git.tdb.fi Git - libs/gl.git/blob - source/programcompiler.cpp
Remove unused variable and struct declarations from the syntax tree
[libs/gl.git] / source / programcompiler.cpp
1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include <msp/strings/utils.h>
4 #include "error.h"
5 #include "program.h"
6 #include "programcompiler.h"
7 #include "shader.h"
8
9 using namespace std;
10
11 namespace Msp {
12 namespace GL {
13
14 using namespace ProgramSyntax;
15
16 ProgramCompiler::ProgramCompiler():
17         module(0)
18 { }
19
20 void ProgramCompiler::compile(const string &source)
21 {
22         module = &parser.parse(source);
23         process();
24 }
25
26 void ProgramCompiler::compile(IO::Base &io)
27 {
28         module = &parser.parse(io);
29         process();
30 }
31
32 void ProgramCompiler::add_shaders(Program &program)
33 {
34         if(!module)
35                 throw invalid_operation("ProgramCompiler::add_shaders");
36
37         string head = "#version 150\n";
38         if(module->vertex_context.present)
39                 program.attach_shader_owned(new VertexShader(head+format_context(module->vertex_context)));
40         if(module->geometry_context.present)
41                 program.attach_shader_owned(new GeometryShader(head+format_context(module->geometry_context)));
42         if(module->fragment_context.present)
43                 program.attach_shader_owned(new FragmentShader(head+format_context(module->fragment_context)));
44
45         program.bind_attribute(VERTEX4, "vertex");
46         program.bind_attribute(NORMAL3, "normal");
47         program.bind_attribute(COLOR4_FLOAT, "color");
48         program.bind_attribute(TEXCOORD4, "texcoord");
49 }
50
51 void ProgramCompiler::process()
52 {
53         if(module->vertex_context.present)
54                 process(module->vertex_context);
55         if(module->geometry_context.present)
56                 process(module->geometry_context);
57         if(module->fragment_context.present)
58                 process(module->fragment_context);
59 }
60
61 void ProgramCompiler::process(Context &context)
62 {
63         inject_block(context.content, module->global_context.content);
64
65         VariableResolver resolver;
66         context.content.visit(resolver);
67
68         while(1)
69         {
70                 UnusedVariableLocator unused_locator;
71                 context.content.visit(unused_locator);
72
73                 NodeRemover remover;
74                 remover.to_remove.insert(unused_locator.unused_variables.begin(), unused_locator.unused_variables.end());
75                 context.content.visit(remover);
76
77                 if(!remover.n_removed)
78                         break;
79         }
80 }
81
82 void ProgramCompiler::inject_block(Block &target, const Block &source)
83 {
84         list<NodePtr<Node> >::iterator insert_point = target.body.begin();
85         for(list<NodePtr<Node> >::const_iterator i=source.body.begin(); i!=source.body.end(); ++i)
86                 target.body.insert(insert_point, (*i)->clone());
87 }
88
89 string ProgramCompiler::format_context(Context &context)
90 {
91         Formatter formatter;
92         context.content.visit(formatter);
93         return formatter.formatted;
94 }
95
96
97 ProgramCompiler::Formatter::Formatter():
98         indent(0),
99         parameter_list(false),
100         else_if(false)
101 { }
102
103 void ProgramCompiler::Formatter::visit(Literal &literal)
104 {
105         formatted += literal.token;
106 }
107
108 void ProgramCompiler::Formatter::visit(ParenthesizedExpression &parexpr)
109 {
110         formatted += '(';
111         parexpr.expression->visit(*this);
112         formatted += ')';
113 }
114
115 void ProgramCompiler::Formatter::visit(VariableReference &var)
116 {
117         formatted += var.name;
118 }
119
120 void ProgramCompiler::Formatter::visit(MemberAccess &memacc)
121 {
122         memacc.left->visit(*this);
123         formatted += format(".%s", memacc.member);
124 }
125
126 void ProgramCompiler::Formatter::visit(UnaryExpression &unary)
127 {
128         if(unary.prefix)
129                 formatted += unary.oper;
130         unary.expression->visit(*this);
131         if(!unary.prefix)
132                 formatted += unary.oper;
133 }
134
135 void ProgramCompiler::Formatter::visit(BinaryExpression &binary)
136 {
137         binary.left->visit(*this);
138         if(binary.assignment)
139                 formatted += format(" %s ", binary.oper);
140         else
141                 formatted += binary.oper;
142         binary.right->visit(*this);
143         formatted += binary.after;
144 }
145
146 void ProgramCompiler::Formatter::visit(FunctionCall &call)
147 {
148         formatted += format("%s(", call.name);
149         for(vector<NodePtr<Expression> >::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
150         {
151                 if(i!=call.arguments.begin())
152                         formatted += ", ";
153                 (*i)->visit(*this);
154         }
155         formatted += ')';
156 }
157
158 void ProgramCompiler::Formatter::visit(ExpressionStatement &expr)
159 {
160         expr.expression->visit(*this);
161         formatted += ';';
162 }
163
164 void ProgramCompiler::Formatter::visit(Block &block)
165 {
166         if(block.use_braces)
167         {
168                 if(else_if)
169                 {
170                         formatted += '\n';
171                         else_if = false;
172                 }
173                 formatted += format("%s{\n", string(indent*2, ' '));
174         }
175
176         bool change_indent = (!formatted.empty() && !else_if);
177         indent += change_indent;
178         string spaces(indent*2, ' ');
179         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); ++i)
180         {
181                 if(i!=block.body.begin())
182                         formatted += '\n';
183                 if(!else_if)
184                         formatted += spaces;
185                 (*i)->visit(*this);
186         }
187         indent -= change_indent;
188
189         if(block.use_braces)
190                 formatted += format("\n%s}", string(indent*2, ' '));
191 }
192
193 void ProgramCompiler::Formatter::visit(Layout &layout)
194 {
195         formatted += "layout(";
196         for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
197         {
198                 if(i!=layout.qualifiers.begin())
199                         formatted += ", ";
200                 formatted += i->identifier;
201                 if(!i->value.empty())
202                         formatted += format("=%s", i->value);
203         }
204         formatted += format(") %s;", layout.interface);
205 }
206
207 void ProgramCompiler::Formatter::visit(StructDeclaration &strct)
208 {
209         formatted += format("struct %s\n", strct.name);
210         strct.members.visit(*this);
211         formatted += ';';
212 }
213
214 void ProgramCompiler::Formatter::visit(VariableDeclaration &var)
215 {
216         if(var.constant)
217                 formatted += "const ";
218         if(!var.sampling.empty())
219                 formatted += format("%s ", var.sampling);
220         if(!var.interface.empty())
221                 formatted += format("%s ", var.interface);
222         formatted += format("%s %s", var.type, var.name);
223         if(var.array)
224         {
225                 formatted += '[';
226                 if(var.array_size)
227                         var.array_size->visit(*this);
228                 formatted += ']';
229         }
230         if(var.init_expression)
231         {
232                 formatted += " = ";
233                 var.init_expression->visit(*this);
234         }
235         if(!parameter_list)
236                 formatted += ';';
237 }
238
239 void ProgramCompiler::Formatter::visit(InterfaceBlock &iface)
240 {
241         formatted += format("%s %s\n", iface.interface, iface.name);
242         iface.members.visit(*this);
243         formatted += ';';
244 }
245
246 void ProgramCompiler::Formatter::visit(FunctionDeclaration &func)
247 {
248         formatted += format("%s %s(", func.return_type, func.name);
249         for(vector<NodePtr<VariableDeclaration> >::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
250         {
251                 if(i!=func.parameters.begin())
252                         formatted += ", ";
253                 SetFlag set(parameter_list);
254                 (*i)->visit(*this);
255         }
256         formatted += ')';
257         if(func.definition)
258         {
259                 formatted += '\n';
260                 func.body.visit(*this);
261         }
262         else
263                 formatted += ';';
264 }
265
266 void ProgramCompiler::Formatter::visit(Conditional &cond)
267 {
268         if(else_if)
269         {
270                 formatted += ' ';
271                 else_if = false;
272         }
273
274         formatted += "if(";
275         cond.condition->visit(*this);
276         formatted += ")\n";
277
278         cond.body.visit(*this);
279         if(!cond.else_body.body.empty())
280         {
281                 formatted += format("\n%selse", string(indent*2, ' '));
282                 SetFlag set(else_if);
283                 cond.else_body.visit(*this);
284         }
285 }
286
287 void ProgramCompiler::Formatter::visit(Iteration &iter)
288 {
289         formatted += "for(";
290         iter.init_statement->visit(*this);
291         formatted += ' ';
292         iter.condition->visit(*this);
293         formatted += "; ";
294         iter.loop_expression->visit(*this);
295         formatted += ")\n";
296         iter.body.visit(*this);
297 }
298
299 void ProgramCompiler::Formatter::visit(Return &ret)
300 {
301         formatted += "return ";
302         ret.expression->visit(*this);
303         formatted += ';';
304 }
305
306
307 ProgramCompiler::VariableResolver::VariableResolver():
308         anonymous(false)
309 { }
310
311 void ProgramCompiler::VariableResolver::visit(Block &block)
312 {
313         blocks.push_back(&block);
314         block.variables.clear();
315         TraversingVisitor::visit(block);
316         blocks.pop_back();
317 }
318
319 void ProgramCompiler::VariableResolver::visit(VariableReference &var)
320 {
321         var.declaration = 0;
322         type = 0;
323         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
324         {
325                 --i;
326                 map<string, VariableDeclaration *>::iterator j = (*i)->variables.find(var.name);
327                 if(j!=(*i)->variables.end())
328                 {
329                         var.declaration = j->second;
330                         type = j->second->type_declaration;
331                         break;
332                 }
333         }
334 }
335
336 void ProgramCompiler::VariableResolver::visit(MemberAccess &memacc)
337 {
338         type = 0;
339         TraversingVisitor::visit(memacc);
340         memacc.declaration = 0;
341         if(type)
342         {
343                 map<string, VariableDeclaration *>::iterator i = type->members.variables.find(memacc.member);
344                 if(i!=type->members.variables.end())
345                 {
346                         memacc.declaration = i->second;
347                         type = i->second->type_declaration;
348                 }
349                 else
350                         type = 0;
351         }
352 }
353
354 void ProgramCompiler::VariableResolver::visit(BinaryExpression &binary)
355 {
356         if(binary.oper=="[")
357         {
358                 binary.right->visit(*this);
359                 type = 0;
360                 binary.left->visit(*this);
361         }
362         else
363         {
364                 TraversingVisitor::visit(binary);
365                 type = 0;
366         }
367 }
368
369 void ProgramCompiler::VariableResolver::visit(StructDeclaration &strct)
370 {
371         TraversingVisitor::visit(strct);
372         blocks.back()->types[strct.name] = &strct;
373 }
374
375 void ProgramCompiler::VariableResolver::visit(VariableDeclaration &var)
376 {
377         for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
378         {
379                 --i;
380                 map<string, StructDeclaration *>::iterator j = (*i)->types.find(var.type);
381                 if(j!=(*i)->types.end())
382                         var.type_declaration = j->second;
383         }
384
385         TraversingVisitor::visit(var);
386         blocks.back()->variables[var.name] = &var;
387         if(anonymous && blocks.size()>1)
388                 blocks[blocks.size()-2]->variables[var.name] = &var;
389 }
390
391 void ProgramCompiler::VariableResolver::visit(InterfaceBlock &iface)
392 {
393         SetFlag set(anonymous);
394         TraversingVisitor::visit(iface);
395 }
396
397
398 void ProgramCompiler::UnusedVariableLocator::visit(VariableReference &var)
399 {
400         unused_variables.erase(var.declaration);
401 }
402
403 void ProgramCompiler::UnusedVariableLocator::visit(MemberAccess &memacc)
404 {
405         TraversingVisitor::visit(memacc);
406         unused_variables.erase(memacc.declaration);
407 }
408
409 void ProgramCompiler::UnusedVariableLocator::visit(VariableDeclaration &var)
410 {
411         unused_variables.insert(&var);
412         TraversingVisitor::visit(var);
413 }
414
415
416 ProgramCompiler::NodeRemover::NodeRemover():
417         n_removed(0),
418         immutable_block(false),
419         remove_block(false)
420 { }
421
422 void ProgramCompiler::NodeRemover::visit(Block &block)
423 {
424         remove_block = immutable_block;
425         for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
426         {
427                 bool remove = false;
428                 if(to_remove.count(&**i))
429                         remove = !immutable_block;
430                 else
431                 {
432                         remove_block = false;
433                         (*i)->visit(*this);
434                         remove = remove_block;
435                 }
436
437                 if(remove)
438                         block.body.erase(i++);
439                 else
440                         ++i;
441
442                 n_removed += remove;
443         }
444 }
445
446 void ProgramCompiler::NodeRemover::visit(StructDeclaration &strct)
447 {
448         SetFlag set(immutable_block);
449         TraversingVisitor::visit(strct);
450 }
451
452 void ProgramCompiler::NodeRemover::visit(InterfaceBlock &iface)
453 {
454         SetFlag set(immutable_block);
455         TraversingVisitor::visit(iface);
456 }
457
458 } // namespace GL
459 } // namespace Msp