]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/optimize.cpp
Fix function inlining regressions
[libs/gl.git] / source / glsl / optimize.cpp
1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include "optimize.h"
4
5 using namespace std;
6
7 namespace Msp {
8 namespace GL {
9 namespace SL {
10
11 InlineableFunctionLocator::InlineableFunctionLocator():
12         current_function(0),
13         return_count(0)
14 { }
15
16 void InlineableFunctionLocator::visit(FunctionCall &call)
17 {
18         FunctionDeclaration *def = call.declaration;
19         if(def)
20                 def = def->definition;
21
22         if(def)
23         {
24                 unsigned &count = refcounts[def];
25                 ++count;
26                 /* Don't inline functions which are called more than once or are called
27                 recursively. */
28                 if(count>1 || def==current_function)
29                         inlineable.erase(def);
30         }
31
32         TraversingVisitor::visit(call);
33 }
34
35 void InlineableFunctionLocator::visit(FunctionDeclaration &func)
36 {
37         bool has_out_params = false;
38         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); (!has_out_params && i!=func.parameters.end()); ++i)
39                 has_out_params = ((*i)->interface=="out");
40
41         unsigned &count = refcounts[func.definition];
42         if(count<=1 && !has_out_params)
43                 inlineable.insert(func.definition);
44
45         SetForScope<FunctionDeclaration *> set(current_function, &func);
46         return_count = 0;
47         TraversingVisitor::visit(func);
48 }
49
50 void InlineableFunctionLocator::visit(Conditional &cond)
51 {
52         TraversingVisitor::visit(cond);
53         inlineable.erase(current_function);
54 }
55
56 void InlineableFunctionLocator::visit(Iteration &iter)
57 {
58         TraversingVisitor::visit(iter);
59         inlineable.erase(current_function);
60 }
61
62 void InlineableFunctionLocator::visit(Return &ret)
63 {
64         TraversingVisitor::visit(ret);
65         if(return_count)
66                 inlineable.erase(current_function);
67         ++return_count;
68 }
69
70
71 InlineContentInjector::InlineContentInjector():
72         source_func(0),
73         pass(DEPENDS)
74 { }
75
76 const string &InlineContentInjector::apply(Stage &stage, FunctionDeclaration &target_func, Block &tgt_blk, const NodeList<Statement>::iterator &ins_pt, FunctionCall &call)
77 {
78         source_func = call.declaration->definition;
79
80         // Collect all declarations the inlined function depends on.
81         pass = DEPENDS;
82         source_func->visit(*this);
83
84         /* Populate referenced_names from the target function so we can rename
85         variables from the inlined function that would conflict. */
86         pass = REFERENCED;
87         target_func.visit(*this);
88
89         /* Inline and rename passes must be interleaved so used variable names are
90         known when inlining the return statement. */
91         pass = INLINE;
92         staging_block.parent = &tgt_blk;
93         staging_block.variables.clear();
94
95         std::vector<RefPtr<VariableDeclaration> > params;
96         params.reserve(source_func->parameters.size());
97         for(NodeArray<VariableDeclaration>::iterator i=source_func->parameters.begin(); i!=source_func->parameters.end(); ++i)
98         {
99                 RefPtr<VariableDeclaration> var = (*i)->clone();
100                 var->interface.clear();
101
102                 SetForScope<Pass> set_pass(pass, RENAME);
103                 var->visit(*this);
104
105                 staging_block.body.push_back_nocopy(var);
106                 params.push_back(var);
107         }
108
109         for(NodeList<Statement>::iterator i=source_func->body.body.begin(); i!=source_func->body.body.end(); ++i)
110         {
111                 r_inlined_statement = 0;
112                 (*i)->visit(*this);
113                 if(!r_inlined_statement)
114                         r_inlined_statement = (*i)->clone();
115
116                 SetForScope<Pass> set_pass(pass, RENAME);
117                 r_inlined_statement->visit(*this);
118
119                 staging_block.body.push_back_nocopy(r_inlined_statement);
120         }
121
122         /* Now collect names from the staging block.  Local variables that would
123         have conflicted with the target function were renamed earlier. */
124         pass = REFERENCED;
125         referenced_names.clear();
126         staging_block.variables.clear();
127         staging_block.visit(*this);
128
129         /* Rename variables in the target function so they don't interfere with
130         global identifiers used by the source function. */
131         pass = RENAME;
132         staging_block.parent = source_func->body.parent;
133         target_func.visit(*this);
134
135         // Put the argument expressions in place after all renaming has been done.
136         for(unsigned i=0; i<source_func->parameters.size(); ++i)
137                 params[i]->init_expression = call.arguments[i]->clone();
138
139         tgt_blk.body.splice(ins_pt, staging_block.body);
140
141         NodeReorderer().apply(stage, target_func, dependencies);
142
143         return r_result_name;
144 }
145
146 void InlineContentInjector::visit(VariableReference &var)
147 {
148         if(pass==RENAME)
149         {
150                 map<string, VariableDeclaration *>::const_iterator i = staging_block.variables.find(var.name);
151                 if(i!=staging_block.variables.end())
152                         var.name = i->second->name;
153         }
154         else if(pass==DEPENDS && var.declaration)
155         {
156                 dependencies.insert(var.declaration);
157                 var.declaration->visit(*this);
158         }
159         else if(pass==REFERENCED)
160                 referenced_names.insert(var.name);
161 }
162
163 void InlineContentInjector::visit(InterfaceBlockReference &iface)
164 {
165         if(pass==DEPENDS && iface.declaration)
166         {
167                 dependencies.insert(iface.declaration);
168                 iface.declaration->visit(*this);
169         }
170         else if(pass==REFERENCED)
171                 referenced_names.insert(iface.name);
172 }
173
174 void InlineContentInjector::visit(FunctionCall &call)
175 {
176         if(pass==DEPENDS && call.declaration)
177                 dependencies.insert(call.declaration);
178         else if(pass==REFERENCED)
179                 referenced_names.insert(call.name);
180         TraversingVisitor::visit(call);
181 }
182
183 void InlineContentInjector::visit(VariableDeclaration &var)
184 {
185         TraversingVisitor::visit(var);
186
187         if(pass==RENAME)
188         {
189                 /* Check against conflicts with the other context as well as variables
190                 already renamed here. */
191                 bool conflict = (staging_block.variables.count(var.name) || referenced_names.count(var.name));
192                 staging_block.variables[var.name] = &var;
193                 if(conflict)
194                 {
195                         string mapped_name = get_unused_variable_name(staging_block, var.name);
196                         if(mapped_name!=var.name)
197                         {
198                                 staging_block.variables[mapped_name] = &var;
199                                 var.name = mapped_name;
200                         }
201                 }
202         }
203         else if(pass==DEPENDS && var.type_declaration)
204         {
205                 dependencies.insert(var.type_declaration);
206                 var.type_declaration->visit(*this);
207         }
208         else if(pass==REFERENCED)
209                 referenced_names.insert(var.type);
210 }
211
212 void InlineContentInjector::visit(Return &ret)
213 {
214         TraversingVisitor::visit(ret);
215
216         if(pass==INLINE && ret.expression)
217         {
218                 // Create a new variable to hold the return value of the inlined function.
219                 r_result_name = get_unused_variable_name(staging_block, "_return");
220                 RefPtr<VariableDeclaration> var = new VariableDeclaration;
221                 var->source = ret.source;
222                 var->line = ret.line;
223                 var->type = source_func->return_type;
224                 var->name = r_result_name;
225                 var->init_expression = ret.expression->clone();
226                 r_inlined_statement = var;
227         }
228 }
229
230
231 FunctionInliner::FunctionInliner():
232         current_function(0),
233         r_any_inlined(false),
234         r_inlined_here(false)
235 { }
236
237 bool FunctionInliner::apply(Stage &s)
238 {
239         stage = &s;
240         inlineable = InlineableFunctionLocator().apply(s);
241         r_any_inlined = false;
242         s.content.visit(*this);
243         return r_any_inlined;
244 }
245
246 void FunctionInliner::visit(RefPtr<Expression> &ptr)
247 {
248         r_inline_result = 0;
249         ptr->visit(*this);
250         if(r_inline_result)
251         {
252                 ptr = r_inline_result;
253                 r_any_inlined = true;
254         }
255         r_inline_result = 0;
256 }
257
258 void FunctionInliner::visit(Block &block)
259 {
260         SetForScope<Block *> set_block(current_block, &block);
261         SetForScope<NodeList<Statement>::iterator> save_insert_point(insert_point, block.body.begin());
262         for(NodeList<Statement>::iterator i=block.body.begin(); (!r_inlined_here && i!=block.body.end()); ++i)
263         {
264                 insert_point = i;
265                 (*i)->visit(*this);
266         }
267 }
268
269 void FunctionInliner::visit(FunctionCall &call)
270 {
271         for(NodeArray<Expression>::iterator i=call.arguments.begin(); (!r_inlined_here && i!=call.arguments.end()); ++i)
272                 visit(*i);
273
274         if(r_inlined_here)
275                 return;
276
277         FunctionDeclaration *def = call.declaration;
278         if(def)
279                 def = def->definition;
280
281         if(def && inlineable.count(def))
282         {
283                 string result_name = InlineContentInjector().apply(*stage, *current_function, *current_block, insert_point, call);
284
285                 // This will later get removed by UnusedVariableRemover.
286                 if(result_name.empty())
287                         result_name = "_msp_unused_from_inline";
288
289                 RefPtr<VariableReference> ref = new VariableReference;
290                 ref->name = result_name;
291                 r_inline_result = ref;
292
293                 /* Inlined variables need to be resolved before this function can be
294                 inlined further. */
295                 inlineable.erase(current_function);
296                 r_inlined_here = true;
297         }
298 }
299
300 void FunctionInliner::visit(FunctionDeclaration &func)
301 {
302         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
303         TraversingVisitor::visit(func);
304         r_inlined_here = false;
305 }
306
307 void FunctionInliner::visit(Iteration &iter)
308 {
309         /* Visit the initialization statement before entering the loop body so the
310         inlined statements get inserted outside. */
311         if(iter.init_statement)
312                 iter.init_statement->visit(*this);
313
314         SetForScope<Block *> set_block(current_block, &iter.body);
315         /* Skip the condition and loop expression parts because they're not properly
316         inside the body block.  Inlining anything into them will require a more
317         comprehensive transformation. */
318         iter.body.visit(*this);
319 }
320
321
322 ExpressionInliner::ExpressionInfo::ExpressionInfo():
323         expression(0),
324         assign_scope(0),
325         inline_point(0),
326         trivial(false),
327         available(true)
328 { }
329
330
331 ExpressionInliner::ExpressionInliner():
332         r_ref_info(0),
333         r_any_inlined(false),
334         r_trivial(false),
335         mutating(false),
336         iteration_init(false),
337         iteration_body(0),
338         r_oper(0)
339 { }
340
341 bool ExpressionInliner::apply(Stage &s)
342 {
343         s.content.visit(*this);
344         return r_any_inlined;
345 }
346
347 void ExpressionInliner::inline_expression(Expression &expr, RefPtr<Expression> &ptr)
348 {
349         ptr = expr.clone();
350         r_any_inlined = true;
351 }
352
353 void ExpressionInliner::visit(Block &block)
354 {
355         TraversingVisitor::visit(block);
356
357         for(map<string, VariableDeclaration *>::iterator i=block.variables.begin(); i!=block.variables.end(); ++i)
358         {
359                 map<Assignment::Target, ExpressionInfo>::iterator j = expressions.lower_bound(i->second);
360                 for(; (j!=expressions.end() && j->first.declaration==i->second); )
361                 {
362                         if(j->second.expression && j->second.inline_point)
363                                 inline_expression(*j->second.expression, *j->second.inline_point);
364
365                         expressions.erase(j++);
366                 }
367         }
368
369         /* Expressions assigned in this block may depend on local variables of the
370         block.  If this is a conditionally executed block, the assignments might not
371         always happen.  Mark the expressions as not available to any outer blocks. */
372         for(map<Assignment::Target, ExpressionInfo>::iterator i=expressions.begin(); i!=expressions.end(); ++i)
373                 if(i->second.assign_scope==&block)
374                         i->second.available = false;
375 }
376
377 void ExpressionInliner::visit(RefPtr<Expression> &expr)
378 {
379         r_ref_info = 0;
380         expr->visit(*this);
381         if(r_ref_info && r_ref_info->expression && r_ref_info->available)
382         {
383                 if(iteration_body && !r_ref_info->trivial)
384                 {
385                         /* Don't inline non-trivial expressions which were assigned outside
386                         an iteration statement.  The iteration may run multiple times, which
387                         would cause the expression to also be evaluated multiple times. */
388                         Block *i = r_ref_info->assign_scope;
389                         for(; (i && i!=iteration_body); i=i->parent) ;
390                         if(!i)
391                                 return;
392                 }
393
394                 if(r_ref_info->trivial)
395                         inline_expression(*r_ref_info->expression, expr);
396                 else
397                         /* Record the inline point for a non-trivial expression but don't
398                         inline it yet.  It might turn out it shouldn't be inlined after all. */
399                         r_ref_info->inline_point = &expr;
400         }
401         r_oper = expr->oper;
402         r_ref_info = 0;
403 }
404
405 void ExpressionInliner::visit(VariableReference &var)
406 {
407         if(var.declaration)
408         {
409                 map<Assignment::Target, ExpressionInfo>::iterator i = expressions.find(var.declaration);
410                 if(i!=expressions.end())
411                 {
412                         /* If a non-trivial expression is referenced multiple times, don't
413                         inline it. */
414                         if(i->second.inline_point && !i->second.trivial)
415                                 i->second.expression = 0;
416                         /* Mutating expressions are analogous to self-referencing assignments
417                         and prevent inlining. */
418                         if(mutating)
419                                 i->second.expression = 0;
420                         r_ref_info = &i->second;
421                 }
422         }
423 }
424
425 void ExpressionInliner::visit(MemberAccess &memacc)
426 {
427         visit(memacc.left);
428         r_trivial = false;
429 }
430
431 void ExpressionInliner::visit(Swizzle &swizzle)
432 {
433         visit(swizzle.left);
434         r_trivial = false;
435 }
436
437 void ExpressionInliner::visit(UnaryExpression &unary)
438 {
439         SetFlag set_target(mutating, mutating || unary.oper->token[1]=='+' || unary.oper->token[1]=='-');
440         visit(unary.expression);
441         r_trivial = false;
442 }
443
444 void ExpressionInliner::visit(BinaryExpression &binary)
445 {
446         visit(binary.left);
447         {
448                 SetFlag clear_target(mutating, false);
449                 visit(binary.right);
450         }
451         r_trivial = false;
452 }
453
454 void ExpressionInliner::visit(Assignment &assign)
455 {
456         {
457                 SetFlag set_target(mutating);
458                 visit(assign.left);
459         }
460         r_oper = 0;
461         visit(assign.right);
462
463         map<Assignment::Target, ExpressionInfo>::iterator i = expressions.find(assign.target);
464         if(i!=expressions.end())
465         {
466                 /* Self-referencing assignments can't be inlined without additional
467                 work.  Just clear any previous expression. */
468                 i->second.expression = (assign.self_referencing ? 0 : assign.right.get());
469                 i->second.assign_scope = current_block;
470                 i->second.inline_point = 0;
471                 i->second.available = true;
472         }
473
474         r_trivial = false;
475 }
476
477 void ExpressionInliner::visit(TernaryExpression &ternary)
478 {
479         visit(ternary.condition);
480         visit(ternary.true_expr);
481         visit(ternary.false_expr);
482         r_trivial = false;
483 }
484
485 void ExpressionInliner::visit(FunctionCall &call)
486 {
487         TraversingVisitor::visit(call);
488         r_trivial = false;
489 }
490
491 void ExpressionInliner::visit(VariableDeclaration &var)
492 {
493         r_oper = 0;
494         r_trivial = true;
495         TraversingVisitor::visit(var);
496
497         bool constant = var.constant;
498         if(constant && var.layout)
499         {
500                 for(vector<Layout::Qualifier>::const_iterator i=var.layout->qualifiers.begin(); (constant && i!=var.layout->qualifiers.end()); ++i)
501                         constant = (i->name!="constant_id");
502         }
503
504         /* Only inline global variables if they're constant and have trivial
505         initializers.  Non-constant variables could change in ways which are hard to
506         analyze and non-trivial expressions could be expensive to inline.  */
507         if((current_block->parent || (constant && r_trivial)) && var.interface.empty())
508         {
509                 ExpressionInfo &info = expressions[&var];
510                 /* Assume variables declared in an iteration initialization statement
511                 will have their values change throughout the iteration. */
512                 info.expression = (iteration_init ? 0 : var.init_expression.get());
513                 info.assign_scope = current_block;
514                 info.trivial = r_trivial;
515         }
516 }
517
518 void ExpressionInliner::visit(Iteration &iter)
519 {
520         SetForScope<Block *> set_block(current_block, &iter.body);
521         if(iter.init_statement)
522         {
523                 SetFlag set_init(iteration_init);
524                 iter.init_statement->visit(*this);
525         }
526
527         SetForScope<Block *> set_body(iteration_body, &iter.body);
528         if(iter.condition)
529                 visit(iter.condition);
530         iter.body.visit(*this);
531         if(iter.loop_expression)
532                 visit(iter.loop_expression);
533 }
534
535
536 BasicTypeDeclaration::Kind ConstantFolder::get_value_kind(const Variant &value)
537 {
538         if(value.check_type<bool>())
539                 return BasicTypeDeclaration::BOOL;
540         else if(value.check_type<int>())
541                 return BasicTypeDeclaration::INT;
542         else if(value.check_type<float>())
543                 return BasicTypeDeclaration::FLOAT;
544         else
545                 return BasicTypeDeclaration::VOID;
546 }
547
548 template<typename T>
549 T ConstantFolder::evaluate_logical(char oper, T left, T right)
550 {
551         switch(oper)
552         {
553         case '&': return left&right;
554         case '|': return left|right;
555         case '^': return left^right;
556         default: return T();
557         }
558 }
559
560 template<typename T>
561 bool ConstantFolder::evaluate_relation(const char *oper, T left, T right)
562 {
563         switch(oper[0]|oper[1])
564         {
565         case '<': return left<right;
566         case '<'|'=': return left<=right;
567         case '>': return left>right;
568         case '>'|'=': return left>=right;
569         default: return false;
570         }
571 }
572
573 template<typename T>
574 T ConstantFolder::evaluate_arithmetic(char oper, T left, T right)
575 {
576         switch(oper)
577         {
578         case '+': return left+right;
579         case '-': return left-right;
580         case '*': return left*right;
581         case '/': return left/right;
582         default: return T();
583         }
584 }
585
586 void ConstantFolder::set_result(const Variant &value, bool literal)
587 {
588         r_constant_value = value;
589         r_constant = true;
590         r_literal = literal;
591 }
592
593 void ConstantFolder::visit(RefPtr<Expression> &expr)
594 {
595         r_constant_value = Variant();
596         r_constant = false;
597         r_literal = false;
598         r_uses_iter_var = false;
599         expr->visit(*this);
600         /* Don't replace literals since they'd only be replaced with an identical
601         literal.  Also skip anything that uses an iteration variable, but pass on
602         the result so the Iteration visiting function can handle it. */
603         if(!r_constant || r_literal || r_uses_iter_var)
604                 return;
605
606         BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value);
607         if(kind==BasicTypeDeclaration::VOID)
608         {
609                 r_constant = false;
610                 return;
611         }
612
613         RefPtr<Literal> literal = new Literal;
614         if(kind==BasicTypeDeclaration::BOOL)
615                 literal->token = (r_constant_value.value<bool>() ? "true" : "false");
616         else if(kind==BasicTypeDeclaration::INT)
617                 literal->token = lexical_cast<string>(r_constant_value.value<int>());
618         else if(kind==BasicTypeDeclaration::FLOAT)
619                 literal->token = lexical_cast<string>(r_constant_value.value<float>());
620         literal->value = r_constant_value;
621         expr = literal;
622 }
623
624 void ConstantFolder::visit(Literal &literal)
625 {
626         set_result(literal.value, true);
627 }
628
629 void ConstantFolder::visit(VariableReference &var)
630 {
631         /* If an iteration variable is initialized with a constant value, return
632         that value here for the purpose of evaluating the loop condition for the
633         first iteration. */
634         if(var.declaration==iteration_var)
635         {
636                 set_result(iter_init_value);
637                 r_uses_iter_var = true;
638         }
639 }
640
641 void ConstantFolder::visit(MemberAccess &memacc)
642 {
643         TraversingVisitor::visit(memacc);
644         r_constant = false;
645 }
646
647 void ConstantFolder::visit(Swizzle &swizzle)
648 {
649         TraversingVisitor::visit(swizzle);
650         r_constant = false;
651 }
652
653 void ConstantFolder::visit(UnaryExpression &unary)
654 {
655         TraversingVisitor::visit(unary);
656         bool can_fold = r_constant;
657         r_constant = false;
658         if(!can_fold)
659                 return;
660
661         BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value);
662
663         char oper = unary.oper->token[0];
664         char oper2 = unary.oper->token[1];
665         if(oper=='!')
666         {
667                 if(kind==BasicTypeDeclaration::BOOL)
668                         set_result(!r_constant_value.value<bool>());
669         }
670         else if(oper=='~')
671         {
672                 if(kind==BasicTypeDeclaration::INT)
673                         set_result(~r_constant_value.value<int>());
674         }
675         else if(oper=='-' && !oper2)
676         {
677                 if(kind==BasicTypeDeclaration::INT)
678                         set_result(-r_constant_value.value<int>());
679                 else if(kind==BasicTypeDeclaration::FLOAT)
680                         set_result(-r_constant_value.value<float>());
681         }
682 }
683
684 void ConstantFolder::visit(BinaryExpression &binary)
685 {
686         visit(binary.left);
687         bool left_constant = r_constant;
688         bool left_iter_var = r_uses_iter_var;
689         Variant left_value = r_constant_value;
690         visit(binary.right);
691         if(left_iter_var)
692                 r_uses_iter_var = true;
693
694         bool can_fold = (left_constant && r_constant);
695         r_constant = false;
696         if(!can_fold)
697                 return;
698
699         BasicTypeDeclaration::Kind left_kind = get_value_kind(left_value);
700         BasicTypeDeclaration::Kind right_kind = get_value_kind(r_constant_value);
701         // Currently only expressions with both sides of equal types are handled.
702         if(left_kind!=right_kind)
703                 return;
704
705         char oper = binary.oper->token[0];
706         char oper2 = binary.oper->token[1];
707         if(oper=='&' || oper=='|' || oper=='^')
708         {
709                 if(oper2==oper && left_kind==BasicTypeDeclaration::BOOL)
710                         set_result(evaluate_logical(oper, left_value.value<bool>(), r_constant_value.value<bool>()));
711                 else if(!oper2 && left_kind==BasicTypeDeclaration::INT)
712                         set_result(evaluate_logical(oper, left_value.value<int>(), r_constant_value.value<int>()));
713         }
714         else if((oper=='<' || oper=='>') && oper2!=oper)
715         {
716                 if(left_kind==BasicTypeDeclaration::INT)
717                         set_result(evaluate_relation(binary.oper->token, left_value.value<int>(), r_constant_value.value<int>()));
718                 else if(left_kind==BasicTypeDeclaration::FLOAT)
719                         set_result(evaluate_relation(binary.oper->token, left_value.value<float>(), r_constant_value.value<float>()));
720         }
721         else if((oper=='=' || oper=='!') && oper2=='=')
722         {
723                 if(left_kind==BasicTypeDeclaration::INT)
724                         set_result((left_value.value<int>()==r_constant_value.value<int>()) == (oper=='='));
725                 if(left_kind==BasicTypeDeclaration::FLOAT)
726                         set_result((left_value.value<float>()==r_constant_value.value<float>()) == (oper=='='));
727         }
728         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
729         {
730                 if(left_kind==BasicTypeDeclaration::INT)
731                         set_result(evaluate_arithmetic(oper, left_value.value<int>(), r_constant_value.value<int>()));
732                 else if(left_kind==BasicTypeDeclaration::FLOAT)
733                         set_result(evaluate_arithmetic(oper, left_value.value<float>(), r_constant_value.value<float>()));
734         }
735         else if(oper=='%' || ((oper=='<' || oper=='>') && oper2==oper))
736         {
737                 if(left_kind!=BasicTypeDeclaration::INT)
738                         return;
739
740                 if(oper=='%')
741                         set_result(left_value.value<int>()%r_constant_value.value<int>());
742                 else if(oper=='<')
743                         set_result(left_value.value<int>()<<r_constant_value.value<int>());
744                 else if(oper=='>')
745                         set_result(left_value.value<int>()>>r_constant_value.value<int>());
746         }
747 }
748
749 void ConstantFolder::visit(Assignment &assign)
750 {
751         TraversingVisitor::visit(assign);
752         r_constant = false;
753 }
754
755 void ConstantFolder::visit(TernaryExpression &ternary)
756 {
757         TraversingVisitor::visit(ternary);
758         r_constant = false;
759 }
760
761 void ConstantFolder::visit(FunctionCall &call)
762 {
763         TraversingVisitor::visit(call);
764         r_constant = false;
765 }
766
767 void ConstantFolder::visit(VariableDeclaration &var)
768 {
769         if(iteration_init && var.init_expression)
770         {
771                 visit(var.init_expression);
772                 if(r_constant)
773                 {
774                         /* Record the value of a constant initialization expression of an
775                         iteration, so it can be used to evaluate the loop condition. */
776                         iteration_var = &var;
777                         iter_init_value = r_constant_value;
778                 }
779         }
780         else
781                 TraversingVisitor::visit(var);
782 }
783
784 void ConstantFolder::visit(Iteration &iter)
785 {
786         SetForScope<Block *> set_block(current_block, &iter.body);
787
788         /* The iteration variable is not normally inlined into expressions, so we
789         process it specially here.  If the initial value causes the loop condition
790         to evaluate to false, then the expression can be folded. */
791         iteration_var = 0;
792         if(iter.init_statement)
793         {
794                 SetFlag set_init(iteration_init);
795                 iter.init_statement->visit(*this);
796         }
797
798         if(iter.condition)
799         {
800                 visit(iter.condition);
801                 if(r_constant && r_constant_value.check_type<bool>() && !r_constant_value.value<bool>())
802                 {
803                         RefPtr<Literal> literal = new Literal;
804                         literal->token = "false";
805                         literal->value = r_constant_value;
806                         iter.condition = literal;
807                 }
808         }
809         iteration_var = 0;
810
811         iter.body.visit(*this);
812         if(iter.loop_expression)
813                 visit(iter.loop_expression);
814 }
815
816
817 void ConstantConditionEliminator::apply(Stage &stage)
818 {
819         stage.content.visit(*this);
820         NodeRemover().apply(stage, nodes_to_remove);
821 }
822
823 ConstantConditionEliminator::ConstantStatus ConstantConditionEliminator::check_constant_condition(const Expression &expr)
824 {
825         if(const Literal *literal = dynamic_cast<const Literal *>(&expr))
826                 if(literal->value.check_type<bool>())
827                         return (literal->value.value<bool>() ? CONSTANT_TRUE : CONSTANT_FALSE);
828         return NOT_CONSTANT;
829 }
830
831 void ConstantConditionEliminator::visit(Block &block)
832 {
833         SetForScope<Block *> set_block(current_block, &block);
834         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
835         {
836                 insert_point = i;
837                 (*i)->visit(*this);
838         }
839 }
840
841 void ConstantConditionEliminator::visit(RefPtr<Expression> &expr)
842 {
843         r_ternary_result = 0;
844         expr->visit(*this);
845         if(r_ternary_result)
846                 expr = r_ternary_result;
847         r_ternary_result = 0;
848 }
849
850 void ConstantConditionEliminator::visit(TernaryExpression &ternary)
851 {
852         ConstantStatus result = check_constant_condition(*ternary.condition);
853         if(result!=NOT_CONSTANT)
854                 r_ternary_result = (result==CONSTANT_TRUE ? ternary.true_expr : ternary.false_expr);
855         else
856                 r_ternary_result = 0;
857 }
858
859 void ConstantConditionEliminator::visit(Conditional &cond)
860 {
861         ConstantStatus result = check_constant_condition(*cond.condition);
862         if(result!=NOT_CONSTANT)
863         {
864                 Block &block = (result==CONSTANT_TRUE ? cond.body : cond.else_body);
865                 // TODO should check variable names for conflicts.  Potentially reuse InlineContentInjector?
866                 current_block->body.splice(insert_point, block.body);
867                 nodes_to_remove.insert(&cond);
868                 return;
869         }
870
871         TraversingVisitor::visit(cond);
872 }
873
874 void ConstantConditionEliminator::visit(Iteration &iter)
875 {
876         if(iter.condition)
877         {
878                 ConstantStatus result = check_constant_condition(*iter.condition);
879                 if(result==CONSTANT_FALSE)
880                 {
881                         nodes_to_remove.insert(&iter);
882                         return;
883                 }
884         }
885
886         TraversingVisitor::visit(iter);
887 }
888
889
890 bool UnusedTypeRemover::apply(Stage &stage)
891 {
892         stage.content.visit(*this);
893         NodeRemover().apply(stage, unused_nodes);
894         return !unused_nodes.empty();
895 }
896
897 void UnusedTypeRemover::visit(Literal &literal)
898 {
899         unused_nodes.erase(literal.type);
900 }
901
902 void UnusedTypeRemover::visit(UnaryExpression &unary)
903 {
904         unused_nodes.erase(unary.type);
905         TraversingVisitor::visit(unary);
906 }
907
908 void UnusedTypeRemover::visit(BinaryExpression &binary)
909 {
910         unused_nodes.erase(binary.type);
911         TraversingVisitor::visit(binary);
912 }
913
914 void UnusedTypeRemover::visit(TernaryExpression &ternary)
915 {
916         unused_nodes.erase(ternary.type);
917         TraversingVisitor::visit(ternary);
918 }
919
920 void UnusedTypeRemover::visit(FunctionCall &call)
921 {
922         unused_nodes.erase(call.type);
923         TraversingVisitor::visit(call);
924 }
925
926 void UnusedTypeRemover::visit(BasicTypeDeclaration &type)
927 {
928         if(type.base_type)
929                 unused_nodes.erase(type.base_type);
930         unused_nodes.insert(&type);
931 }
932
933 void UnusedTypeRemover::visit(ImageTypeDeclaration &type)
934 {
935         if(type.base_type)
936                 unused_nodes.erase(type.base_type);
937         unused_nodes.insert(&type);
938 }
939
940 void UnusedTypeRemover::visit(StructDeclaration &strct)
941 {
942         unused_nodes.insert(&strct);
943         TraversingVisitor::visit(strct);
944 }
945
946 void UnusedTypeRemover::visit(VariableDeclaration &var)
947 {
948         unused_nodes.erase(var.type_declaration);
949 }
950
951 void UnusedTypeRemover::visit(InterfaceBlock &iface)
952 {
953         unused_nodes.erase(iface.type_declaration);
954 }
955
956 void UnusedTypeRemover::visit(FunctionDeclaration &func)
957 {
958         unused_nodes.erase(func.return_type_declaration);
959         TraversingVisitor::visit(func);
960 }
961
962
963 UnusedVariableRemover::UnusedVariableRemover():
964         stage(0),
965         interface_block(0),
966         r_assignment(0),
967         assignment_target(false),
968         r_side_effects(false)
969 { }
970
971 bool UnusedVariableRemover::apply(Stage &s)
972 {
973         stage = &s;
974         s.content.visit(*this);
975
976         for(list<AssignmentInfo>::const_iterator i=assignments.begin(); i!=assignments.end(); ++i)
977                 if(i->used_by.empty())
978                         unused_nodes.insert(i->node);
979
980         for(map<string, InterfaceBlock *>::const_iterator i=s.interface_blocks.begin(); i!=s.interface_blocks.end(); ++i)
981                 if(i->second->instance_name.empty())
982                         unused_nodes.insert(i->second);
983
984         for(BlockVariableMap::const_iterator i=variables.begin(); i!=variables.end(); ++i)
985         {
986                 if(i->second.output)
987                 {
988                         /* The last visible assignments of output variables are used by the
989                         next stage or the API. */
990                         for(vector<AssignmentInfo *>::const_iterator j=i->second.assignments.begin(); j!=i->second.assignments.end(); ++j)
991                                 unused_nodes.erase((*j)->node);
992                 }
993
994                 if(!i->second.output && !i->second.referenced)
995                 {
996                         // Don't remove variables from inside interface blocks.
997                         if(!i->second.interface_block)
998                                 unused_nodes.insert(i->first);
999                 }
1000                 else if(i->second.interface_block)
1001                         // Interface blocks are kept if even one member is used.
1002                         unused_nodes.erase(i->second.interface_block);
1003         }
1004
1005         NodeRemover().apply(s, unused_nodes);
1006
1007         return !unused_nodes.empty();
1008 }
1009
1010 void UnusedVariableRemover::referenced(const Assignment::Target &target, Node &node)
1011 {
1012         VariableInfo &var_info = variables[target.declaration];
1013         var_info.referenced = true;
1014         if(!assignment_target)
1015         {
1016                 for(vector<AssignmentInfo *>::const_iterator i=var_info.assignments.begin(); i!=var_info.assignments.end(); ++i)
1017                         (*i)->used_by.push_back(&node);
1018         }
1019 }
1020
1021 void UnusedVariableRemover::visit(VariableReference &var)
1022 {
1023         referenced(var.declaration, var);
1024 }
1025
1026 void UnusedVariableRemover::visit(InterfaceBlockReference &iface)
1027 {
1028         referenced(iface.declaration, iface);
1029 }
1030
1031 void UnusedVariableRemover::visit(UnaryExpression &unary)
1032 {
1033         TraversingVisitor::visit(unary);
1034         if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
1035                 r_side_effects = true;
1036 }
1037
1038 void UnusedVariableRemover::visit(BinaryExpression &binary)
1039 {
1040         if(binary.oper->token[0]=='[')
1041         {
1042                 binary.left->visit(*this);
1043                 SetFlag set(assignment_target, false);
1044                 binary.right->visit(*this);
1045         }
1046         else
1047                 TraversingVisitor::visit(binary);
1048 }
1049
1050 void UnusedVariableRemover::visit(Assignment &assign)
1051 {
1052         {
1053                 SetFlag set(assignment_target, (assign.oper->token[0]=='='));
1054                 assign.left->visit(*this);
1055         }
1056         assign.right->visit(*this);
1057         r_assignment = &assign;
1058         r_side_effects = true;
1059 }
1060
1061 void UnusedVariableRemover::visit(FunctionCall &call)
1062 {
1063         TraversingVisitor::visit(call);
1064         /* Treat function calls as having side effects so expression statements
1065         consisting of nothing but a function call won't be optimized away. */
1066         r_side_effects = true;
1067
1068         if(stage->type==Stage::GEOMETRY && call.name=="EmitVertex")
1069         {
1070                 for(map<Statement *, VariableInfo>::const_iterator i=variables.begin(); i!=variables.end(); ++i)
1071                         if(i->second.output)
1072                                 referenced(i->first, call);
1073         }
1074 }
1075
1076 void UnusedVariableRemover::record_assignment(const Assignment::Target &target, Node &node)
1077 {
1078         assignments.push_back(AssignmentInfo());
1079         AssignmentInfo &assign_info = assignments.back();
1080         assign_info.node = &node;
1081         assign_info.target = target;
1082
1083         /* An assignment to the target hides any assignments to the same target or
1084         its subfields. */
1085         VariableInfo &var_info = variables[target.declaration];
1086         for(unsigned i=0; i<var_info.assignments.size(); ++i)
1087         {
1088                 const Assignment::Target &t = var_info.assignments[i]->target;
1089
1090                 bool subfield = (t.chain_len>=target.chain_len);
1091                 for(unsigned j=0; (subfield && j<target.chain_len); ++j)
1092                         subfield = (t.chain[j]==target.chain[j]);
1093
1094                 if(subfield)
1095                         var_info.assignments.erase(var_info.assignments.begin()+i);
1096                 else
1097                         ++i;
1098         }
1099
1100         var_info.assignments.push_back(&assign_info);
1101 }
1102
1103 void UnusedVariableRemover::visit(ExpressionStatement &expr)
1104 {
1105         r_assignment = 0;
1106         r_side_effects = false;
1107         TraversingVisitor::visit(expr);
1108         if(r_assignment && r_assignment->target.declaration)
1109                 record_assignment(r_assignment->target, expr);
1110         if(!r_side_effects)
1111                 unused_nodes.insert(&expr);
1112 }
1113
1114 void UnusedVariableRemover::visit(VariableDeclaration &var)
1115 {
1116         VariableInfo &var_info = variables[&var];
1117         var_info.interface_block = interface_block;
1118
1119         /* Mark variables as output if they're used by the next stage or the
1120         graphics API. */
1121         if(interface_block)
1122                 var_info.output = (interface_block->interface=="out" && (interface_block->linked_block || !interface_block->block_name.compare(0, 3, "gl_")));
1123         else
1124                 var_info.output = (var.interface=="out" && (stage->type==Stage::FRAGMENT || var.linked_declaration || !var.name.compare(0, 3, "gl_")));
1125
1126         if(var.init_expression)
1127         {
1128                 var_info.initialized = true;
1129                 record_assignment(&var, *var.init_expression);
1130         }
1131         TraversingVisitor::visit(var);
1132 }
1133
1134 void UnusedVariableRemover::visit(InterfaceBlock &iface)
1135 {
1136         if(iface.instance_name.empty())
1137         {
1138                 SetForScope<InterfaceBlock *> set_block(interface_block, &iface);
1139                 iface.struct_declaration->members.visit(*this);
1140         }
1141         else
1142         {
1143                 VariableInfo &var_info = variables[&iface];
1144                 var_info.output = (iface.interface=="out" && (iface.linked_block || !iface.block_name.compare(0, 3, "gl_")));
1145         }
1146 }
1147
1148 void UnusedVariableRemover::merge_variables(const BlockVariableMap &other_vars)
1149 {
1150         for(BlockVariableMap::const_iterator i=other_vars.begin(); i!=other_vars.end(); ++i)
1151         {
1152                 BlockVariableMap::iterator j = variables.find(i->first);
1153                 if(j!=variables.end())
1154                 {
1155                         /* The merged blocks started as copies of each other so any common
1156                         assignments must be in the beginning. */
1157                         unsigned k = 0;
1158                         for(; (k<i->second.assignments.size() && k<j->second.assignments.size()); ++k)
1159                                 if(i->second.assignments[k]!=j->second.assignments[k])
1160                                         break;
1161
1162                         // Remaining assignments are unique to each block; merge them.
1163                         j->second.assignments.insert(j->second.assignments.end(), i->second.assignments.begin()+k, i->second.assignments.end());
1164                         j->second.referenced |= i->second.referenced;
1165                 }
1166                 else
1167                         variables.insert(*i);
1168         }
1169 }
1170
1171 void UnusedVariableRemover::visit(FunctionDeclaration &func)
1172 {
1173         if(func.body.body.empty())
1174                 return;
1175
1176         BlockVariableMap saved_vars = variables;
1177         // Assignments from other functions should not be visible.
1178         for(BlockVariableMap::iterator i=variables.begin(); i!=variables.end(); ++i)
1179                 i->second.assignments.resize(i->second.initialized);
1180         TraversingVisitor::visit(func);
1181         swap(variables, saved_vars);
1182         merge_variables(saved_vars);
1183
1184         /* Always treat function parameters as referenced.  Removing unused
1185         parameters is not currently supported. */
1186         for(NodeArray<VariableDeclaration>::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1187         {
1188                 BlockVariableMap::iterator j = variables.find(i->get());
1189                 if(j!=variables.end())
1190                         j->second.referenced = true;
1191         }
1192 }
1193
1194 void UnusedVariableRemover::visit(Conditional &cond)
1195 {
1196         cond.condition->visit(*this);
1197         BlockVariableMap saved_vars = variables;
1198         cond.body.visit(*this);
1199         swap(saved_vars, variables);
1200         cond.else_body.visit(*this);
1201
1202         /* Visible assignments after the conditional is the union of those visible
1203         at the end of the if and else blocks.  If there was no else block, then it's
1204         the union of the if block and the state before it. */
1205         merge_variables(saved_vars);
1206 }
1207
1208 void UnusedVariableRemover::visit(Iteration &iter)
1209 {
1210         BlockVariableMap saved_vars = variables;
1211         TraversingVisitor::visit(iter);
1212
1213         /* Merge assignments from the iteration, without clearing previous state.
1214         Further analysis is needed to determine which parts of the iteration body
1215         are always executed, if any. */
1216         merge_variables(saved_vars);
1217 }
1218
1219
1220 bool UnusedFunctionRemover::apply(Stage &stage)
1221 {
1222         stage.content.visit(*this);
1223         NodeRemover().apply(stage, unused_nodes);
1224         return !unused_nodes.empty();
1225 }
1226
1227 void UnusedFunctionRemover::visit(FunctionCall &call)
1228 {
1229         TraversingVisitor::visit(call);
1230
1231         unused_nodes.erase(call.declaration);
1232         if(call.declaration && call.declaration->definition!=call.declaration)
1233                 used_definitions.insert(call.declaration->definition);
1234 }
1235
1236 void UnusedFunctionRemover::visit(FunctionDeclaration &func)
1237 {
1238         TraversingVisitor::visit(func);
1239
1240         if((func.name!="main" || func.body.body.empty()) && !used_definitions.count(&func))
1241                 unused_nodes.insert(&func);
1242 }
1243
1244 } // namespace SL
1245 } // namespace GL
1246 } // namespace Msp