]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/optimize.cpp
Make ConstantFolder check the type of the value directly
[libs/gl.git] / source / glsl / optimize.cpp
1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include "optimize.h"
4 #include "reflect.h"
5
6 using namespace std;
7
8 namespace Msp {
9 namespace GL {
10 namespace SL {
11
12 ConstantSpecializer::ConstantSpecializer():
13         values(0)
14 { }
15
16 void ConstantSpecializer::apply(Stage &stage, const map<string, int> &v)
17 {
18         values = &v;
19         stage.content.visit(*this);
20 }
21
22 void ConstantSpecializer::visit(VariableDeclaration &var)
23 {
24         bool specializable = false;
25         if(var.layout)
26         {
27                 vector<Layout::Qualifier> &qualifiers = var.layout->qualifiers;
28                 for(vector<Layout::Qualifier>::iterator i=qualifiers.begin(); (!specializable && i!=qualifiers.end()); ++i)
29                         if(i->name=="constant_id")
30                         {
31                                 specializable = true;
32                                 qualifiers.erase(i);
33                         }
34
35                 if(qualifiers.empty())
36                         var.layout = 0;
37         }
38
39         if(specializable)
40         {
41                 map<string, int>::const_iterator i = values->find(var.name);
42                 if(i!=values->end())
43                 {
44                         RefPtr<Literal> literal = new Literal;
45                         if(var.type=="bool")
46                         {
47                                 literal->token = (i->second ? "true" : "false");
48                                 literal->value = static_cast<bool>(i->second);
49                         }
50                         else if(var.type=="int")
51                         {
52                                 literal->token = lexical_cast<string>(i->second);
53                                 literal->value = i->second;
54                         }
55                         var.init_expression = literal;
56                 }
57         }
58 }
59
60
61 InlineableFunctionLocator::InlineableFunctionLocator():
62         current_function(0),
63         return_count(0)
64 { }
65
66 void InlineableFunctionLocator::visit(FunctionCall &call)
67 {
68         FunctionDeclaration *def = call.declaration;
69         if(def)
70                 def = def->definition;
71
72         if(def)
73         {
74                 unsigned &count = refcounts[def];
75                 ++count;
76                 /* Don't inline functions which are called more than once or are called
77                 recursively. */
78                 if((count>1 && def->source!=BUILTIN_SOURCE) || def==current_function)
79                         inlineable.erase(def);
80         }
81
82         TraversingVisitor::visit(call);
83 }
84
85 void InlineableFunctionLocator::visit(FunctionDeclaration &func)
86 {
87         bool has_out_params = false;
88         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); (!has_out_params && i!=func.parameters.end()); ++i)
89                 has_out_params = ((*i)->interface=="out");
90
91         unsigned &count = refcounts[func.definition];
92         if((count<=1 || func.source==BUILTIN_SOURCE) && !has_out_params)
93                 inlineable.insert(func.definition);
94
95         SetForScope<FunctionDeclaration *> set(current_function, &func);
96         return_count = 0;
97         TraversingVisitor::visit(func);
98 }
99
100 void InlineableFunctionLocator::visit(Conditional &cond)
101 {
102         TraversingVisitor::visit(cond);
103         inlineable.erase(current_function);
104 }
105
106 void InlineableFunctionLocator::visit(Iteration &iter)
107 {
108         TraversingVisitor::visit(iter);
109         inlineable.erase(current_function);
110 }
111
112 void InlineableFunctionLocator::visit(Return &ret)
113 {
114         TraversingVisitor::visit(ret);
115         if(return_count)
116                 inlineable.erase(current_function);
117         ++return_count;
118 }
119
120
121 InlineContentInjector::InlineContentInjector():
122         source_func(0),
123         pass(REFERENCED)
124 { }
125
126 string InlineContentInjector::apply(Stage &stage, FunctionDeclaration &target_func, Block &tgt_blk, const NodeList<Statement>::iterator &ins_pt, FunctionCall &call)
127 {
128         source_func = call.declaration->definition;
129
130         /* Populate referenced_names from the target function so we can rename
131         variables from the inlined function that would conflict. */
132         pass = REFERENCED;
133         target_func.visit(*this);
134
135         /* Inline and rename passes must be interleaved so used variable names are
136         known when inlining the return statement. */
137         pass = INLINE;
138         staging_block.parent = &tgt_blk;
139         staging_block.variables.clear();
140
141         vector<RefPtr<VariableDeclaration> > params;
142         params.reserve(source_func->parameters.size());
143         for(NodeArray<VariableDeclaration>::iterator i=source_func->parameters.begin(); i!=source_func->parameters.end(); ++i)
144         {
145                 RefPtr<VariableDeclaration> var = (*i)->clone();
146                 var->interface.clear();
147
148                 SetForScope<Pass> set_pass(pass, RENAME);
149                 var->visit(*this);
150
151                 staging_block.body.push_back_nocopy(var);
152                 params.push_back(var);
153         }
154
155         for(NodeList<Statement>::iterator i=source_func->body.body.begin(); i!=source_func->body.body.end(); ++i)
156         {
157                 r_inlined_statement = 0;
158                 (*i)->visit(*this);
159                 if(!r_inlined_statement)
160                         r_inlined_statement = (*i)->clone();
161
162                 SetForScope<Pass> set_pass(pass, RENAME);
163                 r_inlined_statement->visit(*this);
164
165                 staging_block.body.push_back_nocopy(r_inlined_statement);
166         }
167
168         /* Now collect names from the staging block.  Local variables that would
169         have conflicted with the target function were renamed earlier. */
170         pass = REFERENCED;
171         referenced_names.clear();
172         staging_block.variables.clear();
173         staging_block.visit(*this);
174
175         /* Rename variables in the target function so they don't interfere with
176         global identifiers used by the source function. */
177         pass = RENAME;
178         staging_block.parent = source_func->body.parent;
179         target_func.visit(*this);
180
181         // Put the argument expressions in place after all renaming has been done.
182         for(unsigned i=0; i<source_func->parameters.size(); ++i)
183                 params[i]->init_expression = call.arguments[i]->clone();
184
185         tgt_blk.body.splice(ins_pt, staging_block.body);
186
187         NodeReorderer().apply(stage, target_func, DependencyCollector().apply(*source_func));
188
189         return r_result_name;
190 }
191
192 void InlineContentInjector::visit(VariableReference &var)
193 {
194         if(pass==RENAME)
195         {
196                 map<string, VariableDeclaration *>::const_iterator i = staging_block.variables.find(var.name);
197                 if(i!=staging_block.variables.end())
198                         var.name = i->second->name;
199         }
200         else if(pass==REFERENCED)
201                 referenced_names.insert(var.name);
202 }
203
204 void InlineContentInjector::visit(InterfaceBlockReference &iface)
205 {
206         if(pass==REFERENCED)
207                 referenced_names.insert(iface.name);
208 }
209
210 void InlineContentInjector::visit(FunctionCall &call)
211 {
212         if(pass==REFERENCED)
213                 referenced_names.insert(call.name);
214         TraversingVisitor::visit(call);
215 }
216
217 void InlineContentInjector::visit(VariableDeclaration &var)
218 {
219         TraversingVisitor::visit(var);
220
221         if(pass==RENAME)
222         {
223                 /* Check against conflicts with the other context as well as variables
224                 already renamed here. */
225                 bool conflict = (staging_block.variables.count(var.name) || referenced_names.count(var.name));
226                 staging_block.variables[var.name] = &var;
227                 if(conflict)
228                 {
229                         string mapped_name = get_unused_variable_name(staging_block, var.name);
230                         if(mapped_name!=var.name)
231                         {
232                                 staging_block.variables[mapped_name] = &var;
233                                 var.name = mapped_name;
234                         }
235                 }
236         }
237         else if(pass==REFERENCED)
238                 referenced_names.insert(var.type);
239 }
240
241 void InlineContentInjector::visit(Return &ret)
242 {
243         TraversingVisitor::visit(ret);
244
245         if(pass==INLINE && ret.expression)
246         {
247                 // Create a new variable to hold the return value of the inlined function.
248                 r_result_name = get_unused_variable_name(staging_block, "_return");
249                 RefPtr<VariableDeclaration> var = new VariableDeclaration;
250                 var->source = ret.source;
251                 var->line = ret.line;
252                 var->type = source_func->return_type;
253                 var->name = r_result_name;
254                 var->init_expression = ret.expression->clone();
255                 r_inlined_statement = var;
256         }
257 }
258
259
260 FunctionInliner::FunctionInliner():
261         current_function(0),
262         r_any_inlined(false),
263         r_inlined_here(false)
264 { }
265
266 bool FunctionInliner::apply(Stage &s)
267 {
268         stage = &s;
269         inlineable = InlineableFunctionLocator().apply(s);
270         r_any_inlined = false;
271         s.content.visit(*this);
272         return r_any_inlined;
273 }
274
275 void FunctionInliner::visit(RefPtr<Expression> &ptr)
276 {
277         r_inline_result = 0;
278         ptr->visit(*this);
279         if(r_inline_result)
280         {
281                 ptr = r_inline_result;
282                 r_any_inlined = true;
283         }
284         r_inline_result = 0;
285 }
286
287 void FunctionInliner::visit(Block &block)
288 {
289         SetForScope<Block *> set_block(current_block, &block);
290         SetForScope<NodeList<Statement>::iterator> save_insert_point(insert_point, block.body.begin());
291         for(NodeList<Statement>::iterator i=block.body.begin(); (!r_inlined_here && i!=block.body.end()); ++i)
292         {
293                 insert_point = i;
294                 (*i)->visit(*this);
295         }
296 }
297
298 void FunctionInliner::visit(FunctionCall &call)
299 {
300         for(NodeArray<Expression>::iterator i=call.arguments.begin(); (!r_inlined_here && i!=call.arguments.end()); ++i)
301                 visit(*i);
302
303         if(r_inlined_here)
304                 return;
305
306         FunctionDeclaration *def = call.declaration;
307         if(def)
308                 def = def->definition;
309
310         if(def && inlineable.count(def))
311         {
312                 string result_name = InlineContentInjector().apply(*stage, *current_function, *current_block, insert_point, call);
313
314                 // This will later get removed by UnusedVariableRemover.
315                 if(result_name.empty())
316                         result_name = "_msp_unused_from_inline";
317
318                 RefPtr<VariableReference> ref = new VariableReference;
319                 ref->name = result_name;
320                 r_inline_result = ref;
321
322                 /* Inlined variables need to be resolved before this function can be
323                 inlined further. */
324                 inlineable.erase(current_function);
325                 r_inlined_here = true;
326         }
327 }
328
329 void FunctionInliner::visit(FunctionDeclaration &func)
330 {
331         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
332         TraversingVisitor::visit(func);
333         r_inlined_here = false;
334 }
335
336 void FunctionInliner::visit(Iteration &iter)
337 {
338         /* Visit the initialization statement before entering the loop body so the
339         inlined statements get inserted outside. */
340         if(iter.init_statement)
341                 iter.init_statement->visit(*this);
342
343         SetForScope<Block *> set_block(current_block, &iter.body);
344         /* Skip the condition and loop expression parts because they're not properly
345         inside the body block.  Inlining anything into them will require a more
346         comprehensive transformation. */
347         iter.body.visit(*this);
348 }
349
350
351 ExpressionInliner::ExpressionInliner():
352         r_ref_info(0),
353         r_any_inlined(false),
354         r_trivial(false),
355         mutating(false),
356         iteration_init(false),
357         iteration_body(0),
358         r_oper(0)
359 { }
360
361 bool ExpressionInliner::apply(Stage &s)
362 {
363         s.content.visit(*this);
364         return r_any_inlined;
365 }
366
367 void ExpressionInliner::inline_expression(Expression &expr, RefPtr<Expression> &ptr)
368 {
369         ptr = expr.clone();
370         r_any_inlined = true;
371 }
372
373 void ExpressionInliner::visit(Block &block)
374 {
375         TraversingVisitor::visit(block);
376
377         for(map<string, VariableDeclaration *>::iterator i=block.variables.begin(); i!=block.variables.end(); ++i)
378         {
379                 map<Assignment::Target, ExpressionInfo>::iterator j = expressions.lower_bound(i->second);
380                 for(; (j!=expressions.end() && j->first.declaration==i->second); )
381                 {
382                         if(j->second.expression && j->second.inline_point)
383                                 inline_expression(*j->second.expression, *j->second.inline_point);
384
385                         expressions.erase(j++);
386                 }
387         }
388
389         /* Expressions assigned in this block may depend on local variables of the
390         block.  If this is a conditionally executed block, the assignments might not
391         always happen.  Mark the expressions as not available to any outer blocks. */
392         for(map<Assignment::Target, ExpressionInfo>::iterator i=expressions.begin(); i!=expressions.end(); ++i)
393                 if(i->second.assign_scope==&block)
394                         i->second.available = false;
395 }
396
397 void ExpressionInliner::visit(RefPtr<Expression> &expr)
398 {
399         r_ref_info = 0;
400         expr->visit(*this);
401         if(r_ref_info && r_ref_info->expression && r_ref_info->available)
402         {
403                 if(iteration_body && !r_ref_info->trivial)
404                 {
405                         /* Don't inline non-trivial expressions which were assigned outside
406                         an iteration statement.  The iteration may run multiple times, which
407                         would cause the expression to also be evaluated multiple times. */
408                         Block *i = r_ref_info->assign_scope;
409                         for(; (i && i!=iteration_body); i=i->parent) ;
410                         if(!i)
411                                 return;
412                 }
413
414                 if(r_ref_info->trivial)
415                         inline_expression(*r_ref_info->expression, expr);
416                 else
417                         /* Record the inline point for a non-trivial expression but don't
418                         inline it yet.  It might turn out it shouldn't be inlined after all. */
419                         r_ref_info->inline_point = &expr;
420         }
421         r_oper = expr->oper;
422         r_ref_info = 0;
423 }
424
425 void ExpressionInliner::visit(VariableReference &var)
426 {
427         if(var.declaration)
428         {
429                 map<Assignment::Target, ExpressionInfo>::iterator i = expressions.find(var.declaration);
430                 if(i!=expressions.end())
431                 {
432                         /* If a non-trivial expression is referenced multiple times, don't
433                         inline it. */
434                         if(i->second.inline_point && !i->second.trivial)
435                                 i->second.expression = 0;
436                         /* Mutating expressions are analogous to self-referencing assignments
437                         and prevent inlining. */
438                         if(mutating)
439                                 i->second.expression = 0;
440                         r_ref_info = &i->second;
441                 }
442         }
443 }
444
445 void ExpressionInliner::visit(MemberAccess &memacc)
446 {
447         visit(memacc.left);
448         r_trivial = false;
449 }
450
451 void ExpressionInliner::visit(Swizzle &swizzle)
452 {
453         visit(swizzle.left);
454         r_trivial = false;
455 }
456
457 void ExpressionInliner::visit(UnaryExpression &unary)
458 {
459         SetFlag set_target(mutating, mutating || unary.oper->token[1]=='+' || unary.oper->token[1]=='-');
460         visit(unary.expression);
461         r_trivial = false;
462 }
463
464 void ExpressionInliner::visit(BinaryExpression &binary)
465 {
466         visit(binary.left);
467         {
468                 SetFlag clear_target(mutating, false);
469                 visit(binary.right);
470         }
471         r_trivial = false;
472 }
473
474 void ExpressionInliner::visit(Assignment &assign)
475 {
476         {
477                 SetFlag set_target(mutating);
478                 visit(assign.left);
479         }
480         r_oper = 0;
481         visit(assign.right);
482
483         map<Assignment::Target, ExpressionInfo>::iterator i = expressions.find(assign.target);
484         if(i!=expressions.end())
485         {
486                 /* Self-referencing assignments can't be inlined without additional
487                 work.  Just clear any previous expression. */
488                 i->second.expression = (assign.self_referencing ? 0 : assign.right.get());
489                 i->second.assign_scope = current_block;
490                 i->second.inline_point = 0;
491                 i->second.available = true;
492         }
493
494         r_trivial = false;
495 }
496
497 void ExpressionInliner::visit(TernaryExpression &ternary)
498 {
499         visit(ternary.condition);
500         visit(ternary.true_expr);
501         visit(ternary.false_expr);
502         r_trivial = false;
503 }
504
505 void ExpressionInliner::visit(FunctionCall &call)
506 {
507         TraversingVisitor::visit(call);
508         r_trivial = false;
509 }
510
511 void ExpressionInliner::visit(VariableDeclaration &var)
512 {
513         r_oper = 0;
514         r_trivial = true;
515         TraversingVisitor::visit(var);
516
517         bool constant = var.constant;
518         if(constant && var.layout)
519         {
520                 for(vector<Layout::Qualifier>::const_iterator i=var.layout->qualifiers.begin(); (constant && i!=var.layout->qualifiers.end()); ++i)
521                         constant = (i->name!="constant_id");
522         }
523
524         /* Only inline global variables if they're constant and have trivial
525         initializers.  Non-constant variables could change in ways which are hard to
526         analyze and non-trivial expressions could be expensive to inline.  */
527         if((current_block->parent || (constant && r_trivial)) && var.interface.empty())
528         {
529                 ExpressionInfo &info = expressions[&var];
530                 /* Assume variables declared in an iteration initialization statement
531                 will have their values change throughout the iteration. */
532                 info.expression = (iteration_init ? 0 : var.init_expression.get());
533                 info.assign_scope = current_block;
534                 info.trivial = r_trivial;
535         }
536 }
537
538 void ExpressionInliner::visit(Iteration &iter)
539 {
540         SetForScope<Block *> set_block(current_block, &iter.body);
541         if(iter.init_statement)
542         {
543                 SetFlag set_init(iteration_init);
544                 iter.init_statement->visit(*this);
545         }
546
547         SetForScope<Block *> set_body(iteration_body, &iter.body);
548         if(iter.condition)
549                 visit(iter.condition);
550         iter.body.visit(*this);
551         if(iter.loop_expression)
552                 visit(iter.loop_expression);
553 }
554
555
556 template<typename T>
557 T ConstantFolder::evaluate_logical(char oper, T left, T right)
558 {
559         switch(oper)
560         {
561         case '&': return left&right;
562         case '|': return left|right;
563         case '^': return left^right;
564         default: return T();
565         }
566 }
567
568 template<typename T>
569 bool ConstantFolder::evaluate_relation(const char *oper, T left, T right)
570 {
571         switch(oper[0]|oper[1])
572         {
573         case '<': return left<right;
574         case '<'|'=': return left<=right;
575         case '>': return left>right;
576         case '>'|'=': return left>=right;
577         default: return false;
578         }
579 }
580
581 template<typename T>
582 T ConstantFolder::evaluate_arithmetic(char oper, T left, T right)
583 {
584         switch(oper)
585         {
586         case '+': return left+right;
587         case '-': return left-right;
588         case '*': return left*right;
589         case '/': return left/right;
590         default: return T();
591         }
592 }
593
594 void ConstantFolder::set_result(const Variant &value, bool literal)
595 {
596         r_constant_value = value;
597         r_constant = true;
598         r_literal = literal;
599 }
600
601 void ConstantFolder::visit(RefPtr<Expression> &expr)
602 {
603         r_constant_value = Variant();
604         r_constant = false;
605         r_literal = false;
606         r_uses_iter_var = false;
607         expr->visit(*this);
608         /* Don't replace literals since they'd only be replaced with an identical
609         literal.  Also skip anything that uses an iteration variable, but pass on
610         the result so the Iteration visiting function can handle it. */
611         if(!r_constant || r_literal || r_uses_iter_var)
612                 return;
613
614         RefPtr<Literal> literal = new Literal;
615         if(r_constant_value.check_type<bool>())
616                 literal->token = (r_constant_value.value<bool>() ? "true" : "false");
617         else if(r_constant_value.check_type<int>())
618                 literal->token = lexical_cast<string>(r_constant_value.value<int>());
619         else if(r_constant_value.check_type<float>())
620                 literal->token = lexical_cast<string>(r_constant_value.value<float>());
621         else
622         {
623                 r_constant = false;
624                 return;
625         }
626         literal->value = r_constant_value;
627         expr = literal;
628 }
629
630 void ConstantFolder::visit(Literal &literal)
631 {
632         set_result(literal.value, true);
633 }
634
635 void ConstantFolder::visit(VariableReference &var)
636 {
637         /* If an iteration variable is initialized with a constant value, return
638         that value here for the purpose of evaluating the loop condition for the
639         first iteration. */
640         if(var.declaration==iteration_var)
641         {
642                 set_result(iter_init_value);
643                 r_uses_iter_var = true;
644         }
645 }
646
647 void ConstantFolder::visit(MemberAccess &memacc)
648 {
649         TraversingVisitor::visit(memacc);
650         r_constant = false;
651 }
652
653 void ConstantFolder::visit(Swizzle &swizzle)
654 {
655         TraversingVisitor::visit(swizzle);
656         r_constant = false;
657 }
658
659 void ConstantFolder::visit(UnaryExpression &unary)
660 {
661         TraversingVisitor::visit(unary);
662         bool can_fold = r_constant;
663         r_constant = false;
664         if(!can_fold)
665                 return;
666
667         char oper = unary.oper->token[0];
668         char oper2 = unary.oper->token[1];
669         if(oper=='!')
670         {
671                 if(r_constant_value.check_type<bool>())
672                         set_result(!r_constant_value.value<bool>());
673         }
674         else if(oper=='~')
675         {
676                 if(r_constant_value.check_type<int>())
677                         set_result(~r_constant_value.value<int>());
678                 else if(r_constant_value.check_type<unsigned>())
679                         set_result(~r_constant_value.value<unsigned>());
680         }
681         else if(oper=='-' && !oper2)
682         {
683                 if(r_constant_value.check_type<int>())
684                         set_result(-r_constant_value.value<int>());
685                 else if(r_constant_value.check_type<unsigned>())
686                         set_result(-r_constant_value.value<unsigned>());
687                 else if(r_constant_value.check_type<float>())
688                         set_result(-r_constant_value.value<float>());
689         }
690 }
691
692 void ConstantFolder::visit(BinaryExpression &binary)
693 {
694         visit(binary.left);
695         bool left_constant = r_constant;
696         bool left_iter_var = r_uses_iter_var;
697         Variant left_value = r_constant_value;
698         visit(binary.right);
699         if(left_iter_var)
700                 r_uses_iter_var = true;
701
702         bool can_fold = (left_constant && r_constant);
703         r_constant = false;
704         if(!can_fold)
705                 return;
706
707         // Currently only expressions with both sides of equal types are handled.
708         if(!left_value.check_same_type(r_constant_value))
709                 return;
710
711         char oper = binary.oper->token[0];
712         char oper2 = binary.oper->token[1];
713         if(oper=='&' || oper=='|' || oper=='^')
714         {
715                 if(oper2==oper && left_value.check_type<bool>())
716                         set_result(evaluate_logical(oper, left_value.value<bool>(), r_constant_value.value<bool>()));
717                 else if(!oper2 && left_value.check_type<int>())
718                         set_result(evaluate_logical(oper, left_value.value<int>(), r_constant_value.value<int>()));
719         }
720         else if((oper=='<' || oper=='>') && oper2!=oper)
721         {
722                 if(left_value.check_type<int>())
723                         set_result(evaluate_relation(binary.oper->token, left_value.value<int>(), r_constant_value.value<int>()));
724                 else if(left_value.check_type<float>())
725                         set_result(evaluate_relation(binary.oper->token, left_value.value<float>(), r_constant_value.value<float>()));
726         }
727         else if((oper=='=' || oper=='!') && oper2=='=')
728         {
729                 if(left_value.check_type<int>())
730                         set_result((left_value.value<int>()==r_constant_value.value<int>()) == (oper=='='));
731                 if(left_value.check_type<float>())
732                         set_result((left_value.value<float>()==r_constant_value.value<float>()) == (oper=='='));
733         }
734         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
735         {
736                 if(left_value.check_type<int>())
737                         set_result(evaluate_arithmetic(oper, left_value.value<int>(), r_constant_value.value<int>()));
738                 else if(left_value.check_type<float>())
739                         set_result(evaluate_arithmetic(oper, left_value.value<float>(), r_constant_value.value<float>()));
740         }
741         else if(oper=='%' || ((oper=='<' || oper=='>') && oper2==oper))
742         {
743                 if(!left_value.check_type<int>())
744                         return;
745
746                 if(oper=='%')
747                         set_result(left_value.value<int>()%r_constant_value.value<int>());
748                 else if(oper=='<')
749                         set_result(left_value.value<int>()<<r_constant_value.value<int>());
750                 else if(oper=='>')
751                         set_result(left_value.value<int>()>>r_constant_value.value<int>());
752         }
753 }
754
755 void ConstantFolder::visit(Assignment &assign)
756 {
757         TraversingVisitor::visit(assign);
758         r_constant = false;
759 }
760
761 void ConstantFolder::visit(TernaryExpression &ternary)
762 {
763         TraversingVisitor::visit(ternary);
764         r_constant = false;
765 }
766
767 void ConstantFolder::visit(FunctionCall &call)
768 {
769         TraversingVisitor::visit(call);
770         r_constant = false;
771 }
772
773 void ConstantFolder::visit(VariableDeclaration &var)
774 {
775         if(iteration_init && var.init_expression)
776         {
777                 visit(var.init_expression);
778                 if(r_constant)
779                 {
780                         /* Record the value of a constant initialization expression of an
781                         iteration, so it can be used to evaluate the loop condition. */
782                         iteration_var = &var;
783                         iter_init_value = r_constant_value;
784                 }
785         }
786         else
787                 TraversingVisitor::visit(var);
788 }
789
790 void ConstantFolder::visit(Iteration &iter)
791 {
792         SetForScope<Block *> set_block(current_block, &iter.body);
793
794         /* The iteration variable is not normally inlined into expressions, so we
795         process it specially here.  If the initial value causes the loop condition
796         to evaluate to false, then the expression can be folded. */
797         iteration_var = 0;
798         if(iter.init_statement)
799         {
800                 SetFlag set_init(iteration_init);
801                 iter.init_statement->visit(*this);
802         }
803
804         if(iter.condition)
805         {
806                 visit(iter.condition);
807                 if(r_constant && r_constant_value.check_type<bool>() && !r_constant_value.value<bool>())
808                 {
809                         RefPtr<Literal> literal = new Literal;
810                         literal->token = "false";
811                         literal->value = r_constant_value;
812                         iter.condition = literal;
813                 }
814         }
815         iteration_var = 0;
816
817         iter.body.visit(*this);
818         if(iter.loop_expression)
819                 visit(iter.loop_expression);
820 }
821
822
823 void ConstantConditionEliminator::apply(Stage &stage)
824 {
825         stage.content.visit(*this);
826         NodeRemover().apply(stage, nodes_to_remove);
827 }
828
829 ConstantConditionEliminator::ConstantStatus ConstantConditionEliminator::check_constant_condition(const Expression &expr)
830 {
831         if(const Literal *literal = dynamic_cast<const Literal *>(&expr))
832                 if(literal->value.check_type<bool>())
833                         return (literal->value.value<bool>() ? CONSTANT_TRUE : CONSTANT_FALSE);
834         return NOT_CONSTANT;
835 }
836
837 void ConstantConditionEliminator::visit(Block &block)
838 {
839         SetForScope<Block *> set_block(current_block, &block);
840         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
841         {
842                 insert_point = i;
843                 (*i)->visit(*this);
844         }
845 }
846
847 void ConstantConditionEliminator::visit(RefPtr<Expression> &expr)
848 {
849         r_ternary_result = 0;
850         expr->visit(*this);
851         if(r_ternary_result)
852                 expr = r_ternary_result;
853         r_ternary_result = 0;
854 }
855
856 void ConstantConditionEliminator::visit(TernaryExpression &ternary)
857 {
858         ConstantStatus result = check_constant_condition(*ternary.condition);
859         if(result!=NOT_CONSTANT)
860                 r_ternary_result = (result==CONSTANT_TRUE ? ternary.true_expr : ternary.false_expr);
861         else
862                 r_ternary_result = 0;
863 }
864
865 void ConstantConditionEliminator::visit(Conditional &cond)
866 {
867         ConstantStatus result = check_constant_condition(*cond.condition);
868         if(result!=NOT_CONSTANT)
869         {
870                 Block &block = (result==CONSTANT_TRUE ? cond.body : cond.else_body);
871                 // TODO should check variable names for conflicts.  Potentially reuse InlineContentInjector?
872                 current_block->body.splice(insert_point, block.body);
873                 nodes_to_remove.insert(&cond);
874                 return;
875         }
876
877         TraversingVisitor::visit(cond);
878 }
879
880 void ConstantConditionEliminator::visit(Iteration &iter)
881 {
882         if(iter.condition)
883         {
884                 ConstantStatus result = check_constant_condition(*iter.condition);
885                 if(result==CONSTANT_FALSE)
886                 {
887                         nodes_to_remove.insert(&iter);
888                         return;
889                 }
890         }
891
892         TraversingVisitor::visit(iter);
893 }
894
895
896 UnreachableCodeRemover::UnreachableCodeRemover():
897         reachable(true)
898 { }
899
900 bool UnreachableCodeRemover::apply(Stage &stage)
901 {
902         stage.content.visit(*this);
903         NodeRemover().apply(stage, unreachable_nodes);
904         return !unreachable_nodes.empty();
905 }
906
907 void UnreachableCodeRemover::visit(Block &block)
908 {
909         NodeList<Statement>::iterator i = block.body.begin();
910         for(; (reachable && i!=block.body.end()); ++i)
911                 (*i)->visit(*this);
912         for(; i!=block.body.end(); ++i)
913                 unreachable_nodes.insert(i->get());
914 }
915
916 void UnreachableCodeRemover::visit(FunctionDeclaration &func)
917 {
918         TraversingVisitor::visit(func);
919         reachable = true;
920 }
921
922 void UnreachableCodeRemover::visit(Conditional &cond)
923 {
924         cond.body.visit(*this);
925         bool reachable_if_true = reachable;
926         reachable = true;
927         cond.else_body.visit(*this);
928
929         reachable |= reachable_if_true;
930 }
931
932 void UnreachableCodeRemover::visit(Iteration &iter)
933 {
934         TraversingVisitor::visit(iter);
935
936         /* Always consider code after a loop reachable, since there's no checking
937         for whether the loop executes. */
938         reachable = true;
939 }
940
941
942 bool UnusedTypeRemover::apply(Stage &stage)
943 {
944         stage.content.visit(*this);
945         NodeRemover().apply(stage, unused_nodes);
946         return !unused_nodes.empty();
947 }
948
949 void UnusedTypeRemover::visit(RefPtr<Expression> &expr)
950 {
951         unused_nodes.erase(expr->type);
952         TraversingVisitor::visit(expr);
953 }
954
955 void UnusedTypeRemover::visit(BasicTypeDeclaration &type)
956 {
957         if(type.base_type)
958                 unused_nodes.erase(type.base_type);
959         unused_nodes.insert(&type);
960 }
961
962 void UnusedTypeRemover::visit(ImageTypeDeclaration &type)
963 {
964         if(type.base_type)
965                 unused_nodes.erase(type.base_type);
966         unused_nodes.insert(&type);
967 }
968
969 void UnusedTypeRemover::visit(StructDeclaration &strct)
970 {
971         unused_nodes.insert(&strct);
972         TraversingVisitor::visit(strct);
973 }
974
975 void UnusedTypeRemover::visit(VariableDeclaration &var)
976 {
977         unused_nodes.erase(var.type_declaration);
978         TraversingVisitor::visit(var);
979 }
980
981 void UnusedTypeRemover::visit(InterfaceBlock &iface)
982 {
983         unused_nodes.erase(iface.type_declaration);
984 }
985
986 void UnusedTypeRemover::visit(FunctionDeclaration &func)
987 {
988         unused_nodes.erase(func.return_type_declaration);
989         TraversingVisitor::visit(func);
990 }
991
992
993 UnusedVariableRemover::UnusedVariableRemover():
994         stage(0),
995         interface_block(0),
996         r_assignment(0),
997         assignment_target(false),
998         r_side_effects(false),
999         in_struct(false),
1000         composite_reference(false)
1001 { }
1002
1003 bool UnusedVariableRemover::apply(Stage &s)
1004 {
1005         stage = &s;
1006         s.content.visit(*this);
1007
1008         for(list<AssignmentInfo>::const_iterator i=assignments.begin(); i!=assignments.end(); ++i)
1009                 if(i->used_by.empty())
1010                         unused_nodes.insert(i->node);
1011
1012         for(BlockVariableMap::const_iterator i=variables.begin(); i!=variables.end(); ++i)
1013         {
1014                 if(i->second.output)
1015                 {
1016                         /* The last visible assignments of output variables are used by the
1017                         next stage or the API. */
1018                         for(vector<AssignmentInfo *>::const_iterator j=i->second.assignments.begin(); j!=i->second.assignments.end(); ++j)
1019                                 unused_nodes.erase((*j)->node);
1020                 }
1021
1022                 if(!i->second.output && !i->second.referenced)
1023                 {
1024                         // Don't remove variables from inside interface blocks.
1025                         if(!i->second.interface_block)
1026                                 unused_nodes.insert(i->first);
1027                 }
1028                 else if(i->second.interface_block)
1029                         // Interface blocks are kept if even one member is used.
1030                         unused_nodes.erase(i->second.interface_block);
1031         }
1032
1033         NodeRemover().apply(s, unused_nodes);
1034
1035         return !unused_nodes.empty();
1036 }
1037
1038 void UnusedVariableRemover::referenced(const Assignment::Target &target, Node &node)
1039 {
1040         VariableInfo &var_info = variables[target.declaration];
1041         var_info.referenced = true;
1042         if(!assignment_target)
1043         {
1044                 for(vector<AssignmentInfo *>::const_iterator i=var_info.assignments.begin(); i!=var_info.assignments.end(); ++i)
1045                 {
1046                         bool covered = true;
1047                         for(unsigned j=0; (covered && j<(*i)->target.chain_len && j<target.chain_len); ++j)
1048                         {
1049                                 Assignment::Target::ChainType type1 = static_cast<Assignment::Target::ChainType>((*i)->target.chain[j]&0xC0);
1050                                 Assignment::Target::ChainType type2 = static_cast<Assignment::Target::ChainType>(target.chain[j]&0xC0);
1051                                 if(type1==Assignment::Target::SWIZZLE || type2==Assignment::Target::SWIZZLE)
1052                                 {
1053                                         unsigned index1 = (*i)->target.chain[j]&0x3F;
1054                                         unsigned index2 = target.chain[j]&0x3F;
1055                                         if(type1==Assignment::Target::SWIZZLE && type2==Assignment::Target::SWIZZLE)
1056                                                 covered = index1&index2;
1057                                         else if(type1==Assignment::Target::ARRAY && index1<4)
1058                                                 covered = index2&(1<<index1);
1059                                         else if(type2==Assignment::Target::ARRAY && index2<4)
1060                                                 covered = index1&(1<<index2);
1061                                         /* If it's some other combination (shouldn't happen), leave
1062                                         covered as true */
1063                                 }
1064                                 else
1065                                         covered = ((*i)->target.chain[j]==target.chain[j]);
1066                         }
1067                         if(covered)
1068                                 (*i)->used_by.push_back(&node);
1069                 }
1070         }
1071 }
1072
1073 void UnusedVariableRemover::visit(VariableReference &var)
1074 {
1075         if(composite_reference)
1076                 r_reference.declaration = var.declaration;
1077         else
1078                 referenced(var.declaration, var);
1079 }
1080
1081 void UnusedVariableRemover::visit(InterfaceBlockReference &iface)
1082 {
1083         if(composite_reference)
1084                 r_reference.declaration = iface.declaration;
1085         else
1086                 referenced(iface.declaration, iface);
1087 }
1088
1089 void UnusedVariableRemover::visit_composite(Expression &expr)
1090 {
1091         if(!composite_reference)
1092                 r_reference = Assignment::Target();
1093
1094         SetFlag set_composite(composite_reference);
1095         expr.visit(*this);
1096 }
1097
1098 void UnusedVariableRemover::visit(MemberAccess &memacc)
1099 {
1100         visit_composite(*memacc.left);
1101
1102         add_to_chain(r_reference, Assignment::Target::MEMBER, memacc.index);
1103
1104         if(!composite_reference && r_reference.declaration)
1105                 referenced(r_reference, memacc);
1106 }
1107
1108 void UnusedVariableRemover::visit(Swizzle &swizzle)
1109 {
1110         visit_composite(*swizzle.left);
1111
1112         unsigned mask = 0;
1113         for(unsigned i=0; i<swizzle.count; ++i)
1114                 mask |= 1<<swizzle.components[i];
1115         add_to_chain(r_reference, Assignment::Target::SWIZZLE, mask);
1116
1117         if(!composite_reference && r_reference.declaration)
1118                 referenced(r_reference, swizzle);
1119 }
1120
1121 void UnusedVariableRemover::visit(UnaryExpression &unary)
1122 {
1123         TraversingVisitor::visit(unary);
1124         if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
1125                 r_side_effects = true;
1126 }
1127
1128 void UnusedVariableRemover::visit(BinaryExpression &binary)
1129 {
1130         if(binary.oper->token[0]=='[')
1131         {
1132                 visit_composite(*binary.left);
1133
1134                 {
1135                         SetFlag clear_assignment(assignment_target, false);
1136                         SetFlag clear_composite(composite_reference, false);
1137                         binary.right->visit(*this);
1138                 }
1139
1140                 add_to_chain(r_reference, Assignment::Target::ARRAY, 0x3F);
1141
1142                 if(!composite_reference && r_reference.declaration)
1143                         referenced(r_reference, binary);
1144         }
1145         else
1146         {
1147                 SetFlag clear_composite(composite_reference, false);
1148                 TraversingVisitor::visit(binary);
1149         }
1150 }
1151
1152 void UnusedVariableRemover::visit(TernaryExpression &ternary)
1153 {
1154         SetFlag clear_composite(composite_reference, false);
1155         TraversingVisitor::visit(ternary);
1156 }
1157
1158 void UnusedVariableRemover::visit(Assignment &assign)
1159 {
1160         {
1161                 SetFlag set(assignment_target, (assign.oper->token[0]=='='));
1162                 assign.left->visit(*this);
1163         }
1164         assign.right->visit(*this);
1165         r_assignment = &assign;
1166         r_side_effects = true;
1167 }
1168
1169 void UnusedVariableRemover::visit(FunctionCall &call)
1170 {
1171         SetFlag clear_composite(composite_reference, false);
1172         TraversingVisitor::visit(call);
1173         /* Treat function calls as having side effects so expression statements
1174         consisting of nothing but a function call won't be optimized away. */
1175         r_side_effects = true;
1176
1177         if(stage->type==Stage::GEOMETRY && call.name=="EmitVertex")
1178         {
1179                 for(map<Statement *, VariableInfo>::const_iterator i=variables.begin(); i!=variables.end(); ++i)
1180                         if(i->second.output)
1181                                 referenced(i->first, call);
1182         }
1183 }
1184
1185 void UnusedVariableRemover::record_assignment(const Assignment::Target &target, Node &node)
1186 {
1187         assignments.push_back(AssignmentInfo());
1188         AssignmentInfo &assign_info = assignments.back();
1189         assign_info.node = &node;
1190         assign_info.target = target;
1191
1192         /* An assignment to the target hides any assignments to the same target or
1193         its subfields. */
1194         VariableInfo &var_info = variables[target.declaration];
1195         for(unsigned i=0; i<var_info.assignments.size(); ++i)
1196         {
1197                 const Assignment::Target &t = var_info.assignments[i]->target;
1198
1199                 bool subfield = (t.chain_len>=target.chain_len);
1200                 for(unsigned j=0; (subfield && j<target.chain_len); ++j)
1201                         subfield = (t.chain[j]==target.chain[j]);
1202
1203                 if(subfield)
1204                         var_info.assignments.erase(var_info.assignments.begin()+i);
1205                 else
1206                         ++i;
1207         }
1208
1209         var_info.assignments.push_back(&assign_info);
1210 }
1211
1212 void UnusedVariableRemover::visit(ExpressionStatement &expr)
1213 {
1214         r_assignment = 0;
1215         r_side_effects = false;
1216         TraversingVisitor::visit(expr);
1217         if(r_assignment && r_assignment->target.declaration)
1218                 record_assignment(r_assignment->target, expr);
1219         if(!r_side_effects)
1220                 unused_nodes.insert(&expr);
1221 }
1222
1223 void UnusedVariableRemover::visit(StructDeclaration &strct)
1224 {
1225         SetFlag set_struct(in_struct);
1226         TraversingVisitor::visit(strct);
1227 }
1228
1229 void UnusedVariableRemover::visit(VariableDeclaration &var)
1230 {
1231         TraversingVisitor::visit(var);
1232
1233         if(in_struct)
1234                 return;
1235
1236         VariableInfo &var_info = variables[&var];
1237         var_info.interface_block = interface_block;
1238
1239         /* Mark variables as output if they're used by the next stage or the
1240         graphics API. */
1241         if(interface_block)
1242                 var_info.output = (interface_block->interface=="out" && (interface_block->linked_block || !interface_block->block_name.compare(0, 3, "gl_")));
1243         else
1244                 var_info.output = (var.interface=="out" && (stage->type==Stage::FRAGMENT || var.linked_declaration || !var.name.compare(0, 3, "gl_")));
1245
1246         if(var.init_expression)
1247         {
1248                 var_info.initialized = true;
1249                 record_assignment(&var, *var.init_expression);
1250         }
1251 }
1252
1253 void UnusedVariableRemover::visit(InterfaceBlock &iface)
1254 {
1255         VariableInfo &var_info = variables[&iface];
1256         var_info.output = (iface.interface=="out" && (iface.linked_block || !iface.block_name.compare(0, 3, "gl_")));
1257 }
1258
1259 void UnusedVariableRemover::merge_variables(const BlockVariableMap &other_vars)
1260 {
1261         for(BlockVariableMap::const_iterator i=other_vars.begin(); i!=other_vars.end(); ++i)
1262         {
1263                 BlockVariableMap::iterator j = variables.find(i->first);
1264                 if(j!=variables.end())
1265                 {
1266                         /* The merged blocks started as copies of each other so any common
1267                         assignments must be in the beginning. */
1268                         unsigned k = 0;
1269                         for(; (k<i->second.assignments.size() && k<j->second.assignments.size()); ++k)
1270                                 if(i->second.assignments[k]!=j->second.assignments[k])
1271                                         break;
1272
1273                         // Remaining assignments are unique to each block; merge them.
1274                         j->second.assignments.insert(j->second.assignments.end(), i->second.assignments.begin()+k, i->second.assignments.end());
1275                         j->second.referenced |= i->second.referenced;
1276                 }
1277                 else
1278                         variables.insert(*i);
1279         }
1280 }
1281
1282 void UnusedVariableRemover::visit(FunctionDeclaration &func)
1283 {
1284         if(func.body.body.empty())
1285                 return;
1286
1287         BlockVariableMap saved_vars = variables;
1288         // Assignments from other functions should not be visible.
1289         for(BlockVariableMap::iterator i=variables.begin(); i!=variables.end(); ++i)
1290                 i->second.assignments.resize(i->second.initialized);
1291         TraversingVisitor::visit(func);
1292         swap(variables, saved_vars);
1293         merge_variables(saved_vars);
1294
1295         /* Always treat function parameters as referenced.  Removing unused
1296         parameters is not currently supported. */
1297         for(NodeArray<VariableDeclaration>::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1298         {
1299                 BlockVariableMap::iterator j = variables.find(i->get());
1300                 if(j!=variables.end())
1301                         j->second.referenced = true;
1302         }
1303 }
1304
1305 void UnusedVariableRemover::visit(Conditional &cond)
1306 {
1307         cond.condition->visit(*this);
1308         BlockVariableMap saved_vars = variables;
1309         cond.body.visit(*this);
1310         swap(saved_vars, variables);
1311         cond.else_body.visit(*this);
1312
1313         /* Visible assignments after the conditional is the union of those visible
1314         at the end of the if and else blocks.  If there was no else block, then it's
1315         the union of the if block and the state before it. */
1316         merge_variables(saved_vars);
1317 }
1318
1319 void UnusedVariableRemover::visit(Iteration &iter)
1320 {
1321         BlockVariableMap saved_vars = variables;
1322         TraversingVisitor::visit(iter);
1323
1324         /* Merge assignments from the iteration, without clearing previous state.
1325         Further analysis is needed to determine which parts of the iteration body
1326         are always executed, if any. */
1327         merge_variables(saved_vars);
1328 }
1329
1330
1331 bool UnusedFunctionRemover::apply(Stage &stage)
1332 {
1333         stage.content.visit(*this);
1334         NodeRemover().apply(stage, unused_nodes);
1335         return !unused_nodes.empty();
1336 }
1337
1338 void UnusedFunctionRemover::visit(FunctionCall &call)
1339 {
1340         TraversingVisitor::visit(call);
1341
1342         unused_nodes.erase(call.declaration);
1343         if(call.declaration && call.declaration->definition!=call.declaration)
1344                 used_definitions.insert(call.declaration->definition);
1345 }
1346
1347 void UnusedFunctionRemover::visit(FunctionDeclaration &func)
1348 {
1349         TraversingVisitor::visit(func);
1350
1351         if((func.name!="main" || func.body.body.empty()) && !used_definitions.count(&func))
1352                 unused_nodes.insert(&func);
1353 }
1354
1355 } // namespace SL
1356 } // namespace GL
1357 } // namespace Msp