]> git.tdb.fi Git - libs/gl.git/blob - source/programcompiler.cpp
Reorder declarations in shaders
[libs/gl.git] / source / programcompiler.cpp
1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include <msp/strings/utils.h>
4 #include "error.h"
5 #include "program.h"
6 #include "programcompiler.h"
7 #include "resources.h"
8 #include "shader.h"
9
10 using namespace std;
11
namespace {

// GLSL source declaring the built-in gl_PerVertex interface blocks for the
// vertex and geometry stages.  It is parsed by create_builtins_module().
// Each "//////" stage marker starts on a line of its own; the vertex block's
// closing brace is therefore followed by a newline like every other line.
const char builtins_src[] =
	"////// vertex\n"
	"out gl_PerVertex {\n"
	"  vec4 gl_Position;\n"
	"  float gl_ClipDistance[];\n"
	"};\n"
	"////// geometry\n"
	"in gl_PerVertex {\n"
	"  vec4 gl_Position;\n"
	"  float gl_ClipDistance[];\n"
	"} gl_in[];\n"
	"out gl_PerVertex {\n"
	"  vec4 gl_Position;\n"
	"  float gl_ClipDistance[];\n"
	"};\n";

}
31
32 namespace Msp {
33 namespace GL {
34
35 using namespace ProgramSyntax;
36
37 ProgramCompiler::ProgramCompiler():
38         resources(0),
39         module(0)
40 { }
41
42 void ProgramCompiler::compile(const string &source)
43 {
44         resources = 0;
45         module = &parser.parse(source);
46         process();
47 }
48
49 void ProgramCompiler::compile(IO::Base &io, Resources *res)
50 {
51         resources = res;
52         module = &parser.parse(io);
53         process();
54 }
55
56 void ProgramCompiler::add_shaders(Program &program)
57 {
58         if(!module)
59                 throw invalid_operation("ProgramCompiler::add_shaders");
60
61         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
62         {
63                 if(i->type==VERTEX)
64                 {
65                         program.attach_shader_owned(new VertexShader(apply<Formatter>(*i)));
66                         for(map<string, unsigned>::iterator j=i->locations.begin(); j!=i->locations.end(); ++j)
67                                 program.bind_attribute(j->second, j->first);
68                 }
69                 else if(i->type==GEOMETRY)
70                         program.attach_shader_owned(new GeometryShader(apply<Formatter>(*i)));
71                 else if(i->type==FRAGMENT)
72                 {
73                         program.attach_shader_owned(new FragmentShader(apply<Formatter>(*i)));
74                         for(map<string, unsigned>::iterator j=i->locations.begin(); j!=i->locations.end(); ++j)
75                                 program.bind_fragment_data(j->second, j->first);
76                 }
77         }
78 }
79
80 Module *ProgramCompiler::create_builtins_module()
81 {
82         ProgramParser parser;
83         Module *module = new Module(parser.parse(builtins_src));
84         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
85         {
86                 VariableResolver resolver;
87                 i->content.visit(resolver);
88                 for(map<string, VariableDeclaration *>::iterator j=i->content.variables.begin(); j!=i->content.variables.end(); ++j)
89                         j->second->linked_declaration = j->second;
90         }
91         return module;
92 }
93
94 Module &ProgramCompiler::get_builtins_module()
95 {
96         static RefPtr<Module> builtins_module = create_builtins_module();
97         return *builtins_module;
98 }
99
100 Stage *ProgramCompiler::get_builtins(StageType type)
101 {
102         Module &module = get_builtins_module();
103         for(list<Stage>::iterator i=module.stages.begin(); i!=module.stages.end(); ++i)
104                 if(i->type==type)
105                         return &*i;
106         return 0;
107 }
108
109 void ProgramCompiler::process()
110 {
111         list<Import *> imports = apply<NodeGatherer<Import> >(module->shared);
112         for(list<Import *>::iterator i=imports.end(); i!=imports.begin(); )
113                 import((*--i)->module);
114         apply<NodeRemover>(module->shared, set<Node *>(imports.begin(), imports.end()));
115
116         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
117                 generate(*i);
118         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); )
119         {
120                 if(optimize(*i))
121                         i = module->stages.begin();
122                 else
123                         ++i;
124         }
125 }
126
127 void ProgramCompiler::import(const string &name)
128 {
129         string fn = name+".glsl";
130         RefPtr<IO::Seekable> io = (resources ? resources->open_raw(fn) : Resources::get_builtins().open(fn));
131         if(!io)
132                 throw runtime_error(format("module %s not found", name));
133         ProgramParser import_parser;
134         Module &imported_module = import_parser.parse(*io);
135
136         inject_block(module->shared.content, imported_module.shared.content);
137         apply<DeclarationCombiner>(module->shared);
138         for(list<Stage>::iterator i=imported_module.stages.begin(); i!=imported_module.stages.end(); ++i)
139         {
140                 list<Stage>::iterator j;
141                 for(j=module->stages.begin(); (j!=module->stages.end() && j->type<i->type); ++j) ;
142                 if(j==module->stages.end() || j->type>i->type)
143                 {
144                         j = module->stages.insert(j, *i);
145                         list<Stage>::iterator k = j;
146                         if(++k!=module->stages.end())
147                                 k->previous = &*j;
148                         if(j!=module->stages.begin())
149                                 j->previous = &*--(k=j);
150                 }
151                 else
152                 {
153                         inject_block(j->content, i->content);
154                         apply<DeclarationCombiner>(*j);
155                 }
156         }
157 }
158
159 void ProgramCompiler::generate(Stage &stage)
160 {
161         inject_block(stage.content, module->shared.content);
162
163         apply<FunctionResolver>(stage);
164         apply<VariableResolver>(stage);
165         apply<InterfaceGenerator>(stage);
166         apply<VariableResolver>(stage);
167         apply<VariableRenamer>(stage);
168         apply<DeclarationReorderer>(stage);
169         apply<LegacyConverter>(stage);
170 }
171
172 bool ProgramCompiler::optimize(Stage &stage)
173 {
174         apply<ConstantConditionEliminator>(stage);
175
176         set<Node *> unused = apply<UnusedVariableLocator>(stage);
177         set<Node *> unused2 = apply<UnusedFunctionLocator>(stage);
178         unused.insert(unused2.begin(), unused2.end());
179         apply<NodeRemover>(stage, unused);
180
181         return !unused.empty();
182 }
183
184 void ProgramCompiler::inject_block(Block &target, const Block &source)
185 {
186         list<RefPtr<Node> >::iterator insert_point = target.body.begin();
187         for(list<RefPtr<Node> >::const_iterator i=source.body.begin(); i!=source.body.end(); ++i)
188                 target.body.insert(insert_point, (*i)->clone());
189 }
190
191 template<typename T>
192 typename T::ResultType ProgramCompiler::apply(Stage &stage)
193 {
194         T visitor;
195         visitor.apply(stage);
196         return visitor.get_result();
197 }
198
199 template<typename T, typename A>
200 typename T::ResultType ProgramCompiler::apply(Stage &stage, const A &arg)
201 {
202         T visitor(arg);
203         visitor.apply(stage);
204         return visitor.get_result();
205 }
206
207
208 ProgramCompiler::Visitor::Visitor():
209         stage(0)
210 { }
211
212 void ProgramCompiler::Visitor::apply(Stage &s)
213 {
214         SetForScope<Stage *> set(stage, &s);
215         stage->content.visit(*this);
216 }
217
218
219 ProgramCompiler::Formatter::Formatter():
220         indent(0),
221         parameter_list(false),
222         else_if(0)
223 { }
224
225 void ProgramCompiler::Formatter::apply(ProgramSyntax::Stage &s)
226 {
227         const Version &ver = s.required_version;
228         if(ver.major)
229                 formatted += format("#version %d%d\n", ver.major, ver.minor);
230         Visitor::apply(s);
231 }
232
233 void ProgramCompiler::Formatter::visit(Literal &literal)
234 {
235         formatted += literal.token;
236 }
237
238 void ProgramCompiler::Formatter::visit(ParenthesizedExpression &parexpr)
239 {
240         formatted += '(';
241         parexpr.expression->visit(*this);
242         formatted += ')';
243 }
244
245 void ProgramCompiler::Formatter::visit(VariableReference &var)
246 {
247         formatted += var.name;
248 }
249
250 void ProgramCompiler::Formatter::visit(MemberAccess &memacc)
251 {
252         memacc.left->visit(*this);
253         formatted += format(".%s", memacc.member);
254 }
255
256 void ProgramCompiler::Formatter::visit(UnaryExpression &unary)
257 {
258         if(unary.prefix)
259                 formatted += unary.oper;
260         unary.expression->visit(*this);
261         if(!unary.prefix)
262                 formatted += unary.oper;
263 }
264
265 void ProgramCompiler::Formatter::visit(BinaryExpression &binary)
266 {
267         binary.left->visit(*this);
268         formatted += binary.oper;
269         binary.right->visit(*this);
270         formatted += binary.after;
271 }
272
273 void ProgramCompiler::Formatter::visit(Assignment &assign)
274 {
275         assign.left->visit(*this);
276         formatted += format(" %s ", assign.oper);
277         assign.right->visit(*this);
278 }
279
280 void ProgramCompiler::Formatter::visit(FunctionCall &call)
281 {
282         formatted += format("%s(", call.name);
283         for(vector<RefPtr<Expression> >::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
284         {
285                 if(i!=call.arguments.begin())
286                         formatted += ", ";
287                 (*i)->visit(*this);
288         }
289         formatted += ')';
290 }
291
292 void ProgramCompiler::Formatter::visit(ExpressionStatement &expr)
293 {
294         expr.expression->visit(*this);
295         formatted += ';';
296 }
297
298 void ProgramCompiler::Formatter::visit(Block &block)
299 {
300         if(else_if)
301                 --else_if;
302
303         unsigned brace_indent = indent;
304         bool use_braces = (block.use_braces || (indent && block.body.size()!=1));
305         if(use_braces)
306                 formatted += format("%s{\n", string(brace_indent*2, ' '));
307
308         SetForScope<unsigned> set(indent, indent+(indent>0 || use_braces));
309         string spaces(indent*2, ' ');
310         for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); ++i)
311         {
312                 if(i!=block.body.begin())
313                         formatted += '\n';
314                 formatted += spaces;
315                 (*i)->visit(*this);
316                 else_if = 0;
317         }
318
319         if(use_braces)
320                 formatted += format("\n%s}", string(brace_indent*2, ' '));
321 }
322
323 void ProgramCompiler::Formatter::visit(Import &import)
324 {
325         formatted += format("import %s;", import.module);
326 }
327
328 void ProgramCompiler::Formatter::visit(Layout &layout)
329 {
330         formatted += "layout(";
331         for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
332         {
333                 if(i!=layout.qualifiers.begin())
334                         formatted += ", ";
335                 formatted += i->identifier;
336                 if(!i->value.empty())
337                         formatted += format("=%s", i->value);
338         }
339         formatted += ')';
340 }
341
342 void ProgramCompiler::Formatter::visit(InterfaceLayout &layout)
343 {
344         layout.layout.visit(*this);
345         formatted += format(" %s;", layout.interface);
346 }
347
348 void ProgramCompiler::Formatter::visit(StructDeclaration &strct)
349 {
350         formatted += format("struct %s\n", strct.name);
351         strct.members.visit(*this);
352         formatted += ';';
353 }
354
355 void ProgramCompiler::Formatter::visit(VariableDeclaration &var)
356 {
357         if(var.layout)
358         {
359                 var.layout->visit(*this);
360                 formatted += ' ';
361         }
362         if(var.constant)
363                 formatted += "const ";
364         if(!var.sampling.empty())
365                 formatted += format("%s ", var.sampling);
366         if(!var.interface.empty() && var.interface!=block_interface)
367                 formatted += format("%s ", var.interface);
368         formatted += format("%s %s", var.type, var.name);
369         if(var.array)
370         {
371                 formatted += '[';
372                 if(var.array_size)
373                         var.array_size->visit(*this);
374                 formatted += ']';
375         }
376         if(var.init_expression)
377         {
378                 formatted += " = ";
379                 var.init_expression->visit(*this);
380         }
381         if(!parameter_list)
382                 formatted += ';';
383 }
384
385 void ProgramCompiler::Formatter::visit(InterfaceBlock &iface)
386 {
387         SetForScope<string> set(block_interface, iface.interface);
388         formatted += format("%s %s\n", iface.interface, iface.name);
389         iface.members.visit(*this);
390         formatted += ';';
391 }
392
393 void ProgramCompiler::Formatter::visit(FunctionDeclaration &func)
394 {
395         formatted += format("%s %s(", func.return_type, func.name);
396         for(vector<RefPtr<VariableDeclaration> >::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
397         {
398                 if(i!=func.parameters.begin())
399                         formatted += ", ";
400                 SetFlag set(parameter_list);
401                 (*i)->visit(*this);
402         }
403         formatted += ')';
404         if(func.definition==&func)
405         {
406                 formatted += '\n';
407                 func.body.visit(*this);
408         }
409         else
410                 formatted += ';';
411 }
412
413 void ProgramCompiler::Formatter::visit(Conditional &cond)
414 {
415         if(else_if)
416                 formatted.replace(formatted.rfind('\n'), string::npos, 1, ' ');
417
418         indent -= else_if;
419
420         formatted += "if(";
421         cond.condition->visit(*this);
422         formatted += ")\n";
423
424         cond.body.visit(*this);
425         if(!cond.else_body.body.empty())
426         {
427                 formatted += format("\n%selse\n", string(indent*2, ' '));
428                 SetForScope<unsigned> set(else_if, 2);
429                 cond.else_body.visit(*this);
430         }
431 }
432
433 void ProgramCompiler::Formatter::visit(Iteration &iter)
434 {
435         formatted += "for(";
436         iter.init_statement->visit(*this);
437         formatted += ' ';
438         iter.condition->visit(*this);
439         formatted += "; ";
440         iter.loop_expression->visit(*this);
441         formatted += ")\n";
442         iter.body.visit(*this);
443 }
444
445 void ProgramCompiler::Formatter::visit(Return &ret)
446 {
447         formatted += "return ";
448         ret.expression->visit(*this);
449         formatted += ';';
450 }
451
452
453 ProgramCompiler::DeclarationCombiner::DeclarationCombiner():
454         toplevel(true),
455         remove_node(false)
456 { }
457
458 void ProgramCompiler::DeclarationCombiner::visit(Block &block)
459 {
460         if(!toplevel)
461                 return;
462
463         SetForScope<bool> set(toplevel, false);
464         for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
465         {
466                 remove_node = false;
467                 (*i)->visit(*this);
468                 if(remove_node)
469                         block.body.erase(i++);
470                 else
471                         ++i;
472         }
473 }
474
475 void ProgramCompiler::DeclarationCombiner::visit(FunctionDeclaration &func)
476 {
477         vector<FunctionDeclaration *> &decls = functions[func.name];
478         if(func.definition)
479         {
480                 for(vector<FunctionDeclaration *>::iterator i=decls.begin(); i!=decls.end(); ++i)
481                 {
482                         (*i)->definition = func.definition;
483                         (*i)->body.body.clear();
484                 }
485         }
486         decls.push_back(&func);
487 }
488
489 void ProgramCompiler::DeclarationCombiner::visit(VariableDeclaration &var)
490 {
491         VariableDeclaration *&ptr = variables[var.name];
492         if(ptr)
493         {
494                 ptr->type = var.type;
495                 if(var.init_expression)
496                         ptr->init_expression = var.init_expression;
497                 remove_node = true;
498         }
499         else
500                 ptr = &var;
501 }
502
503
504 ProgramCompiler::VariableResolver::VariableResolver():
505         anonymous(false),
506         record_target(false),
507         assignment_target(0),
508         self_referencing(false)
509 { }
510
511 void ProgramCompiler::VariableResolver::apply(Stage &s)
512 {
513         SetForScope<Stage *> set(stage, &s);
514         Stage *builtins = get_builtins(stage->type);
515         if(builtins)
516                 blocks.push_back(&builtins->content);
517         stage->content.visit(*this);
518         if(builtins)
519                 blocks.pop_back();
520 }
521
522 void ProgramCompiler::VariableResolver::visit(Block &block)
523 {
524         blocks.push_back(&block);
525         block.variables.clear();
526         TraversingVisitor::visit(block);
527         blocks.pop_back();
528 }
529
530 void ProgramCompiler::VariableResolver::visit(VariableReference &var)
531 {
532         var.declaration = 0;
533         type = 0;
534         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
535         {
536                 --i;
537                 map<string, VariableDeclaration *>::iterator j = (*i)->variables.find(var.name);
538                 if(j!=(*i)->variables.end())
539                 {
540                         var.declaration = j->second;
541                         type = j->second->type_declaration;
542                         break;
543                 }
544         }
545
546         if(record_target)
547         {
548                 if(assignment_target)
549                 {
550                         record_target = false;
551                         assignment_target = 0;
552                 }
553                 else
554                         assignment_target = var.declaration;
555         }
556         else if(var.declaration && var.declaration==assignment_target)
557                 self_referencing = true;
558 }
559
560 void ProgramCompiler::VariableResolver::visit(MemberAccess &memacc)
561 {
562         type = 0;
563         TraversingVisitor::visit(memacc);
564         memacc.declaration = 0;
565         if(type)
566         {
567                 map<string, VariableDeclaration *>::iterator i = type->members.variables.find(memacc.member);
568                 if(i!=type->members.variables.end())
569                 {
570                         memacc.declaration = i->second;
571                         type = i->second->type_declaration;
572                 }
573                 else
574                         type = 0;
575         }
576 }
577
578 void ProgramCompiler::VariableResolver::visit(BinaryExpression &binary)
579 {
580         if(binary.oper=="[")
581         {
582                 {
583                         SetForScope<bool> set(record_target, false);
584                         binary.right->visit(*this);
585                 }
586                 type = 0;
587                 binary.left->visit(*this);
588         }
589         else
590         {
591                 TraversingVisitor::visit(binary);
592                 type = 0;
593         }
594 }
595
596 void ProgramCompiler::VariableResolver::visit(Assignment &assign)
597 {
598         {
599                 SetFlag set(record_target);
600                 assignment_target = 0;
601                 assign.left->visit(*this);
602         }
603
604         self_referencing = false;
605         assign.right->visit(*this);
606
607         assign.self_referencing = (self_referencing || assign.oper!="=");
608         assign.target_declaration = assignment_target;
609 }
610
611 void ProgramCompiler::VariableResolver::visit(StructDeclaration &strct)
612 {
613         TraversingVisitor::visit(strct);
614         blocks.back()->types[strct.name] = &strct;
615 }
616
617 void ProgramCompiler::VariableResolver::visit(VariableDeclaration &var)
618 {
619         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
620         {
621                 --i;
622                 map<string, StructDeclaration *>::iterator j = (*i)->types.find(var.type);
623                 if(j!=(*i)->types.end())
624                         var.type_declaration = j->second;
625         }
626
627         if(!block_interface.empty() && var.interface.empty())
628                 var.interface = block_interface;
629
630         TraversingVisitor::visit(var);
631         blocks.back()->variables[var.name] = &var;
632         if(anonymous && blocks.size()>1)
633                 blocks[blocks.size()-2]->variables[var.name] = &var;
634 }
635
636 void ProgramCompiler::VariableResolver::visit(InterfaceBlock &iface)
637 {
638         SetFlag set(anonymous);
639         SetForScope<string> set2(block_interface, iface.interface);
640         TraversingVisitor::visit(iface);
641 }
642
643
644 void ProgramCompiler::FunctionResolver::visit(FunctionCall &call)
645 {
646         map<string, vector<FunctionDeclaration *> >::iterator i = functions.find(call.name);
647         if(i!=functions.end())
648                 call.declaration = i->second.back();
649
650         TraversingVisitor::visit(call);
651 }
652
653 void ProgramCompiler::FunctionResolver::visit(FunctionDeclaration &func)
654 {
655         vector<FunctionDeclaration *> &decls = functions[func.name];
656         if(func.definition)
657         {
658                 for(vector<FunctionDeclaration *>::iterator i=decls.begin(); i!=decls.end(); ++i)
659                         (*i)->definition = func.definition;
660                 decls.clear();
661                 decls.push_back(&func);
662         }
663         else if(!decls.empty() && decls.back()->definition)
664                 func.definition = decls.back()->definition;
665         else
666                 decls.push_back(&func);
667
668         TraversingVisitor::visit(func);
669 }
670
671
672 ProgramCompiler::BlockModifier::BlockModifier():
673         remove_node(false)
674 { }
675
676 void ProgramCompiler::BlockModifier::flatten_block(Block &block)
677 {
678         insert_nodes.insert(insert_nodes.end(), block.body.begin(), block.body.end());
679         remove_node = true;
680 }
681
682 void ProgramCompiler::BlockModifier::apply_and_increment(Block &block, list<RefPtr<Node> >::iterator &i)
683 {
684         block.body.insert(i, insert_nodes.begin(), insert_nodes.end());
685         insert_nodes.clear();
686
687         if(remove_node)
688                 block.body.erase(i++);
689         else
690                 ++i;
691         remove_node = false;
692 }
693
694 void ProgramCompiler::BlockModifier::visit(Block &block)
695 {
696         for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
697         {
698                 (*i)->visit(*this);
699                 apply_and_increment(block, i);
700         }
701 }
702
703
704 ProgramCompiler::InterfaceGenerator::InterfaceGenerator():
705         scope_level(0)
706 { }
707
708 string ProgramCompiler::InterfaceGenerator::get_out_prefix(StageType type)
709 {
710         if(type==VERTEX)
711                 return "_vs_out_";
712         else if(type==GEOMETRY)
713                 return "_gs_out_";
714         else
715                 return string();
716 }
717
718 void ProgramCompiler::InterfaceGenerator::apply(Stage &s)
719 {
720         SetForScope<Stage *> set(stage, &s);
721         if(stage->previous)
722                 in_prefix = get_out_prefix(stage->previous->type);
723         out_prefix = get_out_prefix(stage->type);
724         stage->content.visit(*this);
725 }
726
727 void ProgramCompiler::InterfaceGenerator::visit(Block &block)
728 {
729         SetForScope<unsigned> set(scope_level, scope_level+1);
730         for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
731         {
732                 (*i)->visit(*this);
733
734                 if(scope_level==1)
735                 {
736                         for(map<string, RefPtr<VariableDeclaration> >::iterator j=iface_declarations.begin(); j!=iface_declarations.end(); ++j)
737                         {
738                                 list<RefPtr<Node> >::iterator k = block.body.insert(i, j->second);
739                                 (*k)->visit(*this);
740                         }
741                         iface_declarations.clear();
742                 }
743
744                 apply_and_increment(block, i);
745         }
746 }
747
748 string ProgramCompiler::InterfaceGenerator::change_prefix(const string &name, const string &prefix) const
749 {
750         unsigned offset = (name.compare(0, in_prefix.size(), in_prefix) ? 0 : in_prefix.size());
751         return prefix+name.substr(offset);
752 }
753
754 bool ProgramCompiler::InterfaceGenerator::generate_interface(VariableDeclaration &var, const string &iface, const string &name)
755 {
756         const map<string, VariableDeclaration *> &stage_vars = (iface=="in" ? stage->in_variables : stage->out_variables);
757         if(stage_vars.count(name) || iface_declarations.count(name))
758                 return false;
759
760         VariableDeclaration* iface_var = new VariableDeclaration;
761         iface_var->sampling = var.sampling;
762         iface_var->interface = iface;
763         iface_var->type = var.type;
764         iface_var->type_declaration = var.type_declaration;
765         iface_var->name = name;
766         if(stage->type==GEOMETRY)
767                 iface_var->array = ((var.array && var.interface!="in") || iface=="in");
768         else
769                 iface_var->array = var.array;
770         if(iface_var->array)
771                 iface_var->array_size = var.array_size;
772         if(iface=="in")
773                 iface_var->linked_declaration = &var;
774         iface_declarations[name] = iface_var;
775
776         return true;
777 }
778
779 void ProgramCompiler::InterfaceGenerator::insert_assignment(const string &left, ProgramSyntax::Expression *right)
780 {
781         Assignment *assign = new Assignment;
782         VariableReference *ref = new VariableReference;
783         ref->name = left;
784         assign->left = ref;
785         assign->oper = "=";
786         assign->right = right;
787
788         ExpressionStatement *stmt = new ExpressionStatement;
789         stmt->expression = assign;
790         insert_nodes.push_back(stmt);
791 }
792
793 void ProgramCompiler::InterfaceGenerator::visit(VariableReference &var)
794 {
795         if(var.declaration || !stage->previous)
796                 return;
797         if(iface_declarations.count(var.name))
798                 return;
799
800         const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
801         map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
802         if(i==prev_out.end())
803                 i = prev_out.find(in_prefix+var.name);
804         if(i!=prev_out.end())
805                 generate_interface(*i->second, "in", var.name);
806 }
807
808 void ProgramCompiler::InterfaceGenerator::visit(VariableDeclaration &var)
809 {
810         if(var.interface=="out")
811         {
812                 if(scope_level==1)
813                         stage->out_variables[var.name] = &var;
814                 else if(generate_interface(var, "out", change_prefix(var.name, string())))
815                 {
816                         remove_node = true;
817                         if(var.init_expression)
818                                 insert_assignment(var.name, var.init_expression->clone());
819                 }
820         }
821         else if(var.interface=="in")
822         {
823                 stage->in_variables[var.name] = &var;
824                 if(var.linked_declaration)
825                         var.linked_declaration->linked_declaration = &var;
826                 else if(stage->previous)
827                 {
828                         const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
829                         map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
830                         if(i!=prev_out.end())
831                         {
832                                 var.linked_declaration = i->second;
833                                 i->second->linked_declaration = &var;
834                         }
835                 }
836         }
837
838         TraversingVisitor::visit(var);
839 }
840
841 void ProgramCompiler::InterfaceGenerator::visit(Passthrough &pass)
842 {
843         vector<VariableDeclaration *> pass_vars;
844
845         for(map<string, VariableDeclaration *>::const_iterator i=stage->in_variables.begin(); i!=stage->in_variables.end(); ++i)
846                 pass_vars.push_back(i->second);
847         for(map<string, RefPtr<VariableDeclaration> >::const_iterator i=iface_declarations.begin(); i!=iface_declarations.end(); ++i)
848                 if(i->second->interface=="in")
849                         pass_vars.push_back(i->second.get());
850
851         if(stage->previous)
852         {
853                 const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
854                 for(map<string, VariableDeclaration *>::const_iterator i=prev_out.begin(); i!=prev_out.end(); ++i)
855                 {
856                         bool linked = false;
857                         for(vector<VariableDeclaration *>::const_iterator j=pass_vars.begin(); (!linked && j!=pass_vars.end()); ++j)
858                                 linked = ((*j)->linked_declaration==i->second);
859
860                         if(!linked && generate_interface(*i->second, "in", i->second->name))
861                                 pass_vars.push_back(i->second);
862                 }
863         }
864
865         if(stage->type==GEOMETRY)
866         {
867                 VariableReference *ref = new VariableReference;
868                 ref->name = "gl_in";
869
870                 BinaryExpression *subscript = new BinaryExpression;
871                 subscript->left = ref;
872                 subscript->oper = "[";
873                 subscript->right = pass.subscript;
874                 subscript->after = "]";
875
876                 MemberAccess *memacc = new MemberAccess;
877                 memacc->left = subscript;
878                 memacc->member = "gl_Position";
879
880                 insert_assignment("gl_Position", memacc);
881         }
882
883         for(vector<VariableDeclaration *>::const_iterator i=pass_vars.begin(); i!=pass_vars.end(); ++i)
884         {
885                 string out_name = change_prefix((*i)->name, out_prefix);
886                 generate_interface(**i, "out", out_name);
887
888                 VariableReference *ref = new VariableReference;
889                 ref->name = (*i)->name;
890                 if(pass.subscript)
891                 {
892                         BinaryExpression *subscript = new BinaryExpression;
893                         subscript->left = ref;
894                         subscript->oper = "[";
895                         subscript->right = pass.subscript;
896                         subscript->after = "]";
897                         insert_assignment(out_name, subscript);
898                 }
899                 else
900                         insert_assignment(out_name, ref);
901         }
902
903         remove_node = true;
904 }
905
906
907 void ProgramCompiler::VariableRenamer::visit(VariableReference &var)
908 {
909         if(var.declaration)
910                 var.name = var.declaration->name;
911 }
912
913 void ProgramCompiler::VariableRenamer::visit(VariableDeclaration &var)
914 {
915         if(var.linked_declaration)
916                 var.name = var.linked_declaration->name;
917         TraversingVisitor::visit(var);
918 }
919
920
921 ProgramCompiler::DeclarationReorderer::DeclarationReorderer():
922         kind(NO_DECLARATION)
923 { }
924
925 void ProgramCompiler::DeclarationReorderer::visit(Block &block)
926 {
927         list<RefPtr<Node> >::iterator struct_insert_point = block.body.end();
928         list<RefPtr<Node> >::iterator variable_insert_point = block.body.end();
929
930         for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
931         {
932                 kind = NO_DECLARATION;
933                 (*i)->visit(*this);
934
935                 bool moved = false;
936                 if(kind==STRUCT && struct_insert_point!=block.body.end())
937                 {
938                         block.body.insert(struct_insert_point, *i);
939                         moved = true;
940                 }
941                 else if(kind>STRUCT && struct_insert_point==block.body.end())
942                         struct_insert_point = i;
943
944                 if(kind==VARIABLE && variable_insert_point!=block.body.end())
945                 {
946                         block.body.insert(variable_insert_point, *i);
947                         moved = true;
948                 }
949                 else if(kind>VARIABLE && variable_insert_point==block.body.end())
950                         variable_insert_point = i;
951
952                 if(moved)
953                         block.body.erase(i++);
954                 else
955                         ++i;
956         }
957 }
958
959
960 ProgramCompiler::ExpressionEvaluator::ExpressionEvaluator():
961         variable_values(0),
962         result(0.0f),
963         result_valid(false)
964 { }
965
966 ProgramCompiler::ExpressionEvaluator::ExpressionEvaluator(const ValueMap &v):
967         variable_values(&v),
968         result(0.0f),
969         result_valid(false)
970 { }
971
972 void ProgramCompiler::ExpressionEvaluator::visit(Literal &literal)
973 {
974         if(literal.token=="true")
975                 result = 1.0f;
976         else if(literal.token=="false")
977                 result = 0.0f;
978         else
979                 result = lexical_cast<float>(literal.token);
980         result_valid = true;
981 }
982
983 void ProgramCompiler::ExpressionEvaluator::visit(ParenthesizedExpression &parexp)
984 {
985         parexp.expression->visit(*this);
986 }
987
988 void ProgramCompiler::ExpressionEvaluator::visit(VariableReference &var)
989 {
990         if(!var.declaration)
991                 return;
992
993         if(variable_values)
994         {
995                 ValueMap::const_iterator i = variable_values->find(var.declaration);
996                 if(i!=variable_values->end())
997                         i->second->visit(*this);
998         }
999         else if(var.declaration->init_expression)
1000                 var.declaration->init_expression->visit(*this);
1001 }
1002
1003 void ProgramCompiler::ExpressionEvaluator::visit(UnaryExpression &unary)
1004 {
1005         result_valid = false;
1006         unary.expression->visit(*this);
1007         if(!result_valid)
1008                 return;
1009
1010         if(unary.oper=="!")
1011                 result = !result;
1012         else
1013                 result_valid = false;
1014 }
1015
1016 void ProgramCompiler::ExpressionEvaluator::visit(BinaryExpression &binary)
1017 {
1018         result_valid = false;
1019         binary.left->visit(*this);
1020         if(!result_valid)
1021                 return;
1022
1023         float left_result = result;
1024         result_valid = false;
1025         binary.right->visit(*this);
1026         if(!result_valid)
1027                 return;
1028
1029         if(binary.oper=="<")
1030                 result = (left_result<result);
1031         else if(binary.oper=="<=")
1032                 result = (left_result<=result);
1033         else if(binary.oper==">")
1034                 result = (left_result>result);
1035         else if(binary.oper==">=")
1036                 result = (left_result>=result);
1037         else if(binary.oper=="==")
1038                 result = (left_result==result);
1039         else if(binary.oper=="!=")
1040                 result = (left_result!=result);
1041         else if(binary.oper=="&&")
1042                 result = (left_result && result);
1043         else if(binary.oper=="||")
1044                 result = (left_result || result);
1045         else
1046                 result_valid = false;
1047 }
1048
1049
1050 ProgramCompiler::ConstantConditionEliminator::ConstantConditionEliminator():
1051         scope_level(0)
1052 { }
1053
1054 void ProgramCompiler::ConstantConditionEliminator::visit(Block &block)
1055 {
1056         SetForScope<unsigned> set(scope_level, scope_level+1);
1057         BlockModifier::visit(block);
1058
1059         for(map<string, VariableDeclaration *>::const_iterator i=block.variables.begin(); i!=block.variables.end(); ++i)
1060                 variable_values.erase(i->second);
1061 }
1062
1063 void ProgramCompiler::ConstantConditionEliminator::visit(Assignment &assign)
1064 {
1065         variable_values.erase(assign.target_declaration);
1066 }
1067
1068 void ProgramCompiler::ConstantConditionEliminator::visit(VariableDeclaration &var)
1069 {
1070         if(var.constant || scope_level>1)
1071                 variable_values[&var] = var.init_expression.get();
1072 }
1073
1074 void ProgramCompiler::ConstantConditionEliminator::visit(Conditional &cond)
1075 {
1076         ExpressionEvaluator eval(variable_values);
1077         cond.condition->visit(eval);
1078         if(eval.result_valid)
1079                 flatten_block(eval.result ? cond.body : cond.else_body);
1080         else
1081                 TraversingVisitor::visit(cond);
1082 }
1083
1084 void ProgramCompiler::ConstantConditionEliminator::visit(Iteration &iter)
1085 {
1086         if(iter.condition)
1087         {
1088                 ExpressionEvaluator eval;
1089                 iter.condition->visit(eval);
1090                 if(eval.result_valid && !eval.result)
1091                 {
1092                         remove_node = true;
1093                         return;
1094                 }
1095         }
1096
1097         TraversingVisitor::visit(iter);
1098 }
1099
1100
1101 ProgramCompiler::UnusedVariableLocator::UnusedVariableLocator():
1102         aggregate(0),
1103         assignment(0),
1104         assignment_target(false)
1105 { }
1106
1107 void ProgramCompiler::UnusedVariableLocator::apply(Stage &s)
1108 {
1109         assignments.push_back(BlockAssignmentMap());
1110         Visitor::apply(s);
1111         assignments.pop_back();
1112 }
1113
1114 void ProgramCompiler::UnusedVariableLocator::visit(VariableReference &var)
1115 {
1116         unused_nodes.erase(var.declaration);
1117
1118         map<VariableDeclaration *, Node *>::iterator i = aggregates.find(var.declaration);
1119         if(i!=aggregates.end())
1120                 unused_nodes.erase(i->second);
1121
1122         if(assignment_target)
1123                 return;
1124
1125         for(vector<BlockAssignmentMap>::iterator j=assignments.end(); j!=assignments.begin(); )
1126         {
1127                 --j;
1128                 BlockAssignmentMap::iterator k = j->find(var.declaration);
1129                 if(k!=j->end())
1130                 {
1131                         for(vector<Node *>::iterator l=k->second.nodes.begin(); l!=k->second.nodes.end(); ++l)
1132                                 unused_nodes.erase(*l);
1133                         j->erase(k);
1134                         break;
1135                 }
1136         }
1137 }
1138
1139 void ProgramCompiler::UnusedVariableLocator::visit(MemberAccess &memacc)
1140 {
1141         TraversingVisitor::visit(memacc);
1142         unused_nodes.erase(memacc.declaration);
1143 }
1144
1145 void ProgramCompiler::UnusedVariableLocator::visit(BinaryExpression &binary)
1146 {
1147         if(binary.oper=="[")
1148         {
1149                 binary.left->visit(*this);
1150                 SetForScope<bool> set(assignment_target, false);
1151                 binary.right->visit(*this);
1152         }
1153         else
1154                 TraversingVisitor::visit(binary);
1155 }
1156
1157 void ProgramCompiler::UnusedVariableLocator::visit(Assignment &assign)
1158 {
1159         {
1160                 SetForScope<bool> set(assignment_target, !assign.self_referencing);
1161                 assign.left->visit(*this);
1162         }
1163         assign.right->visit(*this);
1164         assignment = &assign;
1165 }
1166
1167 void ProgramCompiler::UnusedVariableLocator::record_assignment(VariableDeclaration &var, Node &node, bool self_ref)
1168 {
1169         unused_nodes.insert(&node);
1170         BlockAssignmentMap &block_assignments = assignments.back();
1171         AssignmentList &var_assignments = block_assignments[&var];
1172         if(!self_ref)
1173                 var_assignments.nodes.clear();
1174         var_assignments.nodes.push_back(&node);
1175         var_assignments.conditional = false;
1176         var_assignments.self_referencing = self_ref;
1177 }
1178
1179 void ProgramCompiler::UnusedVariableLocator::visit(ExpressionStatement &expr)
1180 {
1181         assignment = 0;
1182         TraversingVisitor::visit(expr);
1183         if(assignment && assignment->target_declaration)
1184                 record_assignment(*assignment->target_declaration, expr, assignment->self_referencing);
1185 }
1186
1187 void ProgramCompiler::UnusedVariableLocator::visit(StructDeclaration &strct)
1188 {
1189         SetForScope<Node *> set(aggregate, &strct);
1190         unused_nodes.insert(&strct);
1191         TraversingVisitor::visit(strct);
1192 }
1193
1194 void ProgramCompiler::UnusedVariableLocator::visit(VariableDeclaration &var)
1195 {
1196         if(aggregate)
1197                 aggregates[&var] = aggregate;
1198         else
1199         {
1200                 unused_nodes.insert(&var);
1201                 if(var.init_expression)
1202                         record_assignment(var, *var.init_expression, false);
1203         }
1204         unused_nodes.erase(var.type_declaration);
1205         TraversingVisitor::visit(var);
1206 }
1207
1208 void ProgramCompiler::UnusedVariableLocator::visit(InterfaceBlock &iface)
1209 {
1210         SetForScope<Node *> set(aggregate, &iface);
1211         unused_nodes.insert(&iface);
1212         TraversingVisitor::visit(iface);
1213 }
1214
1215 void ProgramCompiler::UnusedVariableLocator::visit(FunctionDeclaration &func)
1216 {
1217         assignments.push_back(BlockAssignmentMap());
1218
1219         for(vector<RefPtr<VariableDeclaration> >::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1220                 (*i)->visit(*this);
1221         func.body.visit(*this);
1222
1223         BlockAssignmentMap &block_assignments = assignments.back();
1224         for(map<string, VariableDeclaration *>::iterator i=func.body.variables.begin(); i!=func.body.variables.end(); ++i)
1225                 block_assignments.erase(i->second);
1226         for(BlockAssignmentMap::iterator i=block_assignments.begin(); i!=block_assignments.end(); ++i)
1227         {
1228                 if(i->first->interface=="out" && stage->type!=FRAGMENT && !i->first->linked_declaration)
1229                         continue;
1230
1231                 for(vector<Node *>::iterator j=i->second.nodes.begin(); j!=i->second.nodes.end(); ++j)
1232                         unused_nodes.erase(*j);
1233         }
1234
1235         assignments.pop_back();
1236 }
1237
1238 void ProgramCompiler::UnusedVariableLocator::merge_down_assignments()
1239 {
1240         BlockAssignmentMap &parent_assignments = assignments[assignments.size()-2];
1241         BlockAssignmentMap &block_assignments = assignments.back();
1242         for(BlockAssignmentMap::iterator i=block_assignments.begin(); i!=block_assignments.end(); ++i)
1243         {
1244                 BlockAssignmentMap::iterator j = parent_assignments.find(i->first);
1245                 if(j==parent_assignments.end())
1246                         parent_assignments.insert(*i);
1247                 else if(i->second.self_referencing || i->second.conditional)
1248                 {
1249                         j->second.nodes.insert(j->second.nodes.end(), i->second.nodes.begin(), i->second.nodes.end());
1250                         j->second.conditional |= i->second.conditional;
1251                         j->second.self_referencing |= i->second.self_referencing;
1252                 }
1253                 else
1254                         j->second = i->second;
1255         }
1256         assignments.pop_back();
1257 }
1258
1259 void ProgramCompiler::UnusedVariableLocator::visit(Conditional &cond)
1260 {
1261         cond.condition->visit(*this);
1262         assignments.push_back(BlockAssignmentMap());
1263         cond.body.visit(*this);
1264
1265         BlockAssignmentMap if_assignments;
1266         swap(assignments.back(), if_assignments);
1267         cond.else_body.visit(*this);
1268
1269         BlockAssignmentMap &else_assignments = assignments.back();
1270         for(BlockAssignmentMap::iterator i=else_assignments.begin(); i!=else_assignments.end(); ++i)
1271         {
1272                 BlockAssignmentMap::iterator j = if_assignments.find(i->first);
1273                 if(j!=if_assignments.end())
1274                 {
1275                         i->second.nodes.insert(i->second.nodes.end(), j->second.nodes.begin(), j->second.nodes.end());
1276                         i->second.conditional |= j->second.conditional;
1277                         i->second.self_referencing |= j->second.self_referencing;
1278                         if_assignments.erase(j);
1279                 }
1280                 else
1281                         i->second.conditional = true;
1282         }
1283
1284         for(BlockAssignmentMap::iterator i=if_assignments.begin(); i!=if_assignments.end(); ++i)
1285         {
1286                 i->second.conditional = true;
1287                 else_assignments.insert(*i);
1288         }
1289
1290         merge_down_assignments();
1291 }
1292
1293 void ProgramCompiler::UnusedVariableLocator::visit(Iteration &iter)
1294 {
1295         assignments.push_back(BlockAssignmentMap());
1296         TraversingVisitor::visit(iter);
1297         merge_down_assignments();
1298 }
1299
1300
1301 void ProgramCompiler::UnusedFunctionLocator::visit(FunctionCall &call)
1302 {
1303         TraversingVisitor::visit(call);
1304
1305         unused_nodes.erase(call.declaration);
1306         if(call.declaration && call.declaration->definition!=call.declaration)
1307                 used_definitions.insert(call.declaration->definition);
1308 }
1309
1310 void ProgramCompiler::UnusedFunctionLocator::visit(FunctionDeclaration &func)
1311 {
1312         TraversingVisitor::visit(func);
1313
1314         if(func.name!="main" && !used_definitions.count(&func))
1315                 unused_nodes.insert(&func);
1316 }
1317
1318
1319 ProgramCompiler::NodeRemover::NodeRemover(const set<Node *> &r):
1320         to_remove(r)
1321 { }
1322
1323 void ProgramCompiler::NodeRemover::visit(Block &block)
1324 {
1325         for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
1326         {
1327                 (*i)->visit(*this);
1328                 if(to_remove.count(i->get()))
1329                         block.body.erase(i++);
1330                 else
1331                         ++i;
1332         }
1333 }
1334
1335 void ProgramCompiler::NodeRemover::visit(VariableDeclaration &var)
1336 {
1337         if(to_remove.count(&var))
1338         {
1339                 stage->in_variables.erase(var.name);
1340                 stage->out_variables.erase(var.name);
1341                 stage->locations.erase(var.name);
1342                 if(var.linked_declaration)
1343                         var.linked_declaration->linked_declaration = 0;
1344         }
1345         else if(var.init_expression && to_remove.count(var.init_expression.get()))
1346                 var.init_expression = 0;
1347 }
1348
1349
1350 ProgramCompiler::LegacyConverter::LegacyConverter():
1351         target_version(get_glsl_version())
1352 { }
1353
1354 ProgramCompiler::LegacyConverter::LegacyConverter(const Version &v):
1355         target_version(v)
1356 { }
1357
1358 bool ProgramCompiler::LegacyConverter::check_version(const Version &feature_version)
1359 {
1360         if(target_version<feature_version)
1361                 return false;
1362         else if(stage->required_version<feature_version)
1363                 stage->required_version = feature_version;
1364
1365         return true;
1366 }
1367
1368 void ProgramCompiler::LegacyConverter::visit(VariableReference &var)
1369 {
1370         if(var.name==frag_out_name && !check_version(Version(1, 30)))
1371         {
1372                 var.name = "gl_FragColor";
1373                 var.declaration = 0;
1374                 type = "vec4";
1375         }
1376         else if(var.declaration)
1377                 type = var.declaration->type;
1378         else
1379                 type = string();
1380 }
1381
1382 void ProgramCompiler::LegacyConverter::visit(FunctionCall &call)
1383 {
1384         if(call.name=="texture" && !call.declaration && !check_version(Version(1, 30)))
1385         {
1386                 vector<RefPtr<Expression> >::iterator i = call.arguments.begin();
1387                 if(i!=call.arguments.end())
1388                 {
1389                         (*i)->visit(*this);
1390                         if(type=="sampler1D")
1391                                 call.name = "texture1D";
1392                         else if(type=="sampler2D")
1393                                 call.name = "texture2D";
1394                         else if(type=="sampler3D")
1395                                 call.name = "texture3D";
1396                         else if(type=="sampler1DShadow")
1397                                 call.name = "shadow1D";
1398                         else if(type=="sampler2DShadow")
1399                                 call.name = "shadow2D";
1400
1401                         for(; i!=call.arguments.end(); ++i)
1402                                 (*i)->visit(*this);
1403                 }
1404         }
1405         else
1406                 TraversingVisitor::visit(call);
1407 }
1408
1409 void ProgramCompiler::LegacyConverter::visit(VariableDeclaration &var)
1410 {
1411         if(var.layout && !check_version(Version(3, 30)))
1412         {
1413                 vector<Layout::Qualifier>::iterator i;
1414                 for(i=var.layout->qualifiers.begin(); (i!=var.layout->qualifiers.end() && i->identifier!="location"); ++i) ;
1415                 if(i!=var.layout->qualifiers.end())
1416                 {
1417                         unsigned location = lexical_cast<unsigned>(i->value);
1418                         if(stage->type==VERTEX && var.interface=="in")
1419                         {
1420                                 stage->locations[var.name] = location;
1421                                 var.layout->qualifiers.erase(i);
1422                         }
1423                         else if(stage->type==FRAGMENT && var.interface=="out")
1424                         {
1425                                 stage->locations[var.name] = location;
1426                                 var.layout->qualifiers.erase(i);
1427                         }
1428
1429                         if(var.layout->qualifiers.empty())
1430                                 var.layout = 0;
1431                 }
1432         }
1433
1434         if((var.interface=="in" || var.interface=="out") && !check_version(Version(1, 30)))
1435         {
1436                 if(stage->type==VERTEX && var.interface=="in")
1437                         var.interface = "attribute";
1438                 else if((stage->type==VERTEX && var.interface=="out") || (stage->type==FRAGMENT && var.interface=="in"))
1439                         var.interface = "varying";
1440                 else if(stage->type==FRAGMENT && var.interface=="out")
1441                 {
1442                         frag_out_name = var.name;
1443                         remove_node = true;
1444                 }
1445         }
1446
1447         TraversingVisitor::visit(var);
1448 }
1449
1450 void ProgramCompiler::LegacyConverter::visit(InterfaceBlock &iface)
1451 {
1452         if(!check_version(Version(1, 50)))
1453                 flatten_block(iface.members);
1454 }
1455
1456 } // namespace GL
1457 } // namespace Msp