]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/validate.cpp
7e5f295f165ac3eda0993a7767b2b29fa9b6bbc5
[libs/gl.git] / source / glsl / validate.cpp
1 #include <algorithm>
2 #include <cstring>
3 #include <msp/core/raii.h>
4 #include <msp/strings/format.h>
5 #include <msp/strings/utils.h>
6 #include "reflect.h"
7 #include "validate.h"
8
9 using namespace std;
10
11 namespace Msp {
12 namespace GL {
13 namespace SL {
14
15 Validator::Validator():
16         stage(0),
17         last_provoker(0)
18 { }
19
20 void Validator::diagnose(Node &node, Node &provoking_node, Diagnostic::Severity severity, const string &message)
21 {
22         Diagnostic diag;
23         diag.severity = severity;
24         diag.source = node.source;
25         diag.line = node.line;
26         diag.provoking_source = provoking_node.source;
27         diag.provoking_line = provoking_node.line;
28         diag.message = message;
29         stage->diagnostics.push_back(diag);
30
31         last_provoker = &provoking_node;
32 }
33
34 void Validator::add_info(Node &node, const string &message)
35 {
36         if(!last_provoker)
37                 throw logic_error("Tried to add info without a previous provoker");
38         diagnose(node, *last_provoker, Diagnostic::INFO, message);
39 }
40
41
42 DeclarationValidator::DeclarationValidator():
43         scope(GLOBAL),
44         iface_layout(0),
45         iface_block(0),
46         variable(0)
47 { }
48
49 const char *DeclarationValidator::describe_variable(ScopeType scope)
50 {
51         switch(scope)
52         {
53         case GLOBAL: return "global variable";
54         case STRUCT: return "struct member";
55         case INTERFACE_BLOCK: return "interface block member";
56         case FUNCTION_PARAM: return "function parameter";
57         case FUNCTION: return "local variable";
58         default: return "variable";
59         }
60 }
61
62 void DeclarationValidator::visit(Layout &layout)
63 {
64         for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
65         {
66                 bool allowed = false;
67                 string err_descr;
68                 bool value = true;
69                 if(i->name=="location")
70                         allowed = (variable && scope==GLOBAL);
71                 else if(i->name=="binding" || i->name=="set")
72                 {
73                         if(i->name=="set")
74                         {
75                                 error(layout, "Layout qualifier 'set' not allowed when targeting OpenGL");
76                                 continue;
77                         }
78
79                         if(variable)
80                         {
81                                 TypeDeclaration *type = variable->type_declaration;
82                                 while(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(type))
83                                         type = basic->base_type;
84                                 bool uniform = (variable->interface=="uniform");
85                                 allowed = (scope==GLOBAL && uniform && dynamic_cast<ImageTypeDeclaration *>(type));
86                                 err_descr = (uniform ? "variable of non-opaque type" : "non-uniform variable");
87                         }
88                         else if(iface_block)
89                         {
90                                 allowed = (iface_block->interface=="uniform");
91                                 err_descr = "non-uniform interface block";
92                         }
93                 }
94                 else if(i->name=="constant_id")
95                 {
96                         allowed = (variable && scope==GLOBAL);
97                         if(allowed)
98                         {
99                                 if(!variable->constant)
100                                 {
101                                         allowed = false;
102                                         err_descr = "non-constant variable";
103                                 }
104                                 else
105                                 {
106                                         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(variable->type_declaration);
107                                         if(!basic || basic->kind<BasicTypeDeclaration::BOOL || basic->kind>BasicTypeDeclaration::INT)
108                                         {
109                                                 allowed = false;
110                                                 err_descr = format("variable of type '%s'",
111                                                         (variable->type_declaration ? variable->type_declaration->name : variable->type));
112                                         }
113                                 }
114                         }
115                 }
116                 else if(i->name=="offset")
117                         allowed = (variable && scope==INTERFACE_BLOCK && iface_block->interface=="uniform");
118                 else if(i->name=="align")
119                         allowed = (scope==INTERFACE_BLOCK && iface_block->interface=="uniform");
120                 else if(i->name=="points")
121                 {
122                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && (iface_layout->interface=="in" || iface_layout->interface=="out"));
123                         value = false;
124                 }
125                 else if(i->name=="lines" || i->name=="lines_adjacency" || i->name=="triangles" || i->name=="triangles_adjacency")
126                 {
127                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="in");
128                         value = false;
129                 }
130                 else if(i->name=="line_strip" || i->name=="triangle_strip")
131                 {
132                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="out");
133                         value = false;
134                 }
135                 else if(i->name=="invocations")
136                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="in");
137                 else if(i->name=="max_vertices")
138                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="out");
139                 else if(i->name=="std140" || i->name=="std430")
140                 {
141                         allowed = (iface_block && !variable && iface_block->interface=="uniform");
142                         value = false;
143                 }
144
145                 if(!allowed)
146                 {
147                         if(err_descr.empty())
148                         {
149                                 if(variable)
150                                         err_descr = describe_variable(scope);
151                                 else if(iface_block)
152                                         err_descr = "interface block";
153                                 else if(iface_layout)
154                                         err_descr = format("interface '%s'", iface_layout->interface);
155                                 else
156                                         err_descr = "unknown declaration";
157                         }
158                         error(layout, format("Layout qualifier '%s' not allowed on %s", i->name, err_descr));
159                 }
160                 else if(value && !i->has_value)
161                         error(layout, format("Layout qualifier '%s' requires a value", i->name));
162                 else if(!value && i->has_value)
163                         error(layout, format("Layout qualifier '%s' does not allow a value", i->name));
164         }
165 }
166
167 void DeclarationValidator::visit(InterfaceLayout &layout)
168 {
169         SetForScope<InterfaceLayout *> set_layout(iface_layout, &layout);
170         TraversingVisitor::visit(layout);
171 }
172
173 void DeclarationValidator::visit(BasicTypeDeclaration &type)
174 {
175         BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
176         BasicTypeDeclaration::Kind base_kind = (basic_base ? basic_base->kind : BasicTypeDeclaration::VOID);
177
178         if(type.kind==BasicTypeDeclaration::VECTOR && (base_kind<BasicTypeDeclaration::BOOL || base_kind>BasicTypeDeclaration::FLOAT))
179                 error(type, format("Invalid base type '%s' for vector type '%s'", type.base, type.name));
180         else if(type.kind==BasicTypeDeclaration::MATRIX && base_kind!=BasicTypeDeclaration::VECTOR)
181                 error(type, format("Invalid base type '%s' for matrix type '%s'", type.base, type.name));
182         else if(type.kind==BasicTypeDeclaration::ARRAY && basic_base && base_kind==BasicTypeDeclaration::VOID)
183                 error(type, format("Invalid base type '%s' for array type '%s'", type.base, type.name));
184 }
185
186 void DeclarationValidator::visit(ImageTypeDeclaration &type)
187 {
188         BasicTypeDeclaration::Kind base_kind = BasicTypeDeclaration::VOID;
189         if(BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type))
190                 base_kind = basic_base->kind;
191         if(base_kind!=BasicTypeDeclaration::INT && base_kind!=BasicTypeDeclaration::FLOAT)
192                 error(type, format("Invalid base type '%s' for image type '%s'", type.base, type.name));
193 }
194
195 void DeclarationValidator::visit(StructDeclaration &strct)
196 {
197         SetForScope<ScopeType> set_scope(scope, (scope!=INTERFACE_BLOCK ? STRUCT : scope));
198         TraversingVisitor::visit(strct);
199 }
200
201 void DeclarationValidator::visit(VariableDeclaration &var)
202 {
203         SetForScope<VariableDeclaration *> set_var(variable, &var);
204
205         const char *descr = describe_variable(scope);
206
207         if(var.layout)
208         {
209                 if(scope!=GLOBAL && scope!=INTERFACE_BLOCK)
210                         error(var, format("Layout qualifier not allowed on %s", descr));
211                 else
212                         var.layout->visit(*this);
213         }
214
215         if(var.constant)
216         {
217                 if(scope==STRUCT || scope==INTERFACE_BLOCK)
218                         error(var, format("Constant qualifier not allowed on %s", descr));
219         }
220
221         if(!var.interpolation.empty() || !var.sampling.empty())
222         {
223                 if(var.interface!="in" && stage->type==Stage::VERTEX)
224                         error(var, "Interpolation qualifier not allowed on vertex input");
225                 else if(var.interface!="out" && stage->type==Stage::FRAGMENT)
226                         error(var, "Interpolation qualifier not allowed on fragment output");
227                 else if((var.interface!="in" && var.interface!="out") || (scope==FUNCTION_PARAM || scope==FUNCTION))
228                         error(var, "Interpolation qualifier not allowed on non-interpolated variable");
229         }
230
231         if(!var.interface.empty())
232         {
233                 if(iface_block && var.interface!=iface_block->interface)
234                         error(var, format("Mismatched interface qualifier '%s' inside '%s' block", var.interface, iface_block->interface));
235                 else if(scope==STRUCT || scope==FUNCTION)
236                         error(var, format("Interface qualifier not allowed on %s", descr));
237         }
238
239         TypeDeclaration *type = var.type_declaration;
240         while(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(type))
241                 type = basic->base_type;
242         if(dynamic_cast<ImageTypeDeclaration *>(type))
243         {
244                 if(scope!=GLOBAL && scope!=FUNCTION_PARAM)
245                         error(var, format("Type '%s' not allowed on %s", type->name, descr));
246                 else if(scope==GLOBAL && var.interface!="uniform")
247                         error(var, format("Type '%s' only allowed with uniform interface", type->name));
248         }
249
250         if(var.init_expression)
251         {
252                 if(scope==GLOBAL && !var.constant)
253                         error(var, format("Initializer not allowed on non-constant %s", descr));
254                 else if(scope!=GLOBAL && scope!=FUNCTION)
255                         error(var, format("Initializer not allowed on %s", descr));
256                 else
257                         var.init_expression->visit(*this);
258         }
259 }
260
261 void DeclarationValidator::visit(InterfaceBlock &iface)
262 {
263         SetForScope<ScopeType> set_scope(scope, INTERFACE_BLOCK);
264         SetForScope<InterfaceBlock *> set_iface(iface_block, &iface);
265
266         if(stage->type==Stage::VERTEX && iface.interface=="in")
267                 error(iface, "Interface block not allowed on vertex shader input");
268         else if(stage->type==Stage::FRAGMENT && iface.interface=="out")
269                 error(iface, "Interface block not allowed on fragment shader output");
270
271         TraversingVisitor::visit(iface);
272         if(iface.struct_declaration)
273                 iface.struct_declaration->visit(*this);
274 }
275
276 void DeclarationValidator::visit(FunctionDeclaration &func)
277 {
278         SetForScope<ScopeType> set_scope(scope, FUNCTION_PARAM);
279         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
280                 (*i)->visit(*this);
281         scope = FUNCTION;
282         func.body.visit(*this);
283 }
284
285
286 IdentifierValidator::IdentifierValidator():
287         anonymous_block(false)
288 { }
289
290 void IdentifierValidator::multiple_definition(const string &name, Statement &statement, Statement &previous)
291 {
292         error(statement, format("Multiple definition of %s", name));
293         add_info(previous, "Previous definition is here");
294 }
295
296 Statement *IdentifierValidator::find_definition(const string &name)
297 {
298         BlockDeclarationMap *decls = &declarations[current_block];
299         BlockDeclarationMap::const_iterator i = decls->find(name);
300         if(i==decls->end() && anonymous_block)
301         {
302                 decls = &declarations[current_block->parent];
303                 i = decls->find(name);
304         }
305         return (i!=decls->end() ? i->second : 0);
306 }
307
308 void IdentifierValidator::check_definition(const string &name, Statement &statement)
309 {
310         if(Statement *previous = find_definition(name))
311                 multiple_definition(format("'%s'", name), statement, *previous);
312         else
313                 record_definition(name, statement);
314 }
315
316 void IdentifierValidator::record_definition(const string &name, Statement &statement)
317 {
318         declarations[current_block].insert(make_pair(name, &statement));
319         if(anonymous_block)
320                 declarations[current_block->parent].insert(make_pair(name, &statement));
321 }
322
323 void IdentifierValidator::visit(TypeDeclaration &type)
324 {
325         check_definition(type.name, type);
326 }
327
328 void IdentifierValidator::visit(StructDeclaration &strct)
329 {
330         check_definition(strct.name, strct);
331         TraversingVisitor::visit(strct);
332 }
333
334 void IdentifierValidator::visit(VariableDeclaration &var)
335 {
336         check_definition(var.name, var);
337         TraversingVisitor::visit(var);
338 }
339
340 void IdentifierValidator::visit(InterfaceBlock &iface)
341 {
342         string key = iface.interface+iface.block_name;
343         map<string, InterfaceBlock *>::const_iterator i = interface_blocks.find(key);
344         if(i!=interface_blocks.end())
345                 multiple_definition(format("interface block '%s %s'", iface.interface, iface.block_name), iface, *i->second);
346         else
347                 interface_blocks.insert(make_pair(key, &iface));
348
349         if(Statement *previous = find_definition(iface.block_name))
350         {
351                 if(!dynamic_cast<InterfaceBlock *>(previous))
352                         multiple_definition(format("'%s'", iface.block_name), iface, *previous);
353         }
354         else
355                 record_definition(iface.block_name, iface);
356
357         if(!iface.instance_name.empty())
358                 check_definition(iface.instance_name, iface);
359
360         if(iface.instance_name.empty() && iface.struct_declaration)
361         {
362                 // Inject anonymous interface block members into the global scope
363                 const map<string, VariableDeclaration *> &iface_vars = iface.struct_declaration->members.variables;
364                 for(map<string, VariableDeclaration *>::const_iterator j=iface_vars.begin(); j!=iface_vars.end(); ++j)
365                         check_definition(j->first, *j->second);
366         }
367 }
368
369 void IdentifierValidator::visit(FunctionDeclaration &func)
370 {
371         string key = func.name+func.signature;
372         map<string, FunctionDeclaration *>::const_iterator i = overloaded_functions.find(key);
373         if(i==overloaded_functions.end())
374                 overloaded_functions.insert(make_pair(key, &func));
375         else if(func.return_type_declaration && i->second->return_type_declaration!=func.return_type_declaration)
376         {
377                 error(func, format("Conflicting return type '%s' for function '%s'", func.return_type_declaration->name, func.name));
378                 if(i->second->return_type_declaration)
379                         add_info(*i->second, format("Previously declared as returning '%s'", i->second->return_type_declaration->name));
380         }
381
382         if(Statement *previous = find_definition(func.name))
383         {
384                 if(!dynamic_cast<FunctionDeclaration *>(previous))
385                         multiple_definition(format("'%s'", func.name), func, *previous);
386         }
387         else
388                 record_definition(func.name, func);
389
390         if(func.definition==&func)
391                 check_definition(func.name+func.signature, func);
392
393         TraversingVisitor::visit(func);
394 }
395
396
397 void ReferenceValidator::visit(BasicTypeDeclaration &type)
398 {
399         if(!type.base.empty() && !type.base_type)
400                 error(type, format("Use of undeclared type '%s'", type.base));
401 }
402
403 void ReferenceValidator::visit(ImageTypeDeclaration &type)
404 {
405         if(!type.base.empty() && !type.base_type)
406                 error(type, format("Use of undeclared type '%s'", type.base));
407 }
408
409 void ReferenceValidator::visit(VariableReference &var)
410 {
411         if(!var.declaration)
412                 error(var, format("Use of undeclared variable '%s'", var.name));
413         else if(stage->type!=Stage::VERTEX && var.declaration->interface=="in" && var.name.compare(0, 3, "gl_") && !var.declaration->linked_declaration)
414                 error(var, format("Use of unlinked input variable '%s'", var.name));
415 }
416
417 void ReferenceValidator::visit(MemberAccess &memacc)
418 {
419         if(memacc.left->type && !memacc.declaration)
420                 error(memacc, format("Use of undeclared member '%s'", memacc.member));
421         TraversingVisitor::visit(memacc);
422 }
423
424 void ReferenceValidator::visit(InterfaceBlockReference &iface)
425 {
426         /* An interface block reference without a declaration should be impossible
427         since references are generated based on existing declarations. */
428         if(!iface.declaration)
429                 error(iface, format("Use of undeclared interface block '%s'", iface.name));
430         else if(stage->type!=Stage::VERTEX && iface.declaration->interface=="in" && !iface.declaration->linked_block)
431                 error(iface, format("Use of unlinked input block '%s'", iface.name));
432 }
433
434 void ReferenceValidator::visit(FunctionCall &call)
435 {
436         if((!call.constructor && !call.declaration) || (call.constructor && !call.type))
437         {
438                 bool have_declaration = call.constructor;
439                 if(!call.constructor)
440                 {
441                         map<string, FunctionDeclaration *>::iterator i = stage->functions.lower_bound(call.name);
442                         have_declaration = (i!=stage->functions.end() && i->second->name==call.name);
443                 }
444
445                 if(have_declaration)
446                 {
447                         bool valid_types = true;
448                         string signature;
449                         for(NodeArray<Expression>::const_iterator j=call.arguments.begin(); (valid_types && j!=call.arguments.end()); ++j)
450                         {
451                                 if((*j)->type)
452                                         append(signature, ", ", (*j)->type->name);
453                                 else
454                                         valid_types = false;
455                         }
456
457                         if(valid_types)
458                                 error(call, format("No matching %s found for '%s(%s)'", (call.constructor ? "constructor" : "overload"), call.name, signature));
459                 }
460                 else
461                         error(call, format("Call to undeclared function '%s'", call.name));
462         }
463         TraversingVisitor::visit(call);
464 }
465
466 void ReferenceValidator::visit(VariableDeclaration &var)
467 {
468         if(!var.type_declaration)
469                 error(var, format("Use of undeclared type '%s'", var.type));
470         TraversingVisitor::visit(var);
471 }
472
473 void ReferenceValidator::visit(InterfaceBlock &iface)
474 {
475         if(!iface.struct_declaration)
476                 error(iface, format("Interface block '%s %s' lacks a struct declaration", iface.interface, iface.block_name));
477         TraversingVisitor::visit(iface);
478 }
479
480 void ReferenceValidator::visit(FunctionDeclaration &func)
481 {
482         if(!func.return_type_declaration)
483                 error(func, format("Use of undeclared type '%s'", func.return_type));
484         TraversingVisitor::visit(func);
485 }
486
487
488 ExpressionValidator::ExpressionValidator():
489         current_function(0)
490 { }
491
492 void ExpressionValidator::visit(Swizzle &swizzle)
493 {
494         unsigned size = 0;
495         if(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(swizzle.left->type))
496         {
497                 if(basic->kind==BasicTypeDeclaration::INT || basic->kind==BasicTypeDeclaration::FLOAT)
498                         size = 1;
499                 else if(basic->kind==BasicTypeDeclaration::VECTOR)
500                         size = basic->size;
501         }
502
503         if(size)
504         {
505                 static const char component_names[] = { 'x', 'y', 'z', 'w', 'r', 'g', 'b', 'a', 's', 't', 'p', 'q' };
506                 int flavour = -1;
507                 for(unsigned i=0; i<swizzle.count; ++i)
508                 {
509                         unsigned component_flavour = (find(component_names, component_names+12, swizzle.component_group[i])-component_names)/4;
510                         if(flavour==-1)
511                                 flavour = component_flavour;
512                         else if(flavour>=0 && component_flavour!=static_cast<unsigned>(flavour))
513                         {
514                                 error(swizzle, format("Flavour of swizzle component '%c' is inconsistent with '%c'",
515                                         swizzle.component_group[i], swizzle.component_group[0]));
516                                 flavour = -2;
517                         }
518
519                         if(swizzle.components[i]>=size)
520                                 error(swizzle, format("Access to component '%c' which is not present in '%s'",
521                                         swizzle.component_group[i], swizzle.left->type->name));
522                 }
523         }
524         else if(swizzle.left->type)
525                 error(swizzle, format("Swizzle applied to '%s' which is neither a scalar nor a vector", swizzle.left->type->name));
526
527         TraversingVisitor::visit(swizzle);
528 }
529
530 void ExpressionValidator::visit(UnaryExpression &unary)
531 {
532         if(unary.expression->type)
533         {
534                 if(!unary.type)
535                         error(unary, format("No matching operator '%s' found for '%s'", unary.oper->token, unary.expression->type->name));
536                 else if((unary.oper->token[1]=='+' || unary.oper->token[1]=='-') && !unary.expression->lvalue)
537                         error(unary, format("Operand of '%s' is not an lvalue", unary.oper->token));
538         }
539         TraversingVisitor::visit(unary);
540 }
541
542 void ExpressionValidator::visit(BinaryExpression &binary)
543 {
544         if(!binary.type && binary.left->type && binary.right->type)
545         {
546                 if(binary.oper->token[0]=='[')
547                         error(binary, format("Can't index element of '%s' with '%s'",
548                                 binary.left->type->name, binary.right->type->name));
549                 else
550                         error(binary, format("No matching operator '%s' found for '%s' and '%s'",
551                                 binary.oper->token, binary.left->type->name, binary.right->type->name));
552         }
553         TraversingVisitor::visit(binary);
554 }
555
556 void ExpressionValidator::visit(Assignment &assign)
557 {
558         if(assign.left->type)
559         {
560                 if(!assign.left->lvalue)
561                         error(assign, "Target of assignment is not an lvalue");
562                 if(assign.right->type)
563                 {
564                         if(assign.oper->token[0]!='=')
565                         {
566                                 if(!assign.type)
567                                         error(assign, format("No matching operator '%s' found for '%s' and '%s'",
568                                                 string(assign.oper->token, strlen(assign.oper->token)-1), assign.left->type->name, assign.right->type->name));
569                         }
570                         else if(assign.left->type!=assign.right->type)
571                                 error(assign, format("Assignment to variable of type '%s' from expression of incompatible type '%s'",
572                                         assign.left->type->name, assign.right->type->name));
573                 }
574         }
575         TraversingVisitor::visit(assign);
576 }
577
578 void ExpressionValidator::visit(TernaryExpression &ternary)
579 {
580         if(ternary.condition->type)
581         {
582                 BasicTypeDeclaration *basic_cond = dynamic_cast<BasicTypeDeclaration *>(ternary.condition->type);
583                 if(!basic_cond || basic_cond->kind!=BasicTypeDeclaration::BOOL)
584                         error(ternary, "Ternary operator condition is not a boolean");
585                 else if(!ternary.type && ternary.true_expr->type && ternary.false_expr->type)
586                         error(ternary, format("Ternary operator has incompatible types '%s' and '%s'",
587                                 ternary.true_expr->type->name, ternary.false_expr->type->name));
588         }
589         TraversingVisitor::visit(ternary);
590 }
591
592 void ExpressionValidator::visit(VariableDeclaration &var)
593 {
594         if(var.init_expression && var.init_expression->type && var.type_declaration && var.init_expression->type!=var.type_declaration)
595                 error(var, format("Initializing a variable of type '%s' with an expression of incompatible type '%s'",
596                         var.type_declaration->name, var.init_expression->type->name));
597         TraversingVisitor::visit(var);
598 }
599
600 void ExpressionValidator::visit(FunctionDeclaration &func)
601 {
602         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
603         TraversingVisitor::visit(func);
604 }
605
606 void ExpressionValidator::visit(Return &ret)
607 {
608         if(current_function && current_function->return_type_declaration)
609         {
610                 TypeDeclaration *return_type = current_function->return_type_declaration;
611                 BasicTypeDeclaration *basic_return = dynamic_cast<BasicTypeDeclaration *>(return_type);
612                 if(ret.expression)
613                 {
614                         if(ret.expression->type && ret.expression->type!=return_type)
615                                 error(ret, format("Return expression type '%s' is incompatible with declared return type '%s'",
616                                         ret.expression->type->name, return_type->name));
617                 }
618                 else if(!basic_return || basic_return->kind!=BasicTypeDeclaration::VOID)
619                         error(ret, "Return statement without an expression in a function not returning 'void'");
620         }
621
622         TraversingVisitor::visit(ret);
623 }
624
625
626 int StageInterfaceValidator::get_location(const Layout &layout)
627 {
628         for(vector<Layout::Qualifier>::const_iterator i=layout.qualifiers.begin(); i!=layout.qualifiers.end(); ++i)
629                 if(i->name=="location")
630                         return i->value;
631         return -1;
632 }
633
634 void StageInterfaceValidator::visit(VariableDeclaration &var)
635 {
636         int location = (var.layout ? get_location(*var.layout) : -1);
637         if(var.interface=="in" && var.linked_declaration)
638         {
639                 const Layout *linked_layout = var.linked_declaration->layout.get();
640                 int linked_location = (linked_layout ? get_location(*linked_layout) : -1);
641                 if(linked_location!=location)
642                 {
643                         error(var, format("Mismatched location %d for 'in %s'", location, var.name));
644                         add_info(*var.linked_declaration, format("Linked to 'out %s' with location %d",
645                                 var.linked_declaration->name, linked_location));
646                 }
647                 if(var.type_declaration && var.linked_declaration->type_declaration)
648                 {
649                         const TypeDeclaration *type = var.type_declaration;
650                         if(stage->type==Stage::GEOMETRY)
651                         {
652                                 if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(type))
653                                         if(basic->kind==BasicTypeDeclaration::ARRAY && basic->base_type)
654                                                 type = basic->base_type;
655                         }
656                         if(!is_same_type(*type, *var.linked_declaration->type_declaration))
657                         {
658                                 error(var, format("Mismatched type '%s' for 'in %s'", type->name, var.name));
659                                 add_info(*var.linked_declaration, format("Linked to 'out %s' with type '%s'",
660                                         var.linked_declaration->name, var.linked_declaration->type_declaration->name));
661                         }
662                 }
663         }
664
665         if(location>=0 && !var.interface.empty())
666         {
667                 map<unsigned, VariableDeclaration *> &used = used_locations[var.interface];
668
669                 unsigned loc_count = LocationCounter().apply(var);
670                 for(unsigned i=0; i<loc_count; ++i)
671                 {
672                         map<unsigned, VariableDeclaration *>::const_iterator j = used.find(location+i);
673                         if(j!=used.end())
674                         {
675                                 error(var, format("Overlapping location %d for '%s %s'", location+i, var.interface, var.name));
676                                 add_info(*j->second, format("Previously used here for '%s %s'", j->second->interface, j->second->name));
677                         }
678                         else
679                                 used[location+i] = &var;
680                 }
681         }
682 }
683
684
685 void GlobalInterfaceValidator::apply(Module &module)
686 {
687         for(list<Stage>::iterator i=module.stages.begin(); i!=module.stages.end(); ++i)
688         {
689                 stage = &*i;
690                 i->content.visit(*this);
691         }
692 }
693
694 void GlobalInterfaceValidator::check_uniform(const Uniform &uni)
695 {
696         map<std::string, const Uniform *>::const_iterator i = used_names.find(uni.name);
697         if(i!=used_names.end())
698         {
699                 if(uni.location>=0 && i->second->location>=0 && i->second->location!=uni.location)
700                 {
701                         error(*uni.node, format("Mismatched location %d for uniform '%s'", uni.location, uni.name));
702                         add_info(*i->second->node, format("Previously declared here with location %d", i->second->location));
703                 }
704                 if(uni.bind_point>=0 && i->second->bind_point>=0 && i->second->bind_point!=uni.bind_point)
705                 {
706                         error(*uni.node, format("Mismatched binding %d for uniform '%s'", uni.bind_point, uni.name));
707                         add_info(*i->second->node, format("Previously declared here with binding %d", i->second->bind_point));
708                 }
709                 if(uni.type && i->second->type && !is_same_type(*uni.type, *i->second->type))
710                 {
711                         string type_name = (dynamic_cast<const StructDeclaration *>(uni.type) ?
712                                 "structure" : format("type '%s'", uni.type->name));
713                         error(*uni.node, format("Mismatched %s for uniform '%s'", type_name, uni.name));
714
715                         string message = "Previously declared here";
716                         if(!dynamic_cast<const StructDeclaration *>(i->second->type))
717                                 message += format(" with type '%s'", i->second->type->name);
718                         add_info(*i->second->node, message);
719                 }
720         }
721         else
722                 used_names.insert(make_pair(uni.name, &uni));
723
724         if(uni.location>=0)
725         {
726                 map<unsigned, const Uniform *>::const_iterator j = used_locations.find(uni.location);
727                 if(j!=used_locations.end())
728                 {
729                         if(j->second->name!=uni.name)
730                         {
731                                 error(*uni.node, format("Overlapping location %d for '%s'", uni.location, uni.name));
732                                 add_info(*j->second->node, format("Previously used here for '%s'", j->second->name));
733                         }
734                 }
735                 else
736                 {
737                         for(unsigned k=0; k<uni.loc_count; ++k)
738                                 used_locations.insert(make_pair(uni.location+k, &uni));
739                 }
740         }
741
742         if(uni.bind_point>=0)
743         {
744                 map<unsigned, const Uniform *> &used = used_bindings[uni.desc_set];
745                 map<unsigned, const Uniform *>::const_iterator j = used.find(uni.bind_point);
746                 if(j!=used.end())
747                 {
748                         if(j->second->name!=uni.name)
749                         {
750                                 error(*uni.node, format("Overlapping binding %d for '%s'", uni.bind_point, uni.name));
751                                 add_info(*j->second->node, format("Previously used here for '%s'", j->second->name));
752                         }
753                 }
754                 else
755                         used.insert(make_pair(uni.bind_point, &uni));
756         }
757 }
758
759 void GlobalInterfaceValidator::visit(VariableDeclaration &var)
760 {
761         if(var.interface=="uniform")
762         {
763                 Uniform uni;
764                 uni.node = &var;
765                 uni.type = var.type_declaration;
766                 uni.name = var.name;
767                 if(var.layout)
768                 {
769                         uni.location = get_layout_value(*var.layout, "location");
770                         uni.loc_count = LocationCounter().apply(var);
771                         uni.desc_set = get_layout_value(*var.layout, "set", 0);
772                         uni.bind_point = get_layout_value(*var.layout, "binding");
773                 }
774
775                 uniforms.push_back(uni);
776                 check_uniform(uniforms.back());
777         }
778 }
779
780 void GlobalInterfaceValidator::visit(InterfaceBlock &iface)
781 {
782         if(iface.interface=="uniform")
783         {
784                 Uniform uni;
785                 uni.node = &iface;
786                 uni.type = iface.struct_declaration;
787                 uni.name = iface.block_name;
788                 if(iface.layout)
789                 {
790                         uni.desc_set = get_layout_value(*iface.layout, "set", 0);
791                         uni.bind_point = get_layout_value(*iface.layout, "binding");
792                 }
793
794                 uniforms.push_back(uni);
795                 check_uniform(uniforms.back());
796         }
797 }
798
799 } // namespace SL
800 } // namespace GL
801 } // namespace Msp