]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/reflect.cpp
Visit function definitions while collecting dependencies
[libs/gl.git] / source / glsl / reflect.cpp
1 #include "reflect.h"
2
3 using namespace std;
4
5 namespace Msp {
6 namespace GL {
7 namespace SL {
8
9 bool is_scalar(const BasicTypeDeclaration &type)
10 {
11         return (type.kind==BasicTypeDeclaration::INT || type.kind==BasicTypeDeclaration::FLOAT);
12 }
13
14 bool is_vector_or_matrix(const BasicTypeDeclaration &type)
15 {
16         return (type.kind==BasicTypeDeclaration::VECTOR || type.kind==BasicTypeDeclaration::MATRIX);
17 }
18
19 BasicTypeDeclaration *get_element_type(BasicTypeDeclaration &type)
20 {
21         if(is_vector_or_matrix(type) || type.kind==BasicTypeDeclaration::ARRAY)
22         {
23                 BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
24                 return (basic_base ? get_element_type(*basic_base) : 0);
25         }
26         else
27                 return &type;
28 }
29
30 bool can_convert(const BasicTypeDeclaration &from, const BasicTypeDeclaration &to)
31 {
32         if(from.kind==BasicTypeDeclaration::INT && to.kind==BasicTypeDeclaration::FLOAT)
33                 return from.size<=to.size;
34         else if(from.kind!=to.kind)
35                 return false;
36         else if(from.kind==BasicTypeDeclaration::INT && from.sign!=to.sign)
37                 return from.sign && from.size<=to.size;
38         else if(is_vector_or_matrix(from) && from.size==to.size)
39         {
40                 BasicTypeDeclaration *from_base = dynamic_cast<BasicTypeDeclaration *>(from.base_type);
41                 BasicTypeDeclaration *to_base = dynamic_cast<BasicTypeDeclaration *>(to.base_type);
42                 return (from_base && to_base && can_convert(*from_base, *to_base));
43         }
44         else
45                 return false;
46 }
47
48
49 unsigned TypeComparer::next_tag = 1;
50
51 TypeComparer::TypeComparer():
52         first(0),
53         second(0),
54         first_tag(0),
55         r_result(false)
56 { }
57
58 void TypeComparer::compare(Node &node1, Node &node2)
59 {
60         if(&node1==&node2)
61                 r_result = true;
62         else
63         {
64                 second = &node2;
65                 node1.visit(*this);
66         }
67 }
68
69 template<typename T>
70 T *TypeComparer::multi_visit(T &node)
71 {
72         static unsigned tag = next_tag++;
73
74         if(second)
75         {
76                 Node *s = second;
77                 first = &node;
78                 first_tag = tag;
79                 second = 0;
80                 s->visit(*this);
81         }
82         else if(!first || tag!=first_tag)
83                 r_result = false;
84         else
85         {
86                 T *f = static_cast<T *>(first);
87                 first = 0;
88                 return f;
89         }
90
91         return 0;
92 }
93
94 void TypeComparer::visit(Literal &literal)
95 {
96         if(Literal *lit1 = multi_visit(literal))
97         {
98                 if(!lit1->type || !literal.type)
99                         r_result = false;
100                 else
101                 {
102                         compare(*lit1->type, *literal.type);
103                         if(r_result)
104                                 r_result = (literal.value.check_type<int>() && lit1->value.value<int>()==literal.value.value<int>());
105                 }
106         }
107 }
108
109 void TypeComparer::visit(VariableReference &var)
110 {
111         if(VariableReference *var1 = multi_visit(var))
112         {
113                 if(!var1->declaration || !var.declaration)
114                         r_result = false;
115                 else if(!var1->declaration->constant || !var.declaration->constant)
116                         r_result = false;
117                 else if(!var1->declaration->init_expression || !var.declaration->init_expression)
118                         r_result = false;
119                 else
120                         compare(*var1->declaration->init_expression, *var.declaration->init_expression);
121         }
122 }
123
124 void TypeComparer::visit(BasicTypeDeclaration &basic)
125 {
126         if(BasicTypeDeclaration *basic1 = multi_visit(basic))
127         {
128                 if(basic1->kind!=basic.kind || basic1->size!=basic.size || basic1->sign!=basic.sign)
129                         r_result = false;
130                 else if(basic1->base_type && basic.base_type)
131                         compare(*basic1->base_type, *basic.base_type);
132                 else
133                         r_result = (!basic1->base_type && !basic.base_type);
134         }
135 }
136
137 void TypeComparer::visit(ImageTypeDeclaration &image)
138 {
139         if(ImageTypeDeclaration *image1 = multi_visit(image))
140         {
141                 if(image1->dimensions!=image.dimensions || image1->array!=image.array)
142                         r_result = false;
143                 else if(image1->sampled!=image.sampled || image1->shadow!=image.shadow)
144                         r_result = false;
145                 else if(image1->base_type && image.base_type)
146                         compare(*image1->base_type, *image.base_type);
147                 else
148                         r_result = (!image1->base_type && !image.base_type);
149         }
150 }
151
152 void TypeComparer::visit(StructDeclaration &strct)
153 {
154         if(StructDeclaration *strct1 = multi_visit(strct))
155         {
156                 if(strct1->members.body.size()!=strct.members.body.size())
157                         r_result = false;
158                 else
159                 {
160                         r_result = true;
161                         NodeList<Statement>::const_iterator i = strct1->members.body.begin();
162                         NodeList<Statement>::const_iterator j = strct.members.body.begin();
163                         for(; (r_result && i!=strct1->members.body.end()); ++i, ++j)
164                                 compare(**i, **j);
165                 }
166         }
167 }
168
169 void TypeComparer::visit(VariableDeclaration &var)
170 {
171         if(VariableDeclaration *var1 = multi_visit(var))
172         {
173                 if(var1->name!=var.name || var1->array!=var.array)
174                         r_result = false;
175                 else if(!var1->type_declaration || !var.type_declaration)
176                         r_result = false;
177                 else
178                 {
179                         if(var1->array)
180                         {
181                                 r_result = false;
182                                 if(var1->array_size && var.array_size)
183                                         compare(*var1->array_size, *var.array_size);
184                         }
185                         if(r_result && var1->type_declaration!=var.type_declaration)
186                                 compare(*var1->type_declaration, *var.type_declaration);
187                         // TODO Compare layout qualifiers for interface block members
188                 }
189         }
190 }
191
192
193 LocationCounter::LocationCounter():
194         r_count(0)
195 { }
196
197 void LocationCounter::visit(BasicTypeDeclaration &basic)
198 {
199         r_count = basic.kind==BasicTypeDeclaration::MATRIX ? basic.size>>16 : 1;
200 }
201
202 void LocationCounter::visit(ImageTypeDeclaration &)
203 {
204         r_count = 1;
205 }
206
207 void LocationCounter::visit(StructDeclaration &strct)
208 {
209         unsigned total = 0;
210         for(NodeList<Statement>::const_iterator i=strct.members.body.begin(); i!=strct.members.body.end(); ++i)
211         {
212                 r_count = 1;
213                 (*i)->visit(*this);
214                 total += r_count;
215         }
216         r_count = total;
217 }
218
219 void LocationCounter::visit(VariableDeclaration &var)
220 {
221         r_count = 1;
222         if(var.type_declaration)
223                 var.type_declaration->visit(*this);
224         if(var.array)
225                 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
226                         if(literal->value.check_type<int>())
227                                 r_count *= literal->value.value<int>();
228 }
229
230
231 void MemoryRequirementsCalculator::visit(BasicTypeDeclaration &basic)
232 {
233         if(basic.kind==BasicTypeDeclaration::BOOL)
234         {
235                 r_size = 1;
236                 r_alignment = 1;
237         }
238         else if(basic.kind==BasicTypeDeclaration::INT || basic.kind==BasicTypeDeclaration::FLOAT)
239         {
240                 r_size = basic.size/8;
241                 r_alignment = r_size;
242         }
243         else if(basic.kind==BasicTypeDeclaration::VECTOR || basic.kind==BasicTypeDeclaration::MATRIX)
244         {
245                 basic.base_type->visit(*this);
246                 unsigned n_elem = basic.size&0xFFFF;
247                 r_size *= n_elem;
248                 if(basic.kind==BasicTypeDeclaration::VECTOR)
249                         r_alignment *= (n_elem==3 ? 4 : n_elem);
250         }
251         else if(basic.kind==BasicTypeDeclaration::ARRAY)
252                 basic.base_type->visit(*this);
253 }
254
255 void MemoryRequirementsCalculator::visit(StructDeclaration &strct)
256 {
257         unsigned total = 0;
258         unsigned max_align = 1;
259         for(NodeList<Statement>::iterator i=strct.members.body.begin(); i!=strct.members.body.end(); ++i)
260         {
261                 r_size = 0;
262                 r_alignment = 1;
263                 r_offset = -1;
264                 (*i)->visit(*this);
265                 if(r_offset)
266                         total = r_offset;
267                 total += r_alignment-1;
268                 total -= total%r_alignment;
269                 total += r_size;
270                 max_align = max(max_align, r_alignment);
271         }
272         r_size = total;
273         r_alignment = max_align;
274 }
275
276 void MemoryRequirementsCalculator::visit(VariableDeclaration &var)
277 {
278         if(var.layout)
279         {
280                 const vector<Layout::Qualifier> qualifiers = var.layout->qualifiers;
281                 for(vector<Layout::Qualifier>::const_iterator i=qualifiers.begin(); (r_offset<0 && i!=qualifiers.end()); ++i)
282                         if(i->name=="offset")
283                                 r_offset = i->value;
284         }
285
286         if(var.type_declaration)
287                 var.type_declaration->visit(*this);
288         if(var.array)
289                 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
290                         if(literal->value.check_type<int>())
291                                 r_size += r_alignment*(literal->value.value<int>()-1);
292 }
293
294
295 set<Node *> DependencyCollector::apply(FunctionDeclaration &func)
296 {
297         func.visit(*this);
298         return dependencies;
299 }
300
301 void DependencyCollector::visit(VariableReference &var)
302 {
303         if(var.declaration && !locals.count(var.declaration))
304         {
305                 dependencies.insert(var.declaration);
306                 var.declaration->visit(*this);
307         }
308 }
309
310 void DependencyCollector::visit(InterfaceBlockReference &iface)
311 {
312         if(iface.declaration)
313         {
314                 dependencies.insert(iface.declaration);
315                 iface.declaration->visit(*this);
316         }
317 }
318
319 void DependencyCollector::visit(FunctionCall &call)
320 {
321         if(call.declaration)
322         {
323                 dependencies.insert(call.declaration);
324                 if(call.declaration->definition)
325                         call.declaration->definition->visit(*this);
326         }
327         TraversingVisitor::visit(call);
328 }
329
330 void DependencyCollector::visit(VariableDeclaration &var)
331 {
332         locals.insert(&var);
333         if(var.type_declaration)
334         {
335                 dependencies.insert(var.type_declaration);
336                 var.type_declaration->visit(*this);
337         }
338
339         TraversingVisitor::visit(var);
340 }
341
342 void DependencyCollector::visit(FunctionDeclaration &func)
343 {
344         if(!visited_functions.count(&func))
345         {
346                 visited_functions.insert(&func);
347                 TraversingVisitor::visit(func);
348         }
349 }
350
351 } // namespace SL
352 } // namespace GL
353 } // namespace Msp