]> git.tdb.fi Git - libs/gl.git/blobdiff - source/glsl/optimize.cpp
Handle all constructs when inlining GLSL functions
[libs/gl.git] / source / glsl / optimize.cpp
index d26c6e2d519179ac531dd2a635c85e605d98e9e8..b7236feb6fafb4f7e0495b3b6e8260329f91b29b 100644 (file)
@@ -1,4 +1,5 @@
 #include <msp/core/raii.h>
+#include <msp/strings/format.h>
 #include "optimize.h"
 
 using namespace std;
@@ -40,18 +41,21 @@ void InlineableFunctionLocator::visit(FunctionDeclaration &func)
        TraversingVisitor::visit(func);
 }
 
-void InlineableFunctionLocator::visit(Conditional &)
+void InlineableFunctionLocator::visit(Conditional &cond)
 {
+       TraversingVisitor::visit(cond);
        inlineable.erase(current_function);
 }
 
-void InlineableFunctionLocator::visit(Iteration &)
+void InlineableFunctionLocator::visit(Iteration &iter)
 {
+       TraversingVisitor::visit(iter);
        inlineable.erase(current_function);
 }
 
-void InlineableFunctionLocator::visit(Return &)
+void InlineableFunctionLocator::visit(Return &ret)
 {
+       TraversingVisitor::visit(ret);
        if(return_count)
                inlineable.erase(current_function);
        ++return_count;
@@ -198,6 +202,7 @@ void FunctionInliner::visit_and_inline(RefPtr<Expression> &ptr)
 
 void FunctionInliner::visit(Block &block)
 {
+       SetForScope<Block *> set_block(current_block, &block);
        SetForScope<NodeList<Statement>::iterator> save_insert_point(insert_point, block.body.begin());
        for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
        {
@@ -272,6 +277,22 @@ void FunctionInliner::visit(FunctionDeclaration &func)
        TraversingVisitor::visit(func);
 }
 
+void FunctionInliner::visit(Conditional &cond)
+{
+       visit_and_inline(cond.condition);
+       cond.body.visit(*this);
+}
+
+void FunctionInliner::visit(Iteration &iter)
+{
+       SetForScope<Block *> set_block(current_block, &iter.body);
+       if(iter.init_statement)
+               iter.init_statement->visit(*this);
+       /* Skip the condition and loop expression parts because they're executed on
+       every iteration of the loop */
+       iter.body.visit(*this);
+}
+
 void FunctionInliner::visit(Return &ret)
 {
        if(ret.expression)
@@ -305,7 +326,7 @@ void ConstantConditionEliminator::visit(Block &block)
 void ConstantConditionEliminator::visit(UnaryExpression &unary)
 {
        if(VariableReference *var = dynamic_cast<VariableReference *>(unary.expression.get()))
-               if(unary.oper=="++" || unary.oper=="--")
+               if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
                        variable_values.erase(var->declaration);
 }
 
@@ -441,13 +462,13 @@ void UnusedVariableRemover::visit(MemberAccess &memacc)
 void UnusedVariableRemover::visit(UnaryExpression &unary)
 {
        TraversingVisitor::visit(unary);
-       if(unary.oper=="++" || unary.oper=="--")
+       if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
                side_effects = true;
 }
 
 void UnusedVariableRemover::visit(BinaryExpression &binary)
 {
-       if(binary.oper=="[")
+       if(binary.oper->token[0]=='[')
        {
                if(assignment_target)
                        assign_to_subscript = true;