git.tdb.fi Git - libs/gl.git/blob - source/glsl/generate.cpp
Refactor interface management
[libs/gl.git] / source / glsl / generate.cpp
1 #include <msp/core/raii.h>
2 #include "builtin.h"
3 #include "generate.h"
4
5 using namespace std;
6
7 namespace Msp {
8 namespace GL {
9 namespace SL {
10
11 void DeclarationCombiner::apply(Stage &stage)
12 {
13         stage.content.visit(*this);
14         NodeRemover().apply(stage, nodes_to_remove);
15 }
16
17 void DeclarationCombiner::visit(Block &block)
18 {
19         if(current_block)
20                 return;
21
22         TraversingVisitor::visit(block);
23 }
24
25 void DeclarationCombiner::visit(FunctionDeclaration &func)
26 {
27         vector<FunctionDeclaration *> &decls = functions[func.name];
28         if(func.definition)
29         {
30                 for(vector<FunctionDeclaration *>::iterator i=decls.begin(); i!=decls.end(); ++i)
31                 {
32                         (*i)->definition = func.definition;
33                         (*i)->body.body.clear();
34                 }
35         }
36         decls.push_back(&func);
37 }
38
39 void DeclarationCombiner::visit(VariableDeclaration &var)
40 {
41         VariableDeclaration *&ptr = variables[var.name];
42         if(ptr)
43         {
44                 ptr->type = var.type;
45                 if(var.init_expression)
46                         ptr->init_expression = var.init_expression;
47                 if(var.layout)
48                 {
49                         if(ptr->layout)
50                         {
51                                 for(vector<Layout::Qualifier>::iterator i=var.layout->qualifiers.begin(); i!=var.layout->qualifiers.end(); ++i)
52                                 {
53                                         bool found = false;
54                                         for(vector<Layout::Qualifier>::iterator j=ptr->layout->qualifiers.begin(); (!found && j!=ptr->layout->qualifiers.end()); ++j)
55                                                 if(j->name==i->name)
56                                                 {
57                                                         j->has_value = i->value;
58                                                         j->value = i->value;
59                                                         found = true;
60                                                 }
61
62                                         if(!found)
63                                                 ptr->layout->qualifiers.push_back(*i);
64                                 }
65                         }
66                         else
67                                 ptr->layout = var.layout;
68                 }
69                 nodes_to_remove.insert(&var);
70         }
71         else
72                 ptr = &var;
73 }
74
75
76 void BlockResolver::enter(Block &block)
77 {
78         block.parent = current_block;
79 }
80
81 void BlockResolver::visit(InterfaceBlock &iface)
82 {
83         current_block->interfaces.insert(&iface);
84         TraversingVisitor::visit(iface);
85 }
86
87
88 VariableResolver::VariableResolver():
89         record_target(false),
90         assignment_target(0),
91         self_referencing(false)
92 { }
93
94 void VariableResolver::apply(Stage &stage)
95 {
96         Stage *builtin_stage = get_builtins(stage.type);
97         builtins = (builtin_stage ? &builtin_stage->content : 0);
98         stage.content.visit(*this);
99 }
100
101 Block *VariableResolver::next_block(Block &block)
102 {
103         return block.parent ? block.parent : &block!=builtins ? builtins : 0;
104 }
105
106 void VariableResolver::enter(Block &block)
107 {
108         block.variables.clear();
109 }
110
111 void VariableResolver::visit(VariableReference &var)
112 {
113         var.declaration = 0;
114         type = 0;
115         for(Block *block=current_block; block; block=next_block(*block))
116         {
117                 map<string, VariableDeclaration *>::iterator i = block->variables.find(var.name);
118                 if(i!=block->variables.end())
119                         var.declaration = i->second;
120                 else
121                 {
122                         const set<InterfaceBlock *> &ifaces = block->interfaces;
123                         for(set<InterfaceBlock *>::const_iterator j=ifaces.begin(); (!var.declaration && j!=ifaces.end()); ++j)
124                         {
125                                 i = (*j)->members.variables.find(var.name);
126                                 if(i!=(*j)->members.variables.end())
127                                         var.declaration = i->second;
128                         }
129                 }
130
131                 if(var.declaration)
132                 {
133                         type = var.declaration->type_declaration;
134                         break;
135                 }
136         }
137
138         if(record_target)
139         {
140                 if(assignment_target)
141                 {
142                         record_target = false;
143                         assignment_target = 0;
144                 }
145                 else
146                         assignment_target = var.declaration;
147         }
148         else if(var.declaration && var.declaration==assignment_target)
149                 self_referencing = true;
150 }
151
152 void VariableResolver::visit(MemberAccess &memacc)
153 {
154         type = 0;
155         TraversingVisitor::visit(memacc);
156         memacc.declaration = 0;
157         if(type)
158         {
159                 map<string, VariableDeclaration *>::iterator i = type->members.variables.find(memacc.member);
160                 if(i!=type->members.variables.end())
161                 {
162                         memacc.declaration = i->second;
163                         type = i->second->type_declaration;
164                 }
165                 else
166                         type = 0;
167         }
168 }
169
170 void VariableResolver::visit(BinaryExpression &binary)
171 {
172         if(binary.oper=="[")
173         {
174                 {
175                         SetForScope<bool> set(record_target, false);
176                         binary.right->visit(*this);
177                 }
178                 type = 0;
179                 binary.left->visit(*this);
180         }
181         else
182         {
183                 TraversingVisitor::visit(binary);
184                 type = 0;
185         }
186 }
187
188 void VariableResolver::visit(Assignment &assign)
189 {
190         {
191                 SetFlag set(record_target);
192                 assignment_target = 0;
193                 assign.left->visit(*this);
194         }
195
196         self_referencing = false;
197         assign.right->visit(*this);
198
199         assign.self_referencing = (self_referencing || assign.oper!="=");
200         assign.target_declaration = assignment_target;
201 }
202
203 void VariableResolver::visit(StructDeclaration &strct)
204 {
205         TraversingVisitor::visit(strct);
206         current_block->types[strct.name] = &strct;
207 }
208
209 void VariableResolver::visit(VariableDeclaration &var)
210 {
211         for(Block *block=current_block; block; block=next_block(*block))
212         {
213                 map<string, StructDeclaration *>::iterator j = block->types.find(var.type);
214                 if(j!=block->types.end())
215                         var.type_declaration = j->second;
216         }
217
218         if(!block_interface.empty() && var.interface.empty())
219                 var.interface = block_interface;
220
221         TraversingVisitor::visit(var);
222         current_block->variables[var.name] = &var;
223 }
224
225 void VariableResolver::visit(InterfaceBlock &iface)
226 {
227         SetForScope<string> set_iface(block_interface, iface.interface);
228         TraversingVisitor::visit(iface);
229 }
230
231
232 void FunctionResolver::visit(FunctionCall &call)
233 {
234         map<string, vector<FunctionDeclaration *> >::iterator i = functions.find(call.name);
235         if(i!=functions.end())
236                 call.declaration = i->second.back();
237
238         TraversingVisitor::visit(call);
239 }
240
241 void FunctionResolver::visit(FunctionDeclaration &func)
242 {
243         vector<FunctionDeclaration *> &decls = functions[func.name];
244         if(func.definition)
245         {
246                 for(vector<FunctionDeclaration *>::iterator i=decls.begin(); i!=decls.end(); ++i)
247                         (*i)->definition = func.definition;
248                 decls.clear();
249                 decls.push_back(&func);
250         }
251         else if(!decls.empty() && decls.back()->definition)
252                 func.definition = decls.back()->definition;
253         else
254                 decls.push_back(&func);
255
256         TraversingVisitor::visit(func);
257 }
258
259
260 InterfaceGenerator::InterfaceGenerator():
261         stage(0)
262 { }
263
264 string InterfaceGenerator::get_out_prefix(Stage::Type type)
265 {
266         if(type==Stage::VERTEX)
267                 return "_vs_out_";
268         else if(type==Stage::GEOMETRY)
269                 return "_gs_out_";
270         else
271                 return string();
272 }
273
274 void InterfaceGenerator::apply(Stage &s)
275 {
276         stage = &s;
277         if(stage->previous)
278                 in_prefix = get_out_prefix(stage->previous->type);
279         out_prefix = get_out_prefix(stage->type);
280         s.content.visit(*this);
281         NodeRemover().apply(s, nodes_to_remove);
282 }
283
284 void InterfaceGenerator::visit(Block &block)
285 {
286         SetForScope<Block *> set_block(current_block, &block);
287         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
288         {
289                 assignment_insert_point = i;
290                 if(&block==&stage->content)
291                         iface_insert_point = i;
292
293                 (*i)->visit(*this);
294         }
295 }
296
297 string InterfaceGenerator::change_prefix(const string &name, const string &prefix) const
298 {
299         unsigned offset = (name.compare(0, in_prefix.size(), in_prefix) ? 0 : in_prefix.size());
300         return prefix+name.substr(offset);
301 }
302
303 bool InterfaceGenerator::generate_interface(VariableDeclaration &var, const string &iface, const string &name)
304 {
305         if(stage->content.variables.count(name))
306                 return false;
307
308         VariableDeclaration* iface_var = new VariableDeclaration;
309         iface_var->sampling = var.sampling;
310         iface_var->interface = iface;
311         iface_var->type = var.type;
312         iface_var->type_declaration = var.type_declaration;
313         iface_var->name = name;
314         if(stage->type==Stage::GEOMETRY)
315                 iface_var->array = ((var.array && var.interface!="in") || iface=="in");
316         else
317                 iface_var->array = var.array;
318         if(iface_var->array)
319                 iface_var->array_size = var.array_size;
320         if(iface=="in")
321         {
322                 iface_var->linked_declaration = &var;
323                 var.linked_declaration = iface_var;
324         }
325         stage->content.body.insert(iface_insert_point, iface_var);
326         stage->content.variables[name] = iface_var;
327
328         return true;
329 }
330
331 ExpressionStatement &InterfaceGenerator::insert_assignment(const string &left, Expression *right)
332 {
333         Assignment *assign = new Assignment;
334         VariableReference *ref = new VariableReference;
335         ref->name = left;
336         assign->left = ref;
337         assign->oper = "=";
338         assign->right = right;
339
340         ExpressionStatement *stmt = new ExpressionStatement;
341         stmt->expression = assign;
342         current_block->body.insert(assignment_insert_point, stmt);
343         stmt->visit(*this);
344
345         return *stmt;
346 }
347
348 void InterfaceGenerator::visit(VariableReference &var)
349 {
350         if(var.declaration || !stage->previous)
351                 return;
352         /* Don't pull a variable from previous stage if we just generated an out
353         interface in this stage */
354         if(stage->content.variables.count(var.name))
355                 return;
356
357         const map<string, VariableDeclaration *> &prev_vars = stage->previous->content.variables;
358         map<string, VariableDeclaration *>::const_iterator i = prev_vars.find(var.name);
359         if(i==prev_vars.end() || i->second->interface!="out")
360                 i = prev_vars.find(in_prefix+var.name);
361         if(i!=prev_vars.end() && i->second->interface=="out")
362         {
363                 generate_interface(*i->second, "in", i->second->name);
364                 var.name = i->second->name;
365         }
366 }
367
368 void InterfaceGenerator::visit(VariableDeclaration &var)
369 {
370         if(var.interface=="out")
371         {
372                 if(current_block!=&stage->content && generate_interface(var, "out", change_prefix(var.name, string())))
373                 {
374                         nodes_to_remove.insert(&var);
375                         if(var.init_expression)
376                         {
377                                 ExpressionStatement &stmt = insert_assignment(var.name, var.init_expression->clone());
378                                 stmt.source = var.source;
379                                 stmt.line = var.line;
380                                 return;
381                         }
382                 }
383         }
384         else if(var.interface=="in")
385         {
386                 if(!var.linked_declaration && stage->previous)
387                 {
388                         const map<string, VariableDeclaration *> &prev_vars = stage->previous->content.variables;
389                         map<string, VariableDeclaration *>::const_iterator i = prev_vars.find(var.name);
390                         if(i!=prev_vars.end() && i->second->interface=="out")
391                         {
392                                 var.linked_declaration = i->second;
393                                 i->second->linked_declaration = &var;
394                         }
395                 }
396         }
397
398         TraversingVisitor::visit(var);
399 }
400
401 void InterfaceGenerator::visit(FunctionDeclaration &func)
402 {
403         // Skip parameters because they're not useful here
404         func.body.visit(*this);
405 }
406
407 void InterfaceGenerator::visit(Passthrough &pass)
408 {
409         vector<VariableDeclaration *> pass_vars;
410
411         for(map<string, VariableDeclaration *>::const_iterator i=stage->content.variables.begin(); i!=stage->content.variables.end(); ++i)
412                 if(i->second->interface=="in")
413                         pass_vars.push_back(i->second);
414
415         if(stage->previous)
416         {
417                 const map<string, VariableDeclaration *> &prev_vars = stage->previous->content.variables;
418                 for(map<string, VariableDeclaration *>::const_iterator i=prev_vars.begin(); i!=prev_vars.end(); ++i)
419                 {
420                         bool linked = false;
421                         for(vector<VariableDeclaration *>::const_iterator j=pass_vars.begin(); (!linked && j!=pass_vars.end()); ++j)
422                                 linked = ((*j)->linked_declaration==i->second);
423
424                         if(!linked && generate_interface(*i->second, "in", i->second->name))
425                                 pass_vars.push_back(i->second);
426                 }
427         }
428
429         if(stage->type==Stage::GEOMETRY)
430         {
431                 VariableReference *ref = new VariableReference;
432                 ref->name = "gl_in";
433
434                 BinaryExpression *subscript = new BinaryExpression;
435                 subscript->left = ref;
436                 subscript->oper = "[";
437                 subscript->right = pass.subscript;
438                 subscript->after = "]";
439
440                 MemberAccess *memacc = new MemberAccess;
441                 memacc->left = subscript;
442                 memacc->member = "gl_Position";
443
444                 insert_assignment("gl_Position", memacc);
445         }
446
447         for(vector<VariableDeclaration *>::const_iterator i=pass_vars.begin(); i!=pass_vars.end(); ++i)
448         {
449                 string out_name = change_prefix((*i)->name, out_prefix);
450                 generate_interface(**i, "out", out_name);
451
452                 VariableReference *ref = new VariableReference;
453                 ref->name = (*i)->name;
454                 if(pass.subscript)
455                 {
456                         BinaryExpression *subscript = new BinaryExpression;
457                         subscript->left = ref;
458                         subscript->oper = "[";
459                         subscript->right = pass.subscript;
460                         subscript->after = "]";
461                         insert_assignment(out_name, subscript);
462                 }
463                 else
464                         insert_assignment(out_name, ref);
465         }
466
467         nodes_to_remove.insert(&pass);
468 }
469
470
471 DeclarationReorderer::DeclarationReorderer():
472         kind(NO_DECLARATION)
473 { }
474
475 void DeclarationReorderer::visit(FunctionCall &call)
476 {
477         FunctionDeclaration *def = call.declaration;
478         if(def)
479                 def = def->definition;
480         if(def && !ordered_funcs.count(def))
481                 needed_funcs.insert(def);
482 }
483
484 void DeclarationReorderer::visit(Block &block)
485 {
486         if(block.parent)
487                 return TraversingVisitor::visit(block);
488
489         NodeList<Statement>::iterator struct_insert_point = block.body.end();
490         NodeList<Statement>::iterator variable_insert_point = block.body.end();
491         NodeList<Statement>::iterator function_insert_point = block.body.end();
492         unsigned unordered_func_count = 0;
493         bool ordered_any_funcs = false;
494
495         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); )
496         {
497                 kind = NO_DECLARATION;
498                 (*i)->visit(*this);
499
500                 bool moved = false;
501                 if(kind==STRUCT && struct_insert_point!=block.body.end())
502                 {
503                         block.body.insert(struct_insert_point, *i);
504                         moved = true;
505                 }
506                 else if(kind>STRUCT && struct_insert_point==block.body.end())
507                         struct_insert_point = i;
508
509                 if(kind==VARIABLE && variable_insert_point!=block.body.end())
510                 {
511                         block.body.insert(variable_insert_point, *i);
512                         moved = true;
513                 }
514                 else if(kind>VARIABLE && variable_insert_point==block.body.end())
515                         variable_insert_point = i;
516
517                 if(kind==FUNCTION)
518                 {
519                         if(function_insert_point==block.body.end())
520                                 function_insert_point = i;
521
522                         if(needed_funcs.empty())
523                         {
524                                 ordered_funcs.insert(i->get());
525                                 if(i!=function_insert_point)
526                                 {
527                                         block.body.insert(function_insert_point, *i);
528                                         moved = true;
529                                 }
530                                 else
531                                         ++function_insert_point;
532                                 ordered_any_funcs = true;
533                         }
534                         else
535                                 ++unordered_func_count;
536                 }
537
538                 if(moved)
539                 {
540                         if(function_insert_point==i)
541                                 ++function_insert_point;
542                         block.body.erase(i++);
543                 }
544                 else
545                         ++i;
546
547                 if(i==block.body.end() && unordered_func_count)
548                 {
549                         if(!ordered_any_funcs)
550                                 // A subset of the remaining functions forms a recursive loop
551                                 /* TODO pick a function and move it up, adding any necessary
552                                 declarations */
553                                 break;
554
555                         i = function_insert_point;
556                         unordered_func_count = 0;
557                 }
558         }
559 }
560
561 void DeclarationReorderer::visit(VariableDeclaration &var)
562 {
563         TraversingVisitor::visit(var);
564         kind = VARIABLE;
565 }
566
567 void DeclarationReorderer::visit(FunctionDeclaration &func)
568 {
569         needed_funcs.clear();
570         func.body.visit(*this);
571         needed_funcs.erase(&func);
572         kind = FUNCTION;
573 }
574
575 } // namespace SL
576 } // namespace GL
577 } // namespace Msp