1 #include <msp/core/algorithm.h>
2 #include <msp/core/raii.h>
// Returns true when the basic type is a scalar of kind INT or FLOAT.
// BOOL and all other kinds are not reported as scalar by this check.
11 bool is_scalar(const BasicTypeDeclaration &type)
13 return (type.kind==BasicTypeDeclaration::INT || type.kind==BasicTypeDeclaration::FLOAT);
// Returns true when the basic type's kind is VECTOR or MATRIX.
16 bool is_vector_or_matrix(const BasicTypeDeclaration &type)
18 return (type.kind==BasicTypeDeclaration::VECTOR || type.kind==BasicTypeDeclaration::MATRIX);
// Finds the element type of a (possibly nested) vector, matrix or array type.
// For VECTOR, MATRIX and ARRAY kinds, recurses into base_type. If base_type is
// not a BasicTypeDeclaration, the result is null (0).
// NOTE(review): the branch for all other kinds is not visible in this excerpt;
// presumably it returns &type itself, which ends the recursion. Confirm.
21 BasicTypeDeclaration *get_element_type(BasicTypeDeclaration &type)
23 if(is_vector_or_matrix(type) || type.kind==BasicTypeDeclaration::ARRAY)
25 BasicTypeDeclaration *basic_base = dynamic_cast<BasicTypeDeclaration *>(type.base_type);
26 return (basic_base ? get_element_type(*basic_base) : 0);
// Decides whether a value of basic type `from` may be converted to type `to`.
//  - INT -> FLOAT: allowed when from.size <= to.size.
//  - Any other kind mismatch takes its own branch; the result there is not
//    visible in this excerpt (presumably false).
//  - INT -> INT with differing `sign`: allowed only when from.sign is set and
//    from.size <= to.size.
//  - VECTOR/MATRIX of equal size: allowed when both base types are
//    BasicTypeDeclarations and the base types are convertible (recursive).
// NOTE(review): the result for the remaining cases is in lines not visible here.
32 bool can_convert(const BasicTypeDeclaration &from, const BasicTypeDeclaration &to)
34 if(from.kind==BasicTypeDeclaration::INT && to.kind==BasicTypeDeclaration::FLOAT)
35 return from.size<=to.size;
36 else if(from.kind!=to.kind)
38 else if(from.kind==BasicTypeDeclaration::INT && from.sign!=to.sign)
39 return from.sign && from.size<=to.size;
40 else if(is_vector_or_matrix(from) && from.size==to.size)
42 BasicTypeDeclaration *from_base = dynamic_cast<BasicTypeDeclaration *>(from.base_type);
43 BasicTypeDeclaration *to_base = dynamic_cast<BasicTypeDeclaration *>(to.base_type);
44 return (from_base && to_base && can_convert(*from_base, *to_base));
// Counter for per-node-type tags used by multi_visit. Each instantiation
// multi_visit<T> takes the next value once, through a function-local static.
// It starts at 1; presumably 0 is reserved to mean "no tag" (confirm).
51 unsigned TypeComparer::next_tag = 1;
// Compares node1 against node2 structurally.
// NOTE(review): the body is not visible in this excerpt. The visit overloads
// read r_result after calling compare() (e.g. the argument loop in
// visit(FunctionCall&)), so presumably the outcome is left in r_result. Confirm.
53 void TypeComparer::compare(Node &node1, Node &node2)
// Double-dispatch helper for comparing two nodes of the same type.
// Each node type T gets a unique tag the first time this instantiation runs.
// A missing `first` node, or a tag that differs from first_tag (a different
// node type), takes the non-match branch. Otherwise `first` is cast back to T*.
// NOTE(review): the first-visit branch and the return statements are not
// visible in this excerpt.
65 T *TypeComparer::multi_visit(T &node)
67 static unsigned tag = next_tag++;
77 else if(!first || tag!=first_tag)
// Presumably reached only when first_tag matched this T's tag, which makes the
// static_cast safe.
81 T *f = static_cast<T *>(first);
// Compares two literals.
// Both literals must have a type. Their types are compared first. The value
// check then requires literal.value to hold an int and both int values to be
// equal, so literals of any non-int value never pass that check.
// NOTE(review): only literal.value is tested with check_type<int>();
// lit1->value.value<int>() is read without its own check. Presumably the
// preceding type comparison (its guard line is not visible here) guarantees
// both are ints. Confirm.
89 void TypeComparer::visit(Literal &literal)
91 if(Literal *lit1 = multi_visit(literal))
93 if(!lit1->type || !literal.type)
97 compare(*lit1->type, *literal.type);
99 r_result = (literal.value.check_type<int>() && lit1->value.value<int>()==literal.value.value<int>());
// Compares two variable references through the values of their declarations.
// Both references need a declaration, both declarations must be `constant`,
// and both must have an init_expression. The init expressions are then
// compared. The results assigned in the failing branches are not visible in
// this excerpt (presumably false).
104 void TypeComparer::visit(VariableReference &var)
106 if(VariableReference *var1 = multi_visit(var))
108 if(!var1->declaration || !var.declaration)
110 else if(!var1->declaration->constant || !var.declaration->constant)
112 else if(!var1->declaration->init_expression || !var.declaration->init_expression)
115 compare(*var1->declaration->init_expression, *var.declaration->init_expression);
// Compares two unary expressions: the operators must be identical and the
// operand expressions must compare equal.
// `oper` is a pointer (it is dereferenced with -> elsewhere in this file), so
// operators are compared by pointer identity.
119 void TypeComparer::visit(UnaryExpression &unary)
121 if(UnaryExpression *unary1 = multi_visit(unary))
123 if(unary1->oper!=unary.oper)
126 compare(*unary1->expression, *unary.expression);
// Compares two binary expressions: identical operator (pointer identity),
// then left operands, then right operands.
// NOTE(review): the line between the two compares is not visible; presumably
// the right side is only compared when the left side matched.
130 void TypeComparer::visit(BinaryExpression &binary)
132 if(BinaryExpression *binary1 = multi_visit(binary))
134 if(binary1->oper!=binary.oper)
138 compare(*binary1->left, *binary.left);
140 compare(*binary1->right, *binary.right);
// Compares two ternary expressions: identical operator (pointer identity),
// then condition, true branch and false branch, in that order.
// NOTE(review): guards between the compares are not visible; presumably each
// later compare only runs while r_result is still true.
145 void TypeComparer::visit(TernaryExpression &ternary)
147 if(TernaryExpression *ternary1 = multi_visit(ternary))
149 if(ternary1->oper!=ternary.oper)
153 compare(*ternary1->condition, *ternary.condition);
155 compare(*ternary1->true_expr, *ternary.true_expr);
157 compare(*ternary1->false_expr, *ternary.false_expr);
// Compares two function calls.
// Only constructor calls are compared structurally. If either call is not a
// constructor, a separate branch runs (presumably a non-match). Names and
// argument counts must match. Arguments are then compared pairwise, and the
// loop stops at the first mismatching argument.
162 void TypeComparer::visit(FunctionCall &call)
164 if(FunctionCall *call1 = multi_visit(call))
166 if(!call1->constructor || !call.constructor)
168 else if(call1->name!=call.name)
170 else if(call1->arguments.size()!=call.arguments.size())
175 for(unsigned i=0; (r_result && i<call.arguments.size()); ++i)
176 compare(*call1->arguments[i], *call.arguments[i]);
// Compares two basic types: kind, size and sign must all match.
// If both have a base_type, the base types are compared recursively.
// Otherwise the types match only when neither has a base_type.
181 void TypeComparer::visit(BasicTypeDeclaration &basic)
183 if(BasicTypeDeclaration *basic1 = multi_visit(basic))
185 if(basic1->kind!=basic.kind || basic1->size!=basic.size || basic1->sign!=basic.sign)
187 else if(basic1->base_type && basic.base_type)
188 compare(*basic1->base_type, *basic.base_type);
190 r_result = (!basic1->base_type && !basic.base_type);
// Compares two image types. These properties must all match: dimensions,
// array, sampled, shadow, multisample and format.
// If both have a base_type, the base types are compared recursively.
// Otherwise the types match only when neither has a base_type.
194 void TypeComparer::visit(ImageTypeDeclaration &image)
196 if(ImageTypeDeclaration *image1 = multi_visit(image))
198 if(image1->dimensions!=image.dimensions || image1->array!=image.array)
200 else if(image1->sampled!=image.sampled || image1->shadow!=image.shadow || image1->multisample!=image.multisample)
202 else if(image1->format!=image.format)
204 else if(image1->base_type && image.base_type)
205 compare(*image1->base_type, *image.base_type);
207 r_result = (!image1->base_type && !image.base_type);
// Compares two structs member by member, in declaration order.
// The member counts must be equal, which also keeps `j` in range while `i`
// walks strct1's members. The loop stops at the first mismatch.
// NOTE(review): the loop body is not visible in this excerpt; presumably it
// compares **i with **j.
211 void TypeComparer::visit(StructDeclaration &strct)
213 if(StructDeclaration *strct1 = multi_visit(strct))
215 if(strct1->members.body.size()!=strct.members.body.size())
220 auto i = strct1->members.body.begin();
221 auto j = strct.members.body.begin();
222 for(; (r_result && i!=strct1->members.body.end()); ++i, ++j)
// Compares two variable declarations.
// Name and `array` flag must match, and both must have a type_declaration.
// Array sizes are compared when both are present. The case where both are
// absent takes its own branch (presumably a match). The type declarations are
// compared only while r_result holds and they are different objects; an
// identical pointer is taken as equal without a compare.
228 void TypeComparer::visit(VariableDeclaration &var)
230 if(VariableDeclaration *var1 = multi_visit(var))
232 if(var1->name!=var.name || var1->array!=var.array)
234 else if(!var1->type_declaration || !var.type_declaration)
241 if(var1->array_size && var.array_size)
242 compare(*var1->array_size, *var.array_size);
243 else if(!var1->array_size && !var.array_size)
246 if(r_result && var1->type_declaration!=var.type_declaration)
247 compare(*var1->type_declaration, *var.type_declaration);
248 // TODO Compare layout qualifiers for interface block members
// Number of locations taken by a basic type.
// MATRIX types take (size>>16) locations; presumably the upper 16 bits of
// `size` hold the column count (confirm). Every other basic type takes 1.
254 void LocationCounter::visit(BasicTypeDeclaration &basic)
256 r_count = basic.kind==BasicTypeDeclaration::MATRIX ? basic.size>>16 : 1;
// Number of locations taken by an image type.
// NOTE(review): the body is not visible in this excerpt; the parameter is
// unnamed, so the count does not depend on the image's properties. Presumably
// it is 1; confirm.
259 void LocationCounter::visit(ImageTypeDeclaration &)
// Number of locations taken by a struct, found by iterating over its members.
// NOTE(review): the loop body and the accumulator are not visible in this
// excerpt; presumably each member's r_count is summed. Confirm.
264 void LocationCounter::visit(StructDeclaration &strct)
267 for(const RefPtr<Statement> &s: strct.members.body)
// Number of locations taken by a variable: the count for its type, multiplied
// by the array size when that size is an int Literal. Non-literal array sizes
// leave the per-element count unchanged.
// NOTE(review): the initial value of r_count (for a variable with no
// type_declaration) is set in a line not visible here.
276 void LocationCounter::visit(VariableDeclaration &var)
279 if(var.type_declaration)
280 var.type_declaration->visit(*this);
282 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
283 if(literal->value.check_type<int>())
284 r_count *= literal->value.value<int>();
// Computes r_size and r_alignment (in bytes) for a basic type.
//  - BOOL: handled in lines not visible in this excerpt.
//  - INT/FLOAT: size is basic.size/8 bytes (so `size` is in bits), and the
//    alignment equals the size.
//  - VECTOR/MATRIX: starts from the base type's requirements. The element
//    count is the lower 16 bits of `size`. Vectors are aligned to n_elem
//    elements, with 3-element vectors aligned like 4 (this matches the GLSL
//    std140/std430 vec3 rule). The size scaling is in a line not visible here.
//  - ARRAY: uses the base type's requirements.
// With extended_alignment, the alignment is rounded up to a multiple of 16.
// NOTE(review): base_type is dereferenced without a null check.
288 void MemoryRequirementsCalculator::visit(BasicTypeDeclaration &basic)
290 if(basic.kind==BasicTypeDeclaration::BOOL)
295 else if(basic.kind==BasicTypeDeclaration::INT || basic.kind==BasicTypeDeclaration::FLOAT)
297 r_size = basic.size/8;
298 r_alignment = r_size;
300 else if(basic.kind==BasicTypeDeclaration::VECTOR || basic.kind==BasicTypeDeclaration::MATRIX)
302 basic.base_type->visit(*this);
303 unsigned n_elem = basic.size&0xFFFF;
305 if(basic.kind==BasicTypeDeclaration::VECTOR)
306 r_alignment *= (n_elem==3 ? 4 : n_elem);
308 else if(basic.kind==BasicTypeDeclaration::ARRAY)
309 basic.base_type->visit(*this);
311 if(basic.extended_alignment)
312 r_alignment = (r_alignment+15)&~15U;
// Computes r_size and r_alignment for a struct by laying out its members in
// order. Each running `total` is rounded up to the member's alignment
// (`x += a-1; x -= x%a` rounds x up to a multiple of a). The struct alignment
// is the largest member alignment (at least 1), rounded up to 16 with
// extended_alignment. The final size is rounded up to the struct alignment.
// NOTE(review): the member visit, any use of an explicit r_offset, and the
// assignment of total to r_size are in lines not visible in this excerpt.
315 void MemoryRequirementsCalculator::visit(StructDeclaration &strct)
318 unsigned max_align = 1;
319 for(const RefPtr<Statement> &s: strct.members.body)
327 total += r_alignment-1;
328 total -= total%r_alignment;
330 max_align = max(max_align, r_alignment);
333 r_alignment = max_align;
334 if(strct.extended_alignment)
335 r_alignment = (r_alignment+15)&~15U;
336 r_size += r_alignment-1;
337 r_size -= r_size%r_alignment;
// Computes the memory requirements of a variable.
// r_offset is taken from the "offset" layout qualifier; the value used when
// the qualifier is absent depends on get_layout_value, which is not shown here.
// The size and alignment come from the variable's type. For int Literal array
// sizes, each element's size is padded to the alignment and then multiplied
// by the element count.
// NOTE(review): a negative literal array size would be converted to a huge
// unsigned size here; confirm negative sizes are rejected earlier.
340 void MemoryRequirementsCalculator::visit(VariableDeclaration &var)
342 r_offset = get_layout_value(var.layout.get(), "offset");
344 if(var.type_declaration)
345 var.type_declaration->visit(*this);
347 if(const Literal *literal = dynamic_cast<const Literal *>(var.array_size.get()))
348 if(literal->value.check_type<int>())
350 unsigned aligned_size = r_size+r_alignment-1;
351 aligned_size -= aligned_size%r_alignment;
352 r_size = aligned_size*literal->value.value<int>();
// Collects the nodes that `func` depends on. The visit overloads below insert
// non-local variable declarations, called function declarations and variable
// type declarations into `dependencies`.
// NOTE(review): the body is not visible in this excerpt; presumably it visits
// func and returns `dependencies`. Confirm.
357 set<Node *> DependencyCollector::apply(FunctionDeclaration &func)
// A reference to a declared variable that is not in `locals` makes that
// declaration a dependency. The declaration is then traversed so that its own
// dependencies are collected too.
363 void DependencyCollector::visit(VariableReference &var)
365 if(var.declaration && !locals.count(var.declaration))
367 dependencies.insert(var.declaration);
368 var.declaration->visit(*this);
// A function call makes the called declaration a dependency. When that
// declaration has a definition, the definition is traversed as well. The call
// node itself (e.g. its arguments) is then traversed normally.
// NOTE(review): lines before the insert are not visible; they may guard
// against a null call.declaration. Confirm before relying on it.
372 void DependencyCollector::visit(FunctionCall &call)
376 dependencies.insert(call.declaration);
377 if(call.declaration->definition)
378 call.declaration->definition->visit(*this);
380 TraversingVisitor::visit(call);
// A variable declaration makes its type declaration (if any) a dependency and
// traverses that type. The declaration is then traversed normally, e.g. its
// initializer.
// NOTE(review): the line before the type check is not visible; presumably it
// registers `var` in `locals`. Confirm.
383 void DependencyCollector::visit(VariableDeclaration &var)
386 if(var.type_declaration)
388 dependencies.insert(var.type_declaration);
389 var.type_declaration->visit(*this);
392 TraversingVisitor::visit(var);
// Traverses each function declaration at most once. `visited_functions` stops
// repeated traversal when several calls reach the same function, and stops
// endless traversal if calls ever form a cycle.
395 void DependencyCollector::visit(FunctionDeclaration &func)
397 if(!visited_functions.count(&func))
399 visited_functions.insert(&func);
400 TraversingVisitor::visit(func);
// Returns the set of variable declarations that are written to within `node`.
// NOTE(review): the traversal lines are not visible; presumably node is
// visited before assigned_variables is returned.
405 set<Node *> AssignmentCollector::apply(Node &node)
408 return assigned_variables;
// Records the referenced declaration when the reference appears in an
// assignment-target position (assignment_target is set).
// NOTE(review): var.declaration is inserted without a null check. An
// unresolved reference would add nullptr to the set. Confirm this is intended.
411 void AssignmentCollector::visit(VariableReference &var)
413 if(assignment_target)
414 assigned_variables.insert(var.declaration);
// While the operand is traversed, assignment_target is set to whether the
// operator token's second character is '+' or '-'. That presumably identifies
// the increment/decrement operators "++" and "--", which write to their
// operand. SetFlag (msp/core/raii.h) presumably restores the previous value
// when it goes out of scope. Confirm.
417 void AssignmentCollector::visit(UnaryExpression &unary)
419 SetFlag set_assignment(assignment_target, (unary.oper->token[1]=='+' || unary.oper->token[1]=='-'));
420 TraversingVisitor::visit(unary);
// The left operand is visited with the current assignment_target, so an
// enclosing assignment target carries into it. The right operand is visited
// with the flag cleared, so it is never recorded as written. Presumably this
// covers subscripts such as `a[i]`, where only the base is assigned. Confirm.
423 void AssignmentCollector::visit(BinaryExpression &binary)
425 binary.left->visit(*this);
426 SetFlag clear_assignment(assignment_target, false);
427 binary.right->visit(*this);
430 void AssignmentCollector::visit(Assignment &assign)
433 SetFlag set_assignment(assignment_target);
434 assign.left->visit(*this);
436 assign.right->visit(*this);