git.tdb.fi Git - libs/gl.git/blob - source/programcompiler.cpp
Streamline interface declarations
[libs/gl.git] / source / programcompiler.cpp
1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include <msp/strings/utils.h>
4 #include "error.h"
5 #include "program.h"
6 #include "programcompiler.h"
7 #include "shader.h"
8
9 using namespace std;
10
11 namespace Msp {
12 namespace GL {
13
14 using namespace ProgramSyntax;
15
16 ProgramCompiler::ProgramCompiler():
17         module(0)
18 { }
19
20 void ProgramCompiler::compile(const string &source)
21 {
22         module = &parser.parse(source);
23         process();
24 }
25
26 void ProgramCompiler::compile(IO::Base &io)
27 {
28         module = &parser.parse(io);
29         process();
30 }
31
32 void ProgramCompiler::add_shaders(Program &program)
33 {
34         if(!module)
35                 throw invalid_operation("ProgramCompiler::add_shaders");
36
37         string head = "#version 150\n";
38         if(module->vertex_context.present)
39                 program.attach_shader_owned(new VertexShader(head+format_context(module->vertex_context)));
40         if(module->geometry_context.present)
41                 program.attach_shader_owned(new GeometryShader(head+format_context(module->geometry_context)));
42         if(module->fragment_context.present)
43                 program.attach_shader_owned(new FragmentShader(head+format_context(module->fragment_context)));
44
45         program.bind_attribute(VERTEX4, "vertex");
46         program.bind_attribute(NORMAL3, "normal");
47         program.bind_attribute(COLOR4_FLOAT, "color");
48         program.bind_attribute(TEXCOORD4, "texcoord");
49 }
50
51 void ProgramCompiler::process()
52 {
53         if(module->vertex_context.present)
54                 process(module->vertex_context);
55         if(module->geometry_context.present)
56                 process(module->geometry_context);
57         if(module->fragment_context.present)
58                 process(module->fragment_context);
59 }
60
61 void ProgramCompiler::process(Context &context)
62 {
63         inject_block(context.content, module->global_context.content);
64
65         resolve_variables(context);
66
67         InterfaceGenerator generator;
68         generator.visit(context);
69
70         resolve_variables(context);
71
72         VariableRenamer renamer;
73         context.content.visit(renamer);
74
75         while(1)
76         {
77                 UnusedVariableLocator unused_locator;
78                 context.content.visit(unused_locator);
79
80                 NodeRemover remover;
81                 remover.to_remove.insert(unused_locator.unused_variables.begin(), unused_locator.unused_variables.end());
82                 context.content.visit(remover);
83
84                 if(!remover.n_removed)
85                         break;
86         }
87 }
88
89 void ProgramCompiler::inject_block(Block &target, const Block &source)
90 {
91         list<NodePtr<Node> >::iterator insert_point = target.body.begin();
92         for(list<NodePtr<Node> >::const_iterator i=source.body.begin(); i!=source.body.end(); ++i)
93                 target.body.insert(insert_point, (*i)->clone());
94 }
95
96 void ProgramCompiler::resolve_variables(Context &context)
97 {
98         VariableResolver resolver;
99         context.content.visit(resolver);
100 }
101
102 string ProgramCompiler::format_context(Context &context)
103 {
104         Formatter formatter;
105         context.content.visit(formatter);
106         return formatter.formatted;
107 }
108
109
110 ProgramCompiler::Formatter::Formatter():
111         indent(0),
112         parameter_list(false),
113         else_if(false)
114 { }
115
116 void ProgramCompiler::Formatter::visit(Literal &literal)
117 {
118         formatted += literal.token;
119 }
120
121 void ProgramCompiler::Formatter::visit(ParenthesizedExpression &parexpr)
122 {
123         formatted += '(';
124         parexpr.expression->visit(*this);
125         formatted += ')';
126 }
127
128 void ProgramCompiler::Formatter::visit(VariableReference &var)
129 {
130         formatted += var.name;
131 }
132
133 void ProgramCompiler::Formatter::visit(MemberAccess &memacc)
134 {
135         memacc.left->visit(*this);
136         formatted += format(".%s", memacc.member);
137 }
138
139 void ProgramCompiler::Formatter::visit(UnaryExpression &unary)
140 {
141         if(unary.prefix)
142                 formatted += unary.oper;
143         unary.expression->visit(*this);
144         if(!unary.prefix)
145                 formatted += unary.oper;
146 }
147
148 void ProgramCompiler::Formatter::visit(BinaryExpression &binary)
149 {
150         binary.left->visit(*this);
151         if(binary.assignment)
152                 formatted += format(" %s ", binary.oper);
153         else
154                 formatted += binary.oper;
155         binary.right->visit(*this);
156         formatted += binary.after;
157 }
158
159 void ProgramCompiler::Formatter::visit(FunctionCall &call)
160 {
161         formatted += format("%s(", call.name);
162         for(vector<NodePtr<Expression> >::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
163         {
164                 if(i!=call.arguments.begin())
165                         formatted += ", ";
166                 (*i)->visit(*this);
167         }
168         formatted += ')';
169 }
170
171 void ProgramCompiler::Formatter::visit(ExpressionStatement &expr)
172 {
173         expr.expression->visit(*this);
174         formatted += ';';
175 }
176
177 void ProgramCompiler::Formatter::visit(Block &block)
178 {
179         if(block.use_braces)
180         {
181                 if(else_if)
182                 {
183                         formatted += '\n';
184                         else_if = false;
185                 }
186                 formatted += format("%s{\n", string(indent*2, ' '));
187         }
188
189         bool change_indent = (!formatted.empty() && !else_if);
190         indent += change_indent;
191         string spaces(indent*2, ' ');
192         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); ++i)
193         {
194                 if(i!=block.body.begin())
195                         formatted += '\n';
196                 if(!else_if)
197                         formatted += spaces;
198                 (*i)->visit(*this);
199         }
200         indent -= change_indent;
201
202         if(block.use_braces)
203                 formatted += format("\n%s}", string(indent*2, ' '));
204 }
205
206 void ProgramCompiler::Formatter::visit(Layout &layout)
207 {
208         formatted += "layout(";
209         for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
210         {
211                 if(i!=layout.qualifiers.begin())
212                         formatted += ", ";
213                 formatted += i->identifier;
214                 if(!i->value.empty())
215                         formatted += format("=%s", i->value);
216         }
217         formatted += format(") %s;", layout.interface);
218 }
219
220 void ProgramCompiler::Formatter::visit(StructDeclaration &strct)
221 {
222         formatted += format("struct %s\n", strct.name);
223         strct.members.visit(*this);
224         formatted += ';';
225 }
226
227 void ProgramCompiler::Formatter::visit(VariableDeclaration &var)
228 {
229         if(var.constant)
230                 formatted += "const ";
231         if(!var.sampling.empty())
232                 formatted += format("%s ", var.sampling);
233         if(!var.interface.empty())
234                 formatted += format("%s ", var.interface);
235         formatted += format("%s %s", var.type, var.name);
236         if(var.array)
237         {
238                 formatted += '[';
239                 if(var.array_size)
240                         var.array_size->visit(*this);
241                 formatted += ']';
242         }
243         if(var.init_expression)
244         {
245                 formatted += " = ";
246                 var.init_expression->visit(*this);
247         }
248         if(!parameter_list)
249                 formatted += ';';
250 }
251
252 void ProgramCompiler::Formatter::visit(InterfaceBlock &iface)
253 {
254         formatted += format("%s %s\n", iface.interface, iface.name);
255         iface.members.visit(*this);
256         formatted += ';';
257 }
258
259 void ProgramCompiler::Formatter::visit(FunctionDeclaration &func)
260 {
261         formatted += format("%s %s(", func.return_type, func.name);
262         for(vector<NodePtr<VariableDeclaration> >::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
263         {
264                 if(i!=func.parameters.begin())
265                         formatted += ", ";
266                 SetFlag set(parameter_list);
267                 (*i)->visit(*this);
268         }
269         formatted += ')';
270         if(func.definition)
271         {
272                 formatted += '\n';
273                 func.body.visit(*this);
274         }
275         else
276                 formatted += ';';
277 }
278
279 void ProgramCompiler::Formatter::visit(Conditional &cond)
280 {
281         if(else_if)
282         {
283                 formatted += ' ';
284                 else_if = false;
285         }
286
287         formatted += "if(";
288         cond.condition->visit(*this);
289         formatted += ")\n";
290
291         cond.body.visit(*this);
292         if(!cond.else_body.body.empty())
293         {
294                 formatted += format("\n%selse", string(indent*2, ' '));
295                 SetFlag set(else_if);
296                 cond.else_body.visit(*this);
297         }
298 }
299
300 void ProgramCompiler::Formatter::visit(Iteration &iter)
301 {
302         formatted += "for(";
303         iter.init_statement->visit(*this);
304         formatted += ' ';
305         iter.condition->visit(*this);
306         formatted += "; ";
307         iter.loop_expression->visit(*this);
308         formatted += ")\n";
309         iter.body.visit(*this);
310 }
311
312 void ProgramCompiler::Formatter::visit(Return &ret)
313 {
314         formatted += "return ";
315         ret.expression->visit(*this);
316         formatted += ';';
317 }
318
319
320 ProgramCompiler::VariableResolver::VariableResolver():
321         anonymous(false)
322 { }
323
324 void ProgramCompiler::VariableResolver::visit(Block &block)
325 {
326         blocks.push_back(&block);
327         block.variables.clear();
328         TraversingVisitor::visit(block);
329         blocks.pop_back();
330 }
331
332 void ProgramCompiler::VariableResolver::visit(VariableReference &var)
333 {
334         var.declaration = 0;
335         type = 0;
336         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
337         {
338                 --i;
339                 map<string, VariableDeclaration *>::iterator j = (*i)->variables.find(var.name);
340                 if(j!=(*i)->variables.end())
341                 {
342                         var.declaration = j->second;
343                         type = j->second->type_declaration;
344                         break;
345                 }
346         }
347 }
348
349 void ProgramCompiler::VariableResolver::visit(MemberAccess &memacc)
350 {
351         type = 0;
352         TraversingVisitor::visit(memacc);
353         memacc.declaration = 0;
354         if(type)
355         {
356                 map<string, VariableDeclaration *>::iterator i = type->members.variables.find(memacc.member);
357                 if(i!=type->members.variables.end())
358                 {
359                         memacc.declaration = i->second;
360                         type = i->second->type_declaration;
361                 }
362                 else
363                         type = 0;
364         }
365 }
366
367 void ProgramCompiler::VariableResolver::visit(BinaryExpression &binary)
368 {
369         if(binary.oper=="[")
370         {
371                 binary.right->visit(*this);
372                 type = 0;
373                 binary.left->visit(*this);
374         }
375         else
376         {
377                 TraversingVisitor::visit(binary);
378                 type = 0;
379         }
380 }
381
382 void ProgramCompiler::VariableResolver::visit(StructDeclaration &strct)
383 {
384         TraversingVisitor::visit(strct);
385         blocks.back()->types[strct.name] = &strct;
386 }
387
388 void ProgramCompiler::VariableResolver::visit(VariableDeclaration &var)
389 {
390         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
391         {
392                 --i;
393                 map<string, StructDeclaration *>::iterator j = (*i)->types.find(var.type);
394                 if(j!=(*i)->types.end())
395                         var.type_declaration = j->second;
396         }
397
398         TraversingVisitor::visit(var);
399         blocks.back()->variables[var.name] = &var;
400         if(anonymous && blocks.size()>1)
401                 blocks[blocks.size()-2]->variables[var.name] = &var;
402 }
403
404 void ProgramCompiler::VariableResolver::visit(InterfaceBlock &iface)
405 {
406         SetFlag set(anonymous);
407         TraversingVisitor::visit(iface);
408 }
409
410
411 ProgramCompiler::InterfaceGenerator::InterfaceGenerator():
412         context(0),
413         scope_level(0),
414         remove_node(false)
415 { }
416
417 string ProgramCompiler::InterfaceGenerator::get_out_prefix(ContextType type)
418 {
419         if(type==VERTEX)
420                 return "_vs_out_";
421         else if(type==GEOMETRY)
422                 return "_gs_out_";
423         else
424                 return string();
425 }
426
427 void ProgramCompiler::InterfaceGenerator::visit(Context &ctx)
428 {
429         SetForScope<Context *> set(context, &ctx);
430         if(context->previous)
431                 in_prefix = get_out_prefix(context->previous->type);
432         out_prefix = get_out_prefix(context->type);
433         ctx.content.visit(*this);
434 }
435
436 void ProgramCompiler::InterfaceGenerator::visit(Block &block)
437 {
438         SetForScope<unsigned> set(scope_level, scope_level+1);
439         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
440         {
441                 (*i)->visit(*this);
442
443                 if(scope_level==1)
444                 {
445                         for(map<string, NodePtr<Node> >::iterator j=iface_declarations.begin(); j!=iface_declarations.end(); ++j)
446                         {
447                                 list<NodePtr<Node> >::iterator k = block.body.insert(i, j->second);
448                                 (*k)->visit(*this);
449                         }
450                         iface_declarations.clear();
451                 }
452
453                 for(list<NodePtr<Node> >::iterator j=insert_nodes.begin(); j!=insert_nodes.end(); ++j)
454                         block.body.insert(i, *j);
455                 insert_nodes.clear();
456
457                 if(remove_node)
458                         block.body.erase(i++);
459                 else
460                         ++i;
461                 remove_node = false;
462         }
463 }
464
465 string ProgramCompiler::InterfaceGenerator::change_prefix(const string &name, const string &prefix) const
466 {
467         unsigned offset = (name.compare(0, in_prefix.size(), in_prefix) ? 0 : in_prefix.size());
468         return prefix+name.substr(offset);
469 }
470
471 bool ProgramCompiler::InterfaceGenerator::generate_interface(VariableDeclaration &out, const string &iface, const string &name)
472 {
473         const map<string, VariableDeclaration *> &context_vars = (iface=="in" ? context->in_variables : context->out_variables);
474         if(context_vars.count(name) || iface_declarations.count(name))
475                 return false;
476
477         VariableDeclaration* iface_var = new VariableDeclaration;
478         iface_var->sampling = out.sampling;
479         iface_var->interface = iface;
480         iface_var->type = out.type;
481         iface_var->type_declaration = out.type_declaration;
482         iface_var->name = name;
483         iface_var->array = (out.array || (context->type==GEOMETRY && iface=="in"));
484         iface_var->array_size = out.array_size;
485         if(iface=="in")
486                 iface_var->linked_declaration = &out;
487         iface_declarations[iface_var->name] = iface_var;
488
489         return true;
490 }
491
492 void ProgramCompiler::InterfaceGenerator::insert_assignment(const string &left, ProgramSyntax::Expression *right)
493 {
494         BinaryExpression *assign = new BinaryExpression;
495         VariableReference *ref = new VariableReference;
496         ref->name = left;
497         assign->left = ref;
498         assign->oper = "=";
499         assign->right = right;
500         assign->assignment = true;
501
502         ExpressionStatement *stmt = new ExpressionStatement;
503         stmt->expression = assign;
504         insert_nodes.push_back(stmt);
505 }
506
507 void ProgramCompiler::InterfaceGenerator::visit(VariableReference &var)
508 {
509         if(var.declaration || !context->previous)
510                 return;
511         if(iface_declarations.count(var.name))
512                 return;
513
514         const map<string, VariableDeclaration *> &prev_out = context->previous->out_variables;
515         map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
516         if(i==prev_out.end())
517                 i = prev_out.find(in_prefix+var.name);
518         if(i!=prev_out.end())
519                 generate_interface(*i->second, "in", var.name);
520 }
521
522 void ProgramCompiler::InterfaceGenerator::visit(VariableDeclaration &var)
523 {
524         if(var.interface=="out")
525         {
526                 if(scope_level==1)
527                         context->out_variables[var.name] = &var;
528                 else if(generate_interface(var, "out", change_prefix(var.name, string())))
529                 {
530                         remove_node = true;
531                         if(var.init_expression)
532                                 insert_assignment(var.name, var.init_expression->clone());
533                 }
534         }
535         else if(var.interface=="in")
536         {
537                 context->in_variables[var.name] = &var;
538                 if(context->previous)
539                 {
540                         const map<string, VariableDeclaration *> &prev_out = context->previous->out_variables;
541                         map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
542                         if(i!=prev_out.end())
543                         {
544                                 var.linked_declaration = i->second;
545                                 i->second->linked_declaration = &var;
546                         }
547                 }
548         }
549
550         TraversingVisitor::visit(var);
551 }
552
553 void ProgramCompiler::InterfaceGenerator::visit(Passthrough &pass)
554 {
555         if(context->previous)
556         {
557                 const map<string, VariableDeclaration *> &prev_out = context->previous->out_variables;
558                 for(map<string, VariableDeclaration *>::const_iterator i=prev_out.begin(); i!=prev_out.end(); ++i)
559                 {
560                         string out_name = change_prefix(i->second->name, out_prefix);
561                         generate_interface(*i->second, "in", i->second->name);
562                         generate_interface(*i->second, "out", out_name);
563
564                         VariableReference *ref = new VariableReference;
565                         ref->name = i->first;
566                         if(pass.subscript)
567                         {
568                                 BinaryExpression *subscript = new BinaryExpression;
569                                 subscript->left = ref;
570                                 subscript->oper = "[";
571                                 subscript->right = pass.subscript;
572                                 subscript->after = "]";
573                                 insert_assignment(out_name, subscript);
574                         }
575                         else
576                                 insert_assignment(out_name, ref);
577                 }
578         }
579
580         remove_node = true;
581 }
582
583
584 void ProgramCompiler::VariableRenamer::visit(VariableReference &var)
585 {
586         if(var.declaration)
587                 var.name = var.declaration->name;
588 }
589
590 void ProgramCompiler::VariableRenamer::visit(VariableDeclaration &var)
591 {
592         if(var.linked_declaration)
593                 var.name = var.linked_declaration->name;
594         TraversingVisitor::visit(var);
595 }
596
597
598 void ProgramCompiler::UnusedVariableLocator::visit(VariableReference &var)
599 {
600         unused_variables.erase(var.declaration);
601 }
602
603 void ProgramCompiler::UnusedVariableLocator::visit(MemberAccess &memacc)
604 {
605         TraversingVisitor::visit(memacc);
606         unused_variables.erase(memacc.declaration);
607 }
608
609 void ProgramCompiler::UnusedVariableLocator::visit(VariableDeclaration &var)
610 {
611         unused_variables.insert(&var);
612         TraversingVisitor::visit(var);
613 }
614
615
616 ProgramCompiler::NodeRemover::NodeRemover():
617         n_removed(0),
618         immutable_block(false),
619         remove_block(false)
620 { }
621
622 void ProgramCompiler::NodeRemover::visit(Block &block)
623 {
624         remove_block = immutable_block;
625         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
626         {
627                 bool remove = false;
628                 if(to_remove.count(&**i))
629                         remove = !immutable_block;
630                 else
631                 {
632                         remove_block = false;
633                         (*i)->visit(*this);
634                         remove = remove_block;
635                 }
636
637                 if(remove)
638                         block.body.erase(i++);
639                 else
640                         ++i;
641
642                 n_removed += remove;
643         }
644 }
645
646 void ProgramCompiler::NodeRemover::visit(StructDeclaration &strct)
647 {
648         SetFlag set(immutable_block);
649         TraversingVisitor::visit(strct);
650 }
651
652 void ProgramCompiler::NodeRemover::visit(InterfaceBlock &iface)
653 {
654         SetFlag set(immutable_block);
655         TraversingVisitor::visit(iface);
656 }
657
658 } // namespace GL
659 } // namespace Msp