]> git.tdb.fi Git - libs/gl.git/blob - source/programcompiler.cpp
Refactor module and stage management
[libs/gl.git] / source / programcompiler.cpp
1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include <msp/strings/utils.h>
4 #include "error.h"
5 #include "program.h"
6 #include "programcompiler.h"
7 #include "shader.h"
8
9 using namespace std;
10
11 namespace Msp {
12 namespace GL {
13
14 using namespace ProgramSyntax;
15
16 ProgramCompiler::ProgramCompiler():
17         module(0)
18 { }
19
20 void ProgramCompiler::compile(const string &source)
21 {
22         module = &parser.parse(source);
23         process();
24 }
25
26 void ProgramCompiler::compile(IO::Base &io)
27 {
28         module = &parser.parse(io);
29         process();
30 }
31
32 void ProgramCompiler::add_shaders(Program &program)
33 {
34         if(!module)
35                 throw invalid_operation("ProgramCompiler::add_shaders");
36
37         string head = "#version 150\n";
38         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
39         {
40                 if(i->type==VERTEX)
41                         program.attach_shader_owned(new VertexShader(head+create_source(*i)));
42                 else if(i->type==GEOMETRY)
43                         program.attach_shader_owned(new GeometryShader(head+create_source(*i)));
44                 else if(i->type==FRAGMENT)
45                         program.attach_shader_owned(new FragmentShader(head+create_source(*i)));
46         }
47
48         program.bind_attribute(VERTEX4, "vertex");
49         program.bind_attribute(NORMAL3, "normal");
50         program.bind_attribute(COLOR4_FLOAT, "color");
51         program.bind_attribute(TEXCOORD4, "texcoord");
52 }
53
54 void ProgramCompiler::process()
55 {
56         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
57                 generate(*i);
58         for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
59                 optimize(*i);
60 }
61
62 void ProgramCompiler::generate(Stage &stage)
63 {
64         inject_block(stage.content, module->shared.content);
65
66         resolve_variables(stage);
67
68         InterfaceGenerator generator;
69         generator.visit(stage);
70
71         resolve_variables(stage);
72
73         VariableRenamer renamer;
74         stage.content.visit(renamer);
75 }
76
77 void ProgramCompiler::optimize(Stage &stage)
78 {
79         while(1)
80         {
81                 UnusedVariableLocator unused_locator;
82                 unused_locator.visit(stage);
83
84                 NodeRemover remover;
85                 remover.to_remove = unused_locator.unused_nodes;
86                 remover.visit(stage);
87
88                 if(!remover.n_removed)
89                         break;
90         }
91 }
92
93 void ProgramCompiler::inject_block(Block &target, const Block &source)
94 {
95         list<NodePtr<Node> >::iterator insert_point = target.body.begin();
96         for(list<NodePtr<Node> >::const_iterator i=source.body.begin(); i!=source.body.end(); ++i)
97                 target.body.insert(insert_point, (*i)->clone());
98 }
99
100 void ProgramCompiler::resolve_variables(Stage &stage)
101 {
102         VariableResolver resolver;
103         stage.content.visit(resolver);
104 }
105
106 string ProgramCompiler::create_source(Stage &stage)
107 {
108         Formatter formatter;
109         stage.content.visit(formatter);
110         return formatter.formatted;
111 }
112
113
114 ProgramCompiler::Formatter::Formatter():
115         indent(0),
116         parameter_list(false),
117         else_if(false)
118 { }
119
120 void ProgramCompiler::Formatter::visit(Literal &literal)
121 {
122         formatted += literal.token;
123 }
124
125 void ProgramCompiler::Formatter::visit(ParenthesizedExpression &parexpr)
126 {
127         formatted += '(';
128         parexpr.expression->visit(*this);
129         formatted += ')';
130 }
131
132 void ProgramCompiler::Formatter::visit(VariableReference &var)
133 {
134         formatted += var.name;
135 }
136
137 void ProgramCompiler::Formatter::visit(MemberAccess &memacc)
138 {
139         memacc.left->visit(*this);
140         formatted += format(".%s", memacc.member);
141 }
142
143 void ProgramCompiler::Formatter::visit(UnaryExpression &unary)
144 {
145         if(unary.prefix)
146                 formatted += unary.oper;
147         unary.expression->visit(*this);
148         if(!unary.prefix)
149                 formatted += unary.oper;
150 }
151
152 void ProgramCompiler::Formatter::visit(BinaryExpression &binary)
153 {
154         binary.left->visit(*this);
155         if(binary.assignment)
156                 formatted += format(" %s ", binary.oper);
157         else
158                 formatted += binary.oper;
159         binary.right->visit(*this);
160         formatted += binary.after;
161 }
162
163 void ProgramCompiler::Formatter::visit(FunctionCall &call)
164 {
165         formatted += format("%s(", call.name);
166         for(vector<NodePtr<Expression> >::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
167         {
168                 if(i!=call.arguments.begin())
169                         formatted += ", ";
170                 (*i)->visit(*this);
171         }
172         formatted += ')';
173 }
174
175 void ProgramCompiler::Formatter::visit(ExpressionStatement &expr)
176 {
177         expr.expression->visit(*this);
178         formatted += ';';
179 }
180
181 void ProgramCompiler::Formatter::visit(Block &block)
182 {
183         if(block.use_braces)
184         {
185                 if(else_if)
186                 {
187                         formatted += '\n';
188                         else_if = false;
189                 }
190                 formatted += format("%s{\n", string(indent*2, ' '));
191         }
192
193         bool change_indent = (!formatted.empty() && !else_if);
194         indent += change_indent;
195         string spaces(indent*2, ' ');
196         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); ++i)
197         {
198                 if(i!=block.body.begin())
199                         formatted += '\n';
200                 if(!else_if)
201                         formatted += spaces;
202                 (*i)->visit(*this);
203         }
204         indent -= change_indent;
205
206         if(block.use_braces)
207                 formatted += format("\n%s}", string(indent*2, ' '));
208 }
209
210 void ProgramCompiler::Formatter::visit(Layout &layout)
211 {
212         formatted += "layout(";
213         for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
214         {
215                 if(i!=layout.qualifiers.begin())
216                         formatted += ", ";
217                 formatted += i->identifier;
218                 if(!i->value.empty())
219                         formatted += format("=%s", i->value);
220         }
221         formatted += format(") %s;", layout.interface);
222 }
223
224 void ProgramCompiler::Formatter::visit(StructDeclaration &strct)
225 {
226         formatted += format("struct %s\n", strct.name);
227         strct.members.visit(*this);
228         formatted += ';';
229 }
230
231 void ProgramCompiler::Formatter::visit(VariableDeclaration &var)
232 {
233         if(var.constant)
234                 formatted += "const ";
235         if(!var.sampling.empty())
236                 formatted += format("%s ", var.sampling);
237         if(!var.interface.empty())
238                 formatted += format("%s ", var.interface);
239         formatted += format("%s %s", var.type, var.name);
240         if(var.array)
241         {
242                 formatted += '[';
243                 if(var.array_size)
244                         var.array_size->visit(*this);
245                 formatted += ']';
246         }
247         if(var.init_expression)
248         {
249                 formatted += " = ";
250                 var.init_expression->visit(*this);
251         }
252         if(!parameter_list)
253                 formatted += ';';
254 }
255
256 void ProgramCompiler::Formatter::visit(InterfaceBlock &iface)
257 {
258         formatted += format("%s %s\n", iface.interface, iface.name);
259         iface.members.visit(*this);
260         formatted += ';';
261 }
262
263 void ProgramCompiler::Formatter::visit(FunctionDeclaration &func)
264 {
265         formatted += format("%s %s(", func.return_type, func.name);
266         for(vector<NodePtr<VariableDeclaration> >::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
267         {
268                 if(i!=func.parameters.begin())
269                         formatted += ", ";
270                 SetFlag set(parameter_list);
271                 (*i)->visit(*this);
272         }
273         formatted += ')';
274         if(func.definition)
275         {
276                 formatted += '\n';
277                 func.body.visit(*this);
278         }
279         else
280                 formatted += ';';
281 }
282
283 void ProgramCompiler::Formatter::visit(Conditional &cond)
284 {
285         if(else_if)
286         {
287                 formatted += ' ';
288                 else_if = false;
289         }
290
291         formatted += "if(";
292         cond.condition->visit(*this);
293         formatted += ")\n";
294
295         cond.body.visit(*this);
296         if(!cond.else_body.body.empty())
297         {
298                 formatted += format("\n%selse", string(indent*2, ' '));
299                 SetFlag set(else_if);
300                 cond.else_body.visit(*this);
301         }
302 }
303
304 void ProgramCompiler::Formatter::visit(Iteration &iter)
305 {
306         formatted += "for(";
307         iter.init_statement->visit(*this);
308         formatted += ' ';
309         iter.condition->visit(*this);
310         formatted += "; ";
311         iter.loop_expression->visit(*this);
312         formatted += ")\n";
313         iter.body.visit(*this);
314 }
315
316 void ProgramCompiler::Formatter::visit(Return &ret)
317 {
318         formatted += "return ";
319         ret.expression->visit(*this);
320         formatted += ';';
321 }
322
323
324 ProgramCompiler::VariableResolver::VariableResolver():
325         anonymous(false)
326 { }
327
328 void ProgramCompiler::VariableResolver::visit(Block &block)
329 {
330         blocks.push_back(&block);
331         block.variables.clear();
332         TraversingVisitor::visit(block);
333         blocks.pop_back();
334 }
335
336 void ProgramCompiler::VariableResolver::visit(VariableReference &var)
337 {
338         var.declaration = 0;
339         type = 0;
340         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
341         {
342                 --i;
343                 map<string, VariableDeclaration *>::iterator j = (*i)->variables.find(var.name);
344                 if(j!=(*i)->variables.end())
345                 {
346                         var.declaration = j->second;
347                         type = j->second->type_declaration;
348                         break;
349                 }
350         }
351 }
352
353 void ProgramCompiler::VariableResolver::visit(MemberAccess &memacc)
354 {
355         type = 0;
356         TraversingVisitor::visit(memacc);
357         memacc.declaration = 0;
358         if(type)
359         {
360                 map<string, VariableDeclaration *>::iterator i = type->members.variables.find(memacc.member);
361                 if(i!=type->members.variables.end())
362                 {
363                         memacc.declaration = i->second;
364                         type = i->second->type_declaration;
365                 }
366                 else
367                         type = 0;
368         }
369 }
370
371 void ProgramCompiler::VariableResolver::visit(BinaryExpression &binary)
372 {
373         if(binary.oper=="[")
374         {
375                 binary.right->visit(*this);
376                 type = 0;
377                 binary.left->visit(*this);
378         }
379         else
380         {
381                 TraversingVisitor::visit(binary);
382                 type = 0;
383         }
384 }
385
386 void ProgramCompiler::VariableResolver::visit(StructDeclaration &strct)
387 {
388         TraversingVisitor::visit(strct);
389         blocks.back()->types[strct.name] = &strct;
390 }
391
392 void ProgramCompiler::VariableResolver::visit(VariableDeclaration &var)
393 {
394         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
395         {
396                 --i;
397                 map<string, StructDeclaration *>::iterator j = (*i)->types.find(var.type);
398                 if(j!=(*i)->types.end())
399                         var.type_declaration = j->second;
400         }
401
402         TraversingVisitor::visit(var);
403         blocks.back()->variables[var.name] = &var;
404         if(anonymous && blocks.size()>1)
405                 blocks[blocks.size()-2]->variables[var.name] = &var;
406 }
407
408 void ProgramCompiler::VariableResolver::visit(InterfaceBlock &iface)
409 {
410         SetFlag set(anonymous);
411         TraversingVisitor::visit(iface);
412 }
413
414
415 ProgramCompiler::InterfaceGenerator::InterfaceGenerator():
416         stage(0),
417         scope_level(0),
418         remove_node(false)
419 { }
420
421 string ProgramCompiler::InterfaceGenerator::get_out_prefix(StageType type)
422 {
423         if(type==VERTEX)
424                 return "_vs_out_";
425         else if(type==GEOMETRY)
426                 return "_gs_out_";
427         else
428                 return string();
429 }
430
431 void ProgramCompiler::InterfaceGenerator::visit(Stage &s)
432 {
433         SetForScope<Stage *> set(stage, &s);
434         if(stage->previous)
435                 in_prefix = get_out_prefix(stage->previous->type);
436         out_prefix = get_out_prefix(stage->type);
437         stage->content.visit(*this);
438 }
439
440 void ProgramCompiler::InterfaceGenerator::visit(Block &block)
441 {
442         SetForScope<unsigned> set(scope_level, scope_level+1);
443         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
444         {
445                 (*i)->visit(*this);
446
447                 if(scope_level==1)
448                 {
449                         for(map<string, NodePtr<Node> >::iterator j=iface_declarations.begin(); j!=iface_declarations.end(); ++j)
450                         {
451                                 list<NodePtr<Node> >::iterator k = block.body.insert(i, j->second);
452                                 (*k)->visit(*this);
453                         }
454                         iface_declarations.clear();
455                 }
456
457                 for(list<NodePtr<Node> >::iterator j=insert_nodes.begin(); j!=insert_nodes.end(); ++j)
458                         block.body.insert(i, *j);
459                 insert_nodes.clear();
460
461                 if(remove_node)
462                         block.body.erase(i++);
463                 else
464                         ++i;
465                 remove_node = false;
466         }
467 }
468
469 string ProgramCompiler::InterfaceGenerator::change_prefix(const string &name, const string &prefix) const
470 {
471         unsigned offset = (name.compare(0, in_prefix.size(), in_prefix) ? 0 : in_prefix.size());
472         return prefix+name.substr(offset);
473 }
474
475 bool ProgramCompiler::InterfaceGenerator::generate_interface(VariableDeclaration &out, const string &iface, const string &name)
476 {
477         const map<string, VariableDeclaration *> &stage_vars = (iface=="in" ? stage->in_variables : stage->out_variables);
478         if(stage_vars.count(name) || iface_declarations.count(name))
479                 return false;
480
481         VariableDeclaration* iface_var = new VariableDeclaration;
482         iface_var->sampling = out.sampling;
483         iface_var->interface = iface;
484         iface_var->type = out.type;
485         iface_var->type_declaration = out.type_declaration;
486         iface_var->name = name;
487         iface_var->array = (out.array || (stage->type==GEOMETRY && iface=="in"));
488         iface_var->array_size = out.array_size;
489         if(iface=="in")
490                 iface_var->linked_declaration = &out;
491         iface_declarations[iface_var->name] = iface_var;
492
493         return true;
494 }
495
496 void ProgramCompiler::InterfaceGenerator::insert_assignment(const string &left, ProgramSyntax::Expression *right)
497 {
498         BinaryExpression *assign = new BinaryExpression;
499         VariableReference *ref = new VariableReference;
500         ref->name = left;
501         assign->left = ref;
502         assign->oper = "=";
503         assign->right = right;
504         assign->assignment = true;
505
506         ExpressionStatement *stmt = new ExpressionStatement;
507         stmt->expression = assign;
508         insert_nodes.push_back(stmt);
509 }
510
511 void ProgramCompiler::InterfaceGenerator::visit(VariableReference &var)
512 {
513         if(var.declaration || !stage->previous)
514                 return;
515         if(iface_declarations.count(var.name))
516                 return;
517
518         const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
519         map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
520         if(i==prev_out.end())
521                 i = prev_out.find(in_prefix+var.name);
522         if(i!=prev_out.end())
523                 generate_interface(*i->second, "in", var.name);
524 }
525
526 void ProgramCompiler::InterfaceGenerator::visit(VariableDeclaration &var)
527 {
528         if(var.interface=="out")
529         {
530                 if(scope_level==1)
531                         stage->out_variables[var.name] = &var;
532                 else if(generate_interface(var, "out", change_prefix(var.name, string())))
533                 {
534                         remove_node = true;
535                         if(var.init_expression)
536                                 insert_assignment(var.name, var.init_expression->clone());
537                 }
538         }
539         else if(var.interface=="in")
540         {
541                 stage->in_variables[var.name] = &var;
542                 if(var.linked_declaration)
543                         var.linked_declaration->linked_declaration = &var;
544                 else if(stage->previous)
545                 {
546                         const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
547                         map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
548                         if(i!=prev_out.end())
549                         {
550                                 var.linked_declaration = i->second;
551                                 i->second->linked_declaration = &var;
552                         }
553                 }
554         }
555
556         TraversingVisitor::visit(var);
557 }
558
559 void ProgramCompiler::InterfaceGenerator::visit(Passthrough &pass)
560 {
561         if(stage->previous)
562         {
563                 const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
564                 for(map<string, VariableDeclaration *>::const_iterator i=prev_out.begin(); i!=prev_out.end(); ++i)
565                 {
566                         string out_name = change_prefix(i->second->name, out_prefix);
567                         generate_interface(*i->second, "in", i->second->name);
568                         generate_interface(*i->second, "out", out_name);
569
570                         VariableReference *ref = new VariableReference;
571                         ref->name = i->first;
572                         if(pass.subscript)
573                         {
574                                 BinaryExpression *subscript = new BinaryExpression;
575                                 subscript->left = ref;
576                                 subscript->oper = "[";
577                                 subscript->right = pass.subscript;
578                                 subscript->after = "]";
579                                 insert_assignment(out_name, subscript);
580                         }
581                         else
582                                 insert_assignment(out_name, ref);
583                 }
584         }
585
586         remove_node = true;
587 }
588
589
590 void ProgramCompiler::VariableRenamer::visit(VariableReference &var)
591 {
592         if(var.declaration)
593                 var.name = var.declaration->name;
594 }
595
596 void ProgramCompiler::VariableRenamer::visit(VariableDeclaration &var)
597 {
598         if(var.linked_declaration)
599                 var.name = var.linked_declaration->name;
600         TraversingVisitor::visit(var);
601 }
602
603
604 ProgramCompiler::UnusedVariableLocator::UnusedVariableLocator():
605         stage(0),
606         assignment(false),
607         assignment_target(0)
608 { }
609
610 void ProgramCompiler::UnusedVariableLocator::visit(Stage &s)
611 {
612         stage = &s;
613         stage->content.visit(*this);
614 }
615
616 void ProgramCompiler::UnusedVariableLocator::visit(VariableReference &var)
617 {
618         if(assignment)
619                 assignment_target = var.declaration;
620         else
621         {
622                 unused_nodes.erase(var.declaration);
623                 map<VariableDeclaration *, Node *>::iterator i = assignments.find(var.declaration);
624                 if(i!=assignments.end())
625                         unused_nodes.erase(i->second);
626         }
627 }
628
629 void ProgramCompiler::UnusedVariableLocator::visit(MemberAccess &memacc)
630 {
631         TraversingVisitor::visit(memacc);
632         unused_nodes.erase(memacc.declaration);
633 }
634
635 void ProgramCompiler::UnusedVariableLocator::visit(BinaryExpression &binary)
636 {
637         if(binary.assignment)
638         {
639                 binary.right->visit(*this);
640                 assignment = true;
641                 binary.left->visit(*this);
642         }
643         else
644                 TraversingVisitor::visit(binary);
645 }
646
647 void ProgramCompiler::UnusedVariableLocator::visit(ExpressionStatement &expr)
648 {
649         assignment = false;
650         assignment_target = 0;
651         TraversingVisitor::visit(expr);
652         if(assignment && assignment_target)
653         {
654                 if(assignment_target->interface!="out" || (stage->type!=FRAGMENT && !assignment_target->linked_declaration))
655                 {
656                         unused_nodes.insert(&expr);
657                         assignments[assignment_target] = &expr;
658                 }
659                 else
660                         unused_nodes.erase(assignment_target);
661         }
662         assignment = false;
663 }
664
665 void ProgramCompiler::UnusedVariableLocator::visit(VariableDeclaration &var)
666 {
667         unused_nodes.insert(&var);
668         TraversingVisitor::visit(var);
669 }
670
671
672 ProgramCompiler::NodeRemover::NodeRemover():
673         stage(0),
674         n_removed(0),
675         immutable_block(false),
676         remove_block(false)
677 { }
678
679 void ProgramCompiler::NodeRemover::visit(Stage &s)
680 {
681         stage = &s;
682         stage->content.visit(*this);
683 }
684
685 void ProgramCompiler::NodeRemover::visit(Block &block)
686 {
687         remove_block = immutable_block;
688         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
689         {
690                 bool remove = to_remove.count(&**i);
691                 if(!remove)
692                         remove_block = false;
693                 (*i)->visit(*this);
694
695                 if(remove ? !immutable_block : remove_block)
696                 {
697                         block.body.erase(i++);
698                         ++n_removed;
699                 }
700                 else
701                         ++i;
702         }
703 }
704
705 void ProgramCompiler::NodeRemover::visit(StructDeclaration &strct)
706 {
707         SetFlag set(immutable_block);
708         TraversingVisitor::visit(strct);
709 }
710
711 void ProgramCompiler::NodeRemover::visit(VariableDeclaration &var)
712 {
713         if(to_remove.count(&var))
714         {
715                 stage->in_variables.erase(var.name);
716                 stage->out_variables.erase(var.name);
717                 if(var.linked_declaration)
718                         var.linked_declaration->linked_declaration = 0;
719         }
720 }
721
722 void ProgramCompiler::NodeRemover::visit(InterfaceBlock &iface)
723 {
724         SetFlag set(immutable_block);
725         TraversingVisitor::visit(iface);
726 }
727
728 } // namespace GL
729 } // namespace Msp