git.tdb.fi Git - libs/gl.git/blob - source/glsl/reflect.cpp
Use forward references for entry point interfaces in SPIR-V
[libs/gl.git] / source / glsl / reflect.cpp
1 #include <msp/core/algorithm.h>
2 #include <msp/core/raii.h>
3 #include "reflect.h"
4
5 using namespace std;
6
7 namespace Msp {
8 namespace GL {
9 namespace SL {
10
11 bool is_scalar(const BasicTypeDeclaration &type)
12 {
13         return (type.kind==BasicTypeDeclaration::INT || type.kind==BasicTypeDeclaration::FLOAT);
14 }
15
16 bool is_vector_or_matrix(const BasicTypeDeclaration &type)
17 {
18         return (type.kind==BasicTypeDeclaration::VECTOR || type.kind==BasicTypeDeclaration::MATRIX);
19 }
20
21 BasicTypeDeclaration *get_element_type(BasicTypeDeclaration &type)
22 {
23         if(is_vector_or_matrix(type) || type.kind==BasicTypeDeclaration::ARRAY)
24         {
25                 BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
26                 return (basic_base ? get_element_type(*basic_base) : 0);
27         }
28         else
29                 return &type;
30 }
31
32 bool can_convert(const BasicTypeDeclaration &from, const BasicTypeDeclaration &to)
33 {
34         if(from.kind==BasicTypeDeclaration::INT && to.kind==BasicTypeDeclaration::FLOAT)
35                 return from.size<=to.size;
36         else if(from.kind!=to.kind)
37                 return false;
38         else if(from.kind==BasicTypeDeclaration::INT && from.sign!=to.sign)
39                 return from.sign && from.size<=to.size;
40         else if(is_vector_or_matrix(from) && from.size==to.size)
41         {
42                 BasicTypeDeclaration *from_base = dynamic_cast<BasicTypeDeclaration *>(from.base_type);
43                 BasicTypeDeclaration *to_base = dynamic_cast<BasicTypeDeclaration *>(to.base_type);
44                 return (from_base && to_base && can_convert(*from_base, *to_base));
45         }
46         else
47                 return false;
48 }
49
50
51 unsigned TypeComparer::next_tag = 1;
52
53 void TypeComparer::compare(Node &node1, Node &node2)
54 {
55         if(&node1==&node2)
56                 r_result = true;
57         else
58         {
59                 second = &node2;
60                 node1.visit(*this);
61         }
62 }
63
64 template<typename T>
65 T *TypeComparer::multi_visit(T &node)
66 {
67         static unsigned tag = next_tag++;
68
69         if(second)
70         {
71                 Node *s = second;
72                 first = &node;
73                 first_tag = tag;
74                 second = 0;
75                 s->visit(*this);
76         }
77         else if(!first || tag!=first_tag)
78                 r_result = false;
79         else
80         {
81                 T *f = static_cast<T *>(first);
82                 first = 0;
83                 return f;
84         }
85
86         return 0;
87 }
88
89 void TypeComparer::visit(Literal &literal)
90 {
91         if(Literal *lit1 = multi_visit(literal))
92         {
93                 if(!lit1->type || !literal.type)
94                         r_result = false;
95                 else
96                 {
97                         compare(*lit1->type, *literal.type);
98                         if(r_result)
99                                 r_result = (literal.value.check_type<int>() && lit1->value.value<int>()==literal.value.value<int>());
100                 }
101         }
102 }
103
104 void TypeComparer::visit(VariableReference &var)
105 {
106         if(VariableReference *var1 = multi_visit(var))
107         {
108                 if(!var1->declaration || !var.declaration)
109                         r_result = false;
110                 else if(!var1->declaration->constant || !var.declaration->constant)
111                         r_result = false;
112                 else if(!var1->declaration->init_expression || !var.declaration->init_expression)
113                         r_result = false;
114                 else
115                         compare(*var1->declaration->init_expression, *var.declaration->init_expression);
116         }
117 }
118
119 void TypeComparer::visit(UnaryExpression &unary)
120 {
121         if(UnaryExpression *unary1 = multi_visit(unary))
122         {
123                 if(unary1->oper!=unary.oper)
124                         r_result = false;
125                 else
126                         compare(*unary1->expression, *unary.expression);
127         }
128 }
129
130 void TypeComparer::visit(BinaryExpression &binary)
131 {
132         if(BinaryExpression *binary1 = multi_visit(binary))
133         {
134                 if(binary1->oper!=binary.oper)
135                         r_result = false;
136                 else
137                 {
138                         compare(*binary1->left, *binary.left);
139                         if(r_result)
140                                 compare(*binary1->right, *binary.right);
141                 }
142         }
143 }
144
145 void TypeComparer::visit(TernaryExpression &ternary)
146 {
147         if(TernaryExpression *ternary1 = multi_visit(ternary))
148         {
149                 if(ternary1->oper!=ternary.oper)
150                         r_result = false;
151                 else
152                 {
153                         compare(*ternary1->condition, *ternary.condition);
154                         if(r_result)
155                                 compare(*ternary1->true_expr, *ternary.true_expr);
156                         if(r_result)
157                                 compare(*ternary1->false_expr, *ternary.false_expr);
158                 }
159         }
160 }
161
162 void TypeComparer::visit(FunctionCall &call)
163 {
164         if(FunctionCall *call1 = multi_visit(call))
165         {
166                 if(!call1->constructor || !call.constructor)
167                         r_result = false;
168                 else if(call1->name!=call.name)
169                         r_result = false;
170                 else if(call1->arguments.size()!=call.arguments.size())
171                         r_result = false;
172                 else
173                 {
174                         r_result = true;
175                         for(unsigned i=0; (r_result && i<call.arguments.size()); ++i)
176                                 compare(*call1->arguments[i], *call.arguments[i]);
177                 }
178         }
179 }
180
181 void TypeComparer::visit(BasicTypeDeclaration &basic)
182 {
183         if(BasicTypeDeclaration *basic1 = multi_visit(basic))
184         {
185                 if(basic1->kind!=basic.kind || basic1->size!=basic.size || basic1->sign!=basic.sign)
186                         r_result = false;
187                 else if(basic1->base_type && basic.base_type)
188                         compare(*basic1->base_type, *basic.base_type);
189                 else
190                         r_result = (!basic1->base_type && !basic.base_type);
191         }
192 }
193
194 void TypeComparer::visit(ImageTypeDeclaration &image)
195 {
196         if(ImageTypeDeclaration *image1 = multi_visit(image))
197         {
198                 if(image1->dimensions!=image.dimensions || image1->array!=image.array)
199                         r_result = false;
200                 else if(image1->sampled!=image.sampled || image1->shadow!=image.shadow || image1->multisample!=image.multisample)
201                         r_result = false;
202                 else if(image1->format!=image.format)
203                         r_result = false;
204                 else if(image1->base_type && image.base_type)
205                         compare(*image1->base_type, *image.base_type);
206                 else
207                         r_result = (!image1->base_type && !image.base_type);
208         }
209 }
210
211 void TypeComparer::visit(StructDeclaration &strct)
212 {
213         if(StructDeclaration *strct1 = multi_visit(strct))
214         {
215                 if(strct1->members.body.size()!=strct.members.body.size())
216                         r_result = false;
217                 else
218                 {
219                         r_result = true;
220                         auto i = strct1->members.body.begin();
221                         auto j = strct.members.body.begin();
222                         for(; (r_result && i!=strct1->members.body.end()); ++i, ++j)
223                                 compare(**i, **j);
224                 }
225         }
226 }
227
228 void TypeComparer::visit(VariableDeclaration &var)
229 {
230         if(VariableDeclaration *var1 = multi_visit(var))
231         {
232                 if(var1->name!=var.name || var1->array!=var.array)
233                         r_result = false;
234                 else if(!var1->type_declaration || !var.type_declaration)
235                         r_result = false;
236                 else
237                 {
238                         if(var1->array)
239                         {
240                                 r_result = false;
241                                 if(var1->array_size && var.array_size)
242                                         compare(*var1->array_size, *var.array_size);
243                                 else if(!var1->array_size && !var.array_size)
244                                         r_result = true;
245                         }
246                         if(r_result && var1->type_declaration!=var.type_declaration)
247                                 compare(*var1->type_declaration, *var.type_declaration);
248                         // TODO Compare layout qualifiers for interface block members
249                 }
250         }
251 }
252
253
254 void LocationCounter::visit(BasicTypeDeclaration &basic)
255 {
256         r_count = basic.kind==BasicTypeDeclaration::MATRIX ? basic.size>>16 : 1;
257 }
258
259 void LocationCounter::visit(ImageTypeDeclaration &)
260 {
261         r_count = 1;
262 }
263
264 void LocationCounter::visit(StructDeclaration &strct)
265 {
266         unsigned total = 0;
267         for(const RefPtr<Statement> &s: strct.members.body)
268         {
269                 r_count = 1;
270                 s->visit(*this);
271                 total += r_count;
272         }
273         r_count = total;
274 }
275
276 void LocationCounter::visit(VariableDeclaration &var)
277 {
278         r_count = 1;
279         if(var.type_declaration)
280                 var.type_declaration->visit(*this);
281         if(var.array)
282                 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
283                         if(literal->value.check_type<int>())
284                                 r_count *= literal->value.value<int>();
285 }
286
287
288 void MemoryRequirementsCalculator::visit(BasicTypeDeclaration &basic)
289 {
290         if(basic.kind==BasicTypeDeclaration::BOOL)
291         {
292                 r_size = 1;
293                 r_alignment = 1;
294         }
295         else if(basic.kind==BasicTypeDeclaration::INT || basic.kind==BasicTypeDeclaration::FLOAT)
296         {
297                 r_size = basic.size/8;
298                 r_alignment = r_size;
299         }
300         else if(basic.kind==BasicTypeDeclaration::VECTOR || basic.kind==BasicTypeDeclaration::MATRIX)
301         {
302                 basic.base_type->visit(*this);
303                 unsigned n_elem = basic.size&0xFFFF;
304                 r_size *= n_elem;
305                 if(basic.kind==BasicTypeDeclaration::VECTOR)
306                         r_alignment *= (n_elem==3 ? 4 : n_elem);
307         }
308         else if(basic.kind==BasicTypeDeclaration::ARRAY)
309                 basic.base_type->visit(*this);
310
311         if(basic.extended_alignment)
312                 r_alignment = (r_alignment+15)&~15U;
313 }
314
315 void MemoryRequirementsCalculator::visit(StructDeclaration &strct)
316 {
317         unsigned total = 0;
318         unsigned max_align = 1;
319         for(const RefPtr<Statement> &s: strct.members.body)
320         {
321                 r_size = 0;
322                 r_alignment = 1;
323                 r_offset = -1;
324                 s->visit(*this);
325                 if(r_offset>=0)
326                         total = r_offset;
327                 total += r_alignment-1;
328                 total -= total%r_alignment;
329                 total += r_size;
330                 max_align = max(max_align, r_alignment);
331         }
332         r_size = total;
333         r_alignment = max_align;
334         if(strct.extended_alignment)
335                 r_alignment = (r_alignment+15)&~15U;
336         r_size += r_alignment-1;
337         r_size -= r_size%r_alignment;
338 }
339
340 void MemoryRequirementsCalculator::visit(VariableDeclaration &var)
341 {
342         r_offset = get_layout_value(var.layout.get(), "offset");
343
344         if(var.type_declaration)
345                 var.type_declaration->visit(*this);
346         if(var.array)
347                 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
348                         if(literal->value.check_type<int>())
349                         {
350                                 unsigned aligned_size = r_size+r_alignment-1;
351                                 aligned_size -= aligned_size%r_alignment;
352                                 r_size = aligned_size*literal->value.value<int>();
353                         }
354 }
355
356
357 set<Node *> DependencyCollector::apply(FunctionDeclaration &func)
358 {
359         func.visit(*this);
360         return dependencies;
361 }
362
363 void DependencyCollector::visit(VariableReference &var)
364 {
365         if(var.declaration && !locals.count(var.declaration))
366         {
367                 dependencies.insert(var.declaration);
368                 var.declaration->visit(*this);
369         }
370 }
371
372 void DependencyCollector::visit(FunctionCall &call)
373 {
374         if(call.declaration)
375         {
376                 dependencies.insert(call.declaration);
377                 if(call.declaration->definition)
378                         call.declaration->definition->visit(*this);
379         }
380         TraversingVisitor::visit(call);
381 }
382
383 void DependencyCollector::visit(VariableDeclaration &var)
384 {
385         locals.insert(&var);
386         if(var.type_declaration)
387         {
388                 dependencies.insert(var.type_declaration);
389                 var.type_declaration->visit(*this);
390         }
391
392         TraversingVisitor::visit(var);
393 }
394
395 void DependencyCollector::visit(FunctionDeclaration &func)
396 {
397         if(!visited_functions.count(&func))
398         {
399                 visited_functions.insert(&func);
400                 TraversingVisitor::visit(func);
401         }
402 }
403
404
405 set<Node *> AssignmentCollector::apply(Node &node)
406 {
407         node.visit(*this);
408         return assigned_variables;
409 }
410
411 void AssignmentCollector::visit(VariableReference &var)
412 {
413         if(assignment_target)
414                 assigned_variables.insert(var.declaration);
415 }
416
417 void AssignmentCollector::visit(UnaryExpression &unary)
418 {
419         SetFlag set_assignment(assignment_target, (unary.oper->token[1]=='+' || unary.oper->token[1]=='-'));
420         TraversingVisitor::visit(unary);
421 }
422
423 void AssignmentCollector::visit(BinaryExpression &binary)
424 {
425         binary.left->visit(*this);
426         SetFlag clear_assignment(assignment_target, false);
427         binary.right->visit(*this);
428 }
429
430 void AssignmentCollector::visit(Assignment &assign)
431 {
432         {
433                 SetFlag set_assignment(assignment_target);
434                 assign.left->visit(*this);
435         }
436         assign.right->visit(*this);
437 }
438
439 } // namespace SL
440 } // namespace GL
441 } // namespace Msp