]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/generate.cpp
29d52364c1506f1dc151aad780c584dc13b0eab6
[libs/gl.git] / source / glsl / generate.cpp
1 #include <algorithm>
2 #include <msp/core/hash.h>
3 #include <msp/core/raii.h>
4 #include <msp/strings/lexicalcast.h>
5 #include <msp/strings/utils.h>
6 #include "builtin.h"
7 #include "generate.h"
8
9 using namespace std;
10
11 namespace Msp {
12 namespace GL {
13 namespace SL {
14
15 void DeclarationCombiner::apply(Stage &stage)
16 {
17         stage.content.visit(*this);
18         NodeRemover().apply(stage, nodes_to_remove);
19 }
20
21 void DeclarationCombiner::visit(Block &block)
22 {
23         if(current_block)
24                 return;
25
26         TraversingVisitor::visit(block);
27 }
28
29 void DeclarationCombiner::visit(VariableDeclaration &var)
30 {
31         VariableDeclaration *&ptr = variables[var.name];
32         if(ptr)
33         {
34                 ptr->type = var.type;
35                 if(var.init_expression)
36                         ptr->init_expression = var.init_expression;
37                 if(var.layout)
38                 {
39                         if(ptr->layout)
40                         {
41                                 for(vector<Layout::Qualifier>::iterator i=var.layout->qualifiers.begin(); i!=var.layout->qualifiers.end(); ++i)
42                                 {
43                                         bool found = false;
44                                         for(vector<Layout::Qualifier>::iterator j=ptr->layout->qualifiers.begin(); (!found && j!=ptr->layout->qualifiers.end()); ++j)
45                                                 if(j->name==i->name)
46                                                 {
47                                                         j->has_value = i->value;
48                                                         j->value = i->value;
49                                                         found = true;
50                                                 }
51
52                                         if(!found)
53                                                 ptr->layout->qualifiers.push_back(*i);
54                                 }
55                         }
56                         else
57                                 ptr->layout = var.layout;
58                 }
59                 nodes_to_remove.insert(&var);
60         }
61         else
62                 ptr = &var;
63 }
64
65
66 ConstantSpecializer::ConstantSpecializer():
67         values(0)
68 { }
69
70 void ConstantSpecializer::apply(Stage &stage, const map<string, int> *v)
71 {
72         values = v;
73         stage.content.visit(*this);
74 }
75
76 void ConstantSpecializer::visit(VariableDeclaration &var)
77 {
78         bool specializable = false;
79         if(var.layout)
80         {
81                 vector<Layout::Qualifier> &qualifiers = var.layout->qualifiers;
82                 for(vector<Layout::Qualifier>::iterator i=qualifiers.begin(); i!=qualifiers.end(); ++i)
83                         if(i->name=="constant_id")
84                         {
85                                 specializable = true;
86                                 if(values)
87                                         qualifiers.erase(i);
88                                 else if(i->value==-1)
89                                         i->value = hash32(var.name)&0x7FFFFFFF;
90                                 break;
91                         }
92
93                 if(qualifiers.empty())
94                         var.layout = 0;
95         }
96
97         if(specializable && values)
98         {
99                 map<string, int>::const_iterator i = values->find(var.name);
100                 if(i!=values->end())
101                 {
102                         RefPtr<Literal> literal = new Literal;
103                         if(var.type=="bool")
104                         {
105                                 literal->token = (i->second ? "true" : "false");
106                                 literal->value = static_cast<bool>(i->second);
107                         }
108                         else if(var.type=="int")
109                         {
110                                 literal->token = lexical_cast<string>(i->second);
111                                 literal->value = i->second;
112                         }
113                         var.init_expression = literal;
114                 }
115         }
116 }
117
118
119 void BlockHierarchyResolver::enter(Block &block)
120 {
121         r_any_resolved |= (current_block!=block.parent);
122         block.parent = current_block;
123 }
124
125
126 TypeResolver::TypeResolver():
127         stage(0),
128         iface_block(0),
129         r_any_resolved(false)
130 { }
131
132 bool TypeResolver::apply(Stage &s)
133 {
134         stage = &s;
135         s.types.clear();
136         r_any_resolved = false;
137         s.content.visit(*this);
138         return r_any_resolved;
139 }
140
141 TypeDeclaration *TypeResolver::get_or_create_array_type(TypeDeclaration &type)
142 {
143         map<TypeDeclaration *, TypeDeclaration *>::iterator i = array_types.find(&type);
144         if(i!=array_types.end())
145                 return i->second;
146
147         BasicTypeDeclaration *array = new BasicTypeDeclaration;
148         array->source = BUILTIN_SOURCE;
149         array->name = type.name+"[]";
150         array->kind = BasicTypeDeclaration::ARRAY;
151         array->base = type.name;
152         array->base_type = &type;
153         stage->content.body.insert(type_insert_point, array);
154         array_types[&type] = array;
155         return array;
156 }
157
158 void TypeResolver::resolve_type(TypeDeclaration *&type, const string &name, bool array)
159 {
160         TypeDeclaration *resolved = 0;
161         map<string, TypeDeclaration *>::iterator i = stage->types.find(name);
162         if(i!=stage->types.end())
163         {
164                 map<TypeDeclaration *, TypeDeclaration *>::iterator j = alias_map.find(i->second);
165                 resolved = (j!=alias_map.end() ? j->second : i->second);
166         }
167
168         if(resolved && array)
169                 resolved = get_or_create_array_type(*resolved);
170
171         r_any_resolved |= (resolved!=type);
172         type=resolved;
173 }
174
175 void TypeResolver::visit(Block &block)
176 {
177         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
178         {
179                 if(!block.parent)
180                         type_insert_point = i;
181                 (*i)->visit(*this);
182         }
183 }
184
185 void TypeResolver::visit(BasicTypeDeclaration &type)
186 {
187         resolve_type(type.base_type, type.base, false);
188
189         if(type.kind==BasicTypeDeclaration::VECTOR && type.base_type)
190                 if(BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type))
191                         if(basic_base->kind==BasicTypeDeclaration::VECTOR)
192                         {
193                                 type.kind = BasicTypeDeclaration::MATRIX;
194                                 type.size |= basic_base->size<<16;
195                         }
196
197         if(type.kind==BasicTypeDeclaration::ALIAS && type.base_type)
198                 alias_map[&type] = type.base_type;
199         else if(type.kind==BasicTypeDeclaration::ARRAY && type.base_type)
200                 array_types[type.base_type] = &type;
201
202         stage->types.insert(make_pair(type.name, &type));
203 }
204
205 void TypeResolver::visit(ImageTypeDeclaration &type)
206 {
207         resolve_type(type.base_type, type.base, false);
208         stage->types.insert(make_pair(type.name, &type));
209 }
210
211 void TypeResolver::visit(StructDeclaration &strct)
212 {
213         stage->types.insert(make_pair(strct.name, &strct));
214         TraversingVisitor::visit(strct);
215 }
216
217 void TypeResolver::visit(VariableDeclaration &var)
218 {
219         resolve_type(var.type_declaration, var.type, var.array);
220         if(iface_block && var.interface==iface_block->interface)
221                 var.interface.clear();
222 }
223
224 void TypeResolver::visit(InterfaceBlock &iface)
225 {
226         if(iface.members)
227         {
228                 SetForScope<InterfaceBlock *> set_iface(iface_block, &iface);
229                 iface.members->visit(*this);
230
231                 StructDeclaration *strct = new StructDeclaration;
232                 strct->source = INTERNAL_SOURCE;
233                 strct->name = format("_%s_%s", iface.interface, iface.name);
234                 strct->members.body.splice(strct->members.body.begin(), iface.members->body);
235                 stage->content.body.insert(type_insert_point, strct);
236                 stage->types.insert(make_pair(strct->name, strct));
237
238                 iface.members = 0;
239                 strct->interface_block = &iface;
240                 iface.struct_declaration = strct;
241         }
242
243         TypeDeclaration *type = iface.struct_declaration;
244         if(type && iface.array)
245                 type = get_or_create_array_type(*type);
246         r_any_resolved = (type!=iface.type_declaration);
247         iface.type_declaration = type;
248 }
249
250 void TypeResolver::visit(FunctionDeclaration &func)
251 {
252         resolve_type(func.return_type_declaration, func.return_type, false);
253         TraversingVisitor::visit(func);
254 }
255
256
257 VariableResolver::VariableResolver():
258         stage(0),
259         r_any_resolved(false),
260         record_target(false),
261         r_self_referencing(false)
262 { }
263
264 bool VariableResolver::apply(Stage &s)
265 {
266         stage = &s;
267         s.interface_blocks.clear();
268         r_any_resolved = false;
269         s.content.visit(*this);
270         return r_any_resolved;
271 }
272
273 void VariableResolver::enter(Block &block)
274 {
275         block.variables.clear();
276 }
277
278 void VariableResolver::visit(RefPtr<Expression> &expr)
279 {
280         r_replacement_expr = 0;
281         expr->visit(*this);
282         if(r_replacement_expr)
283         {
284                 expr = r_replacement_expr;
285                 /* Don't record assignment target when doing a replacement, because chain
286                 information won't be correct. */
287                 r_assignment_target.declaration = 0;
288                 r_any_resolved = true;
289         }
290         r_replacement_expr = 0;
291 }
292
293 void VariableResolver::check_assignment_target(Statement *declaration)
294 {
295         if(record_target)
296         {
297                 if(r_assignment_target.declaration)
298                 {
299                         /* More than one reference found in assignment target.  Unable to
300                         determine what the primary target is. */
301                         record_target = false;
302                         r_assignment_target.declaration = 0;
303                 }
304                 else
305                         r_assignment_target.declaration = declaration;
306         }
307         // TODO This check is overly broad and may prevent some optimizations.
308         else if(declaration && declaration==r_assignment_target.declaration)
309                 r_self_referencing = true;
310 }
311
312 void VariableResolver::visit(VariableReference &var)
313 {
314         VariableDeclaration *declaration = 0;
315
316         /* Look for variable declarations in the block hierarchy first.  Interface
317         blocks are always defined in the top level so we can't accidentally skip
318         one. */
319         for(Block *block=current_block; (!declaration && block); block=block->parent)
320         {
321                 map<string, VariableDeclaration *>::iterator i = block->variables.find(var.name);
322                 if(i!=block->variables.end())
323                         declaration = i->second;
324         }
325
326         if(!declaration)
327         {
328                 const map<string, InterfaceBlock *> &blocks = stage->interface_blocks;
329                 map<string, InterfaceBlock *>::const_iterator i = blocks.find("_"+var.name);
330                 if(i!=blocks.end())
331                 {
332                         /* The name refers to an interface block with an instance name rather
333                         than a variable.  Prepare a new syntax tree node accordingly. */
334                         InterfaceBlockReference *iface_ref = new InterfaceBlockReference;
335                         iface_ref->source = var.source;
336                         iface_ref->line = var.line;
337                         iface_ref->name = var.name;
338                         iface_ref->declaration = i->second;
339                         r_replacement_expr = iface_ref;
340                 }
341                 else
342                 {
343                         // Look for the variable in anonymous interface blocks.
344                         for(i=blocks.begin(); (!declaration && i!=blocks.end()); ++i)
345                                 if(i->second->instance_name.empty() && i->second->struct_declaration)
346                                 {
347                                         const map<string, VariableDeclaration *> &iface_vars = i->second->struct_declaration->members.variables;
348                                         map<string, VariableDeclaration *>::const_iterator j = iface_vars.find(var.name);
349                                         if(j!=iface_vars.end())
350                                                 declaration = j->second;
351                                 }
352                 }
353         }
354
355         r_any_resolved |= (declaration!=var.declaration);
356         var.declaration = declaration;
357
358         check_assignment_target(var.declaration);
359 }
360
361 void VariableResolver::visit(InterfaceBlockReference &iface)
362 {
363         map<string, InterfaceBlock *>::iterator i = stage->interface_blocks.find("_"+iface.name);
364         InterfaceBlock *declaration = (i!=stage->interface_blocks.end() ? i->second : 0);
365         r_any_resolved |= (declaration!=iface.declaration);
366         iface.declaration = declaration;
367
368         check_assignment_target(iface.declaration);
369 }
370
371 void VariableResolver::add_to_chain(Assignment::Target::ChainType type, unsigned index)
372 {
373         if(r_assignment_target.chain_len<7)
374                 r_assignment_target.chain[r_assignment_target.chain_len] = type | min<unsigned>(index, 0x3F);
375         ++r_assignment_target.chain_len;
376 }
377
378 void VariableResolver::visit(MemberAccess &memacc)
379 {
380         TraversingVisitor::visit(memacc);
381
382         VariableDeclaration *declaration = 0;
383         if(StructDeclaration *strct = dynamic_cast<StructDeclaration *>(memacc.left->type))
384         {
385                 map<string, VariableDeclaration *>::iterator i = strct->members.variables.find(memacc.member);
386                 if(i!=strct->members.variables.end())
387                 {
388                         declaration = i->second;
389
390                         if(record_target)
391                         {
392                                 unsigned index = 0;
393                                 for(NodeList<Statement>::const_iterator j=strct->members.body.begin(); (j!=strct->members.body.end() && j->get()!=i->second); ++j)
394                                         ++index;
395
396                                 add_to_chain(Assignment::Target::MEMBER, index);
397                         }
398                 }
399         }
400         else if(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(memacc.left->type))
401         {
402                 bool scalar_swizzle = ((basic->kind==BasicTypeDeclaration::INT || basic->kind==BasicTypeDeclaration::FLOAT) && memacc.member.size()==1);
403                 bool vector_swizzle = (basic->kind==BasicTypeDeclaration::VECTOR && memacc.member.size()<=4);
404                 if(scalar_swizzle || vector_swizzle)
405                 {
406                         static const char component_names[] = { 'x', 'r', 's', 'y', 'g', 't', 'z', 'b', 'p', 'w', 'a', 'q' };
407
408                         bool ok = true;
409                         UInt8 components[4] = { };
410                         for(unsigned i=0; (ok && i<memacc.member.size()); ++i)
411                                 ok = ((components[i] = (find(component_names, component_names+12, memacc.member[i])-component_names)/3) < 4);
412
413                         if(ok)
414                         {
415                                 Swizzle *swizzle = new Swizzle;
416                                 swizzle->source = memacc.source;
417                                 swizzle->line = memacc.line;
418                                 swizzle->oper = memacc.oper;
419                                 swizzle->left = memacc.left;
420                                 swizzle->component_group = memacc.member;
421                                 swizzle->count = memacc.member.size();
422                                 copy(components, components+memacc.member.size(), swizzle->components);
423                                 r_replacement_expr = swizzle;
424                         }
425                 }
426         }
427
428         r_any_resolved |= (declaration!=memacc.declaration);
429         memacc.declaration = declaration;
430 }
431
432 void VariableResolver::visit(Swizzle &swizzle)
433 {
434         TraversingVisitor::visit(swizzle);
435
436         if(record_target)
437         {
438                 unsigned mask = 0;
439                 for(unsigned i=0; i<swizzle.count; ++i)
440                         mask |= 1<<swizzle.components[i];
441                 add_to_chain(Assignment::Target::SWIZZLE, mask);
442         }
443 }
444
445 void VariableResolver::visit(BinaryExpression &binary)
446 {
447         if(binary.oper->token[0]=='[')
448         {
449                 {
450                         /* The subscript expression is not a part of the primary assignment
451                         target. */
452                         SetFlag set(record_target, false);
453                         visit(binary.right);
454                 }
455                 visit(binary.left);
456
457                 if(record_target)
458                 {
459                         unsigned index = 0x3F;
460                         if(Literal *literal_subscript = dynamic_cast<Literal *>(binary.right.get()))
461                                 if(literal_subscript->value.check_type<int>())
462                                         index = literal_subscript->value.value<int>();
463                         add_to_chain(Assignment::Target::ARRAY, index);
464                 }
465         }
466         else
467                 TraversingVisitor::visit(binary);
468 }
469
470 void VariableResolver::visit(Assignment &assign)
471 {
472         {
473                 SetFlag set(record_target);
474                 r_assignment_target = Assignment::Target();
475                 visit(assign.left);
476                 r_any_resolved |= (r_assignment_target<assign.target || assign.target<r_assignment_target);
477                 assign.target = r_assignment_target;
478         }
479
480         r_self_referencing = false;
481         visit(assign.right);
482         assign.self_referencing = (r_self_referencing || assign.oper->token[0]!='=');
483 }
484
485 void VariableResolver::visit(VariableDeclaration &var)
486 {
487         TraversingVisitor::visit(var);
488         current_block->variables.insert(make_pair(var.name, &var));
489 }
490
491 void VariableResolver::visit(InterfaceBlock &iface)
492 {
493         /* Block names can be reused in different interfaces.  Prefix the name with
494         the first character of the interface to avoid conflicts. */
495         stage->interface_blocks.insert(make_pair(iface.interface+iface.name, &iface));
496         if(!iface.instance_name.empty())
497                 stage->interface_blocks.insert(make_pair("_"+iface.instance_name, &iface));
498
499         TraversingVisitor::visit(iface);
500 }
501
502
503 ExpressionResolver::ExpressionResolver():
504         stage(0),
505         r_any_resolved(false)
506 { }
507
508 bool ExpressionResolver::apply(Stage &s)
509 {
510         stage = &s;
511         r_any_resolved = false;
512         s.content.visit(*this);
513         return r_any_resolved;
514 }
515
516 bool ExpressionResolver::is_scalar(BasicTypeDeclaration &type)
517 {
518         return (type.kind==BasicTypeDeclaration::INT || type.kind==BasicTypeDeclaration::FLOAT);
519 }
520
521 bool ExpressionResolver::is_vector_or_matrix(BasicTypeDeclaration &type)
522 {
523         return (type.kind==BasicTypeDeclaration::VECTOR || type.kind==BasicTypeDeclaration::MATRIX);
524 }
525
526 BasicTypeDeclaration *ExpressionResolver::get_element_type(BasicTypeDeclaration &type)
527 {
528         if(is_vector_or_matrix(type) || type.kind==BasicTypeDeclaration::ARRAY)
529         {
530                 BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
531                 return (basic_base ? get_element_type(*basic_base) : 0);
532         }
533         else
534                 return &type;
535 }
536
537 bool ExpressionResolver::can_convert(BasicTypeDeclaration &from, BasicTypeDeclaration &to)
538 {
539         if(from.kind==BasicTypeDeclaration::INT && to.kind==BasicTypeDeclaration::FLOAT)
540                 return from.size<=to.size;
541         else if(from.kind!=to.kind)
542                 return false;
543         else if((from.kind==BasicTypeDeclaration::VECTOR || from.kind==BasicTypeDeclaration::MATRIX) && from.size==to.size)
544         {
545                 BasicTypeDeclaration *from_base = dynamic_cast<BasicTypeDeclaration *>(from.base_type);
546                 BasicTypeDeclaration *to_base = dynamic_cast<BasicTypeDeclaration *>(to.base_type);
547                 return (from_base && to_base && can_convert(*from_base, *to_base));
548         }
549         else
550                 return false;
551 }
552
553 ExpressionResolver::Compatibility ExpressionResolver::get_compatibility(BasicTypeDeclaration &left, BasicTypeDeclaration &right)
554 {
555         if(&left==&right)
556                 return SAME_TYPE;
557         else if(can_convert(left, right))
558                 return LEFT_CONVERTIBLE;
559         else if(can_convert(right, left))
560                 return RIGHT_CONVERTIBLE;
561         else
562                 return NOT_COMPATIBLE;
563 }
564
565 BasicTypeDeclaration *ExpressionResolver::find_type(BasicTypeDeclaration::Kind kind, unsigned size)
566 {
567         for(vector<BasicTypeDeclaration *>::const_iterator i=basic_types.begin(); i!=basic_types.end(); ++i)
568                 if((*i)->kind==kind && (*i)->size==size)
569                         return *i;
570         return 0;
571 }
572
573 BasicTypeDeclaration *ExpressionResolver::find_type(BasicTypeDeclaration &elem_type, BasicTypeDeclaration::Kind kind, unsigned size)
574 {
575         for(vector<BasicTypeDeclaration *>::const_iterator i=basic_types.begin(); i!=basic_types.end(); ++i)
576                 if(get_element_type(**i)==&elem_type && (*i)->kind==kind && (*i)->size==size)
577                         return *i;
578         return 0;
579 }
580
581 void ExpressionResolver::convert_to(RefPtr<Expression> &expr, BasicTypeDeclaration &type)
582 {
583         RefPtr<FunctionCall> call = new FunctionCall;
584         call->name = type.name;
585         call->constructor = true;
586         call->arguments.push_back(0);
587         call->arguments.back() = expr;
588         call->type = &type;
589         expr = call;
590 }
591
592 bool ExpressionResolver::convert_to_element(RefPtr<Expression> &expr, BasicTypeDeclaration &elem_type)
593 {
594         if(BasicTypeDeclaration *expr_type = dynamic_cast<BasicTypeDeclaration *>(expr->type))
595         {
596                 BasicTypeDeclaration *to_type = &elem_type;
597                 if(is_vector_or_matrix(*expr_type))
598                         to_type = find_type(elem_type, expr_type->kind, expr_type->size);
599                 if(to_type)
600                 {
601                         convert_to(expr, *to_type);
602                         return true;
603                 }
604         }
605
606         return false;
607 }
608
609 void ExpressionResolver::resolve(Expression &expr, TypeDeclaration *type, bool lvalue)
610 {
611         r_any_resolved |= (type!=expr.type || lvalue!=expr.lvalue);
612         expr.type = type;
613         expr.lvalue = lvalue;
614 }
615
616 void ExpressionResolver::visit(Literal &literal)
617 {
618         if(literal.value.check_type<bool>())
619                 resolve(literal, find_type(BasicTypeDeclaration::BOOL, 1), false);
620         else if(literal.value.check_type<int>())
621                 resolve(literal, find_type(BasicTypeDeclaration::INT, 32), false);
622         else if(literal.value.check_type<float>())
623                 resolve(literal, find_type(BasicTypeDeclaration::FLOAT, 32), false);
624 }
625
626 void ExpressionResolver::visit(ParenthesizedExpression &parexpr)
627 {
628         TraversingVisitor::visit(parexpr);
629         resolve(parexpr, parexpr.expression->type, parexpr.expression->lvalue);
630 }
631
632 void ExpressionResolver::visit(VariableReference &var)
633 {
634         if(var.declaration)
635                 resolve(var, var.declaration->type_declaration, true);
636 }
637
638 void ExpressionResolver::visit(InterfaceBlockReference &iface)
639 {
640         if(iface.declaration)
641                 resolve(iface, iface.declaration->type_declaration, true);
642 }
643
644 void ExpressionResolver::visit(MemberAccess &memacc)
645 {
646         TraversingVisitor::visit(memacc);
647
648         if(memacc.declaration)
649                 resolve(memacc, memacc.declaration->type_declaration, memacc.left->lvalue);
650 }
651
652 void ExpressionResolver::visit(Swizzle &swizzle)
653 {
654         TraversingVisitor::visit(swizzle);
655
656         if(BasicTypeDeclaration *left_basic = dynamic_cast<BasicTypeDeclaration *>(swizzle.left->type))
657         {
658                 BasicTypeDeclaration *left_elem = get_element_type(*left_basic);
659                 if(swizzle.count==1)
660                         resolve(swizzle, left_elem, swizzle.left->lvalue);
661                 else if(left_basic->kind==BasicTypeDeclaration::VECTOR && left_elem)
662                         resolve(swizzle, find_type(*left_elem, left_basic->kind, swizzle.count), swizzle.left->lvalue);
663         }
664 }
665
666 void ExpressionResolver::visit(UnaryExpression &unary)
667 {
668         TraversingVisitor::visit(unary);
669
670         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(unary.expression->type);
671         if(!basic)
672                 return;
673
674         char oper = unary.oper->token[0];
675         if(oper=='!')
676         {
677                 if(basic->kind!=BasicTypeDeclaration::BOOL)
678                         return;
679         }
680         else if(oper=='~')
681         {
682                 if(basic->kind!=BasicTypeDeclaration::INT)
683                         return;
684         }
685         else if(oper=='+' || oper=='-')
686         {
687                 BasicTypeDeclaration *elem = get_element_type(*basic);
688                 if(!elem || !is_scalar(*elem))
689                         return;
690         }
691         resolve(unary, basic, unary.expression->lvalue);
692 }
693
694 void ExpressionResolver::visit(BinaryExpression &binary, bool assign)
695 {
696         /* Binary operators are only defined for basic types (not for image or
697         structure types). */
698         BasicTypeDeclaration *basic_left = dynamic_cast<BasicTypeDeclaration *>(binary.left->type);
699         BasicTypeDeclaration *basic_right = dynamic_cast<BasicTypeDeclaration *>(binary.right->type);
700         if(!basic_left || !basic_right)
701                 return;
702
703         char oper = binary.oper->token[0];
704         if(oper=='[')
705         {
706                 /* Subscripting operates on vectors, matrices and arrays, and the right
707                 operand must be an integer. */
708                 if((!is_vector_or_matrix(*basic_left) && basic_left->kind!=BasicTypeDeclaration::ARRAY) || basic_right->kind!=BasicTypeDeclaration::INT)
709                         return;
710
711                 resolve(binary, basic_left->base_type, binary.left->lvalue);
712                 return;
713         }
714         else if(basic_left->kind==BasicTypeDeclaration::ARRAY || basic_right->kind==BasicTypeDeclaration::ARRAY)
715                 // No other binary operator can be used with arrays.
716                 return;
717
718         BasicTypeDeclaration *elem_left = get_element_type(*basic_left);
719         BasicTypeDeclaration *elem_right = get_element_type(*basic_right);
720         if(!elem_left || !elem_right)
721                 return;
722
723         Compatibility compat = get_compatibility(*basic_left, *basic_right);
724         Compatibility elem_compat = get_compatibility(*elem_left, *elem_right);
725         if(elem_compat==NOT_COMPATIBLE)
726                 return;
727         if(assign && (compat==LEFT_CONVERTIBLE || elem_compat==LEFT_CONVERTIBLE))
728                 return;
729
730         TypeDeclaration *type = 0;
731         char oper2 = binary.oper->token[1];
732         if((oper=='<' && oper2!='<') || (oper=='>' && oper2!='>'))
733         {
734                 /* Relational operators compare two scalar integer or floating-point
735                 values. */
736                 if(!is_scalar(*elem_left) || !is_scalar(*elem_right) || compat==NOT_COMPATIBLE)
737                         return;
738
739                 type = find_type(BasicTypeDeclaration::BOOL, 1);
740         }
741         else if((oper=='=' || oper=='!') && oper2=='=')
742         {
743                 // Equality comparison can be done on any compatible types.
744                 if(compat==NOT_COMPATIBLE)
745                         return;
746
747                 type = find_type(BasicTypeDeclaration::BOOL, 1);
748         }
749         else if(oper2=='&' || oper2=='|' || oper2=='^')
750         {
751                 // Logical operators can only be applied to booleans.
752                 if(basic_left->kind!=BasicTypeDeclaration::BOOL || basic_right->kind!=BasicTypeDeclaration::BOOL)
753                         return;
754
755                 type = basic_left;
756         }
757         else if((oper=='&' || oper=='|' || oper=='^' || oper=='%') && !oper2)
758         {
759                 // Bitwise operators and modulo can only be applied to integers.
760                 if(basic_left->kind!=BasicTypeDeclaration::INT || basic_right->kind!=BasicTypeDeclaration::INT)
761                         return;
762
763                 type = (compat==LEFT_CONVERTIBLE ? basic_right : basic_left);
764         }
765         else if((oper=='<' || oper=='>') && oper2==oper)
766         {
767                 // Shifts apply to integer scalars and vectors, with some restrictions.
768                 if(elem_left->kind!=BasicTypeDeclaration::INT || elem_right->kind!=BasicTypeDeclaration::INT)
769                         return;
770                 unsigned left_size = (basic_left->kind==BasicTypeDeclaration::INT ? 1 : basic_left->kind==BasicTypeDeclaration::VECTOR ? basic_left->size : 0);
771                 unsigned right_size = (basic_right->kind==BasicTypeDeclaration::INT ? 1 : basic_right->kind==BasicTypeDeclaration::VECTOR ? basic_right->size : 0);
772                 if(!left_size || (left_size==1 && right_size!=1) || (left_size>1 && right_size!=1 && right_size!=left_size))
773                         return;
774
775                 type = basic_left;
776                 // Don't perform conversion even if the operands are of different sizes.
777                 compat = SAME_TYPE;
778         }
779         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
780         {
781                 // Arithmetic operators require scalar elements.
782                 if(!is_scalar(*elem_left) || !is_scalar(*elem_right))
783                         return;
784
785                 if(oper=='*' && is_vector_or_matrix(*basic_left) && is_vector_or_matrix(*basic_right) &&
786                         (basic_left->kind==BasicTypeDeclaration::MATRIX || basic_right->kind==BasicTypeDeclaration::MATRIX))
787                 {
788                         /* Multiplication has special rules when at least one operand is a
789                         matrix and the other is a vector or a matrix. */
790                         unsigned left_columns = basic_left->size&0xFFFF;
791                         unsigned right_rows = basic_right->size;
792                         if(basic_right->kind==BasicTypeDeclaration::MATRIX)
793                                 right_rows >>= 16;
794                         if(left_columns!=right_rows)
795                                 return;
796
797                         BasicTypeDeclaration *elem_result = (elem_compat==LEFT_CONVERTIBLE ? elem_right : elem_left);
798
799                         if(basic_left->kind==BasicTypeDeclaration::VECTOR)
800                                 type = find_type(*elem_result, BasicTypeDeclaration::VECTOR, basic_right->size&0xFFFF);
801                         else if(basic_right->kind==BasicTypeDeclaration::VECTOR)
802                                 type = find_type(*elem_result, BasicTypeDeclaration::VECTOR, basic_left->size>>16);
803                         else
804                                 type = find_type(*elem_result, BasicTypeDeclaration::MATRIX, (basic_left->size&0xFFFF0000)|(basic_right->size&0xFFFF));
805                 }
806                 else if(compat==NOT_COMPATIBLE)
807                 {
808                         // Arithmetic between scalars and matrices or vectors is supported.
809                         if(is_scalar(*basic_left) && is_vector_or_matrix(*basic_right))
810                                 type = (elem_compat==RIGHT_CONVERTIBLE ? find_type(*elem_left, basic_right->kind, basic_right->size) : basic_right);
811                         else if(is_vector_or_matrix(*basic_left) && is_scalar(*basic_right))
812                                 type = (elem_compat==LEFT_CONVERTIBLE ? find_type(*elem_right, basic_left->kind, basic_left->size) : basic_left);
813                         else
814                                 return;
815                 }
816                 else if(compat==LEFT_CONVERTIBLE)
817                         type = basic_right;
818                 else
819                         type = basic_left;
820         }
821         else
822                 return;
823
824         if(assign && type!=basic_left)
825                 return;
826
827         bool converted = true;
828         if(compat==LEFT_CONVERTIBLE)
829                 convert_to(binary.left, *basic_right);
830         else if(compat==RIGHT_CONVERTIBLE)
831                 convert_to(binary.right, *basic_left);
832         else if(elem_compat==LEFT_CONVERTIBLE)
833                 converted = convert_to_element(binary.left, *elem_right);
834         else if(elem_compat==RIGHT_CONVERTIBLE)
835                 converted = convert_to_element(binary.right, *elem_left);
836
837         if(!converted)
838                 type = 0;
839
840         resolve(binary, type, assign);
841 }
842
843 void ExpressionResolver::visit(BinaryExpression &binary)
844 {
845         TraversingVisitor::visit(binary);
846         visit(binary, false);
847 }
848
849 void ExpressionResolver::visit(Assignment &assign)
850 {
851         TraversingVisitor::visit(assign);
852
853         if(assign.oper->token[0]!='=')
854                 return visit(assign, true);
855         else if(assign.left->type!=assign.right->type)
856         {
857                 BasicTypeDeclaration *basic_left = dynamic_cast<BasicTypeDeclaration *>(assign.left->type);
858                 BasicTypeDeclaration *basic_right = dynamic_cast<BasicTypeDeclaration *>(assign.right->type);
859                 if(!basic_left || !basic_right)
860                         return;
861
862                 Compatibility compat = get_compatibility(*basic_left, *basic_right);
863                 if(compat==RIGHT_CONVERTIBLE)
864                         convert_to(assign.right, *basic_left);
865                 else if(compat!=SAME_TYPE)
866                         return;
867         }
868
869         resolve(assign, assign.left->type, true);
870 }
871
872 void ExpressionResolver::visit(FunctionCall &call)
873 {
874         TraversingVisitor::visit(call);
875
876         TypeDeclaration *type = 0;
877         if(call.declaration)
878                 type = call.declaration->return_type_declaration;
879         else if(call.constructor)
880         {
881                 map<string, TypeDeclaration *>::const_iterator i=stage->types.find(call.name);
882                 type = (i!=stage->types.end() ? i->second : 0);
883         }
884         resolve(call, type, false);
885 }
886
887 void ExpressionResolver::visit(BasicTypeDeclaration &type)
888 {
889         basic_types.push_back(&type);
890 }
891
892 void ExpressionResolver::visit(VariableDeclaration &var)
893 {
894         TraversingVisitor::visit(var);
895         if(!var.init_expression)
896                 return;
897
898         BasicTypeDeclaration *var_basic = dynamic_cast<BasicTypeDeclaration *>(var.type_declaration);
899         BasicTypeDeclaration *init_basic = dynamic_cast<BasicTypeDeclaration *>(var.init_expression->type);
900         if(!var_basic || !init_basic)
901                 return;
902
903         Compatibility compat = get_compatibility(*var_basic, *init_basic);
904         if(compat==RIGHT_CONVERTIBLE)
905                 convert_to(var.init_expression, *var_basic);
906 }
907
908
909 bool FunctionResolver::apply(Stage &s)
910 {
911         stage = &s;
912         s.functions.clear();
913         r_any_resolved = false;
914         s.content.visit(*this);
915         return r_any_resolved;
916 }
917
918 void FunctionResolver::visit(FunctionCall &call)
919 {
920         string arg_types;
921         bool has_signature = true;
922         for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); (has_signature && i!=call.arguments.end()); ++i)
923         {
924                 if((*i)->type)
925                         append(arg_types, ",", (*i)->type->name);
926                 else
927                         has_signature = false;
928         }
929
930         FunctionDeclaration *declaration = 0;
931         if(has_signature)
932         {
933                 map<string, FunctionDeclaration *>::iterator i = stage->functions.find(format("%s(%s)", call.name, arg_types));
934                 declaration = (i!=stage->functions.end() ? i->second : 0);
935         }
936         r_any_resolved |= (declaration!=call.declaration);
937         call.declaration = declaration;
938
939         TraversingVisitor::visit(call);
940 }
941
942 void FunctionResolver::visit(FunctionDeclaration &func)
943 {
944         if(func.signature.empty())
945         {
946                 string param_types;
947                 for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
948                 {
949                         if((*i)->type_declaration)
950                                 append(param_types, ",", (*i)->type_declaration->name);
951                         else
952                                 return;
953                 }
954                 func.signature = format("(%s)", param_types);
955                 r_any_resolved = true;
956         }
957
958         string key = func.name+func.signature;
959         FunctionDeclaration *&stage_decl = stage->functions[key];
960         vector<FunctionDeclaration *> &decls = declarations[key];
961         if(func.definition==&func)
962         {
963                 stage_decl = &func;
964
965                 // Set all previous declarations to use this definition.
966                 for(vector<FunctionDeclaration *>::iterator i=decls.begin(); i!=decls.end(); ++i)
967                 {
968                         r_any_resolved |= (func.definition!=(*i)->definition);
969                         (*i)->definition = func.definition;
970                         (*i)->body.body.clear();
971                 }
972         }
973         else
974         {
975                 FunctionDeclaration *definition = (stage_decl ? stage_decl->definition : 0);
976                 r_any_resolved |= (definition!=func.definition);
977                 func.definition = definition;
978
979                 if(!stage_decl)
980                         stage_decl = &func;
981         }
982         decls.push_back(&func);
983
984         TraversingVisitor::visit(func);
985 }
986
987
988 InterfaceGenerator::InterfaceGenerator():
989         stage(0),
990         function_scope(false),
991         copy_block(false),
992         iface_target_block(0)
993 { }
994
995 string InterfaceGenerator::get_out_prefix(Stage::Type type)
996 {
997         if(type==Stage::VERTEX)
998                 return "_vs_out_";
999         else if(type==Stage::GEOMETRY)
1000                 return "_gs_out_";
1001         else
1002                 return string();
1003 }
1004
1005 void InterfaceGenerator::apply(Stage &s)
1006 {
1007         stage = &s;
1008         iface_target_block = &stage->content;
1009         if(stage->previous)
1010                 in_prefix = get_out_prefix(stage->previous->type);
1011         out_prefix = get_out_prefix(stage->type);
1012         s.content.visit(*this);
1013         NodeRemover().apply(s, nodes_to_remove);
1014 }
1015
1016 void InterfaceGenerator::visit(Block &block)
1017 {
1018         SetForScope<Block *> set_block(current_block, &block);
1019         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
1020         {
1021                 assignment_insert_point = i;
1022                 if(&block==&stage->content)
1023                         iface_insert_point = i;
1024
1025                 (*i)->visit(*this);
1026         }
1027 }
1028
1029 string InterfaceGenerator::change_prefix(const string &name, const string &prefix) const
1030 {
1031         unsigned offset = (name.compare(0, in_prefix.size(), in_prefix) ? 0 : in_prefix.size());
1032         return prefix+name.substr(offset);
1033 }
1034
1035 VariableDeclaration *InterfaceGenerator::generate_interface(VariableDeclaration &var, const string &iface, const string &name)
1036 {
1037         if(stage->content.variables.count(name))
1038                 return 0;
1039
1040         VariableDeclaration* iface_var = new VariableDeclaration;
1041         iface_var->sampling = var.sampling;
1042         iface_var->interface = iface;
1043         iface_var->type = var.type;
1044         iface_var->name = name;
1045         /* Geometry shader inputs are always arrays.  But if we're bringing in an
1046         entire block, the array is on the block and not individual variables. */
1047         if(stage->type==Stage::GEOMETRY && !copy_block)
1048                 iface_var->array = ((var.array && var.interface!="in") || iface=="in");
1049         else
1050                 iface_var->array = var.array;
1051         if(iface_var->array)
1052                 iface_var->array_size = var.array_size;
1053         if(iface=="in")
1054         {
1055                 iface_var->layout = var.layout;
1056                 iface_var->linked_declaration = &var;
1057                 var.linked_declaration = iface_var;
1058         }
1059
1060         iface_target_block->body.insert(iface_insert_point, iface_var);
1061         iface_target_block->variables.insert(make_pair(name, iface_var));
1062
1063         return iface_var;
1064 }
1065
1066 InterfaceBlock *InterfaceGenerator::generate_interface(InterfaceBlock &out_block)
1067 {
1068         if(stage->interface_blocks.count("in"+out_block.name))
1069                 return 0;
1070
1071         InterfaceBlock *in_block = new InterfaceBlock;
1072         in_block->interface = "in";
1073         in_block->name = out_block.name;
1074         in_block->members = new Block;
1075         in_block->instance_name = out_block.instance_name;
1076         if(stage->type==Stage::GEOMETRY)
1077                 in_block->array = true;
1078         else
1079                 in_block->array = out_block.array;
1080         in_block->linked_block = &out_block;
1081         out_block.linked_block = in_block;
1082
1083         {
1084                 SetFlag set_copy(copy_block, true);
1085                 SetForScope<Block *> set_target(iface_target_block, in_block->members.get());
1086                 SetForScope<NodeList<Statement>::iterator> set_ins_pt(iface_insert_point, in_block->members->body.end());
1087                 if(out_block.struct_declaration)
1088                         out_block.struct_declaration->members.visit(*this);
1089                 else if(out_block.members)
1090                         out_block.members->visit(*this);
1091         }
1092
1093         iface_target_block->body.insert(iface_insert_point, in_block);
1094         stage->interface_blocks.insert(make_pair("in"+in_block->name, in_block));
1095         if(!in_block->instance_name.empty())
1096                 stage->interface_blocks.insert(make_pair("_"+in_block->instance_name, in_block));
1097
1098         SetFlag set_scope(function_scope, false);
1099         SetForScope<Block *> set_block(current_block, &stage->content);
1100         in_block->visit(*this);
1101
1102         return in_block;
1103 }
1104
1105 ExpressionStatement &InterfaceGenerator::insert_assignment(const string &left, Expression *right)
1106 {
1107         Assignment *assign = new Assignment;
1108         VariableReference *ref = new VariableReference;
1109         ref->name = left;
1110         assign->left = ref;
1111         assign->oper = &Operator::get_operator("=", Operator::BINARY);
1112         assign->right = right;
1113
1114         ExpressionStatement *stmt = new ExpressionStatement;
1115         stmt->expression = assign;
1116         current_block->body.insert(assignment_insert_point, stmt);
1117         stmt->visit(*this);
1118
1119         return *stmt;
1120 }
1121
1122 void InterfaceGenerator::visit(VariableReference &var)
1123 {
1124         if(var.declaration || !stage->previous)
1125                 return;
1126         /* Don't pull a variable from previous stage if we just generated an output
1127         interface in this stage */
1128         if(stage->content.variables.count(var.name))
1129                 return;
1130
1131         const map<string, VariableDeclaration *> &prev_vars = stage->previous->content.variables;
1132         map<string, VariableDeclaration *>::const_iterator i = prev_vars.find(var.name);
1133         if(i==prev_vars.end() || i->second->interface!="out")
1134                 i = prev_vars.find(in_prefix+var.name);
1135         if(i!=prev_vars.end() && i->second->interface=="out")
1136         {
1137                 generate_interface(*i->second, "in", i->second->name);
1138                 var.name = i->second->name;
1139                 return;
1140         }
1141
1142         const map<string, InterfaceBlock *> &prev_blocks = stage->previous->interface_blocks;
1143         map<string, InterfaceBlock *>::const_iterator j = prev_blocks.find("_"+var.name);
1144         if(j!=prev_blocks.end() && j->second->interface=="out")
1145         {
1146                 generate_interface(*j->second);
1147                 /* Let VariableResolver convert the variable reference into an interface
1148                 block reference. */
1149                 return;
1150         }
1151
1152         for(j=prev_blocks.begin(); j!=prev_blocks.end(); ++j)
1153                 if(j->second->instance_name.empty() && j->second->struct_declaration)
1154                 {
1155                         const map<string, VariableDeclaration *> &iface_vars = j->second->struct_declaration->members.variables;
1156                         i = iface_vars.find(var.name);
1157                         if(i!=iface_vars.end())
1158                         {
1159                                 generate_interface(*j->second);
1160                                 return;
1161                         }
1162                 }
1163 }
1164
1165 void InterfaceGenerator::visit(VariableDeclaration &var)
1166 {
1167         if(copy_block)
1168                 generate_interface(var, "in", var.name);
1169         else if(var.interface=="out")
1170         {
1171                 /* For output variables in function scope, generate a global interface
1172                 and replace the local declaration with an assignment. */
1173                 VariableDeclaration *out_var = 0;
1174                 if(function_scope && (out_var=generate_interface(var, "out", var.name)))
1175                 {
1176                         out_var->source = var.source;
1177                         out_var->line = var.line;
1178                         nodes_to_remove.insert(&var);
1179                         if(var.init_expression)
1180                         {
1181                                 ExpressionStatement &stmt = insert_assignment(var.name, var.init_expression->clone());
1182                                 stmt.source = var.source;
1183                                 stmt.line = var.line;
1184                                 return;
1185                         }
1186                 }
1187         }
1188         else if(var.interface=="in")
1189         {
1190                 /* Try to link input variables in global scope with output variables from
1191                 previous stage. */
1192                 if(current_block==&stage->content && !var.linked_declaration && stage->previous)
1193                 {
1194                         const map<string, VariableDeclaration *> &prev_vars = stage->previous->content.variables;
1195                         map<string, VariableDeclaration *>::const_iterator i = prev_vars.find(var.name);
1196                         if(i!=prev_vars.end() && i->second->interface=="out")
1197                         {
1198                                 var.linked_declaration = i->second;
1199                                 i->second->linked_declaration = &var;
1200                         }
1201                 }
1202         }
1203
1204         TraversingVisitor::visit(var);
1205 }
1206
1207 void InterfaceGenerator::visit(InterfaceBlock &iface)
1208 {
1209         if(iface.interface=="in")
1210         {
1211                 /* Try to link input blocks with output blocks sharing the same block
1212                 name from previous stage. */
1213                 if(!iface.linked_block && stage->previous)
1214                 {
1215                         const map<string, InterfaceBlock *> &prev_blocks = stage->previous->interface_blocks;
1216                         map<string, InterfaceBlock *>::const_iterator i = prev_blocks.find("out"+iface.name);
1217                         if(i!=prev_blocks.end())
1218                         {
1219                                 iface.linked_block = i->second;
1220                                 i->second->linked_block = &iface;
1221                         }
1222                 }
1223         }
1224
1225         TraversingVisitor::visit(iface);
1226 }
1227
1228 void InterfaceGenerator::visit(FunctionDeclaration &func)
1229 {
1230         SetFlag set_scope(function_scope, true);
1231         // Skip parameters because they're not useful here
1232         func.body.visit(*this);
1233 }
1234
1235 void InterfaceGenerator::visit(Passthrough &pass)
1236 {
1237         vector<VariableDeclaration *> pass_vars;
1238
1239         // Pass through all input variables of this stage.
1240         for(map<string, VariableDeclaration *>::const_iterator i=stage->content.variables.begin(); i!=stage->content.variables.end(); ++i)
1241                 if(i->second->interface=="in")
1242                         pass_vars.push_back(i->second);
1243
1244         if(stage->previous)
1245         {
1246                 const map<string, VariableDeclaration *> &prev_vars = stage->previous->content.variables;
1247                 for(map<string, VariableDeclaration *>::const_iterator i=prev_vars.begin(); i!=prev_vars.end(); ++i)
1248                 {
1249                         if(i->second->interface!="out")
1250                                 continue;
1251
1252                         /* Pass through output variables from the previous stage, but only
1253                         those which are not already linked to an input here. */
1254                         if(!i->second->linked_declaration && generate_interface(*i->second, "in", i->second->name))
1255                                 pass_vars.push_back(i->second);
1256                 }
1257         }
1258
1259         if(stage->type==Stage::GEOMETRY)
1260         {
1261                 /* Special case for geometry shader: copy gl_Position from input to
1262                 output. */
1263                 InterfaceBlockReference *ref = new InterfaceBlockReference;
1264                 ref->name = "gl_in";
1265
1266                 BinaryExpression *subscript = new BinaryExpression;
1267                 subscript->left = ref;
1268                 subscript->oper = &Operator::get_operator("[", Operator::BINARY);
1269                 subscript->right = pass.subscript;
1270
1271                 MemberAccess *memacc = new MemberAccess;
1272                 memacc->left = subscript;
1273                 memacc->member = "gl_Position";
1274
1275                 insert_assignment("gl_Position", memacc);
1276         }
1277
1278         for(vector<VariableDeclaration *>::const_iterator i=pass_vars.begin(); i!=pass_vars.end(); ++i)
1279         {
1280                 string out_name = change_prefix((*i)->name, out_prefix);
1281                 generate_interface(**i, "out", out_name);
1282
1283                 VariableReference *ref = new VariableReference;
1284                 ref->name = (*i)->name;
1285                 if(pass.subscript)
1286                 {
1287                         BinaryExpression *subscript = new BinaryExpression;
1288                         subscript->left = ref;
1289                         subscript->oper = &Operator::get_operator("[", Operator::BINARY);
1290                         subscript->right = pass.subscript;
1291                         insert_assignment(out_name, subscript);
1292                 }
1293                 else
1294                         insert_assignment(out_name, ref);
1295         }
1296
1297         nodes_to_remove.insert(&pass);
1298 }
1299
1300 } // namespace SL
1301 } // namespace GL
1302 } // namespace Msp