From 3a1fe833ea04df75449706f1d773f6e65521a392 Mon Sep 17 00:00:00 2001 From: Mikko Rasa Date: Fri, 12 Mar 2021 20:27:42 +0200 Subject: [PATCH] Implement the ternary operator in GLSL --- source/glsl/debug.cpp | 11 ++++ source/glsl/debug.h | 1 + source/glsl/generate.cpp | 30 ++++++++++ source/glsl/generate.h | 1 + source/glsl/optimize.cpp | 15 +++++ source/glsl/optimize.h | 2 + source/glsl/output.cpp | 10 ++++ source/glsl/output.h | 1 + source/glsl/parser.cpp | 16 +++++- source/glsl/parser.h | 1 + source/glsl/syntax.cpp | 10 +++- source/glsl/syntax.h | 13 ++++- source/glsl/validate.cpp | 14 +++++ source/glsl/validate.h | 1 + source/glsl/visitor.cpp | 7 +++ source/glsl/visitor.h | 2 + tests/glsl/ternary_operand_type_mismatch.glsl | 30 ++++++++++ tests/glsl/ternary_operator.glsl | 56 +++++++++++++++++++ 18 files changed, 217 insertions(+), 4 deletions(-) create mode 100644 tests/glsl/ternary_operand_type_mismatch.glsl create mode 100644 tests/glsl/ternary_operator.glsl diff --git a/source/glsl/debug.cpp b/source/glsl/debug.cpp index 68615ed6..700acbf9 100644 --- a/source/glsl/debug.cpp +++ b/source/glsl/debug.cpp @@ -240,6 +240,17 @@ void DumpTree::visit(Assignment &assign) end_sub(); } +void DumpTree::visit(TernaryExpression &ternary) +{ + append(format("Ternary: %s -> %s", (ternary.oper->token[0]=='?' ? "?:" : ternary.oper->token), format_type(ternary.type))); + begin_sub(); + ternary.condition->visit(*this); + ternary.true_expr->visit(*this); + last_branch(); + ternary.false_expr->visit(*this); + end_sub(); +} + void DumpTree::visit(FunctionCall &call) { string head = "Function call: "; diff --git a/source/glsl/debug.h b/source/glsl/debug.h index 3837c569..a5dae090 100644 --- a/source/glsl/debug.h +++ b/source/glsl/debug.h @@ -62,6 +62,7 @@ private: virtual void visit(UnaryExpression &); virtual void visit(BinaryExpression &); virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); virtual void visit(ExpressionStatement &); virtual void visit(Import &); diff --git a/source/glsl/generate.cpp b/source/glsl/generate.cpp index ab487a44..e5855e94 100644 --- a/source/glsl/generate.cpp +++ b/source/glsl/generate.cpp @@ -872,6 +872,36 @@ void ExpressionResolver::visit(Assignment &assign) resolve(assign, assign.left->type, true); } +void ExpressionResolver::visit(TernaryExpression &ternary) +{ + TraversingVisitor::visit(ternary); + + BasicTypeDeclaration *basic_cond = dynamic_cast(ternary.condition->type); + if(!basic_cond || basic_cond->kind!=BasicTypeDeclaration::BOOL) + return; + + TypeDeclaration *type = 0; + if(ternary.true_expr->type==ternary.false_expr->type) + type = ternary.true_expr->type; + else + { + BasicTypeDeclaration *basic_true = dynamic_cast(ternary.true_expr->type); + BasicTypeDeclaration *basic_false = dynamic_cast(ternary.false_expr->type); + Compatibility compat = get_compatibility(*basic_true, *basic_false); + if(compat==NOT_COMPATIBLE) + return; + + type = (compat==LEFT_CONVERTIBLE ? basic_true : basic_false); + + if(compat==LEFT_CONVERTIBLE) + convert_to(ternary.true_expr, *basic_false); + else if(compat==RIGHT_CONVERTIBLE) + convert_to(ternary.false_expr, *basic_true); + } + + resolve(ternary, type, false); +} + void ExpressionResolver::visit(FunctionCall &call) { TraversingVisitor::visit(call); diff --git a/source/glsl/generate.h b/source/glsl/generate.h index c7a162e5..0a940cea 100644 --- a/source/glsl/generate.h +++ b/source/glsl/generate.h @@ -163,6 +163,7 @@ private: void visit(BinaryExpression &, bool); virtual void visit(BinaryExpression &); virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); virtual void visit(BasicTypeDeclaration &); virtual void visit(VariableDeclaration &); diff --git a/source/glsl/optimize.cpp b/source/glsl/optimize.cpp index 9ab8ef70..8f6b2745 100644 --- a/source/glsl/optimize.cpp +++ b/source/glsl/optimize.cpp @@ -452,6 +452,15 @@ void ExpressionInliner::visit(Assignment &assign) r_trivial = false; } +void ExpressionInliner::visit(TernaryExpression &ternary) +{ + visit_and_record(ternary.condition, ternary.oper, false); + visit_and_record(ternary.true_expr, ternary.oper, false); + visit_and_record(ternary.false_expr, ternary.oper, true); + r_oper = ternary.oper; + r_trivial = false; +} + void ExpressionInliner::visit(FunctionCall &call) { TraversingVisitor::visit(call); @@ -581,6 +590,12 @@ void UnusedTypeRemover::visit(BinaryExpression &binary) TraversingVisitor::visit(binary); } +void UnusedTypeRemover::visit(TernaryExpression &ternary) +{ + unused_nodes.erase(ternary.type); + TraversingVisitor::visit(ternary); +} + void UnusedTypeRemover::visit(FunctionCall &call) { unused_nodes.erase(call.type); diff --git a/source/glsl/optimize.h b/source/glsl/optimize.h index 34c0b245..0604a39c 100644 --- a/source/glsl/optimize.h +++ b/source/glsl/optimize.h @@ -135,6 +135,7 @@ private: virtual void visit(UnaryExpression &); virtual void visit(BinaryExpression &); virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); virtual void visit(VariableDeclaration &); virtual void visit(Iteration &); @@ -170,6 +171,7 @@ private: virtual void visit(Literal &); virtual void visit(UnaryExpression &); virtual void visit(BinaryExpression &); + virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); virtual void visit(BasicTypeDeclaration &); virtual void visit(ImageTypeDeclaration &); diff --git a/source/glsl/output.cpp b/source/glsl/output.cpp index 718b1a8b..364e5e4b 100644 --- a/source/glsl/output.cpp +++ b/source/glsl/output.cpp @@ -167,6 +167,16 @@ void Formatter::visit(Assignment &assign) assign.right->visit(*this); } +void Formatter::visit(TernaryExpression &ternary) +{ + ternary.condition->visit(*this); + append(ternary.oper->token); + ternary.true_expr->visit(*this); + if(ternary.oper->token[0]=='?') + append(':'); + ternary.false_expr->visit(*this); +} + void Formatter::visit(FunctionCall &call) { append(format("%s(", call.name)); diff --git a/source/glsl/output.h b/source/glsl/output.h index 0a506a5d..d872dabe 100644 --- a/source/glsl/output.h +++ b/source/glsl/output.h @@ -42,6 +42,7 @@ private: virtual void visit(UnaryExpression &); virtual void visit(BinaryExpression &); virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); virtual void visit(ExpressionStatement &); virtual void visit(Import &); diff --git a/source/glsl/parser.cpp b/source/glsl/parser.cpp index b4631daf..0c3a5e35 100644 --- a/source/glsl/parser.cpp +++ b/source/glsl/parser.cpp @@ -449,7 +449,7 @@ RefPtr Parser::parse_expression(const Operator *outer_oper) oper = i; bool lower_precedence = (oper && oper->type!=Operator::PREFIX && oper->precedence>=outer_precedence); - if(token==";" || token==")" || token=="]" || token=="," || lower_precedence) + if(token==";" || token==")" || token=="]" || token=="," || token==":" || lower_precedence) { if(left) return left; @@ -483,6 +483,8 @@ RefPtr Parser::parse_expression(const Operator *outer_oper) } else if(oper && oper->type==Operator::BINARY) left = parse_binary(left, *oper); + else if(oper && oper->type==Operator::TERNARY) + left = parse_ternary(left, *oper); else throw parse_error(tokenizer.get_location(), token, "an operator"); left_var = 0; @@ -557,6 +559,18 @@ RefPtr Parser::parse_binary(const RefPtr &left, co return binary; } +RefPtr Parser::parse_ternary(const RefPtr &cond, const Operator &oper) +{ + RefPtr ternary = create_node(); + ternary->condition = cond; + ternary->oper = &oper; + tokenizer.expect("?"); + ternary->true_expr = parse_expression(&oper); + tokenizer.expect(":"); + ternary->false_expr = parse_expression(&oper); + return ternary; +} + RefPtr Parser::parse_function_call(const VariableReference &var) { RefPtr call = create_node(); diff --git a/source/glsl/parser.h b/source/glsl/parser.h index bd564139..c98867a3 100644 --- a/source/glsl/parser.h +++ b/source/glsl/parser.h @@ -73,6 +73,7 @@ private: RefPtr parse_expression(const Operator * = 0); RefPtr parse_literal(); RefPtr parse_binary(const RefPtr &, const Operator &); + RefPtr parse_ternary(const RefPtr &, const Operator &); RefPtr parse_function_call(const VariableReference &); RefPtr parse_type_declaration(); RefPtr parse_basic_type_declaration(); diff --git a/source/glsl/syntax.cpp b/source/glsl/syntax.cpp index cad7c9f4..11a099b5 100644 --- a/source/glsl/syntax.cpp +++ b/source/glsl/syntax.cpp @@ -40,8 +40,8 @@ const Operator Operator::operators[] = { "&&", 12, BINARY, ASSOCIATIVE }, { "^^", 13, BINARY, ASSOCIATIVE }, { "||", 14, BINARY, ASSOCIATIVE }, - { "?", 15, BINARY, RIGHT_TO_LEFT }, - { ":", 15, BINARY, RIGHT_TO_LEFT }, + { "?", 15, TERNARY, RIGHT_TO_LEFT }, + { ":", 15, TERNARY, RIGHT_TO_LEFT }, { "=", 16, BINARY, RIGHT_TO_LEFT }, { "+=", 16, BINARY, RIGHT_TO_LEFT }, { "-=", 16, BINARY, RIGHT_TO_LEFT }, @@ -218,6 +218,12 @@ bool Assignment::Target::operator<(const Target &other) const } +void TernaryExpression::visit(NodeVisitor &visitor) +{ + visitor.visit(*this); +} + + FunctionCall::FunctionCall(): constructor(false), declaration(0) diff --git a/source/glsl/syntax.h b/source/glsl/syntax.h index 39bc8264..86972869 100644 --- a/source/glsl/syntax.h +++ b/source/glsl/syntax.h @@ -27,7 +27,8 @@ struct Operator NO_OPERATOR, BINARY, PREFIX, - POSTFIX + POSTFIX, + TERNARY }; enum Associativity @@ -259,6 +260,16 @@ struct Assignment: BinaryExpression virtual void visit(NodeVisitor &); }; +struct TernaryExpression: Expression +{ + NodePtr condition; + NodePtr true_expr; + NodePtr false_expr; + + virtual TernaryExpression *clone() const { return new TernaryExpression(*this); } + virtual void visit(NodeVisitor &); +}; + struct FunctionCall: Expression { std::string name; diff --git a/source/glsl/validate.cpp b/source/glsl/validate.cpp index 8805d41f..6e90698c 100644 --- a/source/glsl/validate.cpp +++ b/source/glsl/validate.cpp @@ -365,6 +365,20 @@ void ExpressionValidator::visit(Assignment &assign) TraversingVisitor::visit(assign); } +void ExpressionValidator::visit(TernaryExpression &ternary) +{ + if(ternary.condition->type) + { + BasicTypeDeclaration *basic_cond = dynamic_cast(ternary.condition->type); + if(!basic_cond || basic_cond->kind!=BasicTypeDeclaration::BOOL) + error(ternary, "Ternary operator condition is not a boolean"); + else if(!ternary.type && ternary.true_expr->type && ternary.false_expr->type) + error(ternary, format("Ternary operator has incompatible types '%s' and '%s'", + ternary.true_expr->type->name, ternary.false_expr->type->name)); + } + TraversingVisitor::visit(ternary); +} + void ExpressionValidator::visit(VariableDeclaration &var) { if(var.init_expression && var.init_expression->type && var.type_declaration && var.init_expression->type!=var.type_declaration) diff --git a/source/glsl/validate.h b/source/glsl/validate.h index 35cc049c..bafa550b 100644 --- a/source/glsl/validate.h +++ b/source/glsl/validate.h @@ -98,6 +98,7 @@ private: virtual void visit(UnaryExpression &); virtual void visit(BinaryExpression &); virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); virtual void visit(VariableDeclaration &); }; diff --git a/source/glsl/visitor.cpp b/source/glsl/visitor.cpp index b1767d3f..5da89f54 100644 --- a/source/glsl/visitor.cpp +++ b/source/glsl/visitor.cpp @@ -53,6 +53,13 @@ void TraversingVisitor::visit(Assignment &assign) visit(assign.right); } +void TraversingVisitor::visit(TernaryExpression &ternary) +{ + visit(ternary.condition); + visit(ternary.true_expr); + visit(ternary.false_expr); +} + void TraversingVisitor::visit(FunctionCall &call) { for(NodeArray::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i) diff --git a/source/glsl/visitor.h b/source/glsl/visitor.h index 85f9e848..3b558bd1 100644 --- a/source/glsl/visitor.h +++ b/source/glsl/visitor.h @@ -27,6 +27,7 @@ public: virtual void visit(UnaryExpression &) { } virtual void visit(BinaryExpression &) { } virtual void visit(Assignment &) { } + virtual void visit(TernaryExpression &) { } virtual void visit(FunctionCall &) { } virtual void visit(ExpressionStatement &) { } virtual void visit(Import &) { } @@ -64,6 +65,7 @@ public: virtual void visit(UnaryExpression &); virtual void visit(BinaryExpression &); virtual void visit(Assignment &); + virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); virtual void visit(ExpressionStatement &); virtual void visit(InterfaceLayout &); diff --git a/tests/glsl/ternary_operand_type_mismatch.glsl b/tests/glsl/ternary_operand_type_mismatch.glsl new file mode 100644 index 00000000..5eae47d8 --- /dev/null +++ b/tests/glsl/ternary_operand_type_mismatch.glsl @@ -0,0 +1,30 @@ +uniform Colors +{ + vec4 color; + float gray; +}; +uniform sampler2D mask; +uniform Transform +{ + mat4 mvp; +} transform; + +#pragma MSP stage(vertex) +layout(location=0) in vec4 position; +layout(location=1) in vec2 texcoord; +void main() +{ + passthrough; + gl_Position = transform.mvp*position; +} + +#pragma MSP stage(fragment) +layout(location=0) out vec4 frag_color; +void main() +{ + frag_color = texture(mask, texcoord).r > 0.5 ? color : gray; +} + +/* Expected error: +:25: Ternary operator has incompatible types 'vec4' and 'float' +*/ diff --git a/tests/glsl/ternary_operator.glsl b/tests/glsl/ternary_operator.glsl new file mode 100644 index 00000000..1a1a0989 --- /dev/null +++ b/tests/glsl/ternary_operator.glsl @@ -0,0 +1,56 @@ +uniform Colors +{ + vec4 color1; + vec4 color2; +}; +uniform sampler2D mask; +uniform Transform +{ + mat4 mvp; +} transform; + +#pragma MSP stage(vertex) +layout(location=0) in vec4 position; +layout(location=1) in vec2 texcoord; +void main() +{ + passthrough; + gl_Position = transform.mvp*position; +} + +#pragma MSP stage(fragment) +layout(location=0) out vec4 frag_color; +void main() +{ + frag_color = texture(mask, texcoord).r > 0.5 ? color1 : color2; +} + +/* Expected output: vertex +uniform Transform +{ + mat4 mvp; +} transform; +layout(location=0) in vec4 position; +layout(location=1) in vec2 texcoord; +out vec2 _vs_out_texcoord; +void main() +{ + _vs_out_texcoord = texcoord; + gl_Position = transform.mvp*position; +} +*/ + +/* Expected output: fragment +uniform Colors +{ + vec4 color1; + vec4 color2; +}; +uniform sampler2D mask; +layout(location=0) out vec4 frag_color; +in vec2 _vs_out_texcoord; +void main() +{ + frag_color = texture(mask, _vs_out_texcoord).r>0.5?color1:color2; +} +*/ -- 2.45.2