]> git.tdb.fi Git - libs/gl.git/commitdiff
Resolve and validate the parameters of constructors in GLSL
authorMikko Rasa <tdb@tdb.fi>
Sun, 14 Mar 2021 14:38:29 +0000 (16:38 +0200)
committerMikko Rasa <tdb@tdb.fi>
Sun, 14 Mar 2021 16:58:57 +0000 (18:58 +0200)
source/glsl/generate.cpp
source/glsl/generate.h
source/glsl/validate.cpp
tests/glsl/constructors.glsl [new file with mode: 0644]
tests/glsl/expression_inline_iteration.glsl
tests/glsl/invalid_constructors.glsl [new file with mode: 0644]
tests/glsl/parentheses.glsl

index 11116b11ecb164cb704fdcf277de380ac882ba82..6430f9d8a0e53d630e284f64ef4ab869a2f85925 100644 (file)
@@ -602,6 +602,30 @@ bool ExpressionResolver::convert_to_element(RefPtr<Expression> &expr, BasicTypeD
        return false;
 }
 
+bool ExpressionResolver::truncate_vector(RefPtr<Expression> &expr, unsigned size)
+{
+       if(BasicTypeDeclaration *expr_basic = dynamic_cast<BasicTypeDeclaration *>(expr->type))
+               if(BasicTypeDeclaration *expr_elem = get_element_type(*expr_basic))
+               {
+                       RefPtr<Swizzle> swizzle = new Swizzle;
+                       swizzle->left = expr;
+                       swizzle->oper = &Operator::get_operator(".", Operator::POSTFIX);
+                       swizzle->component_group = string("xyzw", size);
+                       swizzle->count = size;
+                       for(unsigned i=0; i<size; ++i)
+                               swizzle->components[i] = i;
+                       if(size==1)
+                               swizzle->type = expr_elem;
+                       else
+                               swizzle->type = find_type(*expr_elem, BasicTypeDeclaration::VECTOR, size);
+                       expr = swizzle;
+
+                       return true;
+               }
+
+       return false;
+}
+
 void ExpressionResolver::resolve(Expression &expr, TypeDeclaration *type, bool lvalue)
 {
        r_any_resolved |= (type!=expr.type || lvalue!=expr.lvalue);
@@ -609,6 +633,16 @@ void ExpressionResolver::resolve(Expression &expr, TypeDeclaration *type, bool l
        expr.lvalue = lvalue;
 }
 
+void ExpressionResolver::visit(Block &block)
+{
+       SetForScope<Block *> set_block(current_block, &block);
+       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
+       {
+               insert_point = i;
+               (*i)->visit(*this);
+       }
+}
+
 void ExpressionResolver::visit(Literal &literal)
 {
        if(literal.value.check_type<bool>())
@@ -889,19 +923,223 @@ void ExpressionResolver::visit(TernaryExpression &ternary)
        resolve(ternary, type, false);
 }
 
+void ExpressionResolver::visit_constructor(FunctionCall &call)
+{
+       if(call.arguments.empty())
+               return;
+
+       map<string, TypeDeclaration *>::const_iterator i = stage->types.find(call.name);
+       if(i==stage->types.end())
+               return;
+       else if(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(i->second))
+       {
+               BasicTypeDeclaration *elem = get_element_type(*basic);
+               if(!elem)
+                       return;
+
+               vector<ArgumentInfo> args;
+               args.reserve(call.arguments.size());
+               unsigned arg_component_total = 0;
+               bool has_matrices = false;
+               for(NodeArray<Expression>::const_iterator j=call.arguments.begin(); j!=call.arguments.end(); ++j)
+               {
+                       ArgumentInfo info;
+                       if(!(info.type=dynamic_cast<BasicTypeDeclaration *>((*j)->type)))
+                               return;
+                       if(is_scalar(*info.type) || info.type->kind==BasicTypeDeclaration::BOOL)
+                               info.component_count = 1;
+                       else if(info.type->kind==BasicTypeDeclaration::VECTOR)
+                               info.component_count = info.type->size;
+                       else if(info.type->kind==BasicTypeDeclaration::MATRIX)
+                       {
+                               info.component_count = (info.type->size>>16)*(info.type->size&0xFFFF);
+                               has_matrices = true;
+                       }
+                       else
+                               return;
+                       arg_component_total += info.component_count;
+                       args.push_back(info);
+               }
+
+               bool convert_args = false;
+               if((is_scalar(*basic) || basic->kind==BasicTypeDeclaration::BOOL) && call.arguments.size()==1 && !has_matrices)
+               {
+                       if(arg_component_total>1)
+                               truncate_vector(call.arguments.front(), 1);
+
+                       /* Single-element type constructors never need to convert their
+                       arguments because the constructor *is* the conversion. */
+               }
+               else if(basic->kind==BasicTypeDeclaration::VECTOR && !has_matrices)
+               {
+                       /* Vector constructors need either a single scalar argument or
+                       enough components to fill out the vector. */
+                       if(arg_component_total!=1 && arg_component_total<basic->size)
+                               return;
+
+                       /* A vector of same size can be converted directly.  For other
+                       combinations the individual arguments need to be converted. */
+                       if(call.arguments.size()==1)
+                       {
+                               if(arg_component_total==1)
+                                       convert_args = true;
+                               else if(arg_component_total>basic->size)
+                                       truncate_vector(call.arguments.front(), basic->size);
+                       }
+                       else if(arg_component_total==basic->size)
+                               convert_args = true;
+                       else
+                               return;
+               }
+               else if(basic->kind==BasicTypeDeclaration::MATRIX)
+               {
+                       unsigned column_count = basic->size&0xFFFF;
+                       unsigned row_count = basic->size>>16;
+                       if(call.arguments.size()==1)
+                       {
+                               /* A matrix can be constructed from a single element or another
+                               matrix of sufficient size. */
+                               if(arg_component_total==1)
+                                       convert_args = true;
+                               else if(args.front().type->kind==BasicTypeDeclaration::MATRIX)
+                               {
+                                       unsigned arg_columns = args.front().type->size&0xFFFF;
+                                       unsigned arg_rows = args.front().type->size>>16;
+                                       if(arg_columns<column_count || arg_rows<row_count)
+                                               return;
+
+                                       /* Always generate a temporary here and let the optimization
+                                       stage inline it if that's reasonable. */
+                                       RefPtr<VariableDeclaration> temporary = new VariableDeclaration;
+                                       temporary->type = args.front().type->name;
+                                       temporary->name = get_unused_variable_name(*current_block, "_temp", string());
+                                       temporary->init_expression = call.arguments.front();
+                                       current_block->body.insert(insert_point, temporary);
+
+                                       // Create expressions to build each column.
+                                       vector<RefPtr<Expression> > columns;
+                                       columns.reserve(column_count);
+                                       for(unsigned j=0; j<column_count; ++j)
+                                       {
+                                               RefPtr<VariableReference> ref = new VariableReference;
+                                               ref->name = temporary->name;
+
+                                               RefPtr<Literal> index = new Literal;
+                                               index->token = lexical_cast<string>(j);
+                                               index->value = static_cast<int>(j);
+
+                                               RefPtr<BinaryExpression> subscript = new BinaryExpression;
+                                               subscript->left = ref;
+                                               subscript->oper = &Operator::get_operator("[", Operator::BINARY);
+                                               subscript->right = index;
+                                               subscript->type = args.front().type->base_type;
+
+                                               columns.push_back(subscript);
+                                               if(arg_rows>row_count)
+                                                       truncate_vector(columns.back(), row_count);
+                                       }
+
+                                       call.arguments.resize(column_count);
+                                       copy(columns.begin(), columns.end(), call.arguments.begin());
+
+                                       /* Let VariableResolver process the new nodes and finish
+                                       resolving the constructor on the next pass. */
+                                       r_any_resolved = true;
+                                       return;
+                               }
+                               else
+                                       return;
+                       }
+                       else if(arg_component_total==column_count*row_count && !has_matrices)
+                       {
+                               /* Construct a matrix from individual components in column-major
+                               order.  Arguments must align at column boundaries. */
+                               vector<RefPtr<Expression> > columns;
+                               columns.reserve(column_count);
+
+                               vector<RefPtr<Expression> > column_args;
+                               column_args.reserve(row_count);
+                               unsigned column_component_count = 0;
+
+                               for(unsigned j=0; j<call.arguments.size(); ++j)
+                               {
+                                       const ArgumentInfo &info = args[j];
+                                       if(!column_component_count && info.type->kind==BasicTypeDeclaration::VECTOR && info.component_count==row_count)
+                                               // A vector filling the entire column can be used as is.
+                                               columns.push_back(call.arguments[j]);
+                                       else
+                                       {
+                                               column_args.push_back(call.arguments[j]);
+                                               column_component_count += info.component_count;
+                                               if(column_component_count==row_count)
+                                               {
+                                                       /* The column has filled up.  Create a vector constructor
+                                                       for it.*/
+                                                       RefPtr<FunctionCall> column_call = new FunctionCall;
+                                                       column_call->name = basic->base_type->name;
+                                                       column_call->constructor = true;
+                                                       column_call->arguments.resize(column_args.size());
+                                                       copy(column_args.begin(), column_args.end(), column_call->arguments.begin());
+                                                       column_call->type = basic->base_type;
+                                                       visit_constructor(*column_call);
+                                                       columns.push_back(column_call);
+
+                                                       column_args.clear();
+                                                       column_component_count = 0;
+                                               }
+                                               else if(column_component_count>row_count)
+                                                       // Argument alignment mismatch.
+                                                       return;
+                                       }
+                               }
+                       }
+                       else
+                               return;
+               }
+               else
+                       return;
+
+               if(convert_args)
+               {
+                       // The argument list may have changed so can't rely on args.
+                       for(NodeArray<Expression>::iterator j=call.arguments.begin(); j!=call.arguments.end(); ++j)
+                               if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>((*j)->type))
+                               {
+                                       BasicTypeDeclaration *elem_arg = get_element_type(*basic_arg);
+                                       if(elem_arg!=elem)
+                                               convert_to_element(*j, *elem);
+                               }
+               }
+       }
+       else if(StructDeclaration *strct = dynamic_cast<StructDeclaration *>(i->second))
+       {
+               if(call.arguments.size()!=strct->members.body.size())
+                       return;
+
+               unsigned k = 0;
+               for(NodeList<Statement>::const_iterator j=strct->members.body.begin(); j!=strct->members.body.end(); ++j, ++k)
+               {
+                       if(VariableDeclaration *var = dynamic_cast<VariableDeclaration *>(j->get()))
+                       {
+                               if(!call.arguments[k]->type || call.arguments[k]->type!=var->type_declaration)
+                                       return;
+                       }
+                       else
+                               return;
+               }
+       }
+
+       resolve(call, i->second, false);
+}
+
 void ExpressionResolver::visit(FunctionCall &call)
 {
        TraversingVisitor::visit(call);
 
-       TypeDeclaration *type = 0;
        if(call.declaration)
-               type = call.declaration->return_type_declaration;
+               resolve(call, call.declaration->return_type_declaration, false);
        else if(call.constructor)
-       {
-               map<string, TypeDeclaration *>::const_iterator i=stage->types.find(call.name);
-               type = (i!=stage->types.end() ? i->second : 0);
-       }
-       resolve(call, type, false);
+               visit_constructor(call);
 }
 
 void ExpressionResolver::visit(BasicTypeDeclaration &type)
index 51d86a536854b129dc44fcbc9b06ff5927e0422f..58bc4ca202e6bdbf60f5d625c1f5b5ba2004ef14 100644 (file)
@@ -118,8 +118,15 @@ private:
                SAME_TYPE
        };
 
+       struct ArgumentInfo
+       {
+               BasicTypeDeclaration *type;
+               unsigned component_count;
+       };
+
        Stage *stage;
        std::vector<BasicTypeDeclaration *> basic_types;
+       NodeList<Statement>::iterator insert_point;
        bool r_any_resolved;
 
 public:
@@ -137,8 +144,10 @@ private:
        BasicTypeDeclaration *find_type(BasicTypeDeclaration &, BasicTypeDeclaration::Kind, unsigned);
        void convert_to(RefPtr<Expression> &, BasicTypeDeclaration &);
        bool convert_to_element(RefPtr<Expression> &, BasicTypeDeclaration &);
+       bool truncate_vector(RefPtr<Expression> &, unsigned);
        void resolve(Expression &, TypeDeclaration *, bool);
 
+       virtual void visit(Block &);
        virtual void visit(Literal &);
        virtual void visit(VariableReference &);
        virtual void visit(InterfaceBlockReference &);
@@ -149,6 +158,7 @@ private:
        virtual void visit(BinaryExpression &);
        virtual void visit(Assignment &);
        virtual void visit(TernaryExpression &);
+       void visit_constructor(FunctionCall &);
        virtual void visit(FunctionCall &);
        virtual void visit(BasicTypeDeclaration &);
        virtual void visit(VariableDeclaration &);
index 4b38c5b755e935f0e3f477c0ecd092202e411c7f..b8710509e8d32fe424062f6c6558976ceab6da35 100644 (file)
@@ -239,10 +239,16 @@ void ReferenceValidator::visit(InterfaceBlockReference &iface)
 
 void ReferenceValidator::visit(FunctionCall &call)
 {
-       if(!call.declaration && !call.constructor)
+       if((!call.constructor && !call.declaration) || (call.constructor && !call.type))
        {
-               map<string, FunctionDeclaration *>::iterator i = stage->functions.lower_bound(call.name);
-               if(i!=stage->functions.end() && i->second->name==call.name)
+               bool have_declaration = call.constructor;
+               if(!call.constructor)
+               {
+                       map<string, FunctionDeclaration *>::iterator i = stage->functions.lower_bound(call.name);
+                       have_declaration = (i!=stage->functions.end() && i->second->name==call.name);
+               }
+
+               if(have_declaration)
                {
                        bool valid_types = true;
                        string signature;
@@ -255,7 +261,7 @@ void ReferenceValidator::visit(FunctionCall &call)
                        }
 
                        if(valid_types)
-                               error(call, format("No matching overload found for call to '%s(%s)'", call.name, signature));
+                               error(call, format("No matching %s found for '%s(%s)'", (call.constructor ? "constructor" : "overload"), call.name, signature));
                }
                else
                        error(call, format("Call to undeclared function '%s'", call.name));
diff --git a/tests/glsl/constructors.glsl b/tests/glsl/constructors.glsl
new file mode 100644 (file)
index 0000000..fe7ca48
--- /dev/null
@@ -0,0 +1,58 @@
+uniform mat4 model;
+uniform mat4 view_projection;
+uniform vec3 light_dir;
+uniform sampler2D normalmap;
+
+#pragma MSP stage(vertex)
+layout(location=0) in vec3 position;
+layout(location=1) in vec3 normal;
+layout(location=2) in vec3 tangent;
+layout(location=3) in vec3 binormal;
+layout(location=4) in vec2 texcoord;
+void main()
+{
+       mat3 normal_matrix = mat3(model);
+       mat3 tbn_matrix = mat3(normal_matrix*tangent, normal_matrix*binormal, normal_matrix*normal);
+       out vec3 tbn_light_dir = tbn_matrix*light_dir;
+       gl_Position = view_projection*model*vec4(position, 1);
+       passthrough;
+}
+
+#pragma MSP stage(fragment)
+layout(location=0) out vec4 frag_color;
+void main()
+{
+       vec3 normal = vec3(texture(normalmap, texcoord))*2.0-1.0;
+       frag_color = vec4(vec3(dot(normal, normalize(tbn_light_dir))), 1);
+}
+
+/* Expected output: vertex
+uniform mat4 model;
+uniform mat4 view_projection;
+uniform vec3 light_dir;
+layout(location=0) in vec3 position;
+layout(location=1) in vec3 normal;
+layout(location=2) in vec3 tangent;
+layout(location=3) in vec3 binormal;
+layout(location=4) in vec2 texcoord;
+out vec3 tbn_light_dir;
+out vec2 _vs_out_texcoord;
+void main()
+{
+  mat3 normal_matrix = mat3(model[0].xyz, model[1].xyz, model[2].xyz);
+  tbn_light_dir = mat3(normal_matrix*tangent, normal_matrix*binormal, normal_matrix*normal)*light_dir;
+  gl_Position = view_projection*model*vec4(position, float(1));
+  _vs_out_texcoord = texcoord;
+}
+*/
+
+/* Expected output: fragment
+uniform sampler2D normalmap;
+layout(location=0) out vec4 frag_color;
+in vec2 _vs_out_texcoord;
+in vec3 tbn_light_dir;
+void main()
+{
+  frag_color = vec4(vec3(dot(vec3(texture(normalmap, _vs_out_texcoord).xyz)*2.0-1.0, normalize(tbn_light_dir))), float(1));
+}
+*/
index 4a125099ddc78eeb0389f6897795285d54259e7a..fa17ae9565f2cf5ed33bf4eeb2f37385425a0b2e 100644 (file)
@@ -30,6 +30,6 @@ void main()
                        break;
                step = i;
        }
-       gl_Position = position+vec4(step, 0.0, 0.0, 0.0);
+       gl_Position = position+vec4(float(step), 0.0, 0.0, 0.0);
 }
 */
diff --git a/tests/glsl/invalid_constructors.glsl b/tests/glsl/invalid_constructors.glsl
new file mode 100644 (file)
index 0000000..f9d7a04
--- /dev/null
@@ -0,0 +1,21 @@
+#pragma MSP stage(vertex)
+void main()
+{
+       float f = float();
+       vec3 v = vec3(f, f);
+       vec3 u = vec3(v, 1.0);
+       vec3 w = vec3(v.xy, vec2(3.0, 4.0));
+       mat3 m = mat3(v, u);
+       mat3 m2 = mat3(v, v, u, u);
+       mat2 m3 = mat2(v, 1.0);
+}
+
+/* Expected error:
+<test>:4: No matching constructor found for 'float()'
+<test>:5: No matching constructor found for 'vec3(float, float)'
+<test>:6: No matching constructor found for 'vec3(vec3, float)'
+<test>:7: No matching constructor found for 'vec3(vec2, vec2)'
+<test>:8: No matching constructor found for 'mat3(vec3, vec3)'
+<test>:9: No matching constructor found for 'mat3(vec3, vec3, vec3, vec3)'
+<test>:10: No matching constructor found for 'mat2(vec3, float)'
+*/
index 9899c49ed72aecfcc3f5de72c39f817a2c077880..32651ac97fd60bcd989754431e980aa5c9a6ee11 100644 (file)
@@ -64,6 +64,6 @@ in vec2 _vs_out_texcoord;
 in vec3 _vs_out_normal;
 void main()
 {
-       frag_color = vec4(material_color*(light.ambient+light.color*light.intensity*max(dot(light.dir, normalize(_vs_out_normal)), 0.0))*float(texture(occlusion_map, _vs_out_texcoord)), 1.0);
+       frag_color = vec4(material_color*(light.ambient+light.color*light.intensity*max(dot(light.dir, normalize(_vs_out_normal)), 0.0))*float(texture(occlusion_map, _vs_out_texcoord).x), 1.0);
 }
 */