]> git.tdb.fi Git - libs/gl.git/blobdiff - source/glsl/generate.cpp
Resolve and validate the parameters of constructors in GLSL
[libs/gl.git] / source / glsl / generate.cpp
index 11116b11ecb164cb704fdcf277de380ac882ba82..6430f9d8a0e53d630e284f64ef4ab869a2f85925 100644 (file)
@@ -602,6 +602,30 @@ bool ExpressionResolver::convert_to_element(RefPtr<Expression> &expr, BasicTypeD
        return false;
 }
 
+bool ExpressionResolver::truncate_vector(RefPtr<Expression> &expr, unsigned size)
+{
+       if(BasicTypeDeclaration *expr_basic = dynamic_cast<BasicTypeDeclaration *>(expr->type))
+               if(BasicTypeDeclaration *expr_elem = get_element_type(*expr_basic))
+               {
+                       RefPtr<Swizzle> swizzle = new Swizzle;
+                       swizzle->left = expr;
+                       swizzle->oper = &Operator::get_operator(".", Operator::POSTFIX);
+                       swizzle->component_group = string("xyzw", size);
+                       swizzle->count = size;
+                       for(unsigned i=0; i<size; ++i)
+                               swizzle->components[i] = i;
+                       if(size==1)
+                               swizzle->type = expr_elem;
+                       else
+                               swizzle->type = find_type(*expr_elem, BasicTypeDeclaration::VECTOR, size);
+                       expr = swizzle;
+
+                       return true;
+               }
+
+       return false;
+}
+
 void ExpressionResolver::resolve(Expression &expr, TypeDeclaration *type, bool lvalue)
 {
        r_any_resolved |= (type!=expr.type || lvalue!=expr.lvalue);
@@ -609,6 +633,16 @@ void ExpressionResolver::resolve(Expression &expr, TypeDeclaration *type, bool l
        expr.lvalue = lvalue;
 }
 
+void ExpressionResolver::visit(Block &block)
+{
+       SetForScope<Block *> set_block(current_block, &block);
+       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
+       {
+               insert_point = i;
+               (*i)->visit(*this);
+       }
+}
+
 void ExpressionResolver::visit(Literal &literal)
 {
        if(literal.value.check_type<bool>())
@@ -889,19 +923,223 @@ void ExpressionResolver::visit(TernaryExpression &ternary)
        resolve(ternary, type, false);
 }
 
+void ExpressionResolver::visit_constructor(FunctionCall &call)
+{
+       if(call.arguments.empty())
+               return;
+
+       map<string, TypeDeclaration *>::const_iterator i = stage->types.find(call.name);
+       if(i==stage->types.end())
+               return;
+       else if(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(i->second))
+       {
+               BasicTypeDeclaration *elem = get_element_type(*basic);
+               if(!elem)
+                       return;
+
+               vector<ArgumentInfo> args;
+               args.reserve(call.arguments.size());
+               unsigned arg_component_total = 0;
+               bool has_matrices = false;
+               for(NodeArray<Expression>::const_iterator j=call.arguments.begin(); j!=call.arguments.end(); ++j)
+               {
+                       ArgumentInfo info;
+                       if(!(info.type=dynamic_cast<BasicTypeDeclaration *>((*j)->type)))
+                               return;
+                       if(is_scalar(*info.type) || info.type->kind==BasicTypeDeclaration::BOOL)
+                               info.component_count = 1;
+                       else if(info.type->kind==BasicTypeDeclaration::VECTOR)
+                               info.component_count = info.type->size;
+                       else if(info.type->kind==BasicTypeDeclaration::MATRIX)
+                       {
+                               info.component_count = (info.type->size>>16)*(info.type->size&0xFFFF);
+                               has_matrices = true;
+                       }
+                       else
+                               return;
+                       arg_component_total += info.component_count;
+                       args.push_back(info);
+               }
+
+               bool convert_args = false;
+               if((is_scalar(*basic) || basic->kind==BasicTypeDeclaration::BOOL) && call.arguments.size()==1 && !has_matrices)
+               {
+                       if(arg_component_total>1)
+                               truncate_vector(call.arguments.front(), 1);
+
+                       /* Single-element type constructors never need to convert their
+                       arguments because the constructor *is* the conversion. */
+               }
+               else if(basic->kind==BasicTypeDeclaration::VECTOR && !has_matrices)
+               {
+                       /* Vector constructors need either a single scalar argument or
+                       enough components to fill out the vector. */
+                       if(arg_component_total!=1 && arg_component_total<basic->size)
+                               return;
+
+                       /* A vector of same size can be converted directly.  For other
+                       combinations the individual arguments need to be converted. */
+                       if(call.arguments.size()==1)
+                       {
+                               if(arg_component_total==1)
+                                       convert_args = true;
+                               else if(arg_component_total>basic->size)
+                                       truncate_vector(call.arguments.front(), basic->size);
+                       }
+                       else if(arg_component_total==basic->size)
+                               convert_args = true;
+                       else
+                               return;
+               }
+               else if(basic->kind==BasicTypeDeclaration::MATRIX)
+               {
+                       unsigned column_count = basic->size&0xFFFF;
+                       unsigned row_count = basic->size>>16;
+                       if(call.arguments.size()==1)
+                       {
+                               /* A matrix can be constructed from a single element or another
+                               matrix of sufficient size. */
+                               if(arg_component_total==1)
+                                       convert_args = true;
+                               else if(args.front().type->kind==BasicTypeDeclaration::MATRIX)
+                               {
+                                       unsigned arg_columns = args.front().type->size&0xFFFF;
+                                       unsigned arg_rows = args.front().type->size>>16;
+                                       if(arg_columns<column_count || arg_rows<row_count)
+                                               return;
+
+                                       /* Always generate a temporary here and let the optimization
+                                       stage inline it if that's reasonable. */
+                                       RefPtr<VariableDeclaration> temporary = new VariableDeclaration;
+                                       temporary->type = args.front().type->name;
+                                       temporary->name = get_unused_variable_name(*current_block, "_temp", string());
+                                       temporary->init_expression = call.arguments.front();
+                                       current_block->body.insert(insert_point, temporary);
+
+                                       // Create expressions to build each column.
+                                       vector<RefPtr<Expression> > columns;
+                                       columns.reserve(column_count);
+                                       for(unsigned j=0; j<column_count; ++j)
+                                       {
+                                               RefPtr<VariableReference> ref = new VariableReference;
+                                               ref->name = temporary->name;
+
+                                               RefPtr<Literal> index = new Literal;
+                                               index->token = lexical_cast<string>(j);
+                                               index->value = static_cast<int>(j);
+
+                                               RefPtr<BinaryExpression> subscript = new BinaryExpression;
+                                               subscript->left = ref;
+                                               subscript->oper = &Operator::get_operator("[", Operator::BINARY);
+                                               subscript->right = index;
+                                               subscript->type = args.front().type->base_type;
+
+                                               columns.push_back(subscript);
+                                               if(arg_rows>row_count)
+                                                       truncate_vector(columns.back(), row_count);
+                                       }
+
+                                       call.arguments.resize(column_count);
+                                       copy(columns.begin(), columns.end(), call.arguments.begin());
+
+                                       /* Let VariableResolver process the new nodes and finish
+                                       resolving the constructor on the next pass. */
+                                       r_any_resolved = true;
+                                       return;
+                               }
+                               else
+                                       return;
+                       }
+                       else if(arg_component_total==column_count*row_count && !has_matrices)
+                       {
+                               /* Construct a matrix from individual components in column-major
+                               order.  Arguments must align at column boundaries. */
+                               vector<RefPtr<Expression> > columns;
+                               columns.reserve(column_count);
+
+                               vector<RefPtr<Expression> > column_args;
+                               column_args.reserve(row_count);
+                               unsigned column_component_count = 0;
+
+                               for(unsigned j=0; j<call.arguments.size(); ++j)
+                               {
+                                       const ArgumentInfo &info = args[j];
+                                       if(!column_component_count && info.type->kind==BasicTypeDeclaration::VECTOR && info.component_count==row_count)
+                                               // A vector filling the entire column can be used as is.
+                                               columns.push_back(call.arguments[j]);
+                                       else
+                                       {
+                                               column_args.push_back(call.arguments[j]);
+                                               column_component_count += info.component_count;
+                                               if(column_component_count==row_count)
+                                               {
+                                                       /* The column has filled up.  Create a vector constructor
+                                                       for it.*/
+                                                       RefPtr<FunctionCall> column_call = new FunctionCall;
+                                                       column_call->name = basic->base_type->name;
+                                                       column_call->constructor = true;
+                                                       column_call->arguments.resize(column_args.size());
+                                                       copy(column_args.begin(), column_args.end(), column_call->arguments.begin());
+                                                       column_call->type = basic->base_type;
+                                                       visit_constructor(*column_call);
+                                                       columns.push_back(column_call);
+
+                                                       column_args.clear();
+                                                       column_component_count = 0;
+                                               }
+                                               else if(column_component_count>row_count)
+                                                       // Argument alignment mismatch.
+                                                       return;
+                                       }
+                               }
+                       }
+                       else
+                               return;
+               }
+               else
+                       return;
+
+               if(convert_args)
+               {
+                       // The argument list may have changed so can't rely on args.
+                       for(NodeArray<Expression>::iterator j=call.arguments.begin(); j!=call.arguments.end(); ++j)
+                               if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>((*j)->type))
+                               {
+                                       BasicTypeDeclaration *elem_arg = get_element_type(*basic_arg);
+                                       if(elem_arg!=elem)
+                                               convert_to_element(*j, *elem);
+                               }
+               }
+       }
+       else if(StructDeclaration *strct = dynamic_cast<StructDeclaration *>(i->second))
+       {
+               if(call.arguments.size()!=strct->members.body.size())
+                       return;
+
+               unsigned k = 0;
+               for(NodeList<Statement>::const_iterator j=strct->members.body.begin(); j!=strct->members.body.end(); ++j, ++k)
+               {
+                       if(VariableDeclaration *var = dynamic_cast<VariableDeclaration *>(j->get()))
+                       {
+                               if(!call.arguments[k]->type || call.arguments[k]->type!=var->type_declaration)
+                                       return;
+                       }
+                       else
+                               return;
+               }
+       }
+
+       resolve(call, i->second, false);
+}
+
 void ExpressionResolver::visit(FunctionCall &call)
 {
        TraversingVisitor::visit(call);
 
-       TypeDeclaration *type = 0;
        if(call.declaration)
-               type = call.declaration->return_type_declaration;
+               resolve(call, call.declaration->return_type_declaration, false);
        else if(call.constructor)
-       {
-               map<string, TypeDeclaration *>::const_iterator i=stage->types.find(call.name);
-               type = (i!=stage->types.end() ? i->second : 0);
-       }
-       resolve(call, type, false);
+               visit_constructor(call);
 }
 
 void ExpressionResolver::visit(BasicTypeDeclaration &type)