]> git.tdb.fi Git - libs/gl.git/blobdiff - source/programcompiler.cpp
Implement an import system
[libs/gl.git] / source / programcompiler.cpp
index a730f74348f072fab4edd1398835b380b4295af4..82a324dec52d8e7c9e92b8d8c56cb15264e423f3 100644 (file)
@@ -4,6 +4,7 @@
 #include "error.h"
 #include "program.h"
 #include "programcompiler.h"
+#include "resources.h"
 #include "shader.h"
 
 using namespace std;
@@ -34,17 +35,20 @@ namespace GL {
 using namespace ProgramSyntax;
 
 ProgramCompiler::ProgramCompiler():
+       resources(0),
        module(0)
 { }
 
 void ProgramCompiler::compile(const string &source)
 {
+       resources = 0;
        module = &parser.parse(source);
        process();
 }
 
-void ProgramCompiler::compile(IO::Base &io)
+void ProgramCompiler::compile(IO::Base &io, Resources *res)
 {
+       resources = res;
        module = &parser.parse(io);
        process();
 }
@@ -58,11 +62,11 @@ void ProgramCompiler::add_shaders(Program &program)
        for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
        {
                if(i->type==VERTEX)
-                       program.attach_shader_owned(new VertexShader(head+create_source(*i)));
+                       program.attach_shader_owned(new VertexShader(head+apply<Formatter>(*i)));
                else if(i->type==GEOMETRY)
-                       program.attach_shader_owned(new GeometryShader(head+create_source(*i)));
+                       program.attach_shader_owned(new GeometryShader(head+apply<Formatter>(*i)));
                else if(i->type==FRAGMENT)
-                       program.attach_shader_owned(new FragmentShader(head+create_source(*i)));
+                       program.attach_shader_owned(new FragmentShader(head+apply<Formatter>(*i)));
        }
 
        program.bind_attribute(VERTEX4, "vertex");
@@ -102,6 +106,11 @@ Stage *ProgramCompiler::get_builtins(StageType type)
 
 void ProgramCompiler::process()
 {
+       list<Import *> imports = apply<NodeGatherer<Import> >(module->shared);
+       for(list<Import *>::iterator i=imports.end(); i!=imports.begin(); )
+               import((*--i)->module);
+       apply<NodeRemover>(module->shared, set<Node *>(imports.begin(), imports.end()));
+
        for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
                generate(*i);
        for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); )
@@ -113,6 +122,39 @@ void ProgramCompiler::process()
        }
 }
 
+void ProgramCompiler::import(const string &name)
+{
+       if(!resources)
+               throw runtime_error("no resources");
+       RefPtr<IO::Seekable> io = resources->open_raw(name+".glsl");
+       if(!io)
+               throw runtime_error(format("module %s not found", name));
+       ProgramParser import_parser;
+       Module &imported_module = import_parser.parse(*io);
+
+       inject_block(module->shared.content, imported_module.shared.content);
+       apply<DeclarationCombiner>(module->shared);
+       for(list<Stage>::iterator i=imported_module.stages.begin(); i!=imported_module.stages.end(); ++i)
+       {
+               list<Stage>::iterator j;
+               for(j=module->stages.begin(); (j!=module->stages.end() && j->type<i->type); ++j) ;
+               if(j==module->stages.end() || j->type>i->type)
+               {
+                       j = module->stages.insert(j, *i);
+                       list<Stage>::iterator k = j;
+                       if(++k!=module->stages.end())
+                               k->previous = &*j;
+                       if(j!=module->stages.begin())
+                               j->previous = &*--(k=j);
+               }
+               else
+               {
+                       inject_block(j->content, i->content);
+                       apply<DeclarationCombiner>(*j);
+               }
+       }
+}
+
 void ProgramCompiler::generate(Stage &stage)
 {
        inject_block(stage.content, module->shared.content);
@@ -125,14 +167,10 @@ void ProgramCompiler::generate(Stage &stage)
 
 bool ProgramCompiler::optimize(Stage &stage)
 {
-       UnusedVariableLocator unused_locator;
-       unused_locator.apply(stage);
-
-       NodeRemover remover;
-       remover.to_remove = unused_locator.unused_nodes;
-       remover.apply(stage);
+       set<Node *> unused = apply<UnusedVariableLocator>(stage);
+       apply<NodeRemover>(stage, unused);
 
-       return !unused_locator.unused_nodes.empty();
+       return !unused.empty();
 }
 
 void ProgramCompiler::inject_block(Block &target, const Block &source)
@@ -143,17 +181,19 @@ void ProgramCompiler::inject_block(Block &target, const Block &source)
 }
 
 template<typename T>
-void ProgramCompiler::apply(Stage &stage)
+typename T::ResultType ProgramCompiler::apply(Stage &stage)
 {
        T visitor;
        visitor.apply(stage);
+       return visitor.get_result();
 }
 
-string ProgramCompiler::create_source(Stage &stage)
+template<typename T, typename A>
+typename T::ResultType ProgramCompiler::apply(Stage &stage, const A &arg)
 {
-       Formatter formatter;
-       formatter.apply(stage);
-       return formatter.formatted;
+       T visitor(arg);
+       visitor.apply(stage);
+       return visitor.get_result();
 }
 
 
@@ -264,6 +304,11 @@ void ProgramCompiler::Formatter::visit(Block &block)
                formatted += format("\n%s}", string(brace_indent*2, ' '));
 }
 
+void ProgramCompiler::Formatter::visit(Import &import)
+{
+       formatted += format("import %s;", import.module);
+}
+
 void ProgramCompiler::Formatter::visit(Layout &layout)
 {
        formatted += "layout(";
@@ -329,7 +374,7 @@ void ProgramCompiler::Formatter::visit(FunctionDeclaration &func)
                (*i)->visit(*this);
        }
        formatted += ')';
-       if(func.definition)
+       if(func.definition==&func)
        {
                formatted += '\n';
                func.body.visit(*this);
@@ -378,6 +423,56 @@ void ProgramCompiler::Formatter::visit(Return &ret)
 }
 
 
+ProgramCompiler::DeclarationCombiner::DeclarationCombiner():
+       toplevel(true),
+       remove_node(false)
+{ }
+
+void ProgramCompiler::DeclarationCombiner::visit(Block &block)
+{
+       if(!toplevel)
+               return;
+
+       SetForScope<bool> set(toplevel, false);
+       for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )
+       {
+               remove_node = false;
+               (*i)->visit(*this);
+               if(remove_node)
+                       block.body.erase(i++);
+               else
+                       ++i;
+       }
+}
+
+void ProgramCompiler::DeclarationCombiner::visit(FunctionDeclaration &func)
+{
+       vector<FunctionDeclaration *> &decls = functions[func.name];
+       if(func.definition)
+       {
+               for(vector<FunctionDeclaration *>::iterator i=decls.begin(); i!=decls.end(); ++i)
+               {
+                       (*i)->definition = func.definition;
+                       (*i)->body.body.clear();
+               }
+       }
+       decls.push_back(&func);
+}
+
+void ProgramCompiler::DeclarationCombiner::visit(VariableDeclaration &var)
+{
+       VariableDeclaration *&ptr = variables[var.name];
+       if(ptr)
+       {
+               if(var.init_expression)
+                       ptr->init_expression = var.init_expression;
+               remove_node = true;
+       }
+       else
+               ptr = &var;
+}
+
+
 ProgramCompiler::VariableResolver::VariableResolver():
        anonymous(false),
        record_target(false),
@@ -944,6 +1039,10 @@ void ProgramCompiler::UnusedVariableLocator::visit(Iteration &iter)
 }
 
 
+ProgramCompiler::NodeRemover::NodeRemover(const set<Node *> &r):
+       to_remove(r)
+{ }
+
 void ProgramCompiler::NodeRemover::visit(Block &block)
 {
        for(list<NodePtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); )