1 #include <msp/core/algorithm.h>
// Reports whether the type's kind is INT or FLOAT.
10 bool is_scalar(const BasicTypeDeclaration &type)
12 return (type.kind==BasicTypeDeclaration::INT || type.kind==BasicTypeDeclaration::FLOAT);
// Reports whether the type's kind is VECTOR or MATRIX.
15 bool is_vector_or_matrix(const BasicTypeDeclaration &type)
17 return (type.kind==BasicTypeDeclaration::VECTOR || type.kind==BasicTypeDeclaration::MATRIX);
// Finds the innermost element type of a vector, matrix or array type by
// following base_type recursively.
// Yields a null pointer when a base type along the way is not a
// BasicTypeDeclaration.
// NOTE(review): the result for kinds other than VECTOR, MATRIX and ARRAY is
// not visible in this listing; presumably the type itself -- confirm.
20 BasicTypeDeclaration *get_element_type(BasicTypeDeclaration &type)
22 if(is_vector_or_matrix(type) || type.kind==BasicTypeDeclaration::ARRAY)
// The base type is stored as a more general pointer; only basic types can be
// descended into.
24 BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
25 return (basic_base ? get_element_type(*basic_base) : 0);
// Decides whether a value of type `from` can be converted to type `to`.
// The visible rules are tested in this order:
//  - INT to FLOAT: allowed when from.size does not exceed to.size.
//  - any other pair of differing kinds: handled by the second branch.
//  - INT to INT with differing sign flags: allowed only when from.sign is set
//    and from.size does not exceed to.size.
//  - VECTOR/MATRIX with identical size: allowed when both base types are
//    BasicTypeDeclarations and the source base converts to the target base.
// NOTE(review): the results of the "kinds differ" branch and of the final
// fall-through are not visible in this listing; presumably both are false --
// confirm.
31 bool can_convert(const BasicTypeDeclaration &from, const BasicTypeDeclaration &to)
33 if(from.kind==BasicTypeDeclaration::INT && to.kind==BasicTypeDeclaration::FLOAT)
34 return from.size<=to.size;
35 else if(from.kind!=to.kind)
// From here on both kinds are equal, so testing from.kind alone is enough.
37 else if(from.kind==BasicTypeDeclaration::INT && from.sign!=to.sign)
38 return from.sign && from.size<=to.size;
39 else if(is_vector_or_matrix(from) && from.size==to.size)
// Same shape: convertibility is decided by the base types.
41 BasicTypeDeclaration *from_base = dynamic_cast<BasicTypeDeclaration *>(from.base_type);
42 BasicTypeDeclaration *to_base = dynamic_cast<BasicTypeDeclaration *>(to.base_type);
43 return (from_base && to_base && can_convert(*from_base, *to_base));
// Counter from which multi_visit draws its tag values; the first tag handed
// out is 1.
50 unsigned TypeComparer::next_tag = 1;
// Constructor of TypeComparer.
// NOTE(review): the member initialisers that follow the colon are not visible
// in this listing.
52 TypeComparer::TypeComparer():
// Compares two nodes with each other.
// NOTE(review): nothing is returned; the visit functions in this file read
// r_result after calling compare, so the outcome is presumably left there --
// confirm.
59 void TypeComparer::compare(Node &node1, Node &node2)
// Helper used by every visit function of TypeComparer to obtain the node that
// the visited one is being compared against. The visit functions only go on
// to compare when a non-null pointer comes back.
// NOTE(review): the template header and most of the body are not visible in
// this listing; the notes below cover only the visible statements.
71 T *TypeComparer::multi_visit(T &node)
// A function-local static in a template exists once per instantiation, so
// each node type T receives its own tag the first time it is visited.
73 static unsigned tag = next_tag++;
// Either no first node has been recorded, or it was recorded by a different
// instantiation (i.e. for a different node type).
83 else if(!first || tag!=first_tag)
// NOTE(review): presumably only reached when the tags matched, which is what
// would make this unchecked downcast valid -- confirm.
87 T *f = static_cast<T *>(first);
// Compares two literals: first their types, then their values.
95 void TypeComparer::visit(Literal &literal)
// lit1 is the other literal of the pair being compared.
97 if(Literal *lit1 = multi_visit(literal))
// Both literals must have a resolved type before the types can be compared.
99 if(!lit1->type || !literal.type)
103 compare(*lit1->type, *literal.type);
// Only int values are handled: the result is true only if this literal holds
// an int and both int values are equal.
105 r_result = (literal.value.check_type<int>() && lit1->value.value<int>()==literal.value.value<int>());
110 void TypeComparer::visit(VariableReference &var)
112 if(VariableReference *var1 = multi_visit(var))
114 if(!var1->declaration || !var.declaration)
116 else if(!var1->declaration->constant || !var.declaration->constant)
118 else if(!var1->declaration->init_expression || !var.declaration->init_expression)
121 compare(*var1->declaration->init_expression, *var.declaration->init_expression);
// Compares two unary expressions: the operators are checked first, then the
// operand expressions are compared.
125 void TypeComparer::visit(UnaryExpression &unary)
127 if(UnaryExpression *unary1 = multi_visit(unary))
129 if(unary1->oper!=unary.oper)
132 compare(*unary1->expression, *unary.expression);
// Compares two binary expressions: the operators are checked first, then the
// left operands and the right operands are compared, in that order.
136 void TypeComparer::visit(BinaryExpression &binary)
138 if(BinaryExpression *binary1 = multi_visit(binary))
140 if(binary1->oper!=binary.oper)
144 compare(*binary1->left, *binary.left);
146 compare(*binary1->right, *binary.right);
// Compares two ternary expressions: the operators are checked first, then the
// condition, the true branch and the false branch are compared, in that order.
151 void TypeComparer::visit(TernaryExpression &ternary)
153 if(TernaryExpression *ternary1 = multi_visit(ternary))
155 if(ternary1->oper!=ternary.oper)
159 compare(*ternary1->condition, *ternary.condition);
161 compare(*ternary1->true_expr, *ternary.true_expr);
163 compare(*ternary1->false_expr, *ternary.false_expr);
// Compares two basic types by kind, size and sign, and then by base type.
168 void TypeComparer::visit(BasicTypeDeclaration &basic)
170 if(BasicTypeDeclaration *basic1 = multi_visit(basic))
172 if(basic1->kind!=basic.kind || basic1->size!=basic.size || basic1->sign!=basic.sign)
// When both have a base type the verdict is that of the base type comparison.
174 else if(basic1->base_type && basic.base_type)
175 compare(*basic1->base_type, *basic.base_type);
// True only when neither type has a base type.
177 r_result = (!basic1->base_type && !basic.base_type);
// Compares two image types by their dimensions, array, sampled and shadow
// attributes, and then by base type.
181 void TypeComparer::visit(ImageTypeDeclaration &image)
183 if(ImageTypeDeclaration *image1 = multi_visit(image))
185 if(image1->dimensions!=image.dimensions || image1->array!=image.array)
187 else if(image1->sampled!=image.sampled || image1->shadow!=image.shadow)
// When both have a base type the verdict is that of the base type comparison.
189 else if(image1->base_type && image.base_type)
190 compare(*image1->base_type, *image.base_type);
// True only when neither type has a base type.
192 r_result = (!image1->base_type && !image.base_type);
// Compares two struct declarations by walking over their members in parallel.
196 void TypeComparer::visit(StructDeclaration &strct)
198 if(StructDeclaration *strct1 = multi_visit(strct))
// The member counts are checked before the members are walked.
200 if(strct1->members.body.size()!=strct.members.body.size())
205 auto i = strct1->members.body.begin();
206 auto j = strct.members.body.begin();
// The loop stops as soon as r_result turns false. Only i is tested against an
// end iterator.
// NOTE(review): presumably the loop is reached only when the sizes matched,
// which is what keeps j in range -- confirm.
207 for(; (r_result && i!=strct1->members.body.end()); ++i, ++j)
// Compares two variable declarations by name, array flag, array size and type.
213 void TypeComparer::visit(VariableDeclaration &var)
215 if(VariableDeclaration *var1 = multi_visit(var))
217 if(var1->name!=var.name || var1->array!=var.array)
// Both declarations need a resolved type.
219 else if(!var1->type_declaration || !var.type_declaration)
// Array sizes are compared when both declarations have one.
226 if(var1->array_size && var.array_size)
227 compare(*var1->array_size, *var.array_size);
// The types are compared only if nothing has failed so far, and not at all
// when both declarations point at the very same type declaration.
229 if(r_result && var1->type_declaration!=var.type_declaration)
230 compare(*var1->type_declaration, *var.type_declaration);
231 // TODO Compare layout qualifiers for interface block members
// Constructor of LocationCounter.
// NOTE(review): the member initialisers that follow the colon are not visible
// in this listing.
237 LocationCounter::LocationCounter():
// Sets r_count for a basic type: basic.size>>16 for a MATRIX, one for every
// other kind.
// NOTE(review): presumably the upper 16 bits of a matrix's size hold its
// column count -- confirm against BasicTypeDeclaration.
241 void LocationCounter::visit(BasicTypeDeclaration &basic)
243 r_count = basic.kind==BasicTypeDeclaration::MATRIX ? basic.size>>16 : 1;
// Handles image types. The parameter is unnamed, so the outcome cannot depend
// on the properties of the image type.
246 void LocationCounter::visit(ImageTypeDeclaration &)
// Handles struct types by going through every statement in the struct's
// member block.
// NOTE(review): the loop body is not visible in this listing; presumably the
// members' counts are added up -- confirm.
251 void LocationCounter::visit(StructDeclaration &strct)
254 for(const RefPtr<Statement> &s: strct.members.body)
// Computes the count for a variable: the count of its type, multiplied by the
// array size when that size is an int literal.
263 void LocationCounter::visit(VariableDeclaration &var)
// Visiting the type stores the count of a single element in r_count.
266 if(var.type_declaration)
267 var.type_declaration->visit(*this);
// An array size that is missing, is not a Literal or does not hold an int
// leaves r_count untouched.
269 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
270 if(literal->value.check_type<int>())
271 r_count *= literal->value.value<int>();
// Computes r_size and r_alignment for a basic type, depending on its kind.
// NOTE(review): the BOOL branch body and the statement on the line after
// n_elem is computed are not visible in this listing.
275 void MemoryRequirementsCalculator::visit(BasicTypeDeclaration &basic)
277 if(basic.kind==BasicTypeDeclaration::BOOL)
282 else if(basic.kind==BasicTypeDeclaration::INT || basic.kind==BasicTypeDeclaration::FLOAT)
// Scalars are aligned to their own size.
// NOTE(review): the division by 8 suggests basic.size is in bits -- confirm.
284 r_size = basic.size/8;
285 r_alignment = r_size;
287 else if(basic.kind==BasicTypeDeclaration::VECTOR || basic.kind==BasicTypeDeclaration::MATRIX)
// Start from the requirements of the base type; base_type is dereferenced
// without a null check.
289 basic.base_type->visit(*this);
// The element count is taken from the low 16 bits of size.
290 unsigned n_elem = basic.size&0xFFFF;
// A vector's alignment is the base alignment times its element count, except
// that a three-element vector is aligned like a four-element one.
292 if(basic.kind==BasicTypeDeclaration::VECTOR)
293 r_alignment *= (n_elem==3 ? 4 : n_elem);
// An array takes over the requirements of its base type; base_type is
// dereferenced without a null check.
295 else if(basic.kind==BasicTypeDeclaration::ARRAY)
296 basic.base_type->visit(*this);
// Computes the memory requirements of a struct from those of its members.
// NOTE(review): the declaration of total and the statements that visit each
// member and add its size are not visible in this listing.
299 void MemoryRequirementsCalculator::visit(StructDeclaration &strct)
// The struct's alignment is the largest alignment among its members, and
// never less than 1.
302 unsigned max_align = 1;
303 for(const RefPtr<Statement> &s: strct.members.body)
// Round total up to the next multiple of r_alignment.
311 total += r_alignment-1;
312 total -= total%r_alignment;
314 max_align = max(max_align, r_alignment);
317 r_alignment = max_align;
// Computes the memory requirements of a variable, taking its layout
// qualifiers and array size into account.
// NOTE(review): the statements around the layout lookup, including any check
// that var.layout is set, are not visible in this listing.
320 void MemoryRequirementsCalculator::visit(VariableDeclaration &var)
// Look for an "offset" qualifier in the variable's layout.
// NOTE(review): presumably find_member matches the given string against each
// qualifier's name -- confirm.
324 auto i = find_member(var.layout->qualifiers, string("offset"), &Layout::Qualifier::name);
325 if(i!=var.layout->qualifiers.end())
// Visiting the type stores the requirements of a single element.
329 if(var.type_declaration)
330 var.type_declaration->visit(*this);
// For an array with an int literal size, the size grows by r_alignment for
// every element after the first.
332 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
333 if(literal->value.check_type<int>())
334 r_size += r_alignment*(literal->value.value<int>()-1);
// Entry point of the collector: takes a function declaration and returns a
// set of nodes.
// NOTE(review): the body is not visible in this listing; presumably the
// returned set is the dependencies gathered by the visit functions -- confirm.
338 set<Node *> DependencyCollector::apply(FunctionDeclaration &func)
// Records the declaration of a referenced variable as a dependency and then
// visits that declaration.
344 void DependencyCollector::visit(VariableReference &var)
// Unresolved references and declarations found in locals are skipped.
346 if(var.declaration && !locals.count(var.declaration))
348 dependencies.insert(var.declaration);
349 var.declaration->visit(*this);
// Records the declaration of a referenced interface block as a dependency and
// then visits that declaration. Unresolved references are skipped.
353 void DependencyCollector::visit(InterfaceBlockReference &iface)
355 if(iface.declaration)
357 dependencies.insert(iface.declaration);
358 iface.declaration->visit(*this);
// Records the called function's declaration as a dependency, visits its
// definition when there is one, and finally lets TraversingVisitor handle the
// call.
// NOTE(review): whatever guards the use of call.declaration is not visible in
// this listing; presumably a null check -- confirm.
362 void DependencyCollector::visit(FunctionCall &call)
366 dependencies.insert(call.declaration);
367 if(call.declaration->definition)
368 call.declaration->definition->visit(*this);
370 TraversingVisitor::visit(call);
// Records a variable's type declaration as a dependency and visits it, then
// lets TraversingVisitor handle the declaration.
// NOTE(review): a statement before the type check is not visible in this
// listing.
373 void DependencyCollector::visit(VariableDeclaration &var)
376 if(var.type_declaration)
378 dependencies.insert(var.type_declaration);
379 var.type_declaration->visit(*this);
382 TraversingVisitor::visit(var);
// Traverses a function declaration the first time it is encountered.
385 void DependencyCollector::visit(FunctionDeclaration &func)
// visited_functions remembers what has been traversed already; the function
// is added before it is traversed, so reaching it again from inside its own
// traversal does not traverse it a second time.
387 if(!visited_functions.count(&func))
389 visited_functions.insert(&func);
390 TraversingVisitor::visit(func);