]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/validate.cpp
Disallow bool variables in shader interface blocks
[libs/gl.git] / source / glsl / validate.cpp
1 #include <algorithm>
2 #include <cstring>
3 #include <msp/core/raii.h>
4 #include <msp/strings/format.h>
5 #include <msp/strings/utils.h>
6 #include "reflect.h"
7 #include "validate.h"
8
9 using namespace std;
10
11 namespace Msp {
12 namespace GL {
13 namespace SL {
14
15 void Validator::diagnose(Node &node, Node &provoking_node, Diagnostic::Severity severity, const string &message)
16 {
17         Diagnostic diag;
18         diag.severity = severity;
19         diag.source = node.source;
20         diag.line = node.line;
21         diag.provoking_source = provoking_node.source;
22         diag.provoking_line = provoking_node.line;
23         diag.message = message;
24         stage->diagnostics.push_back(diag);
25
26         last_provoker = &provoking_node;
27 }
28
29 void Validator::add_info(Node &node, const string &message)
30 {
31         if(!last_provoker)
32                 throw logic_error("Tried to add info without a previous provoker");
33         diagnose(node, *last_provoker, Diagnostic::INFO, message);
34 }
35
36
37 const char *DeclarationValidator::describe_variable(ScopeType scope)
38 {
39         switch(scope)
40         {
41         case GLOBAL: return "global variable";
42         case STRUCT: return "struct member";
43         case INTERFACE_BLOCK: return "interface block member";
44         case FUNCTION_PARAM: return "function parameter";
45         case FUNCTION: return "local variable";
46         default: return "variable";
47         }
48 }
49
50 void DeclarationValidator::visit(Layout &layout)
51 {
52         for(const Layout::Qualifier &q: layout.qualifiers)
53         {
54                 bool allowed = false;
55                 string err_descr;
56                 bool value = true;
57                 if(q.name=="location")
58                         allowed = (variable && scope==GLOBAL);
59                 else if(q.name=="binding" || q.name=="set")
60                 {
61                         if(q.name=="set")
62                         {
63                                 error(layout, "Layout qualifier 'set' not allowed when targeting OpenGL");
64                                 continue;
65                         }
66
67                         if(variable)
68                         {
69                                 TypeDeclaration *type = variable->type_declaration;
70                                 while(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(type))
71                                         type = basic->base_type;
72                                 bool uniform = (variable->interface=="uniform");
73                                 allowed = (scope==GLOBAL && uniform && dynamic_cast<ImageTypeDeclaration *>(type));
74                                 err_descr = (uniform ? "variable of non-opaque type" : "non-uniform variable");
75                         }
76                         else if(iface_block)
77                         {
78                                 allowed = (iface_block->interface=="uniform");
79                                 err_descr = "non-uniform interface block";
80                         }
81                 }
82                 else if(q.name=="constant_id")
83                 {
84                         allowed = (variable && scope==GLOBAL);
85                         if(allowed)
86                         {
87                                 if(!variable->constant)
88                                 {
89                                         allowed = false;
90                                         err_descr = "non-constant variable";
91                                 }
92                                 else
93                                 {
94                                         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(variable->type_declaration);
95                                         if(!basic || basic->kind<BasicTypeDeclaration::BOOL || basic->kind>BasicTypeDeclaration::INT)
96                                         {
97                                                 allowed = false;
98                                                 err_descr = format("variable of type '%s'",
99                                                         (variable->type_declaration ? variable->type_declaration->name : variable->type));
100                                         }
101                                 }
102                         }
103                 }
104                 else if(q.name=="offset")
105                         allowed = (variable && scope==INTERFACE_BLOCK && iface_block->interface=="uniform");
106                 else if(q.name=="align")
107                         allowed = (scope==INTERFACE_BLOCK && iface_block->interface=="uniform");
108                 else if(q.name=="points")
109                 {
110                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && (iface_layout->interface=="in" || iface_layout->interface=="out"));
111                         value = false;
112                 }
113                 else if(q.name=="lines" || q.name=="lines_adjacency" || q.name=="triangles" || q.name=="triangles_adjacency")
114                 {
115                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="in");
116                         value = false;
117                 }
118                 else if(q.name=="line_strip" || q.name=="triangle_strip")
119                 {
120                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="out");
121                         value = false;
122                 }
123                 else if(q.name=="invocations")
124                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="in");
125                 else if(q.name=="max_vertices")
126                         allowed = (stage->type==Stage::GEOMETRY && iface_layout && iface_layout->interface=="out");
127                 else if(q.name=="std140" || q.name=="std430")
128                 {
129                         allowed = (iface_block && !variable && iface_block->interface=="uniform");
130                         value = false;
131                 }
132                 else if(q.name=="column_major" || q.name=="row_major")
133                 {
134                         allowed = (variable && scope==INTERFACE_BLOCK);
135                         if(allowed)
136                         {
137                                 BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(variable->type_declaration);
138                                 while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
139                                         basic = dynamic_cast<BasicTypeDeclaration *>(basic->base_type);
140                                 allowed = (basic && basic->kind==BasicTypeDeclaration::MATRIX);
141                                 err_descr = "non-matrix variable";
142                         }
143                 }
144
145                 if(!allowed)
146                 {
147                         if(err_descr.empty())
148                         {
149                                 if(variable)
150                                         err_descr = describe_variable(scope);
151                                 else if(iface_block)
152                                         err_descr = "interface block";
153                                 else if(iface_layout)
154                                         err_descr = format("interface '%s'", iface_layout->interface);
155                                 else
156                                         err_descr = "unknown declaration";
157                         }
158                         error(layout, format("Layout qualifier '%s' not allowed on %s", q.name, err_descr));
159                 }
160                 else if(value && !q.has_value)
161                         error(layout, format("Layout qualifier '%s' requires a value", q.name));
162                 else if(!value && q.has_value)
163                         error(layout, format("Layout qualifier '%s' does not allow a value", q.name));
164         }
165 }
166
167 void DeclarationValidator::visit(InterfaceLayout &layout)
168 {
169         SetForScope<InterfaceLayout *> set_layout(iface_layout, &layout);
170         TraversingVisitor::visit(layout);
171 }
172
173 void DeclarationValidator::visit(BasicTypeDeclaration &type)
174 {
175         BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
176         BasicTypeDeclaration::Kind base_kind = (basic_base ? basic_base->kind : BasicTypeDeclaration::VOID);
177
178         if(type.kind==BasicTypeDeclaration::VECTOR && (base_kind<BasicTypeDeclaration::BOOL || base_kind>BasicTypeDeclaration::FLOAT))
179                 error(type, format("Invalid base type '%s' for vector type '%s'", type.base, type.name));
180         else if(type.kind==BasicTypeDeclaration::MATRIX && base_kind!=BasicTypeDeclaration::VECTOR)
181                 error(type, format("Invalid base type '%s' for matrix type '%s'", type.base, type.name));
182         else if(type.kind==BasicTypeDeclaration::ARRAY && basic_base && base_kind==BasicTypeDeclaration::VOID)
183                 error(type, format("Invalid base type '%s' for array type '%s'", type.base, type.name));
184 }
185
186 void DeclarationValidator::visit(ImageTypeDeclaration &type)
187 {
188         BasicTypeDeclaration::Kind base_kind = BasicTypeDeclaration::VOID;
189         if(BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type))
190                 base_kind = basic_base->kind;
191         if(base_kind!=BasicTypeDeclaration::INT && base_kind!=BasicTypeDeclaration::FLOAT)
192                 error(type, format("Invalid base type '%s' for image type '%s'", type.base, type.name));
193 }
194
195 void DeclarationValidator::visit(StructDeclaration &strct)
196 {
197         SetForScope<ScopeType> set_scope(scope, (scope!=INTERFACE_BLOCK ? STRUCT : scope));
198         TraversingVisitor::visit(strct);
199 }
200
201 void DeclarationValidator::visit(VariableDeclaration &var)
202 {
203         SetForScope<VariableDeclaration *> set_var(variable, &var);
204
205         const char *descr = describe_variable(scope);
206
207         if(var.layout)
208         {
209                 if(scope!=GLOBAL && scope!=INTERFACE_BLOCK)
210                         error(var, format("Layout qualifier not allowed on %s", descr));
211                 else
212                         var.layout->visit(*this);
213         }
214
215         if(var.constant)
216         {
217                 if(scope==STRUCT || scope==INTERFACE_BLOCK)
218                         error(var, format("Constant qualifier not allowed on %s", descr));
219                 if(!var.init_expression)
220                         error(var, "Constant variable must have an initializer");
221         }
222
223         if(!var.interpolation.empty() || !var.sampling.empty())
224         {
225                 if(var.interface!="in" && stage->type==Stage::VERTEX)
226                         error(var, "Interpolation qualifier not allowed on vertex input");
227                 else if(var.interface!="out" && stage->type==Stage::FRAGMENT)
228                         error(var, "Interpolation qualifier not allowed on fragment output");
229                 else if((var.interface!="in" && var.interface!="out") || (scope==FUNCTION_PARAM || scope==FUNCTION))
230                         error(var, "Interpolation qualifier not allowed on non-interpolated variable");
231         }
232
233         if(!var.interface.empty())
234         {
235                 if(iface_block && var.interface!=iface_block->interface)
236                         error(var, format("Mismatched interface qualifier '%s' inside '%s' block", var.interface, iface_block->interface));
237                 else if(scope==STRUCT || scope==FUNCTION)
238                         error(var, format("Interface qualifier not allowed on %s", descr));
239         }
240
241         TypeDeclaration *type = var.type_declaration;
242         BasicTypeDeclaration::Kind kind = BasicTypeDeclaration::ALIAS;
243         while(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(type))
244         {
245                 kind = basic->kind;
246                 type = basic->base_type;
247         }
248         if(dynamic_cast<ImageTypeDeclaration *>(type))
249         {
250                 if(scope!=GLOBAL && scope!=FUNCTION_PARAM)
251                         error(var, format("Type '%s' not allowed on %s", type->name, descr));
252                 else if(scope==GLOBAL && var.interface!="uniform")
253                         error(var, format("Type '%s' only allowed with uniform interface", type->name));
254         }
255         else if(kind==BasicTypeDeclaration::VOID)
256                 error(var, "Type 'void' not allowed on variable");
257         else if(kind==BasicTypeDeclaration::BOOL && var.source!=BUILTIN_SOURCE)
258         {
259                 if(scope==INTERFACE_BLOCK)
260                         error(var, "Type 'bool' not allowed in an interface block");
261                 else if(!var.interface.empty())
262                         error(var, "Type 'bool' not allowed on interface variable");
263         }
264
265         if(var.init_expression)
266         {
267                 if(scope==GLOBAL && !var.constant)
268                         error(var, format("Initializer not allowed on non-constant %s", descr));
269                 else if(scope!=GLOBAL && scope!=FUNCTION)
270                         error(var, format("Initializer not allowed on %s", descr));
271                 else
272                         var.init_expression->visit(*this);
273         }
274 }
275
276 void DeclarationValidator::visit(InterfaceBlock &iface)
277 {
278         SetForScope<ScopeType> set_scope(scope, INTERFACE_BLOCK);
279         SetForScope<InterfaceBlock *> set_iface(iface_block, &iface);
280
281         if(stage->type==Stage::VERTEX && iface.interface=="in")
282                 error(iface, "Interface block not allowed on vertex shader input");
283         else if(stage->type==Stage::FRAGMENT && iface.interface=="out")
284                 error(iface, "Interface block not allowed on fragment shader output");
285
286         TraversingVisitor::visit(iface);
287         if(iface.struct_declaration)
288                 iface.struct_declaration->visit(*this);
289 }
290
291 void DeclarationValidator::visit(FunctionDeclaration &func)
292 {
293         SetForScope<ScopeType> set_scope(scope, FUNCTION_PARAM);
294         for(const RefPtr<VariableDeclaration> &p: func.parameters)
295                 p->visit(*this);
296         scope = FUNCTION;
297         func.body.visit(*this);
298 }
299
300
301 void IdentifierValidator::multiple_definition(const string &name, Statement &statement, Statement &previous)
302 {
303         error(statement, format("Multiple definition of %s", name));
304         add_info(previous, "Previous definition is here");
305 }
306
307 Statement *IdentifierValidator::find_definition(const string &name)
308 {
309         BlockDeclarationMap *decls = &declarations[current_block];
310         auto i = decls->find(name);
311         if(i==decls->end() && anonymous_block)
312         {
313                 decls = &declarations[current_block->parent];
314                 i = decls->find(name);
315         }
316         return (i!=decls->end() ? i->second : 0);
317 }
318
319 void IdentifierValidator::check_definition(const string &name, Statement &statement)
320 {
321         if(Statement *previous = find_definition(name))
322                 multiple_definition(format("'%s'", name), statement, *previous);
323         else
324                 record_definition(name, statement);
325 }
326
327 void IdentifierValidator::record_definition(const string &name, Statement &statement)
328 {
329         declarations[current_block].insert(make_pair(name, &statement));
330         if(anonymous_block)
331                 declarations[current_block->parent].insert(make_pair(name, &statement));
332 }
333
334 void IdentifierValidator::visit(TypeDeclaration &type)
335 {
336         check_definition(type.name, type);
337 }
338
339 void IdentifierValidator::visit(StructDeclaration &strct)
340 {
341         check_definition(strct.name, strct);
342         TraversingVisitor::visit(strct);
343 }
344
345 void IdentifierValidator::visit(VariableDeclaration &var)
346 {
347         check_definition(var.name, var);
348         TraversingVisitor::visit(var);
349 }
350
351 void IdentifierValidator::visit(InterfaceBlock &iface)
352 {
353         string key = format("%s %s", iface.interface, iface.block_name);
354         auto i = interface_blocks.find(key);
355         if(i!=interface_blocks.end())
356                 multiple_definition(format("interface block '%s %s'", iface.interface, iface.block_name), iface, *i->second);
357         else
358                 interface_blocks.insert(make_pair(key, &iface));
359
360         if(Statement *previous = find_definition(iface.block_name))
361         {
362                 if(!dynamic_cast<InterfaceBlock *>(previous))
363                         multiple_definition(format("'%s'", iface.block_name), iface, *previous);
364         }
365         else
366                 record_definition(iface.block_name, iface);
367
368         if(!iface.instance_name.empty())
369                 check_definition(iface.instance_name, iface);
370
371         if(iface.instance_name.empty() && iface.struct_declaration)
372         {
373                 // Inject anonymous interface block members into the global scope
374                 for(const auto &kvp: iface.struct_declaration->members.variables)
375                         check_definition(kvp.first, *kvp.second);
376         }
377 }
378
379 void IdentifierValidator::visit(FunctionDeclaration &func)
380 {
381         string key = func.name+func.signature;
382         auto i = overloaded_functions.find(key);
383         if(i==overloaded_functions.end())
384                 overloaded_functions.insert(make_pair(key, &func));
385         else if(func.return_type_declaration && i->second->return_type_declaration!=func.return_type_declaration)
386         {
387                 error(func, format("Conflicting return type '%s' for function '%s'", func.return_type_declaration->name, func.name));
388                 if(i->second->return_type_declaration)
389                         add_info(*i->second, format("Previously declared as returning '%s'", i->second->return_type_declaration->name));
390         }
391
392         if(Statement *previous = find_definition(func.name))
393         {
394                 if(!dynamic_cast<FunctionDeclaration *>(previous))
395                         multiple_definition(format("'%s'", func.name), func, *previous);
396         }
397         else
398                 record_definition(func.name, func);
399
400         if(func.definition==&func)
401                 check_definition(func.name+func.signature, func);
402
403         TraversingVisitor::visit(func);
404 }
405
406
407 void ReferenceValidator::visit(BasicTypeDeclaration &type)
408 {
409         if(!type.base.empty() && !type.base_type)
410                 error(type, format("Use of undeclared type '%s'", type.base));
411 }
412
413 void ReferenceValidator::visit(ImageTypeDeclaration &type)
414 {
415         if(!type.base.empty() && !type.base_type)
416                 error(type, format("Use of undeclared type '%s'", type.base));
417 }
418
419 void ReferenceValidator::visit(VariableReference &var)
420 {
421         if(!var.declaration)
422                 error(var, format("Use of undeclared variable '%s'", var.name));
423         else if(stage->type!=Stage::VERTEX && var.declaration->interface=="in" && var.name.compare(0, 3, "gl_") && !var.declaration->linked_declaration)
424                 error(var, format("Use of unlinked input variable '%s'", var.name));
425 }
426
427 void ReferenceValidator::visit(MemberAccess &memacc)
428 {
429         if(memacc.left->type && !memacc.declaration)
430                 error(memacc, format("Use of undeclared member '%s'", memacc.member));
431         TraversingVisitor::visit(memacc);
432 }
433
434 void ReferenceValidator::visit(InterfaceBlockReference &iface)
435 {
436         /* An interface block reference without a declaration should be impossible
437         since references are generated based on existing declarations. */
438         if(!iface.declaration)
439                 error(iface, format("Use of undeclared interface block '%s'", iface.name));
440         else if(stage->type!=Stage::VERTEX && iface.declaration->interface=="in" && !iface.declaration->linked_block)
441                 error(iface, format("Use of unlinked input block '%s'", iface.name));
442 }
443
444 void ReferenceValidator::visit(FunctionCall &call)
445 {
446         if((!call.constructor && !call.declaration) || (call.constructor && !call.type))
447         {
448                 bool have_declaration = call.constructor;
449                 if(!call.constructor)
450                 {
451                         auto i = stage->functions.lower_bound(call.name);
452                         have_declaration = (i!=stage->functions.end() && i->second->name==call.name);
453                 }
454
455                 if(have_declaration)
456                 {
457                         bool valid_types = true;
458                         string signature;
459                         for(auto j=call.arguments.begin(); (valid_types && j!=call.arguments.end()); ++j)
460                         {
461                                 if((*j)->type)
462                                         append(signature, ", ", (*j)->type->name);
463                                 else
464                                         valid_types = false;
465                         }
466
467                         if(valid_types)
468                                 error(call, format("No matching %s found for '%s(%s)'", (call.constructor ? "constructor" : "overload"), call.name, signature));
469                 }
470                 else
471                         error(call, format("Call to undeclared function '%s'", call.name));
472         }
473         TraversingVisitor::visit(call);
474 }
475
476 void ReferenceValidator::visit(VariableDeclaration &var)
477 {
478         if(!var.type_declaration)
479                 error(var, format("Use of undeclared type '%s'", var.type));
480         TraversingVisitor::visit(var);
481 }
482
483 void ReferenceValidator::visit(InterfaceBlock &iface)
484 {
485         if(!iface.struct_declaration)
486                 error(iface, format("Interface block '%s %s' lacks a struct declaration", iface.interface, iface.block_name));
487         TraversingVisitor::visit(iface);
488 }
489
490 void ReferenceValidator::visit(FunctionDeclaration &func)
491 {
492         if(!func.return_type_declaration)
493                 error(func, format("Use of undeclared type '%s'", func.return_type));
494         TraversingVisitor::visit(func);
495 }
496
497
498 void ExpressionValidator::visit(VariableReference &var)
499 {
500         if(var.declaration && constant_expression && !var.declaration->constant)
501                 error(var, format("Reference to non-constant variable '%s' in a constant expression", var.name));
502 }
503
504 void ExpressionValidator::visit(InterfaceBlockReference &iface)
505 {
506         if(constant_expression)
507                 error(iface, format("Reference to interface block '%s' in a constant expression", iface.name));
508 }
509
510 void ExpressionValidator::visit(Swizzle &swizzle)
511 {
512         unsigned size = 0;
513         if(BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(swizzle.left->type))
514         {
515                 if(basic->kind==BasicTypeDeclaration::INT || basic->kind==BasicTypeDeclaration::FLOAT)
516                         size = 1;
517                 else if(basic->kind==BasicTypeDeclaration::VECTOR)
518                         size = basic->size;
519         }
520
521         if(size)
522         {
523                 static const char component_names[] = { 'x', 'y', 'z', 'w', 'r', 'g', 'b', 'a', 's', 't', 'p', 'q' };
524                 int flavour = -1;
525                 for(unsigned i=0; i<swizzle.count; ++i)
526                 {
527                         unsigned component_flavour = (find(component_names, component_names+12, swizzle.component_group[i])-component_names)/4;
528                         if(flavour==-1)
529                                 flavour = component_flavour;
530                         else if(flavour>=0 && component_flavour!=static_cast<unsigned>(flavour))
531                         {
532                                 error(swizzle, format("Flavour of swizzle component '%c' is inconsistent with '%c'",
533                                         swizzle.component_group[i], swizzle.component_group[0]));
534                                 flavour = -2;
535                         }
536
537                         if(swizzle.components[i]>=size)
538                                 error(swizzle, format("Access to component '%c' which is not present in '%s'",
539                                         swizzle.component_group[i], swizzle.left->type->name));
540                 }
541         }
542         else if(swizzle.left->type)
543                 error(swizzle, format("Swizzle applied to '%s' which is neither a scalar nor a vector", swizzle.left->type->name));
544
545         TraversingVisitor::visit(swizzle);
546 }
547
548 void ExpressionValidator::visit(UnaryExpression &unary)
549 {
550         if(unary.expression->type)
551         {
552                 if(!unary.type)
553                         error(unary, format("No matching operator '%s' found for '%s'", unary.oper->token, unary.expression->type->name));
554                 else if(unary.oper->token[1]=='+' || unary.oper->token[1]=='-')
555                 {
556                         if(constant_expression)
557                                 error(unary, format("Use of '%s' in a constant expression", unary.oper->token));
558                         else if(!unary.expression->lvalue)
559                                 error(unary, format("Operand of '%s' is not an lvalue", unary.oper->token));
560                 }
561         }
562         TraversingVisitor::visit(unary);
563 }
564
565 void ExpressionValidator::visit(BinaryExpression &binary)
566 {
567         if(!binary.type && binary.left->type && binary.right->type)
568         {
569                 if(binary.oper->token[0]=='[')
570                         error(binary, format("Can't index element of '%s' with '%s'",
571                                 binary.left->type->name, binary.right->type->name));
572                 else
573                         error(binary, format("No matching operator '%s' found for '%s' and '%s'",
574                                 binary.oper->token, binary.left->type->name, binary.right->type->name));
575         }
576         TraversingVisitor::visit(binary);
577 }
578
579 void ExpressionValidator::visit(Assignment &assign)
580 {
581         if(assign.left->type)
582         {
583                 if(constant_expression)
584                         error(assign, "Assignment in constant expression");
585                 else if(!assign.left->lvalue)
586                         error(assign, "Target of assignment is not an lvalue");
587                 if(assign.right->type)
588                 {
589                         if(assign.oper->token[0]!='=')
590                         {
591                                 if(!assign.type)
592                                         error(assign, format("No matching operator '%s' found for '%s' and '%s'",
593                                                 string(assign.oper->token, strlen(assign.oper->token)-1), assign.left->type->name, assign.right->type->name));
594                         }
595                         else if(assign.left->type!=assign.right->type)
596                                 error(assign, format("Assignment to variable of type '%s' from expression of incompatible type '%s'",
597                                         assign.left->type->name, assign.right->type->name));
598                 }
599         }
600         TraversingVisitor::visit(assign);
601 }
602
603 void ExpressionValidator::visit(TernaryExpression &ternary)
604 {
605         if(ternary.condition->type)
606         {
607                 BasicTypeDeclaration *basic_cond = dynamic_cast<BasicTypeDeclaration *>(ternary.condition->type);
608                 if(!basic_cond || basic_cond->kind!=BasicTypeDeclaration::BOOL)
609                         error(ternary, "Ternary operator condition is not a boolean");
610                 else if(!ternary.type && ternary.true_expr->type && ternary.false_expr->type)
611                         error(ternary, format("Ternary operator has incompatible types '%s' and '%s'",
612                                 ternary.true_expr->type->name, ternary.false_expr->type->name));
613         }
614         TraversingVisitor::visit(ternary);
615 }
616
617 void ExpressionValidator::visit(VariableDeclaration &var)
618 {
619         if(var.init_expression && var.init_expression->type && var.type_declaration && var.init_expression->type!=var.type_declaration)
620                 error(var, format("Initializing a variable of type '%s' with an expression of incompatible type '%s'",
621                         var.type_declaration->name, var.init_expression->type->name));
622
623         if(var.layout)
624                 var.layout->visit(*this);
625         if(var.init_expression)
626         {
627                 SetFlag set_const(constant_expression, var.constant);
628                 TraversingVisitor::visit(var.init_expression);
629         }
630         if(var.array_size)
631         {
632                 SetFlag set_const(constant_expression);
633                 TraversingVisitor::visit(var.array_size);
634         }
635 }
636
637 void ExpressionValidator::visit(FunctionDeclaration &func)
638 {
639         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
640         TraversingVisitor::visit(func);
641 }
642
643 void ExpressionValidator::visit(Conditional &cond)
644 {
645         if(cond.condition->type)
646         {
647                 BasicTypeDeclaration *basic_cond = dynamic_cast<BasicTypeDeclaration *>(cond.condition->type);
648                 if(!basic_cond || basic_cond->kind!=BasicTypeDeclaration::BOOL)
649                         error(cond, "Condition is not a boolean");
650         }
651         TraversingVisitor::visit(cond);
652 }
653
654 void ExpressionValidator::visit(Iteration &iter)
655 {
656         if(iter.condition->type)
657         {
658                 BasicTypeDeclaration *basic_cond = dynamic_cast<BasicTypeDeclaration *>(iter.condition->type);
659                 if(!basic_cond || basic_cond->kind!=BasicTypeDeclaration::BOOL)
660                         error(iter, "Loop condition is not a boolean");
661         }
662         TraversingVisitor::visit(iter);
663 }
664
665 void ExpressionValidator::visit(Return &ret)
666 {
667         if(current_function && current_function->return_type_declaration)
668         {
669                 TypeDeclaration *return_type = current_function->return_type_declaration;
670                 BasicTypeDeclaration *basic_return = dynamic_cast<BasicTypeDeclaration *>(return_type);
671                 if(ret.expression)
672                 {
673                         if(ret.expression->type && ret.expression->type!=return_type)
674                                 error(ret, format("Return expression type '%s' is incompatible with declared return type '%s'",
675                                         ret.expression->type->name, return_type->name));
676                 }
677                 else if(!basic_return || basic_return->kind!=BasicTypeDeclaration::VOID)
678                         error(ret, "Return statement without an expression in a function not returning 'void'");
679         }
680
681         TraversingVisitor::visit(ret);
682 }
683
684
685 void FlowControlValidator::visit(Block &block)
686 {
687         for(const RefPtr<Statement> &s: block.body)
688         {
689                 if(!reachable)
690                 {
691                         diagnose(*s, Diagnostic::WARN, "Unreachable code detected");
692                         break;
693                 }
694                 s->visit(*this);
695         }
696 }
697
698 void FlowControlValidator::visit(FunctionDeclaration &func)
699 {
700         func.body.visit(*this);
701
702         if(func.definition==&func && func.return_type_declaration)
703         {
704                 const BasicTypeDeclaration *basic_ret = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
705                 if(reachable && (!basic_ret || basic_ret->kind!=BasicTypeDeclaration::VOID))
706                         error(func, "Missing return statement at the end of a function not returning 'void'");
707         }
708         reachable = true;
709 }
710
711 void FlowControlValidator::visit(Conditional &cond)
712 {
713         cond.body.visit(*this);
714         bool reachable_if_true = reachable;
715         reachable = true;
716         cond.else_body.visit(*this);
717         reachable |= reachable_if_true;
718 }
719
720 void FlowControlValidator::visit(Iteration &iter)
721 {
722         iter.body.visit(*this);
723         reachable = true;
724 }
725
726
727 int StageInterfaceValidator::get_location(const Layout &layout)
728 {
729         return get_layout_value(layout, "location", -1);
730 }
731
732 void StageInterfaceValidator::visit(VariableDeclaration &var)
733 {
734         int location = (var.layout ? get_location(*var.layout) : -1);
735         if(var.interface=="in" && var.linked_declaration)
736         {
737                 const Layout *linked_layout = var.linked_declaration->layout.get();
738                 int linked_location = (linked_layout ? get_location(*linked_layout) : -1);
739                 if(linked_location!=location)
740                 {
741                         error(var, format("Mismatched location %d for 'in %s'", location, var.name));
742                         add_info(*var.linked_declaration, format("Linked to 'out %s' with location %d",
743                                 var.linked_declaration->name, linked_location));
744                 }
745                 if(var.type_declaration && var.linked_declaration->type_declaration)
746                 {
747                         TypeDeclaration *type = var.type_declaration;
748                         if(stage->type==Stage::GEOMETRY)
749                         {
750                                 if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(type))
751                                         if(basic->kind==BasicTypeDeclaration::ARRAY && basic->base_type)
752                                                 type = basic->base_type;
753                         }
754                         if(!TypeComparer().apply(*type, *var.linked_declaration->type_declaration))
755                         {
756                                 error(var, format("Mismatched type '%s' for 'in %s'", type->name, var.name));
757                                 add_info(*var.linked_declaration, format("Linked to 'out %s' with type '%s'",
758                                         var.linked_declaration->name, var.linked_declaration->type_declaration->name));
759                         }
760                 }
761         }
762
763         if(location>=0 && !var.interface.empty())
764         {
765                 map<unsigned, VariableDeclaration *> &used = used_locations[var.interface];
766
767                 unsigned loc_count = LocationCounter().apply(var);
768                 for(unsigned i=0; i<loc_count; ++i)
769                 {
770                         auto j = used.find(location+i);
771                         if(j!=used.end())
772                         {
773                                 error(var, format("Overlapping location %d for '%s %s'", location+i, var.interface, var.name));
774                                 add_info(*j->second, format("Previously used here for '%s %s'", j->second->interface, j->second->name));
775                         }
776                         else
777                                 used[location+i] = &var;
778                 }
779         }
780 }
781
782
783 void GlobalInterfaceValidator::apply(Module &module)
784 {
785         for(Stage &s: module.stages)
786         {
787                 stage = &s;
788                 s.content.visit(*this);
789         }
790 }
791
792 void GlobalInterfaceValidator::check_uniform(const Uniform &uni)
793 {
794         auto i = used_names.find(uni.name);
795         if(i!=used_names.end())
796         {
797                 if(uni.location>=0 && i->second->location>=0 && i->second->location!=uni.location)
798                 {
799                         error(*uni.node, format("Mismatched location %d for uniform '%s'", uni.location, uni.name));
800                         add_info(*i->second->node, format("Previously declared here with location %d", i->second->location));
801                 }
802                 if(uni.bind_point>=0 && i->second->bind_point>=0 && i->second->bind_point!=uni.bind_point)
803                 {
804                         error(*uni.node, format("Mismatched binding %d for uniform '%s'", uni.bind_point, uni.name));
805                         add_info(*i->second->node, format("Previously declared here with binding %d", i->second->bind_point));
806                 }
807                 if(uni.type && i->second->type && !TypeComparer().apply(*uni.type, *i->second->type))
808                 {
809                         string type_name = (dynamic_cast<const StructDeclaration *>(uni.type) ?
810                                 "structure" : format("type '%s'", uni.type->name));
811                         error(*uni.node, format("Mismatched %s for uniform '%s'", type_name, uni.name));
812
813                         string message = "Previously declared here";
814                         if(!dynamic_cast<const StructDeclaration *>(i->second->type))
815                                 message += format(" with type '%s'", i->second->type->name);
816                         add_info(*i->second->node, message);
817                 }
818         }
819         else
820                 used_names.insert(make_pair(uni.name, &uni));
821
822         if(uni.location>=0)
823         {
824                 auto j = used_locations.find(uni.location);
825                 if(j!=used_locations.end())
826                 {
827                         if(j->second->name!=uni.name)
828                         {
829                                 error(*uni.node, format("Overlapping location %d for '%s'", uni.location, uni.name));
830                                 add_info(*j->second->node, format("Previously used here for '%s'", j->second->name));
831                         }
832                 }
833                 else
834                 {
835                         for(unsigned k=0; k<uni.loc_count; ++k)
836                                 used_locations.insert(make_pair(uni.location+k, &uni));
837                 }
838         }
839
840         if(uni.bind_point>=0)
841         {
842                 map<unsigned, const Uniform *> &used = used_bindings[uni.desc_set];
843                 auto j = used.find(uni.bind_point);
844                 if(j!=used.end())
845                 {
846                         if(j->second->name!=uni.name)
847                         {
848                                 error(*uni.node, format("Overlapping binding %d for '%s'", uni.bind_point, uni.name));
849                                 add_info(*j->second->node, format("Previously used here for '%s'", j->second->name));
850                         }
851                 }
852                 else
853                         used.insert(make_pair(uni.bind_point, &uni));
854         }
855 }
856
857 void GlobalInterfaceValidator::visit(VariableDeclaration &var)
858 {
859         if(var.interface=="uniform")
860         {
861                 Uniform uni;
862                 uni.node = &var;
863                 uni.type = var.type_declaration;
864                 uni.name = var.name;
865                 if(var.layout)
866                 {
867                         uni.location = get_layout_value(*var.layout, "location");
868                         uni.loc_count = LocationCounter().apply(var);
869                         uni.desc_set = get_layout_value(*var.layout, "set", 0);
870                         uni.bind_point = get_layout_value(*var.layout, "binding");
871                 }
872
873                 uniforms.push_back(uni);
874                 check_uniform(uniforms.back());
875         }
876 }
877
878 void GlobalInterfaceValidator::visit(InterfaceBlock &iface)
879 {
880         if(iface.interface=="uniform")
881         {
882                 Uniform uni;
883                 uni.node = &iface;
884                 uni.type = iface.struct_declaration;
885                 uni.name = iface.block_name;
886                 if(iface.layout)
887                 {
888                         uni.desc_set = get_layout_value(*iface.layout, "set", 0);
889                         uni.bind_point = get_layout_value(*iface.layout, "binding");
890                 }
891
892                 uniforms.push_back(uni);
893                 check_uniform(uniforms.back());
894         }
895 }
896
897 } // namespace SL
898 } // namespace GL
899 } // namespace Msp