git.tdb.fi Git - libs/gl.git/blob - optimize.cpp
1e974c9f01f8eee6d3eea5faabc00b20759ffd37
[libs/gl.git] / optimize.cpp
#include <limits>
#include <msp/core/algorithm.h>
#include <msp/core/raii.h>
#include <msp/strings/format.h>
#include <msp/strings/utils.h>
#include "optimize.h"
#include "reflect.h"
7
8 using namespace std;
9
10 namespace Msp {
11 namespace GL {
12 namespace SL {
13
14 void ConstantSpecializer::apply(Stage &stage, const map<string, int> &v)
15 {
16         values = &v;
17         stage.content.visit(*this);
18 }
19
20 void ConstantSpecializer::visit(VariableDeclaration &var)
21 {
22         bool specializable = false;
23         if(var.layout)
24         {
25                 vector<Layout::Qualifier> &qualifiers = var.layout->qualifiers;
26                 auto i = find_member(qualifiers, string("constant_id"), &Layout::Qualifier::name);
27                 if(i!=qualifiers.end())
28                 {
29                         specializable = true;
30                         qualifiers.erase(i);
31                         if(qualifiers.empty())
32                                 var.layout = 0;
33                 }
34         }
35
36         if(specializable)
37         {
38                 auto i = values->find(var.name);
39                 if(i!=values->end())
40                 {
41                         RefPtr<Literal> literal = new Literal;
42                         if(var.type=="bool")
43                         {
44                                 literal->token = (i->second ? "true" : "false");
45                                 literal->value = static_cast<bool>(i->second);
46                         }
47                         else if(var.type=="int")
48                         {
49                                 literal->token = lexical_cast<string>(i->second);
50                                 literal->value = i->second;
51                         }
52                         var.init_expression = literal;
53                 }
54         }
55 }
56
57
58 void InlineableFunctionLocator::visit(FunctionCall &call)
59 {
60         FunctionDeclaration *def = call.declaration;
61         if(def)
62                 def = def->definition;
63
64         if(def)
65         {
66                 unsigned &count = refcounts[def];
67                 ++count;
68                 /* Don't inline functions which are called more than once or are called
69                 recursively. */
70                 if((count>1 && def->source!=BUILTIN_SOURCE) || def==current_function)
71                         inlineable.erase(def);
72         }
73
74         TraversingVisitor::visit(call);
75 }
76
77 void InlineableFunctionLocator::visit(FunctionDeclaration &func)
78 {
79         bool has_out_params = any_of(func.parameters.begin(), func.parameters.end(),
80                 [](const RefPtr<VariableDeclaration> &p){ return p->interface=="out"; });
81
82         unsigned &count = refcounts[func.definition];
83         if((count<=1 || func.source==BUILTIN_SOURCE) && !has_out_params)
84                 inlineable.insert(func.definition);
85
86         SetForScope<FunctionDeclaration *> set(current_function, &func);
87         return_count = 0;
88         TraversingVisitor::visit(func);
89 }
90
91 void InlineableFunctionLocator::visit(Conditional &cond)
92 {
93         TraversingVisitor::visit(cond);
94         inlineable.erase(current_function);
95 }
96
97 void InlineableFunctionLocator::visit(Iteration &iter)
98 {
99         TraversingVisitor::visit(iter);
100         inlineable.erase(current_function);
101 }
102
103 void InlineableFunctionLocator::visit(Return &ret)
104 {
105         TraversingVisitor::visit(ret);
106         if(return_count)
107                 inlineable.erase(current_function);
108         ++return_count;
109 }
110
111
112 string InlineContentInjector::apply(Stage &stage, FunctionDeclaration &target_func, Block &tgt_blk, const NodeList<Statement>::iterator &ins_pt, FunctionCall &call)
113 {
114         source_func = call.declaration->definition;
115
116         /* Populate referenced_names from the target function so we can rename
117         variables from the inlined function that would conflict. */
118         pass = REFERENCED;
119         target_func.visit(*this);
120
121         /* Inline and rename passes must be interleaved so used variable names are
122         known when inlining the return statement. */
123         pass = INLINE;
124         staging_block.parent = &tgt_blk;
125         staging_block.variables.clear();
126
127         vector<RefPtr<VariableDeclaration> > params;
128         params.reserve(source_func->parameters.size());
129         for(const RefPtr<VariableDeclaration> &p: source_func->parameters)
130         {
131                 RefPtr<VariableDeclaration> var = p->clone();
132                 var->interface.clear();
133
134                 SetForScope<Pass> set_pass(pass, RENAME);
135                 var->visit(*this);
136
137                 staging_block.body.push_back_nocopy(var);
138                 params.push_back(var);
139         }
140
141         for(const RefPtr<Statement> &s: source_func->body.body)
142         {
143                 r_inlined_statement = 0;
144                 s->visit(*this);
145                 if(!r_inlined_statement)
146                         r_inlined_statement = s->clone();
147
148                 SetForScope<Pass> set_pass(pass, RENAME);
149                 r_inlined_statement->visit(*this);
150
151                 staging_block.body.push_back_nocopy(r_inlined_statement);
152         }
153
154         /* Now collect names from the staging block.  Local variables that would
155         have conflicted with the target function were renamed earlier. */
156         pass = REFERENCED;
157         referenced_names.clear();
158         staging_block.variables.clear();
159         staging_block.visit(*this);
160
161         /* Rename variables in the target function so they don't interfere with
162         global identifiers used by the source function. */
163         pass = RENAME;
164         staging_block.parent = source_func->body.parent;
165         target_func.visit(*this);
166
167         // Put the argument expressions in place after all renaming has been done.
168         for(unsigned i=0; i<source_func->parameters.size(); ++i)
169                 params[i]->init_expression = call.arguments[i]->clone();
170
171         tgt_blk.body.splice(ins_pt, staging_block.body);
172
173         NodeReorderer().apply(stage, target_func, DependencyCollector().apply(*source_func));
174
175         return r_result_name;
176 }
177
178 void InlineContentInjector::visit(VariableReference &var)
179 {
180         if(pass==RENAME)
181         {
182                 auto i = staging_block.variables.find(var.name);
183                 if(i!=staging_block.variables.end())
184                         var.name = i->second->name;
185         }
186         else if(pass==REFERENCED)
187                 referenced_names.insert(var.name);
188 }
189
190 void InlineContentInjector::visit(InterfaceBlockReference &iface)
191 {
192         if(pass==REFERENCED)
193                 referenced_names.insert(iface.name);
194 }
195
196 void InlineContentInjector::visit(FunctionCall &call)
197 {
198         if(pass==REFERENCED)
199                 referenced_names.insert(call.name);
200         TraversingVisitor::visit(call);
201 }
202
203 void InlineContentInjector::visit(VariableDeclaration &var)
204 {
205         TraversingVisitor::visit(var);
206
207         if(pass==RENAME)
208         {
209                 /* Check against conflicts with the other context as well as variables
210                 already renamed here. */
211                 bool conflict = (staging_block.variables.count(var.name) || referenced_names.count(var.name));
212                 staging_block.variables[var.name] = &var;
213                 if(conflict)
214                 {
215                         string mapped_name = get_unused_variable_name(staging_block, var.name);
216                         if(mapped_name!=var.name)
217                         {
218                                 staging_block.variables[mapped_name] = &var;
219                                 var.name = mapped_name;
220                         }
221                 }
222         }
223         else if(pass==REFERENCED)
224                 referenced_names.insert(var.type);
225 }
226
227 void InlineContentInjector::visit(Return &ret)
228 {
229         TraversingVisitor::visit(ret);
230
231         if(pass==INLINE && ret.expression)
232         {
233                 // Create a new variable to hold the return value of the inlined function.
234                 r_result_name = get_unused_variable_name(staging_block, "_return");
235                 RefPtr<VariableDeclaration> var = new VariableDeclaration;
236                 var->source = ret.source;
237                 var->line = ret.line;
238                 var->type = source_func->return_type;
239                 var->name = r_result_name;
240                 var->init_expression = ret.expression->clone();
241                 r_inlined_statement = var;
242         }
243 }
244
245
246 bool FunctionInliner::apply(Stage &s)
247 {
248         stage = &s;
249         inlineable = InlineableFunctionLocator().apply(s);
250         r_any_inlined = false;
251         s.content.visit(*this);
252         return r_any_inlined;
253 }
254
255 void FunctionInliner::visit(RefPtr<Expression> &ptr)
256 {
257         r_inline_result = 0;
258         ptr->visit(*this);
259         if(r_inline_result)
260         {
261                 ptr = r_inline_result;
262                 r_any_inlined = true;
263         }
264         r_inline_result = 0;
265 }
266
267 void FunctionInliner::visit(Block &block)
268 {
269         SetForScope<Block *> set_block(current_block, &block);
270         SetForScope<NodeList<Statement>::iterator> save_insert_point(insert_point, block.body.begin());
271         for(auto i=block.body.begin(); (!r_inlined_here && i!=block.body.end()); ++i)
272         {
273                 insert_point = i;
274                 (*i)->visit(*this);
275         }
276 }
277
278 void FunctionInliner::visit(FunctionCall &call)
279 {
280         for(auto i=call.arguments.begin(); (!r_inlined_here && i!=call.arguments.end()); ++i)
281                 visit(*i);
282
283         if(r_inlined_here)
284                 return;
285
286         FunctionDeclaration *def = call.declaration;
287         if(def)
288                 def = def->definition;
289
290         if(def && inlineable.count(def))
291         {
292                 string result_name = InlineContentInjector().apply(*stage, *current_function, *current_block, insert_point, call);
293
294                 // This will later get removed by UnusedVariableRemover.
295                 if(result_name.empty())
296                         result_name = "_msp_unused_from_inline";
297
298                 RefPtr<VariableReference> ref = new VariableReference;
299                 ref->name = result_name;
300                 r_inline_result = ref;
301
302                 /* Inlined variables need to be resolved before this function can be
303                 inlined further. */
304                 inlineable.erase(current_function);
305                 r_inlined_here = true;
306         }
307 }
308
309 void FunctionInliner::visit(FunctionDeclaration &func)
310 {
311         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
312         TraversingVisitor::visit(func);
313         r_inlined_here = false;
314 }
315
316 void FunctionInliner::visit(Iteration &iter)
317 {
318         /* Visit the initialization statement before entering the loop body so the
319         inlined statements get inserted outside. */
320         if(iter.init_statement)
321                 iter.init_statement->visit(*this);
322
323         SetForScope<Block *> set_block(current_block, &iter.body);
324         /* Skip the condition and loop expression parts because they're not properly
325         inside the body block.  Inlining anything into them will require a more
326         comprehensive transformation. */
327         iter.body.visit(*this);
328 }
329
330
331 bool ExpressionInliner::apply(Stage &s)
332 {
333         s.content.visit(*this);
334
335         bool any_inlined = false;
336         for(ExpressionInfo &e: expressions)
337                 if(e.expression && (e.trivial || e.uses.size()==1))
338                 {
339                         for(ExpressionUse &u: e.uses)
340                                 if(!u.blocked)
341                                 {
342                                         *u.reference = e.expression->clone();
343                                         any_inlined = true;
344                                 }
345                 }
346
347         return any_inlined;
348 }
349
350 void ExpressionInliner::visit(RefPtr<Expression> &expr)
351 {
352         r_ref_info = 0;
353         expr->visit(*this);
354         if(r_ref_info && r_ref_info->expression)
355         {
356                 ExpressionUse use;
357                 use.reference = &expr;
358                 use.ref_scope = current_block;
359                 use.blocked = access_write;
360
361                 if(iteration_body && !r_ref_info->trivial)
362                 {
363                         /* Block inlining of non-trivial expressions assigned outside an
364                         iteration statement.  The iteration may run multiple times, which
365                         would cause the expression to also be evaluated multiple times. */
366                         for(Block *i=iteration_body->parent; (!use.blocked && i); i=i->parent)
367                                 use.blocked = (i==r_ref_info->assign_scope);
368                 }
369
370                 /* Block inlining assignments from from inner scopes.  The assignment may
371                 depend on local variables of that scope or may not always be executed. */
372                 for(Block *i=r_ref_info->assign_scope->parent; (!use.blocked && i); i=i->parent)
373                         use.blocked = (i==current_block);
374
375                 r_ref_info->uses.push_back(use);
376         }
377         r_oper = expr->oper;
378         r_ref_info = 0;
379 }
380
381 void ExpressionInliner::visit(VariableReference &var)
382 {
383         if(var.declaration && access_read)
384         {
385                 auto i = assignments.find(var.declaration);
386                 if(i!=assignments.end())
387                         r_ref_info = i->second;
388         }
389 }
390
391 void ExpressionInliner::visit(MemberAccess &memacc)
392 {
393         visit(memacc.left);
394         r_trivial = false;
395 }
396
397 void ExpressionInliner::visit(Swizzle &swizzle)
398 {
399         visit(swizzle.left);
400         r_trivial = false;
401 }
402
403 void ExpressionInliner::visit(UnaryExpression &unary)
404 {
405         SetFlag set_write(access_write, access_write || unary.oper->token[1]=='+' || unary.oper->token[1]=='-');
406         visit(unary.expression);
407         r_trivial = false;
408 }
409
410 void ExpressionInliner::visit(BinaryExpression &binary)
411 {
412         visit(binary.left);
413         {
414                 SetFlag clear_write(access_write, false);
415                 visit(binary.right);
416         }
417         r_trivial = false;
418 }
419
420 void ExpressionInliner::visit(Assignment &assign)
421 {
422         {
423                 SetFlag set_read(access_read, assign.oper->token[0]!='=');
424                 SetFlag set_write(access_write);
425                 visit(assign.left);
426         }
427         r_oper = 0;
428         r_trivial = true;
429         visit(assign.right);
430
431         auto i = assignments.find(assign.target);
432         if(i!=assignments.end())
433         {
434                 if(iteration_body && i->second->expression)
435                 {
436                         /* Block inlining into previous references within the iteration
437                         statement.  On iterations after the first they would refer to the
438                         assignment within the iteration. */
439                         for(ExpressionUse &u: i->second->uses)
440                                 for(Block *k=u.ref_scope; (!u.blocked && k); k=k->parent)
441                                         u.blocked = (k==iteration_body);
442                 }
443
444                 expressions.push_back(ExpressionInfo());
445                 ExpressionInfo &info = expressions.back();
446                 info.target = assign.target;
447                 // Self-referencing assignments can't be inlined without additional work.
448                 if(!assign.self_referencing)
449                         info.expression = assign.right;
450                 info.assign_scope = current_block;
451                 info.trivial = r_trivial;
452
453                 i->second = &info;
454         }
455
456         r_trivial = false;
457 }
458
459 void ExpressionInliner::visit(TernaryExpression &ternary)
460 {
461         visit(ternary.condition);
462         visit(ternary.true_expr);
463         visit(ternary.false_expr);
464         r_trivial = false;
465 }
466
467 void ExpressionInliner::visit(FunctionCall &call)
468 {
469         TraversingVisitor::visit(call);
470         r_trivial = false;
471 }
472
473 void ExpressionInliner::visit(VariableDeclaration &var)
474 {
475         r_oper = 0;
476         r_trivial = true;
477         TraversingVisitor::visit(var);
478
479         bool constant = var.constant;
480         if(constant && var.layout)
481         {
482                 constant = !any_of(var.layout->qualifiers.begin(), var.layout->qualifiers.end(),
483                         [](const Layout::Qualifier &q){ return q.name=="constant_id"; });
484         }
485
486         /* Only inline global variables if they're constant and have trivial
487         initializers.  Non-constant variables could change in ways which are hard to
488         analyze and non-trivial expressions could be expensive to inline.  */
489         if((current_block->parent || (constant && r_trivial)) && var.interface.empty())
490         {
491                 expressions.push_back(ExpressionInfo());
492                 ExpressionInfo &info = expressions.back();
493                 info.target = &var;
494                 /* Assume variables declared in an iteration initialization statement
495                 will have their values change throughout the iteration. */
496                 if(!iteration_init)
497                         info.expression = var.init_expression;
498                 info.assign_scope = current_block;
499                 info.trivial = r_trivial;
500
501                 assignments[&var] = &info;
502         }
503 }
504
505 void ExpressionInliner::visit(Iteration &iter)
506 {
507         SetForScope<Block *> set_block(current_block, &iter.body);
508         if(iter.init_statement)
509         {
510                 SetFlag set_init(iteration_init);
511                 iter.init_statement->visit(*this);
512         }
513
514         SetForScope<Block *> set_body(iteration_body, &iter.body);
515         if(iter.condition)
516                 visit(iter.condition);
517         iter.body.visit(*this);
518         if(iter.loop_expression)
519                 visit(iter.loop_expression);
520 }
521
522
523 bool AggregateDismantler::apply(Stage &stage)
524 {
525         stage.content.visit(*this);
526
527         bool any_dismantled = false;
528         for(const auto &kvp: aggregates)
529         {
530                 if(kvp.second.referenced || !kvp.second.members_referenced)
531                         continue;
532
533                 for(const AggregateMember &m: kvp.second.members)
534                 {
535                         string name;
536                         if(m.declaration)
537                                 name = format("%s_%s", kvp.second.declaration->name, m.declaration->name);
538                         else
539                                 name = format("%s_%d", kvp.second.declaration->name, m.index);
540
541                         VariableDeclaration *var = new VariableDeclaration;
542                         var->source = kvp.first->source;
543                         var->line = kvp.first->line;
544                         var->name = get_unused_variable_name(*kvp.second.decl_scope, name);
545                         /* XXX This is kind of brittle and depends on the array declaration's
546                         textual type not having brackets in it. */
547                         var->type = (m.declaration ? m.declaration : kvp.second.declaration)->type;
548                         if(m.initializer)
549                                 var->init_expression = m.initializer->clone();
550
551                         kvp.second.decl_scope->body.insert(kvp.second.insert_point, var);
552
553                         for(RefPtr<Expression> *r: m.references)
554                         {
555                                 VariableReference *ref = new VariableReference;
556                                 ref->name = var->name;
557                                 *r = ref;
558                         }
559
560                         any_dismantled = true;
561                 }
562         }
563
564         return any_dismantled;
565 }
566
567 void AggregateDismantler::visit(Block &block)
568 {
569         SetForScope<Block *> set_block(current_block, &block);
570         for(auto i=block.body.begin(); i!=block.body.end(); ++i)
571         {
572                 insert_point = i;
573                 (*i)->visit(*this);
574         }
575 }
576
577 void AggregateDismantler::visit(RefPtr<Expression> &expr)
578 {
579         r_aggregate_ref = 0;
580         expr->visit(*this);
581         if(r_aggregate_ref && r_reference.chain_len==1)
582         {
583                 if((r_reference.chain[0]&0x3F)!=0x3F)
584                 {
585                         r_aggregate_ref->members[r_reference.chain[0]&0x3F].references.push_back(&expr);
586                         r_aggregate_ref->members_referenced = true;
587                 }
588                 else
589                         /* If the accessed member is not known, mark the entire aggregate as
590                         referenced. */
591                         r_aggregate_ref->referenced = true;
592         }
593         r_aggregate_ref = 0;
594 }
595
596 void AggregateDismantler::visit(VariableReference &var)
597 {
598         if(composite_reference)
599                 r_reference.declaration = var.declaration;
600         else
601         {
602                 /* If an aggregate variable is referenced as a whole, it should not be
603                 dismantled. */
604                 auto i = aggregates.find(var.declaration);
605                 if(i!=aggregates.end())
606                         i->second.referenced = true;
607         }
608 }
609
610 void AggregateDismantler::visit_composite(RefPtr<Expression> &expr)
611 {
612         if(!composite_reference)
613                 r_reference = Assignment::Target();
614
615         SetFlag set_composite(composite_reference);
616         visit(expr);
617 }
618
619 void AggregateDismantler::visit(MemberAccess &memacc)
620 {
621         visit_composite(memacc.left);
622
623         add_to_chain(r_reference, Assignment::Target::MEMBER, memacc.index);
624
625         if(r_reference.declaration && r_reference.chain_len==1)
626         {
627                 auto i = aggregates.find(r_reference.declaration);
628                 r_aggregate_ref = (i!=aggregates.end() ? &i->second : 0);
629         }
630         else
631                 r_aggregate_ref = 0;
632 }
633
634 void AggregateDismantler::visit(BinaryExpression &binary)
635 {
636         if(binary.oper->token[0]=='[')
637         {
638                 visit_composite(binary.left);
639                 {
640                         SetFlag clear_composite(composite_reference, false);
641                         visit(binary.right);
642                 }
643
644                 unsigned index = 0x3F;
645                 if(Literal *literal_subscript = dynamic_cast<Literal *>(binary.right.get()))
646                         if(literal_subscript->value.check_type<int>())
647                                 index = literal_subscript->value.value<int>();
648                 add_to_chain(r_reference, Assignment::Target::ARRAY, index);
649
650                 if(r_reference.declaration && r_reference.chain_len==1)
651                 {
652                         auto i = aggregates.find(r_reference.declaration);
653                         r_aggregate_ref = (i!=aggregates.end() ? &i->second : 0);
654                 }
655                 else
656                         r_aggregate_ref = 0;
657         }
658         else
659         {
660                 SetFlag clear_composite(composite_reference, false);
661                 TraversingVisitor::visit(binary);
662         }
663 }
664
665 void AggregateDismantler::visit(VariableDeclaration &var)
666 {
667         TraversingVisitor::visit(var);
668
669         if(var.interface.empty())
670         {
671                 if(const StructDeclaration *strct = dynamic_cast<const StructDeclaration *>(var.type_declaration))
672                 {
673                         const FunctionCall *init_call = dynamic_cast<const FunctionCall *>(var.init_expression.get());
674                         if((init_call && init_call->constructor) || !var.init_expression)
675                         {
676
677                                 Aggregate &aggre = aggregates[&var];
678                                 aggre.declaration = &var;
679                                 aggre.decl_scope = current_block;
680                                 aggre.insert_point = insert_point;
681
682                                 unsigned i = 0;
683                                 for(const RefPtr<Statement> &s: strct->members.body)
684                                 {
685                                         if(const VariableDeclaration *mem_decl = dynamic_cast<const VariableDeclaration *>(s.get()))
686                                         {
687                                                 AggregateMember member;
688                                                 member.declaration = mem_decl;
689                                                 member.index = i;
690                                                 if(init_call)
691                                                         member.initializer = init_call->arguments[i];
692                                                 aggre.members.push_back(member);
693                                         }
694                                         ++i;
695                                 }
696                         }
697                 }
698                 else if(const Literal *literal_size = dynamic_cast<const Literal *>(var.array_size.get()))
699                 {
700                         if(literal_size->value.check_type<int>())
701                         {
702                                 Aggregate &aggre = aggregates[&var];
703                                 aggre.declaration = &var;
704                                 aggre.decl_scope = current_block;
705                                 aggre.insert_point = insert_point;
706
707                                 int size = literal_size->value.value<int>();
708                                 for(int i=0; i<size; ++i)
709                                 {
710                                         AggregateMember member;
711                                         member.index = i;
712                                         // Array initializers are not supported yet
713                                         aggre.members.push_back(member);
714                                 }
715                         }
716                 }
717         }
718 }
719
720 void AggregateDismantler::visit(FunctionDeclaration &func)
721 {
722         func.body.visit(*this);
723 }
724
725
726 template<typename T>
727 T ConstantFolder::evaluate_logical(char oper, T left, T right)
728 {
729         switch(oper)
730         {
731         case '&': return left&right;
732         case '|': return left|right;
733         case '^': return left^right;
734         default: return T();
735         }
736 }
737
738 template<typename T>
739 bool ConstantFolder::evaluate_relation(const char *oper, T left, T right)
740 {
741         switch(oper[0]|oper[1])
742         {
743         case '<': return left<right;
744         case '<'|'=': return left<=right;
745         case '>': return left>right;
746         case '>'|'=': return left>=right;
747         default: return false;
748         }
749 }
750
751 template<typename T>
752 T ConstantFolder::evaluate_arithmetic(char oper, T left, T right)
753 {
754         switch(oper)
755         {
756         case '+': return left+right;
757         case '-': return left-right;
758         case '*': return left*right;
759         case '/': return left/right;
760         default: return T();
761         }
762 }
763
764 template<typename T>
765 T ConstantFolder::evaluate_int_special_op(char oper, T left, T right)
766 {
767         switch(oper)
768         {
769         case '%': return left%right;
770         case '<': return left<<right;
771         case '>': return left>>right;
772         default: return T();
773         }
774 }
775
776 template<typename T>
777 void ConstantFolder::convert_to_result(const Variant &value)
778 {
779         if(value.check_type<bool>())
780                 set_result(static_cast<T>(value.value<bool>()));
781         else if(value.check_type<int>())
782                 set_result(static_cast<T>(value.value<int>()));
783         else if(value.check_type<unsigned>())
784                 set_result(static_cast<T>(value.value<unsigned>()));
785         else if(value.check_type<float>())
786                 set_result(static_cast<T>(value.value<float>()));
787 }
788
789 void ConstantFolder::set_result(const Variant &value, bool literal)
790 {
791         r_constant_value = value;
792         r_constant = true;
793         r_literal = literal;
794 }
795
796 void ConstantFolder::visit(RefPtr<Expression> &expr)
797 {
798         r_constant_value = Variant();
799         r_constant = false;
800         r_literal = false;
801         r_uses_iter_var = false;
802         expr->visit(*this);
803         /* Don't replace literals since they'd only be replaced with an identical
804         literal.  Also skip anything that uses an iteration variable, but pass on
805         the result so the Iteration visiting function can handle it. */
806         if(!r_constant || r_literal || r_uses_iter_var)
807                 return;
808
809         RefPtr<Literal> literal = new Literal;
810         if(r_constant_value.check_type<bool>())
811                 literal->token = (r_constant_value.value<bool>() ? "true" : "false");
812         else if(r_constant_value.check_type<int>())
813                 literal->token = lexical_cast<string>(r_constant_value.value<int>());
814         else if(r_constant_value.check_type<unsigned>())
815                 literal->token = lexical_cast<string>(r_constant_value.value<unsigned>())+"u";
816         else if(r_constant_value.check_type<float>())
817         {
818                 literal->token = lexical_cast<string>(r_constant_value.value<float>(), Fmt().precision(8));
819                 if(literal->token.find('.')==string::npos && literal->token.find('e')==string::npos)
820                         literal->token += ".0";
821         }
822         else
823         {
824                 r_constant = false;
825                 return;
826         }
827         literal->value = r_constant_value;
828         expr = literal;
829         r_any_folded = true;
830 }
831
832 void ConstantFolder::visit(Literal &literal)
833 {
834         set_result(literal.value, true);
835 }
836
837 void ConstantFolder::visit(VariableReference &var)
838 {
839         /* If an iteration variable is initialized with a constant value, return
840         that value here for the purpose of evaluating the loop condition for the
841         first iteration. */
842         if(var.declaration==iteration_var)
843         {
844                 set_result(iter_init_value);
845                 r_uses_iter_var = true;
846         }
847 }
848
849 void ConstantFolder::visit(MemberAccess &memacc)
850 {
851         TraversingVisitor::visit(memacc);
852         r_constant = false;
853 }
854
855 void ConstantFolder::visit(Swizzle &swizzle)
856 {
857         TraversingVisitor::visit(swizzle);
858         r_constant = false;
859 }
860
861 void ConstantFolder::visit(UnaryExpression &unary)
862 {
863         TraversingVisitor::visit(unary);
864         bool can_fold = r_constant;
865         r_constant = false;
866         if(!can_fold)
867                 return;
868
869         char oper = unary.oper->token[0];
870         char oper2 = unary.oper->token[1];
871         if(oper=='!')
872         {
873                 if(r_constant_value.check_type<bool>())
874                         set_result(!r_constant_value.value<bool>());
875         }
876         else if(oper=='~')
877         {
878                 if(r_constant_value.check_type<int>())
879                         set_result(~r_constant_value.value<int>());
880                 else if(r_constant_value.check_type<unsigned>())
881                         set_result(~r_constant_value.value<unsigned>());
882         }
883         else if(oper=='-' && !oper2)
884         {
885                 if(r_constant_value.check_type<int>())
886                         set_result(-r_constant_value.value<int>());
887                 else if(r_constant_value.check_type<unsigned>())
888                         set_result(-r_constant_value.value<unsigned>());
889                 else if(r_constant_value.check_type<float>())
890                         set_result(-r_constant_value.value<float>());
891         }
892 }
893
894 void ConstantFolder::visit(BinaryExpression &binary)
895 {
896         visit(binary.left);
897         bool left_constant = r_constant;
898         bool left_iter_var = r_uses_iter_var;
899         Variant left_value = r_constant_value;
900         visit(binary.right);
901         if(left_iter_var)
902                 r_uses_iter_var = true;
903
904         bool can_fold = (left_constant && r_constant);
905         r_constant = false;
906         if(!can_fold)
907                 return;
908
909         // Currently only expressions with both sides of equal types are handled.
910         if(!left_value.check_same_type(r_constant_value))
911                 return;
912
913         char oper = binary.oper->token[0];
914         char oper2 = binary.oper->token[1];
915         if(oper=='&' || oper=='|' || oper=='^')
916         {
917                 if(oper2==oper && left_value.check_type<bool>())
918                         set_result(evaluate_logical(oper, left_value.value<bool>(), r_constant_value.value<bool>()));
919                 else if(!oper2 && left_value.check_type<int>())
920                         set_result(evaluate_logical(oper, left_value.value<int>(), r_constant_value.value<int>()));
921                 else if(!oper2 && left_value.check_type<unsigned>())
922                         set_result(evaluate_logical(oper, left_value.value<unsigned>(), r_constant_value.value<unsigned>()));
923         }
924         else if((oper=='<' || oper=='>') && oper2!=oper)
925         {
926                 if(left_value.check_type<int>())
927                         set_result(evaluate_relation(binary.oper->token, left_value.value<int>(), r_constant_value.value<int>()));
928                 else if(left_value.check_type<unsigned>())
929                         set_result(evaluate_relation(binary.oper->token, left_value.value<unsigned>(), r_constant_value.value<unsigned>()));
930                 else if(left_value.check_type<float>())
931                         set_result(evaluate_relation(binary.oper->token, left_value.value<float>(), r_constant_value.value<float>()));
932         }
933         else if((oper=='=' || oper=='!') && oper2=='=')
934         {
935                 if(left_value.check_type<int>())
936                         set_result((left_value.value<int>()==r_constant_value.value<int>()) == (oper=='='));
937                 else if(left_value.check_type<unsigned>())
938                         set_result((left_value.value<unsigned>()==r_constant_value.value<unsigned>()) == (oper=='='));
939                 else if(left_value.check_type<float>())
940                         set_result((left_value.value<float>()==r_constant_value.value<float>()) == (oper=='='));
941         }
942         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
943         {
944                 if(left_value.check_type<int>())
945                         set_result(evaluate_arithmetic(oper, left_value.value<int>(), r_constant_value.value<int>()));
946                 else if(left_value.check_type<unsigned>())
947                         set_result(evaluate_arithmetic(oper, left_value.value<unsigned>(), r_constant_value.value<unsigned>()));
948                 else if(left_value.check_type<float>())
949                         set_result(evaluate_arithmetic(oper, left_value.value<float>(), r_constant_value.value<float>()));
950         }
951         else if(oper=='%' || ((oper=='<' || oper=='>') && oper2==oper))
952         {
953                 if(left_value.check_type<int>())
954                         set_result(evaluate_int_special_op(oper, left_value.value<int>(), r_constant_value.value<int>()));
955                 else if(left_value.check_type<unsigned>())
956                         set_result(evaluate_int_special_op(oper, left_value.value<unsigned>(), r_constant_value.value<unsigned>()));
957         }
958 }
959
960 void ConstantFolder::visit(Assignment &assign)
961 {
962         TraversingVisitor::visit(assign);
963         r_constant = false;
964 }
965
966 void ConstantFolder::visit(TernaryExpression &ternary)
967 {
968         TraversingVisitor::visit(ternary);
969         r_constant = false;
970 }
971
972 void ConstantFolder::visit(FunctionCall &call)
973 {
974         if(call.constructor && call.type && call.arguments.size()==1)
975         {
976                 const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(call.type);
977                 if(basic)
978                 {
979                         visit(call.arguments[0]);
980                         bool can_fold = r_constant;
981                         r_constant = false;
982                         if(!can_fold)
983                                 return;
984
985                         if(basic->kind==BasicTypeDeclaration::BOOL)
986                                 convert_to_result<bool>(r_constant_value);
987                         else if(basic->kind==BasicTypeDeclaration::INT && basic->size==32 && basic->sign)
988                                 convert_to_result<int>(r_constant_value);
989                         else if(basic->kind==BasicTypeDeclaration::INT && basic->size==32 && !basic->sign)
990                                 convert_to_result<unsigned>(r_constant_value);
991                         else if(basic->kind==BasicTypeDeclaration::FLOAT && basic->size==32)
992                                 convert_to_result<float>(r_constant_value);
993
994                         return;
995                 }
996         }
997
998         TraversingVisitor::visit(call);
999         r_constant = false;
1000 }
1001
1002 void ConstantFolder::visit(VariableDeclaration &var)
1003 {
1004         if(iteration_init && var.init_expression)
1005         {
1006                 visit(var.init_expression);
1007                 if(r_constant)
1008                 {
1009                         /* Record the value of a constant initialization expression of an
1010                         iteration, so it can be used to evaluate the loop condition. */
1011                         iteration_var = &var;
1012                         iter_init_value = r_constant_value;
1013                 }
1014         }
1015         else
1016                 TraversingVisitor::visit(var);
1017 }
1018
1019 void ConstantFolder::visit(Iteration &iter)
1020 {
1021         SetForScope<Block *> set_block(current_block, &iter.body);
1022
1023         /* The iteration variable is not normally inlined into expressions, so we
1024         process it specially here.  If the initial value causes the loop condition
1025         to evaluate to false, then the expression can be folded. */
1026         iteration_var = 0;
1027         if(iter.init_statement)
1028         {
1029                 SetFlag set_init(iteration_init);
1030                 iter.init_statement->visit(*this);
1031         }
1032
1033         if(iter.condition)
1034         {
1035                 visit(iter.condition);
1036                 if(r_constant && r_constant_value.check_type<bool>() && !r_constant_value.value<bool>())
1037                 {
1038                         RefPtr<Literal> literal = new Literal;
1039                         literal->token = "false";
1040                         literal->value = r_constant_value;
1041                         iter.condition = literal;
1042                 }
1043         }
1044         iteration_var = 0;
1045
1046         iter.body.visit(*this);
1047         if(iter.loop_expression)
1048                 visit(iter.loop_expression);
1049 }
1050
1051
1052 void ConstantConditionEliminator::apply(Stage &stage)
1053 {
1054         stage.content.visit(*this);
1055         NodeRemover().apply(stage, nodes_to_remove);
1056 }
1057
1058 ConstantConditionEliminator::ConstantStatus ConstantConditionEliminator::check_constant_condition(const Expression &expr)
1059 {
1060         if(const Literal *literal = dynamic_cast<const Literal *>(&expr))
1061                 if(literal->value.check_type<bool>())
1062                         return (literal->value.value<bool>() ? CONSTANT_TRUE : CONSTANT_FALSE);
1063         return NOT_CONSTANT;
1064 }
1065
1066 void ConstantConditionEliminator::visit(Block &block)
1067 {
1068         SetForScope<Block *> set_block(current_block, &block);
1069         for(auto i=block.body.begin(); i!=block.body.end(); ++i)
1070         {
1071                 insert_point = i;
1072                 (*i)->visit(*this);
1073         }
1074 }
1075
1076 void ConstantConditionEliminator::visit(RefPtr<Expression> &expr)
1077 {
1078         r_ternary_result = 0;
1079         expr->visit(*this);
1080         if(r_ternary_result)
1081                 expr = r_ternary_result;
1082         r_ternary_result = 0;
1083 }
1084
1085 void ConstantConditionEliminator::visit(UnaryExpression &unary)
1086 {
1087         if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
1088                 if(const VariableReference *var = dynamic_cast<const VariableReference *>(unary.expression.get()))
1089                 {
1090                         auto i = current_block->variables.find(var->name);
1091                         r_external_side_effects = (i==current_block->variables.end() || i->second!=var->declaration);
1092                         return;
1093                 }
1094
1095         TraversingVisitor::visit(unary);
1096 }
1097
1098 void ConstantConditionEliminator::visit(Assignment &assign)
1099 {
1100         auto i = find_if(current_block->variables, [&assign](const pair<string, VariableDeclaration *> &kvp){ return kvp.second==assign.target.declaration; });
1101         if(i==current_block->variables.end())
1102                 r_external_side_effects = true;
1103         TraversingVisitor::visit(assign);
1104 }
1105
1106 void ConstantConditionEliminator::visit(TernaryExpression &ternary)
1107 {
1108         ConstantStatus result = check_constant_condition(*ternary.condition);
1109         if(result!=NOT_CONSTANT)
1110                 r_ternary_result = (result==CONSTANT_TRUE ? ternary.true_expr : ternary.false_expr);
1111         else
1112                 r_ternary_result = 0;
1113 }
1114
1115 void ConstantConditionEliminator::visit(FunctionCall &call)
1116 {
1117         r_external_side_effects = true;
1118         TraversingVisitor::visit(call);
1119 }
1120
1121 void ConstantConditionEliminator::visit(Conditional &cond)
1122 {
1123         ConstantStatus result = check_constant_condition(*cond.condition);
1124         if(result!=NOT_CONSTANT)
1125         {
1126                 Block &block = (result==CONSTANT_TRUE ? cond.body : cond.else_body);
1127                 // TODO should check variable names for conflicts.  Potentially reuse InlineContentInjector?
1128                 current_block->body.splice(insert_point, block.body);
1129                 nodes_to_remove.insert(&cond);
1130                 return;
1131         }
1132
1133         r_external_side_effects = false;
1134         TraversingVisitor::visit(cond);
1135
1136         if(cond.body.body.empty() && cond.else_body.body.empty() && !r_external_side_effects)
1137                 nodes_to_remove.insert(&cond);
1138 }
1139
1140 void ConstantConditionEliminator::visit(Iteration &iter)
1141 {
1142         if(iter.condition)
1143         {
1144                 ConstantStatus result = check_constant_condition(*iter.condition);
1145                 if(result==CONSTANT_FALSE)
1146                 {
1147                         nodes_to_remove.insert(&iter);
1148                         return;
1149                 }
1150         }
1151
1152         r_external_side_effects = false;
1153         TraversingVisitor::visit(iter);
1154         if(iter.body.body.empty() && !r_external_side_effects)
1155                 nodes_to_remove.insert(&iter);
1156 }
1157
1158
1159 bool UnreachableCodeRemover::apply(Stage &stage)
1160 {
1161         stage.content.visit(*this);
1162         NodeRemover().apply(stage, unreachable_nodes);
1163         return !unreachable_nodes.empty();
1164 }
1165
1166 void UnreachableCodeRemover::visit(Block &block)
1167 {
1168         auto i = block.body.begin();
1169         for(; (reachable && i!=block.body.end()); ++i)
1170                 (*i)->visit(*this);
1171         for(; i!=block.body.end(); ++i)
1172                 unreachable_nodes.insert(i->get());
1173 }
1174
1175 void UnreachableCodeRemover::visit(FunctionDeclaration &func)
1176 {
1177         TraversingVisitor::visit(func);
1178         reachable = true;
1179 }
1180
1181 void UnreachableCodeRemover::visit(Conditional &cond)
1182 {
1183         cond.body.visit(*this);
1184         bool reachable_if_true = reachable;
1185         reachable = true;
1186         cond.else_body.visit(*this);
1187
1188         reachable |= reachable_if_true;
1189 }
1190
1191 void UnreachableCodeRemover::visit(Iteration &iter)
1192 {
1193         TraversingVisitor::visit(iter);
1194
1195         /* Always consider code after a loop reachable, since there's no checking
1196         for whether the loop executes. */
1197         reachable = true;
1198 }
1199
1200
1201 bool UnusedTypeRemover::apply(Stage &stage)
1202 {
1203         stage.content.visit(*this);
1204         NodeRemover().apply(stage, unused_nodes);
1205         return !unused_nodes.empty();
1206 }
1207
1208 void UnusedTypeRemover::visit(RefPtr<Expression> &expr)
1209 {
1210         unused_nodes.erase(expr->type);
1211         TraversingVisitor::visit(expr);
1212 }
1213
1214 void UnusedTypeRemover::visit(BasicTypeDeclaration &type)
1215 {
1216         if(type.base_type)
1217                 unused_nodes.erase(type.base_type);
1218         unused_nodes.insert(&type);
1219 }
1220
1221 void UnusedTypeRemover::visit(ImageTypeDeclaration &type)
1222 {
1223         if(type.base_type)
1224                 unused_nodes.erase(type.base_type);
1225         unused_nodes.insert(&type);
1226 }
1227
1228 void UnusedTypeRemover::visit(StructDeclaration &strct)
1229 {
1230         unused_nodes.insert(&strct);
1231         TraversingVisitor::visit(strct);
1232 }
1233
1234 void UnusedTypeRemover::visit(VariableDeclaration &var)
1235 {
1236         unused_nodes.erase(var.type_declaration);
1237         TraversingVisitor::visit(var);
1238 }
1239
1240 void UnusedTypeRemover::visit(InterfaceBlock &iface)
1241 {
1242         unused_nodes.erase(iface.type_declaration);
1243 }
1244
1245 void UnusedTypeRemover::visit(FunctionDeclaration &func)
1246 {
1247         unused_nodes.erase(func.return_type_declaration);
1248         TraversingVisitor::visit(func);
1249 }
1250
1251
1252 bool UnusedVariableRemover::apply(Stage &s)
1253 {
1254         stage = &s;
1255         s.content.visit(*this);
1256
1257         for(const AssignmentInfo &a: assignments)
1258                 if(a.used_by.empty())
1259                         unused_nodes.insert(a.node);
1260
1261         for(const auto &kvp: variables)
1262         {
1263                 if(!kvp.second.referenced)
1264                         unused_nodes.insert(kvp.first);
1265                 else if(kvp.second.output)
1266                 {
1267                         /* The last visible assignments of output variables are used by the
1268                         next stage or the API. */
1269                         for(AssignmentInfo *a: kvp.second.assignments)
1270                                 unused_nodes.erase(a->node);
1271                 }
1272         }
1273
1274         NodeRemover().apply(s, unused_nodes);
1275
1276         return !unused_nodes.empty();
1277 }
1278
1279 void UnusedVariableRemover::referenced(const Assignment::Target &target, Node &node)
1280 {
1281         VariableInfo &var_info = variables[target.declaration];
1282         var_info.referenced = true;
1283         if(!assignment_target)
1284         {
1285                 bool loop_external = false;
1286                 for(AssignmentInfo *a: var_info.assignments)
1287                 {
1288                         bool covered = true;
1289                         for(unsigned j=0; (covered && j<a->target.chain_len && j<target.chain_len); ++j)
1290                         {
1291                                 Assignment::Target::ChainType type1 = static_cast<Assignment::Target::ChainType>(a->target.chain[j]&0xC0);
1292                                 Assignment::Target::ChainType type2 = static_cast<Assignment::Target::ChainType>(target.chain[j]&0xC0);
1293                                 unsigned index1 = a->target.chain[j]&0x3F;
1294                                 unsigned index2 = target.chain[j]&0x3F;
1295                                 if(type1==Assignment::Target::SWIZZLE || type2==Assignment::Target::SWIZZLE)
1296                                 {
1297                                         if(type1==Assignment::Target::SWIZZLE && type2==Assignment::Target::SWIZZLE)
1298                                                 covered = index1&index2;
1299                                         else if(type1==Assignment::Target::ARRAY && index1<4)
1300                                                 covered = index2&(1<<index1);
1301                                         else if(type2==Assignment::Target::ARRAY && index2<4)
1302                                                 covered = index1&(1<<index2);
1303                                         /* If it's some other combination (shouldn't happen), leave
1304                                         covered as true */
1305                                 }
1306                                 else
1307                                         covered = (type1==type2 && (index1==index2 || index1==0x3F || index2==0x3F));
1308                         }
1309
1310                         if(covered)
1311                         {
1312                                 a->used_by.push_back(&node);
1313                                 if(a->in_loop<in_loop)
1314                                         loop_external = true;
1315                         }
1316                 }
1317
1318                 if(loop_external)
1319                         loop_ext_refs.push_back(&node);
1320         }
1321 }
1322
1323 void UnusedVariableRemover::visit(VariableReference &var)
1324 {
1325         if(composite_reference)
1326                 r_reference.declaration = var.declaration;
1327         else
1328                 referenced(var.declaration, var);
1329 }
1330
1331 void UnusedVariableRemover::visit(InterfaceBlockReference &iface)
1332 {
1333         if(composite_reference)
1334                 r_reference.declaration = iface.declaration;
1335         else
1336                 referenced(iface.declaration, iface);
1337 }
1338
1339 void UnusedVariableRemover::visit_composite(Expression &expr)
1340 {
1341         if(!composite_reference)
1342                 r_reference = Assignment::Target();
1343
1344         SetFlag set_composite(composite_reference);
1345         expr.visit(*this);
1346 }
1347
1348 void UnusedVariableRemover::visit(MemberAccess &memacc)
1349 {
1350         visit_composite(*memacc.left);
1351
1352         add_to_chain(r_reference, Assignment::Target::MEMBER, memacc.index);
1353
1354         if(!composite_reference && r_reference.declaration)
1355                 referenced(r_reference, memacc);
1356 }
1357
1358 void UnusedVariableRemover::visit(Swizzle &swizzle)
1359 {
1360         visit_composite(*swizzle.left);
1361
1362         unsigned mask = 0;
1363         for(unsigned i=0; i<swizzle.count; ++i)
1364                 mask |= 1<<swizzle.components[i];
1365         add_to_chain(r_reference, Assignment::Target::SWIZZLE, mask);
1366
1367         if(!composite_reference && r_reference.declaration)
1368                 referenced(r_reference, swizzle);
1369 }
1370
1371 void UnusedVariableRemover::visit(UnaryExpression &unary)
1372 {
1373         TraversingVisitor::visit(unary);
1374         if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
1375                 r_side_effects = true;
1376 }
1377
1378 void UnusedVariableRemover::visit(BinaryExpression &binary)
1379 {
1380         if(binary.oper->token[0]=='[')
1381         {
1382                 visit_composite(*binary.left);
1383
1384                 {
1385                         SetFlag clear_assignment(assignment_target, false);
1386                         SetFlag clear_composite(composite_reference, false);
1387                         binary.right->visit(*this);
1388                 }
1389
1390                 add_to_chain(r_reference, Assignment::Target::ARRAY, 0x3F);
1391
1392                 if(!composite_reference && r_reference.declaration)
1393                         referenced(r_reference, binary);
1394         }
1395         else
1396         {
1397                 SetFlag clear_composite(composite_reference, false);
1398                 TraversingVisitor::visit(binary);
1399         }
1400 }
1401
1402 void UnusedVariableRemover::visit(TernaryExpression &ternary)
1403 {
1404         SetFlag clear_composite(composite_reference, false);
1405         TraversingVisitor::visit(ternary);
1406 }
1407
1408 void UnusedVariableRemover::visit(Assignment &assign)
1409 {
1410         {
1411                 SetFlag set(assignment_target, (assign.oper->token[0]=='='));
1412                 assign.left->visit(*this);
1413         }
1414         assign.right->visit(*this);
1415         r_assignment = &assign;
1416         r_side_effects = true;
1417 }
1418
1419 void UnusedVariableRemover::visit(FunctionCall &call)
1420 {
1421         SetFlag clear_composite(composite_reference, false);
1422         TraversingVisitor::visit(call);
1423         /* Treat function calls as having side effects so expression statements
1424         consisting of nothing but a function call won't be optimized away. */
1425         r_side_effects = true;
1426
1427         if(stage->type==Stage::GEOMETRY && call.name=="EmitVertex")
1428         {
1429                 for(const auto &kvp: variables)
1430                         if(kvp.second.output)
1431                                 referenced(kvp.first, call);
1432         }
1433 }
1434
1435 void UnusedVariableRemover::record_assignment(const Assignment::Target &target, Node &node)
1436 {
1437         assignments.push_back(AssignmentInfo());
1438         AssignmentInfo &assign_info = assignments.back();
1439         assign_info.node = &node;
1440         assign_info.target = target;
1441         assign_info.in_loop = in_loop;
1442
1443         /* An assignment to the target hides any assignments to the same target or
1444         its subfields. */
1445         VariableInfo &var_info = variables[target.declaration];
1446         for(unsigned i=0; i<var_info.assignments.size(); )
1447         {
1448                 const Assignment::Target &t = var_info.assignments[i]->target;
1449
1450                 bool subfield = (t.chain_len>=target.chain_len);
1451                 for(unsigned j=0; (subfield && j<target.chain_len); ++j)
1452                         subfield = (t.chain[j]==target.chain[j]);
1453
1454                 if(subfield)
1455                         var_info.assignments.erase(var_info.assignments.begin()+i);
1456                 else
1457                         ++i;
1458         }
1459
1460         var_info.assignments.push_back(&assign_info);
1461 }
1462
1463 void UnusedVariableRemover::visit(ExpressionStatement &expr)
1464 {
1465         r_assignment = 0;
1466         r_side_effects = false;
1467         TraversingVisitor::visit(expr);
1468         if(r_assignment && r_assignment->target.declaration)
1469                 record_assignment(r_assignment->target, expr);
1470         if(!r_side_effects)
1471                 unused_nodes.insert(&expr);
1472 }
1473
1474 void UnusedVariableRemover::visit(StructDeclaration &strct)
1475 {
1476         SetFlag set_struct(in_struct);
1477         TraversingVisitor::visit(strct);
1478 }
1479
1480 void UnusedVariableRemover::visit(VariableDeclaration &var)
1481 {
1482         TraversingVisitor::visit(var);
1483
1484         if(in_struct)
1485                 return;
1486
1487         VariableInfo &var_info = variables[&var];
1488
1489         /* Mark variables as output if they're used by the next stage or the
1490         graphics API. */
1491         var_info.output = (var.interface=="out" && (stage->type==Stage::FRAGMENT || var.linked_declaration || !var.name.compare(0, 3, "gl_")));
1492
1493         // Linked outputs are automatically referenced.
1494         if(var_info.output && var.linked_declaration)
1495                 var_info.referenced = true;
1496
1497         if(var.init_expression)
1498         {
1499                 var_info.initialized = true;
1500                 record_assignment(&var, *var.init_expression);
1501         }
1502 }
1503
1504 void UnusedVariableRemover::visit(InterfaceBlock &iface)
1505 {
1506         VariableInfo &var_info = variables[&iface];
1507         var_info.output = (iface.interface=="out" && (iface.linked_block || !iface.block_name.compare(0, 3, "gl_")));
1508 }
1509
1510 void UnusedVariableRemover::merge_variables(const BlockVariableMap &other_vars)
1511 {
1512         for(const auto &kvp: other_vars)
1513         {
1514                 auto j = variables.find(kvp.first);
1515                 if(j!=variables.end())
1516                 {
1517                         /* The merged blocks started as copies of each other so any common
1518                         assignments must be in the beginning. */
1519                         unsigned k = 0;
1520                         for(; (k<kvp.second.assignments.size() && k<j->second.assignments.size()); ++k)
1521                                 if(kvp.second.assignments[k]!=j->second.assignments[k])
1522                                         break;
1523
1524                         // Remaining assignments are unique to each block; merge them.
1525                         j->second.assignments.insert(j->second.assignments.end(), kvp.second.assignments.begin()+k, kvp.second.assignments.end());
1526                         j->second.referenced |= kvp.second.referenced;
1527                 }
1528                 else
1529                         variables.insert(kvp);
1530         }
1531 }
1532
1533 void UnusedVariableRemover::visit(FunctionDeclaration &func)
1534 {
1535         if(func.body.body.empty())
1536                 return;
1537
1538         BlockVariableMap saved_vars = variables;
1539         // Assignments from other functions should not be visible.
1540         for(auto &kvp: variables)
1541                 kvp.second.assignments.resize(kvp.second.initialized);
1542         TraversingVisitor::visit(func);
1543         swap(variables, saved_vars);
1544         merge_variables(saved_vars);
1545
1546         /* Always treat function parameters as referenced.  Removing unused
1547         parameters is not currently supported. */
1548         for(const RefPtr<VariableDeclaration> &p: func.parameters)
1549         {
1550                 auto j = variables.find(p.get());
1551                 if(j!=variables.end())
1552                         j->second.referenced = true;
1553         }
1554 }
1555
1556 void UnusedVariableRemover::visit(Conditional &cond)
1557 {
1558         cond.condition->visit(*this);
1559         BlockVariableMap saved_vars = variables;
1560         cond.body.visit(*this);
1561         swap(saved_vars, variables);
1562         cond.else_body.visit(*this);
1563
1564         /* Visible assignments after the conditional is the union of those visible
1565         at the end of the if and else blocks.  If there was no else block, then it's
1566         the union of the if block and the state before it. */
1567         merge_variables(saved_vars);
1568 }
1569
1570 void UnusedVariableRemover::visit(Iteration &iter)
1571 {
1572         BlockVariableMap saved_vars = variables;
1573         vector<Node *> saved_refs;
1574         swap(loop_ext_refs, saved_refs);
1575         {
1576                 SetForScope<unsigned> set_loop(in_loop, in_loop+1);
1577                 TraversingVisitor::visit(iter);
1578         }
1579         swap(loop_ext_refs, saved_refs);
1580
1581         /* Visit the external references of the loop again to record assignments
1582         done in the loop as used. */
1583         for(Node *n: saved_refs)
1584                 n->visit(*this);
1585
1586         /* Merge assignments from the iteration, without clearing previous state.
1587         Further analysis is needed to determine which parts of the iteration body
1588         are always executed, if any. */
1589         merge_variables(saved_vars);
1590 }
1591
1592
1593 bool UnusedFunctionRemover::apply(Stage &stage)
1594 {
1595         stage.content.visit(*this);
1596         NodeRemover().apply(stage, unused_nodes);
1597         return !unused_nodes.empty();
1598 }
1599
1600 void UnusedFunctionRemover::visit(FunctionCall &call)
1601 {
1602         TraversingVisitor::visit(call);
1603
1604         unused_nodes.erase(call.declaration);
1605         if(call.declaration && call.declaration->definition!=call.declaration)
1606                 used_definitions.insert(call.declaration->definition);
1607 }
1608
1609 void UnusedFunctionRemover::visit(FunctionDeclaration &func)
1610 {
1611         TraversingVisitor::visit(func);
1612
1613         if((func.name!="main" || func.body.body.empty()) && !used_definitions.count(&func))
1614                 unused_nodes.insert(&func);
1615 }
1616
1617 } // namespace SL
1618 } // namespace GL
1619 } // namespace Msp