]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/visitor.cpp
Take care of SPIR-V load IDs in ternary expressions
[libs/gl.git] / source / glsl / visitor.cpp
1 #include <msp/core/raii.h>
2 #include "visitor.h"
3
4 using namespace std;
5
6 namespace Msp {
7 namespace GL {
8 namespace SL {
9
10 void TraversingVisitor::visit(Block &block)
11 {
12         if(&block!=current_block)
13                 enter(block);
14         SetForScope<Block *> set_block(current_block, &block);
15         for(const RefPtr<Statement> &s: block.body)
16                 s->visit(*this);
17 }
18
19 void TraversingVisitor::visit(RefPtr<Expression> &expr)
20 {
21         expr->visit(*this);
22 }
23
24 void TraversingVisitor::visit(MemberAccess &memacc)
25 {
26         visit(memacc.left);
27 }
28
29 void TraversingVisitor::visit(Swizzle &swizzle)
30 {
31         visit(swizzle.left);
32 }
33
34 void TraversingVisitor::visit(UnaryExpression &unary)
35 {
36         visit(unary.expression);
37 }
38
39 void TraversingVisitor::visit(BinaryExpression &binary)
40 {
41         visit(binary.left);
42         visit(binary.right);
43 }
44
45 void TraversingVisitor::visit(Assignment &assign)
46 {
47         visit(assign.left);
48         visit(assign.right);
49 }
50
51 void TraversingVisitor::visit(TernaryExpression &ternary)
52 {
53         visit(ternary.condition);
54         visit(ternary.true_expr);
55         visit(ternary.false_expr);
56 }
57
58 void TraversingVisitor::visit(FunctionCall &call)
59 {
60         for(RefPtr<Expression> &a: call.arguments)
61                 visit(a);
62 }
63
64 void TraversingVisitor::visit(ExpressionStatement &expr)
65 {
66         visit(expr.expression);
67 }
68
69 void TraversingVisitor::visit(InterfaceLayout &layout)
70 {
71         layout.layout.visit(*this);
72 }
73
74 void TraversingVisitor::visit(StructDeclaration &strct)
75 {
76         strct.members.visit(*this);
77 }
78
79 void TraversingVisitor::visit(VariableDeclaration &var)
80 {
81         if(var.layout)
82                 var.layout->visit(*this);
83         if(var.init_expression)
84                 visit(var.init_expression);
85         if(var.array_size)
86                 visit(var.array_size);
87 }
88
89 void TraversingVisitor::visit(FunctionDeclaration &func)
90 {
91         enter(func.body);
92         SetForScope<Block *> set_block(current_block, &func.body);
93         for(const RefPtr<VariableDeclaration> &p: func.parameters)
94                 p->visit(*this);
95         func.body.visit(*this);
96 }
97
98 void TraversingVisitor::visit(Conditional &cond)
99 {
100         visit(cond.condition);
101         cond.body.visit(*this);
102         cond.else_body.visit(*this);
103 }
104
105 void TraversingVisitor::visit(Iteration &iter)
106 {
107         enter(iter.body);
108         SetForScope<Block *> set_block(current_block, &iter.body);
109         if(iter.init_statement)
110                 iter.init_statement->visit(*this);
111         if(iter.condition)
112                 visit(iter.condition);
113         iter.body.visit(*this);
114         if(iter.loop_expression)
115                 visit(iter.loop_expression);
116 }
117
118 void TraversingVisitor::visit(Passthrough &pass)
119 {
120         if(pass.subscript)
121                 visit(pass.subscript);
122 }
123
124 void TraversingVisitor::visit(Return &ret)
125 {
126         if(ret.expression)
127                 visit(ret.expression);
128 }
129
130
131 void NodeRemover::apply(Stage &s, const set<Node *> &tr)
132 {
133         stage = &s;
134         to_remove = &tr;
135         s.content.visit(*this);
136 }
137
138 template<typename T>
139 void NodeRemover::remove_from_map(map<string, T *> &vars, const string &key, T &node)
140 {
141         auto i = vars.find(key);
142         if(i!=vars.end() && i->second==&node)
143                 vars.erase(i);
144 }
145
146 void NodeRemover::visit(Block &block)
147 {
148         SetForScope<Block *> set_block(current_block, &block);
149         for(auto i=block.body.begin(); i!=block.body.end(); )
150         {
151                 (*i)->visit(*this);
152                 if(to_remove->count(i->get()))
153                         block.body.erase(i++);
154                 else
155                         ++i;
156         }
157 }
158
159 void NodeRemover::visit(TypeDeclaration &type)
160 {
161         if(to_remove->count(&type))
162                 remove_from_map(stage->types, type.name, type);
163 }
164
165 void NodeRemover::visit(StructDeclaration &strct)
166 {
167         if(to_remove->count(&strct))
168         {
169                 remove_from_map<TypeDeclaration>(stage->types, strct.name, strct);
170                 if(strct.block_declaration)
171                 {
172                         string key = format("%s %s", strct.block_declaration->interface, strct.block_name);
173                         remove_from_map(stage->interface_blocks, key, *strct.block_declaration);
174                         remove_from_map(stage->interface_blocks, strct.block_declaration->name, *strct.block_declaration);
175                         strct.block_declaration->block_declaration = 0;
176                 }
177         }
178 }
179
180 void NodeRemover::visit(VariableDeclaration &var)
181 {
182         if(to_remove->count(&var))
183         {
184                 remove_from_map(current_block->variables, var.name, var);
185                 if(var.block_declaration)
186                 {
187                         remove_from_map(stage->interface_blocks, format("%s %s", var.interface, var.block_declaration->block_name), var);
188                         remove_from_map(stage->interface_blocks, var.name, var);
189                         var.block_declaration->block_declaration = 0;
190                 }
191                 stage->locations.erase(var.name);
192                 if(var.linked_declaration)
193                         var.linked_declaration->linked_declaration = 0;
194         }
195         else if(var.init_expression && to_remove->count(var.init_expression.get()))
196                 var.init_expression = 0;
197 }
198
199 void NodeRemover::visit(FunctionDeclaration &func)
200 {
201         if(to_remove->count(&func))
202         {
203                 remove_from_map(stage->functions, func.name, func);
204                 if(!func.signature.empty())
205                         remove_from_map(stage->functions, func.name+func.signature, func);
206         }
207         TraversingVisitor::visit(func);
208 }
209
210 void NodeRemover::visit(Iteration &iter)
211 {
212         if(to_remove->count(iter.init_statement.get()))
213                 iter.init_statement = 0;
214         TraversingVisitor::visit(iter);
215 }
216
217
218 void NodeReorderer::apply(Stage &stage, Node &before, const set<Node *> &tr)
219 {
220         reorder_before = &before;
221         to_reorder = &tr;
222         stage.content.visit(*this);
223 }
224
225 void NodeReorderer::visit(Block &block)
226 {
227         auto insert_point = block.body.end();
228         for(auto i=block.body.begin(); i!=block.body.end(); )
229         {
230                 (*i)->visit(*this);
231                 if(insert_point!=block.body.end() && to_reorder->count(i->get()))
232                         block.body.splice(insert_point, block.body, i++);
233                 else
234                 {
235                         if(i->get()==reorder_before)
236                                 insert_point = i;
237                         ++i;
238                 }
239         }
240 }
241
242 } // namespace SL
243 } // namespace GL
244 } // namespace Msp