From 49080e2c9359a3929e77817707ad7e8bf3f8a36d Mon Sep 17 00:00:00 2001 From: Mikko Rasa Date: Tue, 20 Apr 2021 18:32:22 +0300 Subject: [PATCH] Recognize composite constants when generating SPIR-V --- source/glsl/spirv.cpp | 45 +++++++++++++++++++++++++++++++++++-------- source/glsl/spirv.h | 3 ++- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/source/glsl/spirv.cpp b/source/glsl/spirv.cpp index 1cce97df..2220dcc1 100644 --- a/source/glsl/spirv.cpp +++ b/source/glsl/spirv.cpp @@ -394,7 +394,10 @@ SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, u writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0)); } else if(opcode==OP_COMPOSITE_CONSTRUCT) - writer.begin_op(content.globals, OP_SPEC_CONSTANT_COMPOSITE, (n_args ? 1+has_result*2+n_args : 0)); + writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE), + (n_args ? 1+has_result*2+n_args : 0)); + else if(!spec_constant) + throw internal_error("invalid non-specialization constant expression"); else writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0)); @@ -404,7 +407,7 @@ SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, u writer.write(type_id); writer.write(result_id); } - if(constant_expression && opcode!=OP_COMPOSITE_CONSTRUCT) + if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT) writer.write(opcode); return result_id; @@ -413,7 +416,7 @@ SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, u void SpirVGenerator::end_expression(Opcode opcode) { if(constant_expression) - opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? OP_SPEC_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP); + opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP); writer.end_op(opcode); } @@ -468,6 +471,7 @@ void SpirVGenerator::visit(Literal &literal) r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true); else r_expression_result_id = get_constant_id(type_id, literal.value); + r_constant_result = true; } void SpirVGenerator::visit(VariableReference &var) @@ -478,11 +482,13 @@ void SpirVGenerator::visit(VariableReference &var) throw internal_error("reference to non-constant variable in constant context"); r_expression_result_id = get_id(*var.declaration); + r_constant_result = true; return; } else if(!current_function) throw internal_error("non-constant context outside a function"); + r_constant_result = false; if(composite_access) { r_composite_base = var.declaration; @@ -505,6 +511,7 @@ void SpirVGenerator::visit(InterfaceBlockReference &iface) r_composite_base = iface.declaration; r_expression_result_id = 0; + r_constant_result = false; } void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type) @@ -550,6 +557,7 @@ void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type) writer.write(*i); end_expression(opcode); + r_constant_result = false; if(r_composite_base) { if(assignment_source_id) @@ -643,6 +651,7 @@ void SpirVGenerator::visit(Swizzle &swizzle) writer.write(swizzle.components[i]); end_expression(OP_VECTOR_SHUFFLE); } + r_constant_result = false; } void SpirVGenerator::visit(UnaryExpression &unary) @@ -665,6 +674,7 @@ void SpirVGenerator::visit(UnaryExpression &unary) Id result_type_id = get_id(*unary.type); Opcode opcode = OP_NOP; + r_constant_result = false; if(oper=='!') opcode = OP_LOGICAL_NOT; else if(oper=='~') @@ -753,6 +763,8 @@ void SpirVGenerator::visit(BinaryExpression &binary) Opcode opcode = OP_NOP; bool swap_operands = false; + r_constant_result = false; + char oper2 = binary.oper->token[1]; if((oper=='<' || oper=='>') && oper2!=oper) { @@ -936,6 +948,7 @@ void SpirVGenerator::visit(Assignment &assign) SetForScope set_assign(assignment_source_id, r_expression_result_id); assign.left->visit(*this); + r_constant_result = false; } void SpirVGenerator::visit(TernaryExpression &ternary) @@ -983,13 +996,13 @@ void SpirVGenerator::visit(TernaryExpression &ternary) writer.write(false_result_id); writer.write(false_label_id); end_expression(OP_PHI); + + r_constant_result = false; } void SpirVGenerator::visit(FunctionCall &call) { - if(constant_expression) - throw internal_error("function call in constant expression"); - else if(assignment_source_id) + if(assignment_source_id) throw internal_error("assignment to function call"); else if(composite_access) return visit_isolated(call); @@ -998,16 +1011,21 @@ void SpirVGenerator::visit(FunctionCall &call) vector argument_ids; argument_ids.reserve(call.arguments.size()); + bool all_args_const = true; for(NodeArray::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i) { (*i)->visit(*this); argument_ids.push_back(r_expression_result_id); + all_args_const &= r_constant_result; } + if(constant_expression && (!call.constructor || !all_args_const)) + throw internal_error("function call in constant expression"); + Id result_type_id = get_id(*call.type); if(call.constructor) - visit_constructor(call, argument_ids); + visit_constructor(call, argument_ids, all_args_const); else if(call.declaration->source==BUILTIN_SOURCE) { string arg_types; @@ -1077,7 +1095,7 @@ void SpirVGenerator::visit(FunctionCall &call) } } -void SpirVGenerator::visit_constructor(FunctionCall &call, const vector &argument_ids) +void SpirVGenerator::visit_constructor(FunctionCall &call, const vector &argument_ids, bool all_args_const) { Id result_type_id = get_id(*call.type); @@ -1091,6 +1109,8 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector &arg return; } + SetFlag set_const(constant_expression, constant_expression || all_args_const); + BasicTypeDeclaration &elem = *get_element_type(*basic); BasicTypeDeclaration &basic_arg0 = dynamic_cast(*call.arguments[0]->type); BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0); @@ -1136,6 +1156,9 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector &arg } else if(elem.kind==BasicTypeDeclaration::BOOL) { + if(constant_expression) + throw internal_error("unconverted constant"); + // Conversion to boolean is implemented as comparing against zero. Id number_type_id = get_id(elem_arg0); Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? @@ -1148,6 +1171,9 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector &arg } else if(elem_arg0.kind==BasicTypeDeclaration::BOOL) { + if(constant_expression) + throw internal_error("unconverted constant"); + /* Conversion from boolean is implemented as selecting from zero or one. */ Id number_type_id = get_id(elem); @@ -1169,6 +1195,9 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector &arg } else { + if(constant_expression) + throw internal_error("unconverted constant"); + // Scalar or vector conversion between types of equal size. Opcode opcode; if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT) diff --git a/source/glsl/spirv.h b/source/glsl/spirv.h index 1e7ec370..98aa673e 100644 --- a/source/glsl/spirv.h +++ b/source/glsl/spirv.h @@ -84,6 +84,7 @@ private: std::map variable_load_ids; Id next_id; Id r_expression_result_id; + bool r_constant_result; bool constant_expression; bool spec_constant; bool reachable; @@ -142,7 +143,7 @@ private: virtual void visit(Assignment &); virtual void visit(TernaryExpression &); virtual void visit(FunctionCall &); - void visit_constructor(FunctionCall &, const std::vector &); + void visit_constructor(FunctionCall &, const std::vector &, bool); void visit_builtin_matrix_comp_mult(FunctionCall &, const std::vector &); void visit_builtin_texture(FunctionCall &, const std::vector &); void visit_builtin_texel_fetch(FunctionCall &, const std::vector &); -- 2.43.0