]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/optimize.cpp
0d428149dcc667ca6a7990c5d760a9ca22ea9aee
[libs/gl.git] / source / glsl / optimize.cpp
#include <limits>
#include <msp/core/raii.h>
#include <msp/strings/format.h>
#include "optimize.h"
4
5 using namespace std;
6
7 namespace Msp {
8 namespace GL {
9 namespace SL {
10
11 InlineableFunctionLocator::InlineableFunctionLocator():
12         current_function(0),
13         return_count(0)
14 { }
15
16 void InlineableFunctionLocator::visit(FunctionCall &call)
17 {
18         FunctionDeclaration *def = call.declaration;
19         if(def)
20                 def = def->definition;
21
22         if(def)
23         {
24                 unsigned &count = refcounts[def];
25                 ++count;
26                 /* Don't inline functions which are called more than once or are called
27                 recursively. */
28                 if(count>1 || def==current_function)
29                         inlineable.erase(def);
30         }
31
32         TraversingVisitor::visit(call);
33 }
34
35 void InlineableFunctionLocator::visit(FunctionDeclaration &func)
36 {
37         unsigned &count = refcounts[func.definition];
38         if(count<=1 && func.parameters.empty())
39                 inlineable.insert(func.definition);
40
41         SetForScope<FunctionDeclaration *> set(current_function, &func);
42         return_count = 0;
43         TraversingVisitor::visit(func);
44 }
45
46 void InlineableFunctionLocator::visit(Conditional &cond)
47 {
48         TraversingVisitor::visit(cond);
49         inlineable.erase(current_function);
50 }
51
52 void InlineableFunctionLocator::visit(Iteration &iter)
53 {
54         TraversingVisitor::visit(iter);
55         inlineable.erase(current_function);
56 }
57
58 void InlineableFunctionLocator::visit(Return &ret)
59 {
60         TraversingVisitor::visit(ret);
61         if(return_count)
62                 inlineable.erase(current_function);
63         ++return_count;
64 }
65
66
67 InlineContentInjector::InlineContentInjector():
68         source_func(0),
69         pass(DEPENDS)
70 { }
71
/* Copies the body of src into tgt_blk in front of ins_pt, renaming variables
where needed so names from the two functions don't clash.  The work is done in
several passes over the syntax tree:
1. DEPENDS: walk src and record the declarations it references in
   dependencies.
2. REFERENCED: record every name used in target_func.
3. INLINE then RENAME, per statement of src: a return statement with an
   expression becomes a declaration of a new result variable, any other
   statement is cloned; the copy is then renamed against the referenced
   names and appended to the staging block.
4. REFERENCED: recollect names, this time from the staging block.
5. RENAME: rename variables of target_func that clash with those names.
Finally the staged statements are spliced into tgt_blk and NodeReorderer is
run with the collected dependencies.

Returns the name of the variable holding the inlined return value.  It is
only assigned when a return statement with an expression is inlined, so it is
empty for a fresh injector whose source function returns nothing.  The result
is a reference to a member of this object. */
const string &InlineContentInjector::apply(Stage &stage, FunctionDeclaration &target_func, Block &tgt_blk, const NodeList<Statement>::iterator &ins_pt, FunctionDeclaration &src)
{
	source_func = &src;

	// Collect all declarations the inlined function depends on.
	pass = DEPENDS;
	source_func->visit(*this);

	/* Populate referenced_names from the target function so we can rename
	variables from the inlined function that would conflict. */
	pass = REFERENCED;
	target_func.visit(*this);

	/* Inline and rename passes must be interleaved so used variable names are
	known when inlining the return statement. */
	pass = INLINE;
	staging_block.parent = &tgt_blk;
	staging_block.variables.clear();
	// Renamed variables from the source function get the source function's name as prefix.
	remap_prefix = source_func->name;

	for(NodeList<Statement>::iterator i=src.body.body.begin(); i!=src.body.body.end(); ++i)
	{
		r_inlined_statement = 0;
		(*i)->visit(*this);
		// Statements not replaced by the INLINE pass are copied as-is.
		if(!r_inlined_statement)
			r_inlined_statement = (*i)->clone();

		SetForScope<Pass> set_pass(pass, RENAME);
		r_inlined_statement->visit(*this);

		staging_block.body.push_back(0);
		staging_block.body.back() = r_inlined_statement;
	}

	/* Now collect names from the staging block.  Local variables that would
	have conflicted with the target function were renamed earlier. */
	pass = REFERENCED;
	referenced_names.clear();
	staging_block.variables.clear();
	staging_block.visit(*this);

	/* Rename variables in the target function so they don't interfere with
	global identifiers used by the source function. */
	pass = RENAME;
	/* NOTE(review): presumably re-parented so get_unused_variable_name checks
	names against the source function's enclosing scope -- TODO confirm. */
	staging_block.parent = source_func->body.parent;
	remap_prefix = target_func.name;
	target_func.visit(*this);

	// Move the staged statements in front of the statement containing the call.
	tgt_blk.body.splice(ins_pt, staging_block.body);

	NodeReorderer().apply(stage, target_func, dependencies);

	return r_result_name;
}
126
127 void InlineContentInjector::visit(VariableReference &var)
128 {
129         if(pass==RENAME)
130         {
131                 map<string, VariableDeclaration *>::const_iterator i = staging_block.variables.find(var.name);
132                 if(i!=staging_block.variables.end())
133                         var.name = i->second->name;
134         }
135         else if(pass==DEPENDS && var.declaration)
136         {
137                 dependencies.insert(var.declaration);
138                 var.declaration->visit(*this);
139         }
140         else if(pass==REFERENCED)
141                 referenced_names.insert(var.name);
142 }
143
144 void InlineContentInjector::visit(InterfaceBlockReference &iface)
145 {
146         if(pass==DEPENDS && iface.declaration)
147         {
148                 dependencies.insert(iface.declaration);
149                 iface.declaration->visit(*this);
150         }
151         else if(pass==REFERENCED)
152                 referenced_names.insert(iface.name);
153 }
154
155 void InlineContentInjector::visit(FunctionCall &call)
156 {
157         if(pass==DEPENDS && call.declaration)
158                 dependencies.insert(call.declaration);
159         else if(pass==REFERENCED)
160                 referenced_names.insert(call.name);
161         TraversingVisitor::visit(call);
162 }
163
164 void InlineContentInjector::visit(VariableDeclaration &var)
165 {
166         TraversingVisitor::visit(var);
167
168         if(pass==RENAME)
169         {
170                 staging_block.variables[var.name] = &var;
171                 if(referenced_names.count(var.name))
172                 {
173                         string mapped_name = get_unused_variable_name(staging_block, var.name, remap_prefix);
174                         if(mapped_name!=var.name)
175                         {
176                                 staging_block.variables[mapped_name] = &var;
177                                 var.name = mapped_name;
178                         }
179                 }
180         }
181         else if(pass==DEPENDS && var.type_declaration)
182         {
183                 dependencies.insert(var.type_declaration);
184                 var.type_declaration->visit(*this);
185         }
186         else if(pass==REFERENCED)
187                 referenced_names.insert(var.type);
188 }
189
190 void InlineContentInjector::visit(Return &ret)
191 {
192         TraversingVisitor::visit(ret);
193
194         if(pass==INLINE && ret.expression)
195         {
196                 // Create a new variable to hold the return value of the inlined function.
197                 r_result_name = get_unused_variable_name(staging_block, "_return", source_func->name);
198                 RefPtr<VariableDeclaration> var = new VariableDeclaration;
199                 var->source = ret.source;
200                 var->line = ret.line;
201                 var->type = source_func->return_type;
202                 var->name = r_result_name;
203                 var->init_expression = ret.expression->clone();
204                 r_inlined_statement = var;
205         }
206 }
207
208
209 FunctionInliner::FunctionInliner():
210         current_function(0),
211         r_any_inlined(false),
212         r_inlined_here(false)
213 { }
214
215 bool FunctionInliner::apply(Stage &s)
216 {
217         stage = &s;
218         inlineable = InlineableFunctionLocator().apply(s);
219         r_any_inlined = false;
220         s.content.visit(*this);
221         return r_any_inlined;
222 }
223
224 void FunctionInliner::visit(RefPtr<Expression> &ptr)
225 {
226         r_inline_result = 0;
227         ptr->visit(*this);
228         if(r_inline_result)
229         {
230                 ptr = r_inline_result;
231                 r_any_inlined = true;
232         }
233         r_inline_result = 0;
234 }
235
236 void FunctionInliner::visit(Block &block)
237 {
238         SetForScope<Block *> set_block(current_block, &block);
239         SetForScope<NodeList<Statement>::iterator> save_insert_point(insert_point, block.body.begin());
240         for(NodeList<Statement>::iterator i=block.body.begin(); (!r_inlined_here && i!=block.body.end()); ++i)
241         {
242                 insert_point = i;
243                 (*i)->visit(*this);
244         }
245 }
246
247 void FunctionInliner::visit(FunctionCall &call)
248 {
249         if(r_inlined_here)
250                 return;
251
252         for(NodeArray<Expression>::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
253                 visit(*i);
254
255         FunctionDeclaration *def = call.declaration;
256         if(def)
257                 def = def->definition;
258
259         if(def && inlineable.count(def))
260         {
261                 string result_name = InlineContentInjector().apply(*stage, *current_function, *current_block, insert_point, *def);
262
263                 // This will later get removed by UnusedVariableRemover.
264                 if(result_name.empty())
265                         result_name = "msp_unused_from_inline";
266
267                 RefPtr<VariableReference> ref = new VariableReference;
268                 ref->name = result_name;
269                 r_inline_result = ref;
270
271                 /* Inlined variables need to be resolved before this function can be
272                 inlined further. */
273                 inlineable.erase(current_function);
274                 r_inlined_here = true;
275         }
276 }
277
278 void FunctionInliner::visit(FunctionDeclaration &func)
279 {
280         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
281         TraversingVisitor::visit(func);
282         r_inlined_here = false;
283 }
284
285 void FunctionInliner::visit(Iteration &iter)
286 {
287         /* Visit the initialization statement before entering the loop body so the
288         inlined statements get inserted outside. */
289         if(iter.init_statement)
290                 iter.init_statement->visit(*this);
291
292         SetForScope<Block *> set_block(current_block, &iter.body);
293         /* Skip the condition and loop expression parts because they're not properly
294         inside the body block.  Inlining anything into them will require a more
295         comprehensive transformation. */
296         iter.body.visit(*this);
297 }
298
299
300 ExpressionInliner::ExpressionInfo::ExpressionInfo():
301         expression(0),
302         assign_scope(0),
303         inline_point(0),
304         trivial(false),
305         available(true)
306 { }
307
308
309 ExpressionInliner::ExpressionInliner():
310         r_ref_info(0),
311         r_any_inlined(false),
312         r_trivial(false),
313         mutating(false),
314         iteration_init(false),
315         iteration_body(0),
316         r_oper(0)
317 { }
318
319 bool ExpressionInliner::apply(Stage &s)
320 {
321         s.content.visit(*this);
322         return r_any_inlined;
323 }
324
325 void ExpressionInliner::inline_expression(Expression &expr, RefPtr<Expression> &ptr)
326 {
327         ptr = expr.clone();
328         r_any_inlined = true;
329 }
330
331 void ExpressionInliner::visit(Block &block)
332 {
333         TraversingVisitor::visit(block);
334
335         for(map<string, VariableDeclaration *>::iterator i=block.variables.begin(); i!=block.variables.end(); ++i)
336         {
337                 map<Assignment::Target, ExpressionInfo>::iterator j = expressions.lower_bound(i->second);
338                 for(; (j!=expressions.end() && j->first.declaration==i->second); )
339                 {
340                         if(j->second.expression && j->second.inline_point)
341                                 inline_expression(*j->second.expression, *j->second.inline_point);
342
343                         expressions.erase(j++);
344                 }
345         }
346
347         /* Expressions assigned in this block may depend on local variables of the
348         block.  If this is a conditionally executed block, the assignments might not
349         always happen.  Mark the expressions as not available to any outer blocks. */
350         for(map<Assignment::Target, ExpressionInfo>::iterator i=expressions.begin(); i!=expressions.end(); ++i)
351                 if(i->second.assign_scope==&block)
352                         i->second.available = false;
353 }
354
355 void ExpressionInliner::visit(RefPtr<Expression> &expr)
356 {
357         r_ref_info = 0;
358         expr->visit(*this);
359         if(r_ref_info && r_ref_info->expression && r_ref_info->available)
360         {
361                 if(iteration_body && !r_ref_info->trivial)
362                 {
363                         /* Don't inline non-trivial expressions which were assigned outside
364                         an iteration statement.  The iteration may run multiple times, which
365                         would cause the expression to also be evaluated multiple times. */
366                         Block *i = r_ref_info->assign_scope;
367                         for(; (i && i!=iteration_body); i=i->parent) ;
368                         if(!i)
369                                 return;
370                 }
371
372                 if(r_ref_info->trivial)
373                         inline_expression(*r_ref_info->expression, expr);
374                 else
375                         /* Record the inline point for a non-trivial expression but don't
376                         inline it yet.  It might turn out it shouldn't be inlined after all. */
377                         r_ref_info->inline_point = &expr;
378         }
379         r_oper = expr->oper;
380         r_ref_info = 0;
381 }
382
383 void ExpressionInliner::visit(VariableReference &var)
384 {
385         if(var.declaration)
386         {
387                 map<Assignment::Target, ExpressionInfo>::iterator i = expressions.find(var.declaration);
388                 if(i!=expressions.end())
389                 {
390                         /* If a non-trivial expression is referenced multiple times, don't
391                         inline it. */
392                         if(i->second.inline_point && !i->second.trivial)
393                                 i->second.expression = 0;
394                         /* Mutating expressions are analogous to self-referencing assignments
395                         and prevent inlining. */
396                         if(mutating)
397                                 i->second.expression = 0;
398                         r_ref_info = &i->second;
399                 }
400         }
401 }
402
403 void ExpressionInliner::visit(MemberAccess &memacc)
404 {
405         visit(memacc.left);
406         r_trivial = false;
407 }
408
409 void ExpressionInliner::visit(Swizzle &swizzle)
410 {
411         visit(swizzle.left);
412         r_trivial = false;
413 }
414
415 void ExpressionInliner::visit(UnaryExpression &unary)
416 {
417         SetFlag set_target(mutating, mutating || unary.oper->token[1]=='+' || unary.oper->token[1]=='-');
418         visit(unary.expression);
419         r_trivial = false;
420 }
421
422 void ExpressionInliner::visit(BinaryExpression &binary)
423 {
424         visit(binary.left);
425         {
426                 SetFlag clear_target(mutating, false);
427                 visit(binary.right);
428         }
429         r_trivial = false;
430 }
431
432 void ExpressionInliner::visit(Assignment &assign)
433 {
434         {
435                 SetFlag set_target(mutating);
436                 visit(assign.left);
437         }
438         r_oper = 0;
439         visit(assign.right);
440
441         map<Assignment::Target, ExpressionInfo>::iterator i = expressions.find(assign.target);
442         if(i!=expressions.end())
443         {
444                 /* Self-referencing assignments can't be inlined without additional
445                 work.  Just clear any previous expression. */
446                 i->second.expression = (assign.self_referencing ? 0 : assign.right.get());
447                 i->second.assign_scope = current_block;
448                 i->second.inline_point = 0;
449                 i->second.available = true;
450         }
451
452         r_trivial = false;
453 }
454
455 void ExpressionInliner::visit(TernaryExpression &ternary)
456 {
457         visit(ternary.condition);
458         visit(ternary.true_expr);
459         visit(ternary.false_expr);
460         r_trivial = false;
461 }
462
463 void ExpressionInliner::visit(FunctionCall &call)
464 {
465         TraversingVisitor::visit(call);
466         r_trivial = false;
467 }
468
469 void ExpressionInliner::visit(VariableDeclaration &var)
470 {
471         r_oper = 0;
472         r_trivial = true;
473         TraversingVisitor::visit(var);
474
475         bool constant = var.constant;
476         if(constant && var.layout)
477         {
478                 for(vector<Layout::Qualifier>::const_iterator i=var.layout->qualifiers.begin(); (constant && i!=var.layout->qualifiers.end()); ++i)
479                         constant = (i->name!="constant_id");
480         }
481
482         /* Only inline global variables if they're constant and have trivial
483         initializers.  Non-constant variables could change in ways which are hard to
484         analyze and non-trivial expressions could be expensive to inline.  */
485         if((current_block->parent || (constant && r_trivial)) && var.interface.empty())
486         {
487                 ExpressionInfo &info = expressions[&var];
488                 /* Assume variables declared in an iteration initialization statement
489                 will have their values change throughout the iteration. */
490                 info.expression = (iteration_init ? 0 : var.init_expression.get());
491                 info.assign_scope = current_block;
492                 info.trivial = r_trivial;
493         }
494 }
495
496 void ExpressionInliner::visit(Iteration &iter)
497 {
498         SetForScope<Block *> set_block(current_block, &iter.body);
499         if(iter.init_statement)
500         {
501                 SetFlag set_init(iteration_init);
502                 iter.init_statement->visit(*this);
503         }
504
505         SetForScope<Block *> set_body(iteration_body, &iter.body);
506         if(iter.condition)
507                 visit(iter.condition);
508         iter.body.visit(*this);
509         if(iter.loop_expression)
510                 visit(iter.loop_expression);
511 }
512
513
514 BasicTypeDeclaration::Kind ConstantFolder::get_value_kind(const Variant &value)
515 {
516         if(value.check_type<bool>())
517                 return BasicTypeDeclaration::BOOL;
518         else if(value.check_type<int>())
519                 return BasicTypeDeclaration::INT;
520         else if(value.check_type<float>())
521                 return BasicTypeDeclaration::FLOAT;
522         else
523                 return BasicTypeDeclaration::VOID;
524 }
525
526 template<typename T>
527 T ConstantFolder::evaluate_logical(char oper, T left, T right)
528 {
529         switch(oper)
530         {
531         case '&': return left&right;
532         case '|': return left|right;
533         case '^': return left^right;
534         default: return T();
535         }
536 }
537
538 template<typename T>
539 bool ConstantFolder::evaluate_relation(const char *oper, T left, T right)
540 {
541         switch(oper[0]|oper[1])
542         {
543         case '<': return left<right;
544         case '<'|'=': return left<=right;
545         case '>': return left>right;
546         case '>'|'=': return left>=right;
547         default: return false;
548         }
549 }
550
551 template<typename T>
552 T ConstantFolder::evaluate_arithmetic(char oper, T left, T right)
553 {
554         switch(oper)
555         {
556         case '+': return left+right;
557         case '-': return left-right;
558         case '*': return left*right;
559         case '/': return left/right;
560         default: return T();
561         }
562 }
563
564 void ConstantFolder::set_result(const Variant &value, bool literal)
565 {
566         r_constant_value = value;
567         r_constant = true;
568         r_literal = literal;
569 }
570
571 void ConstantFolder::visit(RefPtr<Expression> &expr)
572 {
573         r_constant_value = Variant();
574         r_constant = false;
575         r_literal = false;
576         r_uses_iter_var = false;
577         expr->visit(*this);
578         /* Don't replace literals since they'd only be replaced with an identical
579         literal.  Also skip anything that uses an iteration variable, but pass on
580         the result so the Iteration visiting function can handle it. */
581         if(!r_constant || r_literal || r_uses_iter_var)
582                 return;
583
584         BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value);
585         if(kind==BasicTypeDeclaration::VOID)
586         {
587                 r_constant = false;
588                 return;
589         }
590
591         RefPtr<Literal> literal = new Literal;
592         if(kind==BasicTypeDeclaration::BOOL)
593                 literal->token = (r_constant_value.value<bool>() ? "true" : "false");
594         else if(kind==BasicTypeDeclaration::INT)
595                 literal->token = lexical_cast<string>(r_constant_value.value<int>());
596         else if(kind==BasicTypeDeclaration::FLOAT)
597                 literal->token = lexical_cast<string>(r_constant_value.value<float>());
598         literal->value = r_constant_value;
599         expr = literal;
600 }
601
602 void ConstantFolder::visit(Literal &literal)
603 {
604         set_result(literal.value, true);
605 }
606
607 void ConstantFolder::visit(VariableReference &var)
608 {
609         /* If an iteration variable is initialized with a constant value, return
610         that value here for the purpose of evaluating the loop condition for the
611         first iteration. */
612         if(var.declaration==iteration_var)
613         {
614                 set_result(iter_init_value);
615                 r_uses_iter_var = true;
616         }
617 }
618
619 void ConstantFolder::visit(MemberAccess &memacc)
620 {
621         TraversingVisitor::visit(memacc);
622         r_constant = false;
623 }
624
625 void ConstantFolder::visit(Swizzle &swizzle)
626 {
627         TraversingVisitor::visit(swizzle);
628         r_constant = false;
629 }
630
631 void ConstantFolder::visit(UnaryExpression &unary)
632 {
633         TraversingVisitor::visit(unary);
634         bool can_fold = r_constant;
635         r_constant = false;
636         if(!can_fold)
637                 return;
638
639         BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value);
640
641         char oper = unary.oper->token[0];
642         char oper2 = unary.oper->token[1];
643         if(oper=='!')
644         {
645                 if(kind==BasicTypeDeclaration::BOOL)
646                         set_result(!r_constant_value.value<bool>());
647         }
648         else if(oper=='~')
649         {
650                 if(kind==BasicTypeDeclaration::INT)
651                         set_result(~r_constant_value.value<int>());
652         }
653         else if(oper=='-' && !oper2)
654         {
655                 if(kind==BasicTypeDeclaration::INT)
656                         set_result(-r_constant_value.value<int>());
657                 else if(kind==BasicTypeDeclaration::FLOAT)
658                         set_result(-r_constant_value.value<float>());
659         }
660 }
661
662 void ConstantFolder::visit(BinaryExpression &binary)
663 {
664         visit(binary.left);
665         bool left_constant = r_constant;
666         bool left_iter_var = r_uses_iter_var;
667         Variant left_value = r_constant_value;
668         visit(binary.right);
669         if(left_iter_var)
670                 r_uses_iter_var = true;
671
672         bool can_fold = (left_constant && r_constant);
673         r_constant = false;
674         if(!can_fold)
675                 return;
676
677         BasicTypeDeclaration::Kind left_kind = get_value_kind(left_value);
678         BasicTypeDeclaration::Kind right_kind = get_value_kind(r_constant_value);
679         // Currently only expressions with both sides of equal types are handled.
680         if(left_kind!=right_kind)
681                 return;
682
683         char oper = binary.oper->token[0];
684         char oper2 = binary.oper->token[1];
685         if(oper=='&' || oper=='|' || oper=='^')
686         {
687                 if(oper2==oper && left_kind==BasicTypeDeclaration::BOOL)
688                         set_result(evaluate_logical(oper, left_value.value<bool>(), r_constant_value.value<bool>()));
689                 else if(!oper2 && left_kind==BasicTypeDeclaration::INT)
690                         set_result(evaluate_logical(oper, left_value.value<int>(), r_constant_value.value<int>()));
691         }
692         else if((oper=='<' || oper=='>') && oper2!=oper)
693         {
694                 if(left_kind==BasicTypeDeclaration::INT)
695                         set_result(evaluate_relation(binary.oper->token, left_value.value<int>(), r_constant_value.value<int>()));
696                 else if(left_kind==BasicTypeDeclaration::FLOAT)
697                         set_result(evaluate_relation(binary.oper->token, left_value.value<float>(), r_constant_value.value<float>()));
698         }
699         else if((oper=='=' || oper=='!') && oper2=='=')
700         {
701                 if(left_kind==BasicTypeDeclaration::INT)
702                         set_result((left_value.value<int>()==r_constant_value.value<int>()) == (oper=='='));
703                 if(left_kind==BasicTypeDeclaration::FLOAT)
704                         set_result((left_value.value<float>()==r_constant_value.value<float>()) == (oper=='='));
705         }
706         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
707         {
708                 if(left_kind==BasicTypeDeclaration::INT)
709                         set_result(evaluate_arithmetic(oper, left_value.value<int>(), r_constant_value.value<int>()));
710                 else if(left_kind==BasicTypeDeclaration::FLOAT)
711                         set_result(evaluate_arithmetic(oper, left_value.value<float>(), r_constant_value.value<float>()));
712         }
713         else if(oper=='%' || ((oper=='<' || oper=='>') && oper2==oper))
714         {
715                 if(left_kind!=BasicTypeDeclaration::INT)
716                         return;
717
718                 if(oper=='%')
719                         set_result(left_value.value<int>()%r_constant_value.value<int>());
720                 else if(oper=='<')
721                         set_result(left_value.value<int>()<<r_constant_value.value<int>());
722                 else if(oper=='>')
723                         set_result(left_value.value<int>()>>r_constant_value.value<int>());
724         }
725 }
726
727 void ConstantFolder::visit(Assignment &assign)
728 {
729         TraversingVisitor::visit(assign);
730         r_constant = false;
731 }
732
733 void ConstantFolder::visit(TernaryExpression &ternary)
734 {
735         TraversingVisitor::visit(ternary);
736         r_constant = false;
737 }
738
739 void ConstantFolder::visit(FunctionCall &call)
740 {
741         TraversingVisitor::visit(call);
742         r_constant = false;
743 }
744
745 void ConstantFolder::visit(VariableDeclaration &var)
746 {
747         if(iteration_init && var.init_expression)
748         {
749                 visit(var.init_expression);
750                 if(r_constant)
751                 {
752                         /* Record the value of a constant initialization expression of an
753                         iteration, so it can be used to evaluate the loop condition. */
754                         iteration_var = &var;
755                         iter_init_value = r_constant_value;
756                 }
757         }
758         else
759                 TraversingVisitor::visit(var);
760 }
761
762 void ConstantFolder::visit(Iteration &iter)
763 {
764         SetForScope<Block *> set_block(current_block, &iter.body);
765
766         /* The iteration variable is not normally inlined into expressions, so we
767         process it specially here.  If the initial value causes the loop condition
768         to evaluate to false, then the expression can be folded. */
769         iteration_var = 0;
770         if(iter.init_statement)
771         {
772                 SetFlag set_init(iteration_init);
773                 iter.init_statement->visit(*this);
774         }
775
776         if(iter.condition)
777         {
778                 visit(iter.condition);
779                 if(r_constant && r_constant_value.check_type<bool>() && !r_constant_value.value<bool>())
780                 {
781                         RefPtr<Literal> literal = new Literal;
782                         literal->token = "false";
783                         literal->value = r_constant_value;
784                         iter.condition = literal;
785                 }
786         }
787         iteration_var = 0;
788
789         iter.body.visit(*this);
790         if(iter.loop_expression)
791                 visit(iter.loop_expression);
792 }
793
794
795 void ConstantConditionEliminator::apply(Stage &stage)
796 {
797         stage.content.visit(*this);
798         NodeRemover().apply(stage, nodes_to_remove);
799 }
800
801 ConstantConditionEliminator::ConstantStatus ConstantConditionEliminator::check_constant_condition(const Expression &expr)
802 {
803         if(const Literal *literal = dynamic_cast<const Literal *>(&expr))
804                 if(literal->value.check_type<bool>())
805                         return (literal->value.value<bool>() ? CONSTANT_TRUE : CONSTANT_FALSE);
806         return NOT_CONSTANT;
807 }
808
809 void ConstantConditionEliminator::visit(Block &block)
810 {
811         SetForScope<Block *> set_block(current_block, &block);
812         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
813         {
814                 insert_point = i;
815                 (*i)->visit(*this);
816         }
817 }
818
819 void ConstantConditionEliminator::visit(RefPtr<Expression> &expr)
820 {
821         r_ternary_result = 0;
822         expr->visit(*this);
823         if(r_ternary_result)
824                 expr = r_ternary_result;
825         r_ternary_result = 0;
826 }
827
828 void ConstantConditionEliminator::visit(TernaryExpression &ternary)
829 {
830         ConstantStatus result = check_constant_condition(*ternary.condition);
831         if(result!=NOT_CONSTANT)
832                 r_ternary_result = (result==CONSTANT_TRUE ? ternary.true_expr : ternary.false_expr);
833         else
834                 r_ternary_result = 0;
835 }
836
837 void ConstantConditionEliminator::visit(Conditional &cond)
838 {
839         ConstantStatus result = check_constant_condition(*cond.condition);
840         if(result!=NOT_CONSTANT)
841         {
842                 Block &block = (result==CONSTANT_TRUE ? cond.body : cond.else_body);
843                 // TODO should check variable names for conflicts.  Potentially reuse InlineContentInjector?
844                 current_block->body.splice(insert_point, block.body);
845                 nodes_to_remove.insert(&cond);
846                 return;
847         }
848
849         TraversingVisitor::visit(cond);
850 }
851
852 void ConstantConditionEliminator::visit(Iteration &iter)
853 {
854         if(iter.condition)
855         {
856                 ConstantStatus result = check_constant_condition(*iter.condition);
857                 if(result==CONSTANT_FALSE)
858                 {
859                         nodes_to_remove.insert(&iter);
860                         return;
861                 }
862         }
863
864         TraversingVisitor::visit(iter);
865 }
866
867
868 bool UnusedTypeRemover::apply(Stage &stage)
869 {
870         stage.content.visit(*this);
871         NodeRemover().apply(stage, unused_nodes);
872         return !unused_nodes.empty();
873 }
874
875 void UnusedTypeRemover::visit(Literal &literal)
876 {
877         unused_nodes.erase(literal.type);
878 }
879
880 void UnusedTypeRemover::visit(UnaryExpression &unary)
881 {
882         unused_nodes.erase(unary.type);
883         TraversingVisitor::visit(unary);
884 }
885
886 void UnusedTypeRemover::visit(BinaryExpression &binary)
887 {
888         unused_nodes.erase(binary.type);
889         TraversingVisitor::visit(binary);
890 }
891
892 void UnusedTypeRemover::visit(TernaryExpression &ternary)
893 {
894         unused_nodes.erase(ternary.type);
895         TraversingVisitor::visit(ternary);
896 }
897
898 void UnusedTypeRemover::visit(FunctionCall &call)
899 {
900         unused_nodes.erase(call.type);
901         TraversingVisitor::visit(call);
902 }
903
904 void UnusedTypeRemover::visit(BasicTypeDeclaration &type)
905 {
906         if(type.base_type)
907                 unused_nodes.erase(type.base_type);
908         unused_nodes.insert(&type);
909 }
910
911 void UnusedTypeRemover::visit(ImageTypeDeclaration &type)
912 {
913         if(type.base_type)
914                 unused_nodes.erase(type.base_type);
915         unused_nodes.insert(&type);
916 }
917
918 void UnusedTypeRemover::visit(StructDeclaration &strct)
919 {
920         unused_nodes.insert(&strct);
921         TraversingVisitor::visit(strct);
922 }
923
924 void UnusedTypeRemover::visit(VariableDeclaration &var)
925 {
926         unused_nodes.erase(var.type_declaration);
927 }
928
929 void UnusedTypeRemover::visit(InterfaceBlock &iface)
930 {
931         unused_nodes.erase(iface.type_declaration);
932 }
933
934 void UnusedTypeRemover::visit(FunctionDeclaration &func)
935 {
936         unused_nodes.erase(func.return_type_declaration);
937         TraversingVisitor::visit(func);
938 }
939
940
/* Starts with no stage (apply() sets it), no enclosing interface block, no
pending assignment result and both traversal flags cleared. */
UnusedVariableRemover::UnusedVariableRemover():
	stage(0),
	interface_block(0),
	r_assignment(0),
	assignment_target(false),
	r_side_effects(false)
{ }
948
949 bool UnusedVariableRemover::apply(Stage &s)
950 {
951         stage = &s;
952         s.content.visit(*this);
953
954         for(list<AssignmentInfo>::const_iterator i=assignments.begin(); i!=assignments.end(); ++i)
955                 if(i->used_by.empty())
956                         unused_nodes.insert(i->node);
957
958         for(map<string, InterfaceBlock *>::const_iterator i=s.interface_blocks.begin(); i!=s.interface_blocks.end(); ++i)
959                 if(i->second->instance_name.empty())
960                         unused_nodes.insert(i->second);
961
962         for(BlockVariableMap::const_iterator i=variables.begin(); i!=variables.end(); ++i)
963         {
964                 if(i->second.output)
965                 {
966                         /* The last visible assignments of output variables are used by the
967                         next stage or the API. */
968                         for(vector<AssignmentInfo *>::const_iterator j=i->second.assignments.begin(); j!=i->second.assignments.end(); ++j)
969                                 unused_nodes.erase((*j)->node);
970                 }
971
972                 if(!i->second.output && !i->second.referenced)
973                 {
974                         // Don't remove variables from inside interface blocks.
975                         if(!i->second.interface_block)
976                                 unused_nodes.insert(i->first);
977                 }
978                 else if(i->second.interface_block)
979                         // Interface blocks are kept if even one member is used.
980                         unused_nodes.erase(i->second.interface_block);
981         }
982
983         NodeRemover().apply(s, unused_nodes);
984
985         return !unused_nodes.empty();
986 }
987
988 void UnusedVariableRemover::referenced(const Assignment::Target &target, Node &node)
989 {
990         VariableInfo &var_info = variables[target.declaration];
991         var_info.referenced = true;
992         if(!assignment_target)
993         {
994                 for(vector<AssignmentInfo *>::const_iterator i=var_info.assignments.begin(); i!=var_info.assignments.end(); ++i)
995                         (*i)->used_by.push_back(&node);
996         }
997 }
998
999 void UnusedVariableRemover::visit(VariableReference &var)
1000 {
1001         referenced(var.declaration, var);
1002 }
1003
1004 void UnusedVariableRemover::visit(InterfaceBlockReference &iface)
1005 {
1006         referenced(iface.declaration, iface);
1007 }
1008
1009 void UnusedVariableRemover::visit(UnaryExpression &unary)
1010 {
1011         TraversingVisitor::visit(unary);
1012         if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
1013                 r_side_effects = true;
1014 }
1015
1016 void UnusedVariableRemover::visit(BinaryExpression &binary)
1017 {
1018         if(binary.oper->token[0]=='[')
1019         {
1020                 binary.left->visit(*this);
1021                 SetFlag set(assignment_target, false);
1022                 binary.right->visit(*this);
1023         }
1024         else
1025                 TraversingVisitor::visit(binary);
1026 }
1027
1028 void UnusedVariableRemover::visit(Assignment &assign)
1029 {
1030         {
1031                 SetFlag set(assignment_target, (assign.oper->token[0]=='='));
1032                 assign.left->visit(*this);
1033         }
1034         assign.right->visit(*this);
1035         r_assignment = &assign;
1036         r_side_effects = true;
1037 }
1038
1039 void UnusedVariableRemover::visit(FunctionCall &call)
1040 {
1041         TraversingVisitor::visit(call);
1042         /* Treat function calls as having side effects so expression statements
1043         consisting of nothing but a function call won't be optimized away. */
1044         r_side_effects = true;
1045
1046         if(stage->type==Stage::GEOMETRY && call.name=="EmitVertex")
1047         {
1048                 for(map<Statement *, VariableInfo>::const_iterator i=variables.begin(); i!=variables.end(); ++i)
1049                         if(i->second.output)
1050                                 referenced(i->first, call);
1051         }
1052 }
1053
1054 void UnusedVariableRemover::record_assignment(const Assignment::Target &target, Node &node)
1055 {
1056         assignments.push_back(AssignmentInfo());
1057         AssignmentInfo &assign_info = assignments.back();
1058         assign_info.node = &node;
1059         assign_info.target = target;
1060
1061         /* An assignment to the target hides any assignments to the same target or
1062         its subfields. */
1063         VariableInfo &var_info = variables[target.declaration];
1064         for(unsigned i=0; i<var_info.assignments.size(); ++i)
1065         {
1066                 const Assignment::Target &t = var_info.assignments[i]->target;
1067
1068                 bool subfield = (t.chain_len>=target.chain_len);
1069                 for(unsigned j=0; (subfield && j<target.chain_len); ++j)
1070                         subfield = (t.chain[j]==target.chain[j]);
1071
1072                 if(subfield)
1073                         var_info.assignments.erase(var_info.assignments.begin()+i);
1074                 else
1075                         ++i;
1076         }
1077
1078         var_info.assignments.push_back(&assign_info);
1079 }
1080
1081 void UnusedVariableRemover::visit(ExpressionStatement &expr)
1082 {
1083         r_assignment = 0;
1084         r_side_effects = false;
1085         TraversingVisitor::visit(expr);
1086         if(r_assignment && r_assignment->target.declaration)
1087                 record_assignment(r_assignment->target, expr);
1088         if(!r_side_effects)
1089                 unused_nodes.insert(&expr);
1090 }
1091
1092 void UnusedVariableRemover::visit(VariableDeclaration &var)
1093 {
1094         VariableInfo &var_info = variables[&var];
1095         var_info.interface_block = interface_block;
1096
1097         /* Mark variables as output if they're used by the next stage or the
1098         graphics API. */
1099         if(interface_block)
1100                 var_info.output = (interface_block->interface=="out" && (interface_block->linked_block || !interface_block->name.compare(0, 3, "gl_")));
1101         else
1102                 var_info.output = (var.interface=="out" && (stage->type==Stage::FRAGMENT || var.linked_declaration || !var.name.compare(0, 3, "gl_")));
1103
1104         if(var.init_expression)
1105         {
1106                 var_info.initialized = true;
1107                 record_assignment(&var, *var.init_expression);
1108         }
1109         TraversingVisitor::visit(var);
1110 }
1111
1112 void UnusedVariableRemover::visit(InterfaceBlock &iface)
1113 {
1114         if(iface.instance_name.empty())
1115         {
1116                 SetForScope<InterfaceBlock *> set_block(interface_block, &iface);
1117                 iface.struct_declaration->members.visit(*this);
1118         }
1119         else
1120         {
1121                 VariableInfo &var_info = variables[&iface];
1122                 var_info.output = (iface.interface=="out" && (iface.linked_block || !iface.name.compare(0, 3, "gl_")));
1123         }
1124 }
1125
1126 void UnusedVariableRemover::merge_variables(const BlockVariableMap &other_vars)
1127 {
1128         for(BlockVariableMap::const_iterator i=other_vars.begin(); i!=other_vars.end(); ++i)
1129         {
1130                 BlockVariableMap::iterator j = variables.find(i->first);
1131                 if(j!=variables.end())
1132                 {
1133                         /* The merged blocks started as copies of each other so any common
1134                         assignments must be in the beginning. */
1135                         unsigned k = 0;
1136                         for(; (k<i->second.assignments.size() && k<j->second.assignments.size()); ++k)
1137                                 if(i->second.assignments[k]!=j->second.assignments[k])
1138                                         break;
1139
1140                         // Remaining assignments are unique to each block; merge them.
1141                         j->second.assignments.insert(j->second.assignments.end(), i->second.assignments.begin()+k, i->second.assignments.end());
1142                         j->second.referenced |= i->second.referenced;
1143                 }
1144                 else
1145                         variables.insert(*i);
1146         }
1147 }
1148
1149 void UnusedVariableRemover::visit(FunctionDeclaration &func)
1150 {
1151         if(func.body.body.empty())
1152                 return;
1153
1154         BlockVariableMap saved_vars = variables;
1155         // Assignments from other functions should not be visible.
1156         for(BlockVariableMap::iterator i=variables.begin(); i!=variables.end(); ++i)
1157                 i->second.assignments.resize(i->second.initialized);
1158         TraversingVisitor::visit(func);
1159         swap(variables, saved_vars);
1160         merge_variables(saved_vars);
1161
1162         /* Always treat function parameters as referenced.  Removing unused
1163         parameters is not currently supported. */
1164         for(NodeArray<VariableDeclaration>::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1165         {
1166                 BlockVariableMap::iterator j = variables.find(i->get());
1167                 if(j!=variables.end())
1168                         j->second.referenced = true;
1169         }
1170 }
1171
1172 void UnusedVariableRemover::visit(Conditional &cond)
1173 {
1174         cond.condition->visit(*this);
1175         BlockVariableMap saved_vars = variables;
1176         cond.body.visit(*this);
1177         swap(saved_vars, variables);
1178         cond.else_body.visit(*this);
1179
1180         /* Visible assignments after the conditional is the union of those visible
1181         at the end of the if and else blocks.  If there was no else block, then it's
1182         the union of the if block and the state before it. */
1183         merge_variables(saved_vars);
1184 }
1185
1186 void UnusedVariableRemover::visit(Iteration &iter)
1187 {
1188         BlockVariableMap saved_vars = variables;
1189         TraversingVisitor::visit(iter);
1190
1191         /* Merge assignments from the iteration, without clearing previous state.
1192         Further analysis is needed to determine which parts of the iteration body
1193         are always executed, if any. */
1194         merge_variables(saved_vars);
1195 }
1196
1197
1198 bool UnusedFunctionRemover::apply(Stage &stage)
1199 {
1200         stage.content.visit(*this);
1201         NodeRemover().apply(stage, unused_nodes);
1202         return !unused_nodes.empty();
1203 }
1204
1205 void UnusedFunctionRemover::visit(FunctionCall &call)
1206 {
1207         TraversingVisitor::visit(call);
1208
1209         unused_nodes.erase(call.declaration);
1210         if(call.declaration && call.declaration->definition!=call.declaration)
1211                 used_definitions.insert(call.declaration->definition);
1212 }
1213
1214 void UnusedFunctionRemover::visit(FunctionDeclaration &func)
1215 {
1216         TraversingVisitor::visit(func);
1217
1218         if((func.name!="main" || func.body.body.empty()) && !used_definitions.count(&func))
1219                 unused_nodes.insert(&func);
1220 }
1221
1222 } // namespace SL
1223 } // namespace GL
1224 } // namespace Msp