From: Mikko Rasa Date: Mon, 15 Mar 2021 09:05:12 +0000 (+0200) Subject: Implement constant folding in the GLSL compiler X-Git-Url: http://git.tdb.fi/?a=commitdiff_plain;h=5e4204e;p=libs%2Fgl.git Implement constant folding in the GLSL compiler This replaces the old expression evaluator with a more comprehensive solution. Folding constant expressions may open up further possibilities for inlining. --- diff --git a/source/glsl/compiler.cpp b/source/glsl/compiler.cpp index b4b63e27..e9933e2f 100644 --- a/source/glsl/compiler.cpp +++ b/source/glsl/compiler.cpp @@ -318,6 +318,8 @@ bool Compiler::diagnostic_line_order(const Diagnostic &diag1, const Diagnostic & Compiler::OptimizeResult Compiler::optimize(Stage &stage) { + if(ConstantFolder().apply(stage)) + resolve(stage, RESOLVE_EXPRESSIONS); ConstantConditionEliminator().apply(stage); bool any_inlined = false; diff --git a/source/glsl/evaluate.cpp b/source/glsl/evaluate.cpp deleted file mode 100644 index 2b3f50a6..00000000 --- a/source/glsl/evaluate.cpp +++ /dev/null @@ -1,95 +0,0 @@ -#include -#include "evaluate.h" - -namespace Msp { -namespace GL { -namespace SL { - -ExpressionEvaluator::ExpressionEvaluator(): - variable_values(0), - r_result(0.0f), - r_result_valid(false) -{ } - -ExpressionEvaluator::ExpressionEvaluator(const ValueMap &v): - variable_values(&v), - r_result(0.0f), - r_result_valid(false) -{ } - -void ExpressionEvaluator::visit(Literal &literal) -{ - if(literal.token=="true") - r_result = 1.0f; - else if(literal.token=="false") - r_result = 0.0f; - else - r_result = lexical_cast(literal.token); - r_result_valid = true; -} - -void ExpressionEvaluator::visit(VariableReference &var) -{ - if(!var.declaration) - return; - - if(variable_values) - { - ValueMap::const_iterator i = variable_values->find(var.declaration); - if(i!=variable_values->end()) - i->second->visit(*this); - } - else if(var.declaration->init_expression) - var.declaration->init_expression->visit(*this); -} - -void ExpressionEvaluator::visit(UnaryExpression &unary) -{ - r_result_valid = false; - unary.expression->visit(*this); - if(!r_result_valid) - return; - - if(unary.oper->token[0]=='!') - r_result = !r_result; - else - r_result_valid = false; -} - -void ExpressionEvaluator::visit(BinaryExpression &binary) -{ - r_result_valid = false; - binary.left->visit(*this); - if(!r_result_valid) - return; - - float left_result = r_result; - r_result_valid = false; - binary.right->visit(*this); - if(!r_result_valid) - return; - - std::string oper = binary.oper->token; - if(oper=="<") - r_result = (left_result") - r_result = (left_result>r_result); - else if(oper==">=") - r_result = (left_result>=r_result); - else if(oper=="==") - r_result = (left_result==r_result); - else if(oper=="!=") - r_result = (left_result!=r_result); - else if(oper=="&&") - r_result = (left_result && r_result); - else if(oper=="||") - r_result = (left_result || r_result); - else - r_result_valid = false; -} - -} // namespace SL -} // namespace GL -} // namespace Msp diff --git a/source/glsl/evaluate.h b/source/glsl/evaluate.h deleted file mode 100644 index 4ad26779..00000000 --- a/source/glsl/evaluate.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef MSP_GL_SL_EVALUATE_H_ -#define MSP_GL_SL_EVALUATE_H_ - -#include "visitor.h" - -namespace Msp { -namespace GL { -namespace SL { - -/** Evaluates an expression. Only expressions consisting entirely of compile- -time constants can be evaluated. */ -class ExpressionEvaluator: public NodeVisitor -{ -public: - typedef std::map ValueMap; - -private: - const ValueMap *variable_values; - float r_result; - bool r_result_valid; - -public: - ExpressionEvaluator(); - ExpressionEvaluator(const ValueMap &); - - float get_result() const { return r_result; } - bool is_result_valid() const { return r_result_valid; } - - using NodeVisitor::visit; - virtual void visit(Literal &); - virtual void visit(VariableReference &); - virtual void visit(UnaryExpression &); - virtual void visit(BinaryExpression &); -}; - -} // namespace SL -} // namespace GL -} // namespace Msp - -#endif diff --git a/source/glsl/optimize.cpp b/source/glsl/optimize.cpp index 2a9bfb17..9cfad86a 100644 --- a/source/glsl/optimize.cpp +++ b/source/glsl/optimize.cpp @@ -487,12 +487,301 @@ void ExpressionInliner::visit(Iteration &iter) } +BasicTypeDeclaration::Kind ConstantFolder::get_value_kind(const Variant &value) +{ + if(value.check_type()) + return BasicTypeDeclaration::BOOL; + else if(value.check_type()) + return BasicTypeDeclaration::INT; + else if(value.check_type()) + return BasicTypeDeclaration::FLOAT; + else + return BasicTypeDeclaration::VOID; +} + +template +T ConstantFolder::evaluate_logical(char oper, T left, T right) +{ + switch(oper) + { + case '&': return left&right; + case '|': return left|right; + case '^': return left^right; + default: return T(); + } +} + +template +bool ConstantFolder::evaluate_relation(const char *oper, T left, T right) +{ + switch(oper[0]|oper[1]) + { + case '<': return left': return left>right; + case '>'|'=': return left>=right; + default: return false; + } +} + +template +T ConstantFolder::evaluate_arithmetic(char oper, T left, T right) +{ + switch(oper) + { + case '+': return left+right; + case '-': return left-right; + case '*': return left*right; + case '/': return left/right; + default: return T(); + } +} + +void ConstantFolder::set_result(const Variant &value, bool literal) +{ + r_constant_value = value; + r_constant = true; + r_literal = literal; +} + +void ConstantFolder::visit(RefPtr &expr) +{ + r_constant_value = Variant(); + r_constant = false; + r_literal = false; + r_uses_iter_var = false; + expr->visit(*this); + /* Don't replace literals since they'd only be replaced with an identical + literal. Also skip anything that uses an iteration variable, but pass on + the result so the Iteration visiting function can handle it. */ + if(!r_constant || r_literal || r_uses_iter_var) + return; + + BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value); + if(kind==BasicTypeDeclaration::VOID) + { + r_constant = false; + return; + } + + RefPtr literal = new Literal; + if(kind==BasicTypeDeclaration::BOOL) + literal->token = (r_constant_value.value() ? "true" : "false"); + else if(kind==BasicTypeDeclaration::INT) + literal->token = lexical_cast(r_constant_value.value()); + else if(kind==BasicTypeDeclaration::FLOAT) + literal->token = lexical_cast(r_constant_value.value()); + literal->value = r_constant_value; + expr = literal; +} + +void ConstantFolder::visit(Literal &literal) +{ + set_result(literal.value, true); +} + +void ConstantFolder::visit(VariableReference &var) +{ + /* If an iteration variable is initialized with a constant value, return + that value here for the purpose of evaluating the loop condition for the + first iteration. */ + if(var.declaration==iteration_var) + { + set_result(iter_init_value); + r_uses_iter_var = true; + } +} + +void ConstantFolder::visit(MemberAccess &memacc) +{ + TraversingVisitor::visit(memacc); + r_constant = false; +} + +void ConstantFolder::visit(Swizzle &swizzle) +{ + TraversingVisitor::visit(swizzle); + r_constant = false; +} + +void ConstantFolder::visit(UnaryExpression &unary) +{ + TraversingVisitor::visit(unary); + bool can_fold = r_constant; + r_constant = false; + if(!can_fold) + return; + + BasicTypeDeclaration::Kind kind = get_value_kind(r_constant_value); + + char oper = unary.oper->token[0]; + char oper2 = unary.oper->token[1]; + if(oper=='!') + { + if(kind==BasicTypeDeclaration::BOOL) + set_result(!r_constant_value.value()); + } + else if(oper=='~') + { + if(kind==BasicTypeDeclaration::INT) + set_result(~r_constant_value.value()); + } + else if(oper=='-' && !oper2) + { + if(kind==BasicTypeDeclaration::INT) + set_result(-r_constant_value.value()); + else if(kind==BasicTypeDeclaration::FLOAT) + set_result(-r_constant_value.value()); + } +} + +void ConstantFolder::visit(BinaryExpression &binary) +{ + visit(binary.left); + bool left_constant = r_constant; + bool left_iter_var = r_uses_iter_var; + Variant left_value = r_constant_value; + visit(binary.right); + if(left_iter_var) + r_uses_iter_var = true; + + bool can_fold = (left_constant && r_constant); + r_constant = false; + if(!can_fold) + return; + + BasicTypeDeclaration::Kind left_kind = get_value_kind(left_value); + BasicTypeDeclaration::Kind right_kind = get_value_kind(r_constant_value); + // Currently only expressions with both sides of equal types are handled. + if(left_kind!=right_kind) + return; + + char oper = binary.oper->token[0]; + char oper2 = binary.oper->token[1]; + if(oper=='&' || oper=='|' || oper=='^') + { + if(oper2==oper && left_kind==BasicTypeDeclaration::BOOL) + set_result(evaluate_logical(oper, left_value.value(), r_constant_value.value())); + else if(!oper2 && left_kind==BasicTypeDeclaration::INT) + set_result(evaluate_logical(oper, left_value.value(), r_constant_value.value())); + } + else if((oper=='<' || oper=='>') && oper2!=oper) + { + if(left_kind==BasicTypeDeclaration::INT) + set_result(evaluate_relation(binary.oper->token, left_value.value(), r_constant_value.value())); + else if(left_kind==BasicTypeDeclaration::FLOAT) + set_result(evaluate_relation(binary.oper->token, left_value.value(), r_constant_value.value())); + } + else if((oper=='=' || oper=='!') && oper2=='=') + { + if(left_kind==BasicTypeDeclaration::INT) + set_result((left_value.value()==r_constant_value.value()) == (oper=='=')); + if(left_kind==BasicTypeDeclaration::FLOAT) + set_result((left_value.value()==r_constant_value.value()) == (oper=='=')); + } + else if(oper=='+' || oper=='-' || oper=='*' || oper=='/') + { + if(left_kind==BasicTypeDeclaration::INT) + set_result(evaluate_arithmetic(oper, left_value.value(), r_constant_value.value())); + else if(left_kind==BasicTypeDeclaration::FLOAT) + set_result(evaluate_arithmetic(oper, left_value.value(), r_constant_value.value())); + } + else if(oper=='%' || ((oper=='<' || oper=='>') && oper2==oper)) + { + if(left_kind!=BasicTypeDeclaration::INT) + return; + + if(oper=='%') + set_result(left_value.value()%r_constant_value.value()); + else if(oper=='<') + set_result(left_value.value()<()); + else if(oper=='>') + set_result(left_value.value()>>r_constant_value.value()); + } +} + +void ConstantFolder::visit(Assignment &assign) +{ + TraversingVisitor::visit(assign); + r_constant = false; +} + +void ConstantFolder::visit(TernaryExpression &ternary) +{ + TraversingVisitor::visit(ternary); + r_constant = false; +} + +void ConstantFolder::visit(FunctionCall &call) +{ + TraversingVisitor::visit(call); + r_constant = false; +} + +void ConstantFolder::visit(VariableDeclaration &var) +{ + if(iteration_init && var.init_expression) + { + visit(var.init_expression); + if(r_constant) + { + /* Record the value of a constant initialization expression of an + iteration, so it can be used to evaluate the loop condition. */ + iteration_var = &var; + iter_init_value = r_constant_value; + } + } + else + TraversingVisitor::visit(var); +} + +void ConstantFolder::visit(Iteration &iter) +{ + SetForScope set_block(current_block, &iter.body); + + /* The iteration variable is not normally inlined into expressions, so we + process it specially here. If the initial value causes the loop condition + to evaluate to false, then the expression can be folded. */ + iteration_var = 0; + if(iter.init_statement) + { + SetFlag set_init(iteration_init); + iter.init_statement->visit(*this); + } + + if(iter.condition) + { + visit(iter.condition); + if(r_constant && r_constant_value.check_type() && !r_constant_value.value()) + { + RefPtr literal = new Literal; + literal->token = "false"; + literal->value = r_constant_value; + iter.condition = literal; + } + } + iteration_var = 0; + + iter.body.visit(*this); + if(iter.loop_expression) + visit(iter.loop_expression); +} + + void ConstantConditionEliminator::apply(Stage &stage) { stage.content.visit(*this); NodeRemover().apply(stage, nodes_to_remove); } +ConstantConditionEliminator::ConstantStatus ConstantConditionEliminator::check_constant_condition(const Expression &expr) +{ + if(const Literal *literal = dynamic_cast(&expr)) + if(literal->value.check_type()) + return (literal->value.value() ? CONSTANT_TRUE : CONSTANT_FALSE); + return NOT_CONSTANT; +} + void ConstantConditionEliminator::visit(Block &block) { SetForScope set_block(current_block, &block); @@ -505,14 +794,15 @@ void ConstantConditionEliminator::visit(Block &block) void ConstantConditionEliminator::visit(Conditional &cond) { - if(Literal *literal = dynamic_cast(cond.condition.get())) - if(literal->value.check_type()) - { - Block &block = (literal->value.value() ? cond.body : cond.else_body); - current_block->body.splice(insert_point, block.body); - nodes_to_remove.insert(&cond); - return; - } + ConstantStatus result = check_constant_condition(*cond.condition); + if(result!=NOT_CONSTANT) + { + Block &block = (result==CONSTANT_TRUE ? cond.body : cond.else_body); + // TODO should check variable names for conflicts. Potentially reuse InlineContentInjector? + current_block->body.splice(insert_point, block.body); + nodes_to_remove.insert(&cond); + return; + } TraversingVisitor::visit(cond); } @@ -521,14 +811,8 @@ void ConstantConditionEliminator::visit(Iteration &iter) { if(iter.condition) { - /* If the loop condition is always false on the first iteration, the - entire loop can be removed */ - ExpressionEvaluator::ValueMap values; - if(VariableDeclaration *var = dynamic_cast(iter.init_statement.get())) - values[var] = var->init_expression.get(); - ExpressionEvaluator eval(values); - iter.condition->visit(eval); - if(eval.is_result_valid() && !eval.get_result()) + ConstantStatus result = check_constant_condition(*iter.condition); + if(result==CONSTANT_FALSE) { nodes_to_remove.insert(&iter); return; diff --git a/source/glsl/optimize.h b/source/glsl/optimize.h index 8a888e15..1262823f 100644 --- a/source/glsl/optimize.h +++ b/source/glsl/optimize.h @@ -3,7 +3,6 @@ #include #include -#include "evaluate.h" #include "visitor.h" namespace Msp { @@ -137,11 +136,59 @@ private: virtual void visit(Iteration &); }; +/** Replaces expressions consisting entirely of literals with the results of +evaluating the expression.*/ +class ConstantFolder: private TraversingVisitor +{ +private: + VariableDeclaration *iteration_var; + Variant iter_init_value; + Variant r_constant_value; + bool iteration_init; + bool r_constant; + bool r_literal; + bool r_uses_iter_var; + bool r_any_folded; + +public: + bool apply(Stage &s) { s.content.visit(*this); return r_any_folded; } + +private: + static BasicTypeDeclaration::Kind get_value_kind(const Variant &); + template + static T evaluate_logical(char, T, T); + template + static bool evaluate_relation(const char *, T, T); + template + static T evaluate_arithmetic(char, T, T); + void set_result(const Variant &, bool = false); + + virtual void visit(RefPtr &); + virtual void visit(Literal &); + virtual void visit(VariableReference &); + virtual void visit(MemberAccess &); + virtual void visit(Swizzle &); + virtual void visit(UnaryExpression &); + virtual void visit(BinaryExpression &); + virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); + virtual void visit(FunctionCall &); + virtual void visit(VariableDeclaration &); + virtual void visit(Iteration &); +}; + /** Removes conditional statements and loops where the condition can be determined as constant at compile time. */ class ConstantConditionEliminator: private TraversingVisitor { private: + enum ConstantStatus + { + CONSTANT_FALSE, + CONSTANT_TRUE, + NOT_CONSTANT + }; + NodeList::iterator insert_point; std::set nodes_to_remove; @@ -149,6 +196,8 @@ public: void apply(Stage &); private: + ConstantStatus check_constant_condition(const Expression &); + virtual void visit(Block &); virtual void visit(Conditional &); virtual void visit(Iteration &); diff --git a/tests/glsl/complex_constant_condition_removal.glsl b/tests/glsl/complex_constant_condition_removal.glsl new file mode 100644 index 00000000..8a9d96ac --- /dev/null +++ b/tests/glsl/complex_constant_condition_removal.glsl @@ -0,0 +1,42 @@ +const int lod = 1; +const int bias = 1; +const int threshold = 3; + +#pragma MSP stage(vertex) +layout(location=0) in vec4 position; +layout(location=1) in vec4 color; +void main() +{ + gl_Position = position; + passthrough; +} + +#pragma MSP stage(fragment) +layout(location=0) out vec4 frag_color; +void main() +{ + if(lod+bias