1 #include <msp/core/raii.h>
2 #include <msp/strings/format.h>
3 #include <msp/strings/utils.h>
6 #include "programcompiler.h"
14 using namespace ProgramSyntax;
// Default constructor of ProgramCompiler.
// NOTE(review): the member initializer list that follows the colon is not
// visible in this listing.
16 ProgramCompiler::ProgramCompiler():
// Parses shader source given as a string and stores the address of the module
// returned by the parser in `module`.
// NOTE(review): any steps after parsing are not visible in this listing.
20 void ProgramCompiler::compile(const string &source)
22 module = &parser.parse(source);
// Parses shader source read from an I/O object and stores the address of the
// module returned by the parser in `module`.
// NOTE(review): any steps after parsing are not visible in this listing.
26 void ProgramCompiler::compile(IO::Base &io)
28 module = &parser.parse(io);
// Formats each present stage of the parsed module into shader source, attaches
// the resulting shader objects to `program`, and binds four attribute names.
32 void ProgramCompiler::add_shaders(Program &program)
// NOTE(review): the condition guarding this throw is not visible in this
// listing; presumably it fires when no module has been compiled - confirm.
35 throw invalid_operation("ProgramCompiler::add_shaders");
// Every generated shader source starts with the same version directive.
37 string head = "#version 150\n";
// One shader object is created per stage whose context is flagged as present
// and is handed to the program through attach_shader_owned.
38 if(module->vertex_context.present)
39 program.attach_shader_owned(new VertexShader(head+format_context(module->vertex_context)));
40 if(module->geometry_context.present)
41 program.attach_shader_owned(new GeometryShader(head+format_context(module->geometry_context)));
42 if(module->fragment_context.present)
43 program.attach_shader_owned(new FragmentShader(head+format_context(module->fragment_context)));
// Bind the attribute constants to the names "vertex", "normal", "color" and
// "texcoord".
45 program.bind_attribute(VERTEX4, "vertex");
46 program.bind_attribute(NORMAL3, "normal");
47 program.bind_attribute(COLOR4_FLOAT, "color");
48 program.bind_attribute(TEXCOORD4, "texcoord");
// Runs process(Context &) on each of the vertex, geometry and fragment
// contexts of the module that is flagged as present.
51 void ProgramCompiler::process()
53 if(module->vertex_context.present)
54 process(module->vertex_context);
55 if(module->geometry_context.present)
56 process(module->geometry_context);
57 if(module->fragment_context.present)
58 process(module->fragment_context);
// Processes a single stage context: injects the global context's content,
// resolves variables, locates unused variable declarations and removes them.
61 void ProgramCompiler::process(Context &context)
// Clone the global context's nodes into this stage's content.
63 inject_block(context.content, module->global_context.content);
65 VariableResolver resolver;
66 context.content.visit(resolver);
70 UnusedVariableLocator unused_locator;
71 context.content.visit(unused_locator);
// Pass everything the locator still considers unused to the remover.
// NOTE(review): the declaration of `remover` is not visible in this listing.
74 remover.to_remove.insert(unused_locator.unused_variables.begin(), unused_locator.unused_variables.end());
75 context.content.visit(remover);
// NOTE(review): the statement controlled by this test, and any loop enclosing
// the steps above, are not visible here; presumably processing stops once a
// pass removes nothing - confirm.
77 if(!remover.n_removed)
// Inserts a clone of every node in `source` ahead of the existing contents of
// `target`.
82 void ProgramCompiler::inject_block(Block &target, const Block &source)
// insert_point keeps referring to the original first element of target (list
// insertion does not invalidate iterators), so the clones end up before it in
// the same order as in source.
84 list<NodePtr<Node> >::iterator insert_point = target.body.begin();
85 for(list<NodePtr<Node> >::const_iterator i=source.body.begin(); i!=source.body.end(); ++i)
86 target.body.insert(insert_point, (*i)->clone());
// Visits the context's content with `formatter` and returns the text it
// accumulated in its `formatted` member.
// NOTE(review): the declaration of `formatter` is not visible in this listing.
89 string ProgramCompiler::format_context(Context &context)
92 context.content.visit(formatter);
93 return formatter.formatted;
// Constructor of Formatter; initializes parameter_list to false.
// NOTE(review): the remaining member initializers are not visible in this
// listing.
97 ProgramCompiler::Formatter::Formatter():
99 parameter_list(false),
// Appends the literal's token text to the output.
103 void ProgramCompiler::Formatter::visit(Literal &literal)
105 formatted += literal.token;
// Formats a parenthesized expression by visiting the inner expression.
// NOTE(review): the statements emitting the surrounding parentheses are
// presumably among the lines missing from this listing - confirm.
108 void ProgramCompiler::Formatter::visit(ParenthesizedExpression &parexpr)
111 parexpr.expression->visit(*this);
// Appends the referenced variable's name to the output.
115 void ProgramCompiler::Formatter::visit(VariableReference &var)
117 formatted += var.name;
// Formats the left-hand expression followed by a dot and the member name.
120 void ProgramCompiler::Formatter::visit(MemberAccess &memacc)
122 memacc.left->visit(*this);
123 formatted += format(".%s", memacc.member);
// Formats a unary expression: the operator text is appended before and/or
// after the operand.
// NOTE(review): the conditions that presumably select between the prefix and
// the postfix append are not visible in this listing - confirm.
126 void ProgramCompiler::Formatter::visit(UnaryExpression &unary)
129 formatted += unary.oper;
130 unary.expression->visit(*this);
132 formatted += unary.oper;
// Formats a binary expression: left operand, operator, right operand, then
// the text stored in binary.after.
135 void ProgramCompiler::Formatter::visit(BinaryExpression &binary)
137 binary.left->visit(*this);
// Assignment operators are padded with one space on each side.
138 if(binary.assignment)
139 formatted += format(" %s ", binary.oper);
// NOTE(review): this unpadded append presumably belongs to an else branch
// whose keyword is not visible in this listing - confirm.
141 formatted += binary.oper;
142 binary.right->visit(*this);
143 formatted += binary.after;
// Formats a function call: the name and an opening parenthesis, followed by a
// pass over the arguments.
146 void ProgramCompiler::Formatter::visit(FunctionCall &call)
148 formatted += format("%s(", call.name);
149 for(vector<NodePtr<Expression> >::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
// True for every argument except the first.
// NOTE(review): the controlled statement (presumably a separator) and the rest
// of the loop body are not visible in this listing.
151 if(i!=call.arguments.begin())
// Formats an expression statement by visiting its expression.
// NOTE(review): any statement terminator is emitted on a line not visible in
// this listing.
158 void ProgramCompiler::Formatter::visit(ExpressionStatement &expr)
160 expr.expression->visit(*this);
// Formats a block of nodes with two spaces of indentation per level.
// NOTE(review): the conditions around the brace output and the loop body are
// not visible in this listing.
164 void ProgramCompiler::Formatter::visit(Block &block)
// Opening brace on its own line at the current indentation.
173 formatted += format("%s{\n", string(indent*2, ' '));
// Indentation is deepened only when some output already exists and the
// else_if flag is not set.
176 bool change_indent = (!formatted.empty() && !else_if);
177 indent += change_indent;
178 string spaces(indent*2, ' ');
179 for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); ++i)
// True for every node except the first.
181 if(i!=block.body.begin())
// Undo the indentation change made above.
187 indent -= change_indent;
// Closing brace on a new line at the restored indentation.
190 formatted += format("\n%s}", string(indent*2, ' '));
// Formats a layout declaration as "layout(<qualifiers>) <interface>;".
193 void ProgramCompiler::Formatter::visit(Layout &layout)
195 formatted += "layout(";
196 for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
// True for every qualifier except the first.
// NOTE(review): the controlled statement (presumably a separator) is not
// visible in this listing.
198 if(i!=layout.qualifiers.begin())
200 formatted += i->identifier;
// A qualifier with a non-empty value is written as identifier=value.
201 if(!i->value.empty())
202 formatted += format("=%s", i->value);
204 formatted += format(") %s;", layout.interface);
// Formats a struct declaration: "struct <name>" and a newline, followed by the
// members block.
// NOTE(review): anything emitted after the members is not visible in this
// listing.
207 void ProgramCompiler::Formatter::visit(StructDeclaration &strct)
209 formatted += format("struct %s\n", strct.name);
210 strct.members.visit(*this);
// Formats a variable declaration: optional qualifiers, then "<type> <name>",
// then the array size and initializer expressions.
214 void ProgramCompiler::Formatter::visit(VariableDeclaration &var)
// NOTE(review): the condition guarding this append is not visible in this
// listing.
217 formatted += "const ";
// Sampling and interface qualifiers are written, each followed by a space,
// only when non-empty.
218 if(!var.sampling.empty())
219 formatted += format("%s ", var.sampling);
220 if(!var.interface.empty())
221 formatted += format("%s ", var.interface);
222 formatted += format("%s %s", var.type, var.name);
// NOTE(review): the condition and surrounding text for the array size are not
// visible in this listing.
227 var.array_size->visit(*this);
230 if(var.init_expression)
233 var.init_expression->visit(*this);
// Formats an interface block: "<interface> <name>" and a newline, followed by
// the members block.
// NOTE(review): anything emitted after the members is not visible in this
// listing.
239 void ProgramCompiler::Formatter::visit(InterfaceBlock &iface)
241 formatted += format("%s %s\n", iface.interface, iface.name);
242 iface.members.visit(*this);
// Formats a function declaration: "<return_type> <name>(", a pass over the
// parameters, then the function body.
246 void ProgramCompiler::Formatter::visit(FunctionDeclaration &func)
248 formatted += format("%s %s(", func.return_type, func.name);
249 for(vector<NodePtr<VariableDeclaration> >::iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
// True for every parameter except the first.
// NOTE(review): the controlled statement (presumably a separator) is not
// visible in this listing.
251 if(i!=func.parameters.begin())
// Marks parameter_list while a parameter is formatted.
// NOTE(review): SetFlag presumably restores the flag when it goes out of
// scope - confirm against msp/core/raii.h.
253 SetFlag set(parameter_list);
260 func.body.visit(*this);
// Formats a conditional: the condition, the body, and an else branch when
// else_body is non-empty.
// NOTE(review): the text emitted around the condition is not visible in this
// listing.
266 void ProgramCompiler::Formatter::visit(Conditional &cond)
275 cond.condition->visit(*this);
278 cond.body.visit(*this);
279 if(!cond.else_body.body.empty())
// "else" goes on a new line at the current indentation.
281 formatted += format("\n%selse", string(indent*2, ' '));
// The else body is formatted with the else_if flag set, which visit(Block)
// reads when deciding whether to deepen the indentation.
282 SetFlag set(else_if);
283 cond.else_body.visit(*this);
// Formats an iteration statement by visiting, in order, its init statement,
// condition, loop expression and body.
// NOTE(review): the text emitted between these parts is not visible in this
// listing.
287 void ProgramCompiler::Formatter::visit(Iteration &iter)
290 iter.init_statement->visit(*this);
292 iter.condition->visit(*this);
294 iter.loop_expression->visit(*this);
296 iter.body.visit(*this);
// Formats a return statement: "return " followed by the returned expression.
// NOTE(review): any statement terminator is emitted on a line not visible in
// this listing.
299 void ProgramCompiler::Formatter::visit(Return &ret)
301 formatted += "return ";
302 ret.expression->visit(*this);
// Default constructor of VariableResolver.
// NOTE(review): the member initializer list that follows the colon is not
// visible in this listing.
307 ProgramCompiler::VariableResolver::VariableResolver():
// Enters a block: pushes it onto the `blocks` stack, clears its variable map
// and traverses its contents.
// NOTE(review): the matching removal from `blocks` is presumably on a line not
// visible in this listing - confirm.
311 void ProgramCompiler::VariableResolver::visit(Block &block)
313 blocks.push_back(&block);
314 block.variables.clear();
315 TraversingVisitor::visit(block);
// Resolves a variable reference by looking its name up in the variable maps of
// the blocks on the stack.
319 void ProgramCompiler::VariableResolver::visit(VariableReference &var)
// The loop starts at blocks.end() and runs towards blocks.begin().
// NOTE(review): the iterator step is on a line not visible in this listing.
323 for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
326 map<string, VariableDeclaration *>::iterator j = (*i)->variables.find(var.name);
327 if(j!=(*i)->variables.end())
// On a match, record the declaration on the reference and keep its type
// declaration in `type`.
329 var.declaration = j->second;
330 type = j->second->type_declaration;
// Resolves a member access: traverses the node first, then looks the member
// name up among the members of the struct currently held in `type`.
336 void ProgramCompiler::VariableResolver::visit(MemberAccess &memacc)
339 TraversingVisitor::visit(memacc);
// Start out unresolved; set below only if the member is found.
340 memacc.declaration = 0;
// NOTE(review): any check that `type` is non-null before this dereference is
// not visible in this listing.
343 map<string, VariableDeclaration *>::iterator i = type->members.variables.find(memacc.member);
344 if(i!=type->members.variables.end())
346 memacc.declaration = i->second;
347 type = i->second->type_declaration;
// Visits a binary expression: either the right operand and then the left one
// explicitly, or through the default traversal.
// NOTE(review): the conditions selecting between the two paths are not visible
// in this listing.
354 void ProgramCompiler::VariableResolver::visit(BinaryExpression &binary)
358 binary.right->visit(*this);
360 binary.left->visit(*this);
364 TraversingVisitor::visit(binary);
// Traverses a struct declaration, then registers it under its name in the type
// map of the block at the back of the stack.
369 void ProgramCompiler::VariableResolver::visit(StructDeclaration &strct)
371 TraversingVisitor::visit(strct);
372 blocks.back()->types[strct.name] = &strct;
// Resolves the struct type of a variable declaration and registers the
// variable in the enclosing block or blocks.
375 void ProgramCompiler::VariableResolver::visit(VariableDeclaration &var)
// The loop starts at blocks.end() and runs towards blocks.begin(), looking for
// a struct named like the variable's type.
// NOTE(review): the iterator step is on a line not visible in this listing.
377 for(vector<Block *>::iterator i=blocks.end(); i!=blocks.begin(); )
380 map<string, StructDeclaration *>::iterator j = (*i)->types.find(var.type);
381 if(j!=(*i)->types.end())
382 var.type_declaration = j->second;
385 TraversingVisitor::visit(var);
386 blocks.back()->variables[var.name] = &var;
// While the `anonymous` flag is set, the variable is also registered in the
// block one level out from the back of the stack.
387 if(anonymous && blocks.size()>1)
388 blocks[blocks.size()-2]->variables[var.name] = &var;
// Traverses an interface block with the `anonymous` flag set.
// NOTE(review): SetFlag presumably restores the flag when it goes out of
// scope - confirm against msp/core/raii.h.
391 void ProgramCompiler::VariableResolver::visit(InterfaceBlock &iface)
393 SetFlag set(anonymous);
394 TraversingVisitor::visit(iface);
// A reference to a variable takes its declaration out of the unused set.
398 void ProgramCompiler::UnusedVariableLocator::visit(VariableReference &var)
400 unused_variables.erase(var.declaration);
// Traverses the member access, then takes the accessed member's declaration
// out of the unused set.
403 void ProgramCompiler::UnusedVariableLocator::visit(MemberAccess &memacc)
405 TraversingVisitor::visit(memacc);
406 unused_variables.erase(memacc.declaration);
// Adds the declaration to the unused set, then traverses it. The
// reference-visiting overloads erase entries from the set again.
409 void ProgramCompiler::UnusedVariableLocator::visit(VariableDeclaration &var)
411 unused_variables.insert(&var);
412 TraversingVisitor::visit(var);
// Constructor of NodeRemover; initializes immutable_block to false.
// NOTE(review): the remaining member initializers are not visible in this
// listing.
416 ProgramCompiler::NodeRemover::NodeRemover():
418 immutable_block(false),
// Walks the nodes of a block and erases those selected for removal.
// NOTE(review): several conditions and the non-erasing path of the loop are
// not visible in this listing, so the exact selection rules cannot be
// confirmed from here.
422 void ProgramCompiler::NodeRemover::visit(Block &block)
424 remove_block = immutable_block;
// The loop header has no increment; the iterator is advanced in the body.
425 for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
// A node listed in to_remove is removed only when the block is not immutable.
428 if(to_remove.count(&**i))
429 remove = !immutable_block;
432 remove_block = false;
434 remove = remove_block;
// Post-increment moves i past the element before it is erased, so the iterator
// stays valid.
438 block.body.erase(i++);
// Traverses a struct declaration with the immutable_block flag set.
// NOTE(review): SetFlag presumably restores the flag when it goes out of
// scope - confirm against msp/core/raii.h.
446 void ProgramCompiler::NodeRemover::visit(StructDeclaration &strct)
448 SetFlag set(immutable_block);
449 TraversingVisitor::visit(strct);
// Traverses an interface block with the immutable_block flag set.
// NOTE(review): SetFlag presumably restores the flag when it goes out of
// scope - confirm against msp/core/raii.h.
452 void ProgramCompiler::NodeRemover::visit(InterfaceBlock &iface)
454 SetFlag set(immutable_block);
455 TraversingVisitor::visit(iface);