]> git.tdb.fi Git - libs/gl.git/commitdiff
Implement constant folding in the GLSL compiler
authorMikko Rasa <tdb@tdb.fi>
Mon, 15 Mar 2021 09:05:12 +0000 (11:05 +0200)
committerMikko Rasa <tdb@tdb.fi>
Mon, 15 Mar 2021 09:15:21 +0000 (11:15 +0200)
This replaces the old expression evaluator with a more comprehensive
solution.  Folding constant expressions may open up further possibilities
for inlining.

source/glsl/compiler.cpp
source/glsl/evaluate.cpp [deleted file]
source/glsl/evaluate.h [deleted file]
source/glsl/optimize.cpp
source/glsl/optimize.h
tests/glsl/complex_constant_condition_removal.glsl [new file with mode: 0644]
tests/glsl/constant_last_argument.glsl [new file with mode: 0644]
tests/glsl/dead_loop_removal.glsl [new file with mode: 0644]

index b4b63e27637c32231341a2631383c0b972be8019..e9933e2f9b67c4ec414999438a9e8c1c0ec60ba9 100644 (file)
@@ -318,6 +318,8 @@ bool Compiler::diagnostic_line_order(const Diagnostic &diag1, const Diagnostic &
 
 Compiler::OptimizeResult Compiler::optimize(Stage &stage)
 {
+       if(ConstantFolder().apply(stage))
+               resolve(stage, RESOLVE_EXPRESSIONS);
        ConstantConditionEliminator().apply(stage);
 
        bool any_inlined = false;
diff --git a/source/glsl/evaluate.cpp b/source/glsl/evaluate.cpp
deleted file mode 100644 (file)
index 2b3f50a..0000000
+++ /dev/null
@@ -1,95 +0,0 @@
-#include <msp/strings/lexicalcast.h>
-#include "evaluate.h"
-
-namespace Msp {
-namespace GL {
-namespace SL {
-
-ExpressionEvaluator::ExpressionEvaluator():
-       variable_values(0),
-       r_result(0.0f),
-       r_result_valid(false)
-{ }
-
-ExpressionEvaluator::ExpressionEvaluator(const ValueMap &v):
-       variable_values(&v),
-       r_result(0.0f),
-       r_result_valid(false)
-{ }
-
-void ExpressionEvaluator::visit(Literal &literal)
-{
-       if(literal.token=="true")
-               r_result = 1.0f;
-       else if(literal.token=="false")
-               r_result = 0.0f;
-       else
-               r_result = lexical_cast<float>(literal.token);
-       r_result_valid = true;
-}
-
-void ExpressionEvaluator::visit(VariableReference &var)
-{
-       if(!var.declaration)
-               return;
-
-       if(variable_values)
-       {
-               ValueMap::const_iterator i = variable_values->find(var.declaration);
-               if(i!=variable_values->end())
-                       i->second->visit(*this);
-       }
-       else if(var.declaration->init_expression)
-               var.declaration->init_expression->visit(*this);
-}
-
-void ExpressionEvaluator::visit(UnaryExpression &unary)
-{
-       r_result_valid = false;
-       unary.expression->visit(*this);
-       if(!r_result_valid)
-               return;
-
-       if(unary.oper->token[0]=='!')
-               r_result = !r_result;
-       else
-               r_result_valid = false;
-}
-
-void ExpressionEvaluator::visit(BinaryExpression &binary)
-{
-       r_result_valid = false;
-       binary.left->visit(*this);
-       if(!r_result_valid)
-               return;
-
-       float left_result = r_result;
-       r_result_valid = false;
-       binary.right->visit(*this);
-       if(!r_result_valid)
-               return;
-
-       std::string oper = binary.oper->token;
-       if(oper=="<")
-               r_result = (left_result<r_result);
-       else if(oper=="<=")
-               r_result = (left_result<=r_result);
-       else if(oper==">")
-               r_result = (left_result>r_result);
-       else if(oper==">=")
-               r_result = (left_result>=r_result);
-       else if(oper=="==")
-               r_result = (left_result==r_result);
-       else if(oper=="!=")
-               r_result = (left_result!=r_result);
-       else if(oper=="&&")
-               r_result = (left_result && r_result);
-       else if(oper=="||")
-               r_result = (left_result || r_result);
-       else
-               r_result_valid = false;
-}
-
-} // namespace SL
-} // namespace GL
-} // namespace Msp
diff --git a/source/glsl/evaluate.h b/source/glsl/evaluate.h
deleted file mode 100644 (file)
index 4ad2677..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-#ifndef MSP_GL_SL_EVALUATE_H_
-#define MSP_GL_SL_EVALUATE_H_
-
-#include "visitor.h"
-
-namespace Msp {
-namespace GL {
-namespace SL {
-
-/** Evaluates an expression.  Only expressions consisting entirely of compile-
-time constants can be evaluated. */
-class ExpressionEvaluator: public NodeVisitor
-{
-public:
-       typedef std::map<VariableDeclaration *, Expression *> ValueMap;
-
-private:
-       const ValueMap *variable_values;
-       float r_result;
-       bool r_result_valid;
-
-public:
-       ExpressionEvaluator();
-       ExpressionEvaluator(const ValueMap &);
-
-       float get_result() const { return r_result; }
-       bool is_result_valid() const { return r_result_valid; }
-
-       using NodeVisitor::visit;
-       virtual void visit(Literal &);
-       virtual void visit(VariableReference &);
-       virtual void visit(UnaryExpression &);
-       virtual void visit(BinaryExpression &);
-};
-
-} // namespace SL
-} // namespace GL
-} // namespace Msp
-
-#endif
index 2a9bfb17273aa743c7dd7af8608065f7b66058a0..9cfad86a69d70069fcf6abf2e5877bb99f167000 100644 (file)
@@ -487,12 +487,301 @@ void ExpressionInliner::visit(Iteration &iter)
 }
 
 
+BasicTypeDeclaration::Kind ConstantFolder::get_value_kind(const Variant &value)
+{
+       if(value.check_type<bool>())
+               return BasicTypeDeclaration::BOOL;
+       else if(value.check_type<int>())
+               return BasicTypeDeclaration::INT;
+       else if(value.check_type<float>())
+               return BasicTypeDeclaration::FLOAT;
+       else
+               return BasicTypeDeclaration::VOID;
+}
+
+template<typename T>
+T ConstantFolder::evaluate_logical(char oper, T left, T right)
+{
+       switch(oper)
+       {
+       case '&': return left&right;
+       case '|': return left|right;
+       case '^': return left^right;
+       default: return T();
+       }
+}
+
+template<typename T>
+bool ConstantFolder::evaluate_relation(const char *oper, T left, T right)
+{
+       switch(oper[0]|oper[1])
+       {
+       case '<': return left<right;
+       case '<'|'=': return left<=right;
+       case '>': return left>right;
+       case '>'|'=': return left>=right;
+       default: return false;
+       }
+}
+
+template<typename T>
+T ConstantFolder::evaluate_arithmetic(char oper, T left, T right)
+{
+       switch(oper)
+       {
+       case '+': return left+right;
+       case '-': return left-right;
+       case '*': return left*right;
+       case '/': return left/right;
+       default: return T();
+       }
+}
+
+void ConstantFolder::set_result(const Variant &value, bool literal)
+{
+       r_constant_value = value;
+       r_constant = true;
+       r_literal = literal;
+}
+
+void ConstantFolder::visit(RefPtr<Expression> &expr)
+{
+       r_constant_value = Variant();
+       r_constant = false;
+       r_literal = false;
+       r_uses_iter_var = false;
+       expr->visit(*this);
+       /* Don't replace literals since they'd only be replaced with an identical
+       literal.  Also skip anything that uses an iteration variable, but pass on
+       the result so the Iteration visiting function can handle it. */
+       if(!r_constant || r_literal || r_uses_iter_var)
+               return;
+
+       BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value);
+       if(kind==BasicTypeDeclaration::VOID)
+       {
+               r_constant = false;
+               return;
+       }
+
+       RefPtr<Literal> literal = new Literal;
+       if(kind==BasicTypeDeclaration::BOOL)
+               literal->token = (r_constant_value.value<bool>() ? "true" : "false");
+       else if(kind==BasicTypeDeclaration::INT)
+               literal->token = lexical_cast<string>(r_constant_value.value<int>());
+       else if(kind==BasicTypeDeclaration::FLOAT)
+               literal->token = lexical_cast<string>(r_constant_value.value<float>());
+       literal->value = r_constant_value;
+       expr = literal;
+}
+
+void ConstantFolder::visit(Literal &literal)
+{
+       set_result(literal.value, true);
+}
+
+void ConstantFolder::visit(VariableReference &var)
+{
+       /* If an iteration variable is initialized with a constant value, return
+       that value here for the purpose of evaluating the loop condition for the
+       first iteration. */
+       if(var.declaration==iteration_var)
+       {
+               set_result(iter_init_value);
+               r_uses_iter_var = true;
+       }
+}
+
+void ConstantFolder::visit(MemberAccess &memacc)
+{
+       TraversingVisitor::visit(memacc);
+       r_constant = false;
+}
+
+void ConstantFolder::visit(Swizzle &swizzle)
+{
+       TraversingVisitor::visit(swizzle);
+       r_constant = false;
+}
+
+void ConstantFolder::visit(UnaryExpression &unary)
+{
+       TraversingVisitor::visit(unary);
+       bool can_fold = r_constant;
+       r_constant = false;
+       if(!can_fold)
+               return;
+
+       BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value);
+
+       char oper = unary.oper->token[0];
+       char oper2 = unary.oper->token[1];
+       if(oper=='!')
+       {
+               if(kind==BasicTypeDeclaration::BOOL)
+                       set_result(!r_constant_value.value<bool>());
+       }
+       else if(oper=='~')
+       {
+               if(kind==BasicTypeDeclaration::INT)
+                       set_result(~r_constant_value.value<int>());
+       }
+       else if(oper=='-' && !oper2)
+       {
+               if(kind==BasicTypeDeclaration::INT)
+                       set_result(-r_constant_value.value<int>());
+               else if(kind==BasicTypeDeclaration::FLOAT)
+                       set_result(-r_constant_value.value<float>());
+       }
+}
+
+void ConstantFolder::visit(BinaryExpression &binary)
+{
+       visit(binary.left);
+       bool left_constant = r_constant;
+       bool left_iter_var = r_uses_iter_var;
+       Variant left_value = r_constant_value;
+       visit(binary.right);
+       if(left_iter_var)
+               r_uses_iter_var = true;
+
+       bool can_fold = (left_constant && r_constant);
+       r_constant = false;
+       if(!can_fold)
+               return;
+
+       BasicTypeDeclaration::Kind left_kind = get_value_kind(left_value);
+       BasicTypeDeclaration::Kind right_kind = get_value_kind(r_constant_value);
+       // Currently only expressions with both sides of equal types are handled.
+       if(left_kind!=right_kind)
+               return;
+
+       char oper = binary.oper->token[0];
+       char oper2 = binary.oper->token[1];
+       if(oper=='&' || oper=='|' || oper=='^')
+       {
+               if(oper2==oper && left_kind==BasicTypeDeclaration::BOOL)
+                       set_result(evaluate_logical(oper, left_value.value<bool>(), r_constant_value.value<bool>()));
+               else if(!oper2 && left_kind==BasicTypeDeclaration::INT)
+                       set_result(evaluate_logical(oper, left_value.value<int>(), r_constant_value.value<int>()));
+       }
+       else if((oper=='<' || oper=='>') && oper2!=oper)
+       {
+               if(left_kind==BasicTypeDeclaration::INT)
+                       set_result(evaluate_relation(binary.oper->token, left_value.value<int>(), r_constant_value.value<int>()));
+               else if(left_kind==BasicTypeDeclaration::FLOAT)
+                       set_result(evaluate_relation(binary.oper->token, left_value.value<float>(), r_constant_value.value<float>()));
+       }
+       else if((oper=='=' || oper=='!') && oper2=='=')
+       {
+               if(left_kind==BasicTypeDeclaration::INT)
+                       set_result((left_value.value<int>()==r_constant_value.value<int>()) == (oper=='='));
+               if(left_kind==BasicTypeDeclaration::FLOAT)
+                       set_result((left_value.value<float>()==r_constant_value.value<float>()) == (oper=='='));
+       }
+       else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
+       {
+               if(left_kind==BasicTypeDeclaration::INT)
+                       set_result(evaluate_arithmetic(oper, left_value.value<int>(), r_constant_value.value<int>()));
+               else if(left_kind==BasicTypeDeclaration::FLOAT)
+                       set_result(evaluate_arithmetic(oper, left_value.value<float>(), r_constant_value.value<float>()));
+       }
+       else if(oper=='%' || ((oper=='<' || oper=='>') && oper2==oper))
+       {
+               if(left_kind!=BasicTypeDeclaration::INT)
+                       return;
+
+               if(oper=='%')
+                       set_result(left_value.value<int>()%r_constant_value.value<int>());
+               else if(oper=='<')
+                       set_result(left_value.value<int>()<<r_constant_value.value<int>());
+               else if(oper=='>')
+                       set_result(left_value.value<int>()>>r_constant_value.value<int>());
+       }
+}
+
+void ConstantFolder::visit(Assignment &assign)
+{
+       TraversingVisitor::visit(assign);
+       r_constant = false;
+}
+
+void ConstantFolder::visit(TernaryExpression &ternary)
+{
+       TraversingVisitor::visit(ternary);
+       r_constant = false;
+}
+
+void ConstantFolder::visit(FunctionCall &call)
+{
+       TraversingVisitor::visit(call);
+       r_constant = false;
+}
+
+void ConstantFolder::visit(VariableDeclaration &var)
+{
+       if(iteration_init && var.init_expression)
+       {
+               visit(var.init_expression);
+               if(r_constant)
+               {
+                       /* Record the value of a constant initialization expression of an
+                       iteration, so it can be used to evaluate the loop condition. */
+                       iteration_var = &var;
+                       iter_init_value = r_constant_value;
+               }
+       }
+       else
+               TraversingVisitor::visit(var);
+}
+
+void ConstantFolder::visit(Iteration &iter)
+{
+       SetForScope<Block *> set_block(current_block, &iter.body);
+
+       /* The iteration variable is not normally inlined into expressions, so we
+       process it specially here.  If the initial value causes the loop condition
+       to evaluate to false, then the expression can be folded. */
+       iteration_var = 0;
+       if(iter.init_statement)
+       {
+               SetFlag set_init(iteration_init);
+               iter.init_statement->visit(*this);
+       }
+
+       if(iter.condition)
+       {
+               visit(iter.condition);
+               if(r_constant && r_constant_value.check_type<bool>() && !r_constant_value.value<bool>())
+               {
+                       RefPtr<Literal> literal = new Literal;
+                       literal->token = "false";
+                       literal->value = r_constant_value;
+                       iter.condition = literal;
+               }
+       }
+       iteration_var = 0;
+
+       iter.body.visit(*this);
+       if(iter.loop_expression)
+               visit(iter.loop_expression);
+}
+
+
 void ConstantConditionEliminator::apply(Stage &stage)
 {
        stage.content.visit(*this);
        NodeRemover().apply(stage, nodes_to_remove);
 }
 
+ConstantConditionEliminator::ConstantStatus ConstantConditionEliminator::check_constant_condition(const Expression &expr)
+{
+       if(const Literal *literal = dynamic_cast<const Literal *>(&expr))
+               if(literal->value.check_type<bool>())
+                       return (literal->value.value<bool>() ? CONSTANT_TRUE : CONSTANT_FALSE);
+       return NOT_CONSTANT;
+}
+
 void ConstantConditionEliminator::visit(Block &block)
 {
        SetForScope<Block *> set_block(current_block, &block);
@@ -505,14 +794,15 @@ void ConstantConditionEliminator::visit(Block &block)
 
 void ConstantConditionEliminator::visit(Conditional &cond)
 {
-       if(Literal *literal = dynamic_cast<Literal *>(cond.condition.get()))
-               if(literal->value.check_type<bool>())
-               {
-                       Block &block = (literal->value.value<bool>() ? cond.body : cond.else_body);
-                       current_block->body.splice(insert_point, block.body);
-                       nodes_to_remove.insert(&cond);
-                       return;
-               }
+       ConstantStatus result = check_constant_condition(*cond.condition);
+       if(result!=NOT_CONSTANT)
+       {
+               Block &block = (result==CONSTANT_TRUE ? cond.body : cond.else_body);
+               // TODO should check variable names for conflicts.  Potentially reuse InlineContentInjector?
+               current_block->body.splice(insert_point, block.body);
+               nodes_to_remove.insert(&cond);
+               return;
+       }
 
        TraversingVisitor::visit(cond);
 }
@@ -521,14 +811,8 @@ void ConstantConditionEliminator::visit(Iteration &iter)
 {
        if(iter.condition)
        {
-               /* If the loop condition is always false on the first iteration, the
-               entire loop can be removed */
-               ExpressionEvaluator::ValueMap values;
-               if(VariableDeclaration *var = dynamic_cast<VariableDeclaration *>(iter.init_statement.get()))
-                       values[var] = var->init_expression.get();
-               ExpressionEvaluator eval(values);
-               iter.condition->visit(eval);
-               if(eval.is_result_valid() && !eval.get_result())
+               ConstantStatus result = check_constant_condition(*iter.condition);
+               if(result==CONSTANT_FALSE)
                {
                        nodes_to_remove.insert(&iter);
                        return;
index 8a888e15bec7c11e019b172570e758a06f46963b..1262823f53ead86c058734fc0d62841b70f7dea3 100644 (file)
@@ -3,7 +3,6 @@
 
 #include <map>
 #include <set>
-#include "evaluate.h"
 #include "visitor.h"
 
 namespace Msp {
@@ -137,11 +136,59 @@ private:
        virtual void visit(Iteration &);
 };
 
+/** Replaces expressions consisting entirely of literals with the results of
+evaluating the expression.*/
+class ConstantFolder: private TraversingVisitor
+{
+private:
+       VariableDeclaration *iteration_var;
+       Variant iter_init_value;
+       Variant r_constant_value;
+       bool iteration_init;
+       bool r_constant;
+       bool r_literal;
+       bool r_uses_iter_var;
+       bool r_any_folded;
+
+public:
+       bool apply(Stage &s) { s.content.visit(*this); return r_any_folded; }
+
+private:
+       static BasicTypeDeclaration::Kind get_value_kind(const Variant &);
+       template<typename T>
+       static T evaluate_logical(char, T, T);
+       template<typename T>
+       static bool evaluate_relation(const char *, T, T);
+       template<typename T>
+       static T evaluate_arithmetic(char, T, T);
+       void set_result(const Variant &, bool = false);
+
+       virtual void visit(RefPtr<Expression> &);
+       virtual void visit(Literal &);
+       virtual void visit(VariableReference &);
+       virtual void visit(MemberAccess &);
+       virtual void visit(Swizzle &);
+       virtual void visit(UnaryExpression &);
+       virtual void visit(BinaryExpression &);
+       virtual void visit(Assignment &);
+       virtual void visit(TernaryExpression &);
+       virtual void visit(FunctionCall &);
+       virtual void visit(VariableDeclaration &);
+       virtual void visit(Iteration &);
+};
+
 /** Removes conditional statements and loops where the condition can be
 determined as constant at compile time. */
 class ConstantConditionEliminator: private TraversingVisitor
 {
 private:
+       enum ConstantStatus
+       {
+               CONSTANT_FALSE,
+               CONSTANT_TRUE,
+               NOT_CONSTANT
+       };
+
        NodeList<Statement>::iterator insert_point;
        std::set<Node *> nodes_to_remove;
 
@@ -149,6 +196,8 @@ public:
        void apply(Stage &);
 
 private:
+       ConstantStatus check_constant_condition(const Expression &);
+
        virtual void visit(Block &);
        virtual void visit(Conditional &);
        virtual void visit(Iteration &);
diff --git a/tests/glsl/complex_constant_condition_removal.glsl b/tests/glsl/complex_constant_condition_removal.glsl
new file mode 100644 (file)
index 0000000..8a9d96a
--- /dev/null
@@ -0,0 +1,42 @@
+const int lod = 1;
+const int bias = 1;
+const int threshold = 3;
+
+#pragma MSP stage(vertex)
+layout(location=0) in vec4 position;
+layout(location=1) in vec4 color;
+void main()
+{
+       gl_Position = position;
+       passthrough;
+}
+
+#pragma MSP stage(fragment)
+layout(location=0) out vec4 frag_color;
+void main()
+{
+       if(lod+bias<threshold)
+               frag_color = color;
+       else
+               frag_color = vec4(1.0);
+}
+
+/* Expected output: vertex
+layout(location=0) in vec4 position;
+layout(location=1) in vec4 color;
+out vec4 _vs_out_color;
+void main()
+{
+       gl_Position = position;
+       _vs_out_color = color;
+}
+*/
+
+/* Expected output: fragment
+layout(location=0) out vec4 frag_color;
+in vec4 _vs_out_color;
+void main()
+{
+       frag_color = _vs_out_color;
+}
+*/
diff --git a/tests/glsl/constant_last_argument.glsl b/tests/glsl/constant_last_argument.glsl
new file mode 100644 (file)
index 0000000..b73f182
--- /dev/null
@@ -0,0 +1,40 @@
+uniform sampler2D tex;
+uniform mat4 mvp;
+
+#pragma MSP stage(vertex)
+layout(location=0) in vec4 position;
+layout(location=1) in vec2 texcoord;
+void main()
+{
+       gl_Position = mvp*position;
+       passthrough;
+}
+
+#pragma MSP stage(fragment)
+layout(location=0) out vec4 frag_color;
+void main()
+{
+       frag_color = textureLod(tex, texcoord, 0.0)*0.8;
+}
+
+/* Expected output: vertex
+uniform mat4 mvp;
+layout(location=0) in vec4 position;
+layout(location=1) in vec2 texcoord;
+out vec2 _vs_out_texcoord;
+void main()
+{
+       gl_Position = mvp*position;
+       _vs_out_texcoord = texcoord;
+}
+*/
+
+/* Expected output: fragment
+uniform sampler2D tex;
+layout(location=0) out vec4 frag_color;
+in vec2 _vs_out_texcoord;
+void main()
+{
+       frag_color = textureLod(tex, _vs_out_texcoord, 0.0)*0.8;
+}
+*/
diff --git a/tests/glsl/dead_loop_removal.glsl b/tests/glsl/dead_loop_removal.glsl
new file mode 100644 (file)
index 0000000..7cfba08
--- /dev/null
@@ -0,0 +1,48 @@
+const int n_lights = 0;
+struct LightParams
+{
+       vec3 direction;
+       vec3 color;
+};
+uniform LightParams lights[n_lights];
+uniform vec3 ambient;
+uniform mat4 model_matrix;
+uniform mat4 vp_matrix;
+
+#pragma MSP stage(vertex)
+layout(location=0) in vec4 position;
+layout(location=1) in vec3 normal;
+void main()
+{
+       out vec3 world_normal = mat3(model_matrix)*normal;
+       gl_Position = vp_matrix*model_matrix*position;
+}
+
+#pragma MSP stage(fragment)
+layout(location=0) out vec4 frag_color;
+void main()
+{
+       vec3 color = ambient;
+       for(int i=0; i<n_lights; ++i)
+               color += max(dot(normalize(world_normal), lights[i].direction), 0.0)*lights[i].color;
+       frag_color = vec4(color, 1.0);
+}
+
+/* Expected output: vertex
+uniform mat4 model_matrix;
+uniform mat4 vp_matrix;
+layout(location=0) in vec4 position;
+void main()
+{
+  gl_Position = vp_matrix*model_matrix*position;
+}
+*/
+
+/* Expected output: fragment
+uniform vec3 ambient;
+layout(location=0) out vec4 frag_color;
+void main()
+{
+  frag_color = vec4(ambient, 1.0);
+}
+*/