]> git.tdb.fi Git - libs/gl.git/blobdiff - source/glsl/spirv.cpp
Recognize composite constants when generating SPIR-V
[libs/gl.git] / source / glsl / spirv.cpp
index 1cce97dfbf34888580ab6137140127e6fc67b094..2220dcc1f012c00655fb207e841787b76be09238 100644 (file)
@@ -394,7 +394,10 @@ SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, u
                writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
        }
        else if(opcode==OP_COMPOSITE_CONSTRUCT)
-               writer.begin_op(content.globals, OP_SPEC_CONSTANT_COMPOSITE, (n_args ? 1+has_result*2+n_args : 0));
+               writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
+                       (n_args ? 1+has_result*2+n_args : 0));
+       else if(!spec_constant)
+               throw internal_error("invalid non-specialization constant expression");
        else
                writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
 
@@ -404,7 +407,7 @@ SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, u
                writer.write(type_id);
                writer.write(result_id);
        }
-       if(constant_expression && opcode!=OP_COMPOSITE_CONSTRUCT)
+       if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
                writer.write(opcode);
 
        return result_id;
@@ -413,7 +416,7 @@ SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, u
 void SpirVGenerator::end_expression(Opcode opcode)
 {
        if(constant_expression)
-               opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? OP_SPEC_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
+               opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
        writer.end_op(opcode);
 }
 
@@ -468,6 +471,7 @@ void SpirVGenerator::visit(Literal &literal)
                r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
        else
                r_expression_result_id = get_constant_id(type_id, literal.value);
+       r_constant_result = true;
 }
 
 void SpirVGenerator::visit(VariableReference &var)
@@ -478,11 +482,13 @@ void SpirVGenerator::visit(VariableReference &var)
                        throw internal_error("reference to non-constant variable in constant context");
 
                r_expression_result_id = get_id(*var.declaration);
+               r_constant_result = true;
                return;
        }
        else if(!current_function)
                throw internal_error("non-constant context outside a function");
 
+       r_constant_result = false;
        if(composite_access)
        {
                r_composite_base = var.declaration;
@@ -505,6 +511,7 @@ void SpirVGenerator::visit(InterfaceBlockReference &iface)
 
        r_composite_base = iface.declaration;
        r_expression_result_id = 0;
+       r_constant_result = false;
 }
 
 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
@@ -550,6 +557,7 @@ void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
                writer.write(*i);
        end_expression(opcode);
 
+       r_constant_result = false;
        if(r_composite_base)
        {
                if(assignment_source_id)
@@ -643,6 +651,7 @@ void SpirVGenerator::visit(Swizzle &swizzle)
                        writer.write(swizzle.components[i]);
                end_expression(OP_VECTOR_SHUFFLE);
        }
+       r_constant_result = false;
 }
 
 void SpirVGenerator::visit(UnaryExpression &unary)
@@ -665,6 +674,7 @@ void SpirVGenerator::visit(UnaryExpression &unary)
        Id result_type_id = get_id(*unary.type);
        Opcode opcode = OP_NOP;
 
+       r_constant_result = false;
        if(oper=='!')
                opcode = OP_LOGICAL_NOT;
        else if(oper=='~')
@@ -753,6 +763,8 @@ void SpirVGenerator::visit(BinaryExpression &binary)
        Opcode opcode = OP_NOP;
        bool swap_operands = false;
 
+       r_constant_result = false;
+
        char oper2 = binary.oper->token[1];
        if((oper=='<' || oper=='>') && oper2!=oper)
        {
@@ -936,6 +948,7 @@ void SpirVGenerator::visit(Assignment &assign)
 
        SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
        assign.left->visit(*this);
+       r_constant_result = false;
 }
 
 void SpirVGenerator::visit(TernaryExpression &ternary)
@@ -983,13 +996,13 @@ void SpirVGenerator::visit(TernaryExpression &ternary)
        writer.write(false_result_id);
        writer.write(false_label_id);
        end_expression(OP_PHI);
+
+       r_constant_result = false;
 }
 
 void SpirVGenerator::visit(FunctionCall &call)
 {
-       if(constant_expression)
-               throw internal_error("function call in constant expression");
-       else if(assignment_source_id)
+       if(assignment_source_id)
                throw internal_error("assignment to function call");
        else if(composite_access)
                return visit_isolated(call);
@@ -998,16 +1011,21 @@ void SpirVGenerator::visit(FunctionCall &call)
 
        vector<Id> argument_ids;
        argument_ids.reserve(call.arguments.size());
+       bool all_args_const = true;
        for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
        {
                (*i)->visit(*this);
                argument_ids.push_back(r_expression_result_id);
+               all_args_const &= r_constant_result;
        }
 
+       if(constant_expression && (!call.constructor || !all_args_const))
+               throw internal_error("function call in constant expression");
+
        Id result_type_id = get_id(*call.type);
 
        if(call.constructor)
-               visit_constructor(call, argument_ids);
+               visit_constructor(call, argument_ids, all_args_const);
        else if(call.declaration->source==BUILTIN_SOURCE)
        {
                string arg_types;
@@ -1077,7 +1095,7 @@ void SpirVGenerator::visit(FunctionCall &call)
        }
 }
 
-void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids)
+void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
 {
        Id result_type_id = get_id(*call.type);
 
@@ -1091,6 +1109,8 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &arg
                return;
        }
 
+       SetFlag set_const(constant_expression, constant_expression || all_args_const);
+
        BasicTypeDeclaration &elem = *get_element_type(*basic);
        BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
        BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
@@ -1136,6 +1156,9 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &arg
        }
        else if(elem.kind==BasicTypeDeclaration::BOOL)
        {
+               if(constant_expression)
+                       throw internal_error("unconverted constant");
+
                // Conversion to boolean is implemented as comparing against zero.
                Id number_type_id = get_id(elem_arg0);
                Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
@@ -1148,6 +1171,9 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &arg
        }
        else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
        {
+               if(constant_expression)
+                       throw internal_error("unconverted constant");
+
                /* Conversion from boolean is implemented as selecting from zero
                or one. */
                Id number_type_id = get_id(elem);
@@ -1169,6 +1195,9 @@ void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &arg
        }
        else
        {
+               if(constant_expression)
+                       throw internal_error("unconverted constant");
+
                // Scalar or vector conversion between types of equal size.
                Opcode opcode;
                if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)