]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/optimize.h
Fix a name conflict in certain inlining scenarios
[libs/gl.git] / source / glsl / optimize.h
1 #ifndef MSP_GL_SL_OPTIMIZE_H_
2 #define MSP_GL_SL_OPTIMIZE_H_
3
4 #include <map>
5 #include <set>
6 #include "visitor.h"
7
8 namespace Msp {
9 namespace GL {
10 namespace SL {
11
12 /** Assigns values to specialization constants, turning them into normal
13 constants. */
14 class ConstantSpecializer: private TraversingVisitor
15 {
16 private:
17         const std::map<std::string, int> *values = 0;
18
19 public:
20         void apply(Stage &, const std::map<std::string, int> &);
21
22 private:
23         virtual void visit(VariableDeclaration &);
24 };
25
26 /** Finds functions which are candidates for inlining.  Currently this means
27 functions which have no flow control statements, no more than one return
28 statement, and are either builtins or only called once. */
29 class InlineableFunctionLocator: private TraversingVisitor
30 {
31 private:
32         std::map<FunctionDeclaration *, unsigned> refcounts;
33         std::set<FunctionDeclaration *> inlineable;
34         FunctionDeclaration *current_function = 0;
35         unsigned return_count = 0;
36
37 public:
38         std::set<FunctionDeclaration *> apply(Stage &s) { s.content.visit(*this); return inlineable; }
39
40 private:
41         virtual void visit(FunctionCall &);
42         virtual void visit(FunctionDeclaration &);
43         virtual void visit(Conditional &);
44         virtual void visit(Iteration &);
45         virtual void visit(Return &);
46 };
47
48 /** Injects statements from one function into another.  Local variables are
49 renamed to avoid conflicts.  After inlining, uses NodeReorderer to cause
50 dependencies of the inlined statements to appear before the target function. */
51 class InlineContentInjector: private TraversingVisitor
52 {
53 private:
54         enum Pass
55         {
56                 REFERENCED,
57                 INLINE,
58                 RENAME
59         };
60
61         FunctionDeclaration *source_func = 0;
62         Block staging_block;
63         Pass pass = REFERENCED;
64         RefPtr<Statement> r_inlined_statement;
65         std::set<Node *> dependencies;
66         std::set<std::string> referenced_names;
67         std::string r_result_name;
68
69 public:
70         std::string apply(Stage &, FunctionDeclaration &, Block &, const NodeList<Statement>::iterator &, FunctionCall &);
71
72 private:
73         virtual void visit(VariableReference &);
74         virtual void visit(InterfaceBlockReference &);
75         virtual void visit(FunctionCall &);
76         virtual void visit(VariableDeclaration &);
77         virtual void visit(Return &);
78 };
79
80 /** Inlines functions.  Internally uses InlineableFunctionLocator to find
81 candidate functions.  Only functions which consist of a single return statement
82 are inlined. */
83 class FunctionInliner: private TraversingVisitor
84 {
85 private:
86         Stage *stage = 0;
87         std::set<FunctionDeclaration *> inlineable;
88         FunctionDeclaration *current_function = 0;
89         NodeList<Statement>::iterator insert_point;
90         RefPtr<Expression> r_inline_result;
91         bool r_any_inlined = false;
92         bool r_inlined_here = false;
93
94 public:
95         bool apply(Stage &);
96
97 private:
98         virtual void visit(RefPtr<Expression> &);
99         virtual void visit(Block &);
100         virtual void visit(FunctionCall &);
101         virtual void visit(FunctionDeclaration &);
102         virtual void visit(Iteration &);
103 };
104
105 /** Inlines variables into expressions.  Variables with trivial values (those
106 consisting of a single literal or variable reference) are always inlined.
107 Variables which are only referenced once are also inlined. */
108 class ExpressionInliner: private TraversingVisitor
109 {
110 private:
111         struct ExpressionUse
112         {
113                 RefPtr<Expression> *reference = 0;
114                 Block *ref_scope = 0;
115                 bool blocked = false;
116         };
117
118         struct ExpressionInfo
119         {
120                 Assignment::Target target;
121                 RefPtr<Expression> expression;
122                 Block *assign_scope = 0;
123                 std::vector<ExpressionUse> uses;
124                 bool trivial = false;
125                 bool blocked = false;
126         };
127
128         std::list<ExpressionInfo> expressions;
129         std::map<Assignment::Target, ExpressionInfo *> assignments;
130         ExpressionInfo *r_ref_info = 0;
131         bool r_trivial = false;
132         bool access_read = true;
133         bool access_write = false;
134         bool iteration_init = false;
135         Block *iteration_body = 0;
136         const Operator *r_oper = 0;
137
138 public:
139         bool apply(Stage &);
140
141 private:
142         virtual void visit(RefPtr<Expression> &);
143         virtual void visit(VariableReference &);
144         virtual void visit(MemberAccess &);
145         virtual void visit(Swizzle &);
146         virtual void visit(UnaryExpression &);
147         virtual void visit(BinaryExpression &);
148         virtual void visit(Assignment &);
149         virtual void visit(TernaryExpression &);
150         virtual void visit(FunctionCall &);
151         virtual void visit(VariableDeclaration &);
152         virtual void visit(Iteration &);
153 };
154
155 /**
156 Breaks aggregates up into separate variables if only the individual fields are
157 accessed and not the aggregate as a whole.
158 */
159 class AggregateDismantler: public TraversingVisitor
160 {
161 private:
162         struct AggregateMember
163         {
164                 const VariableDeclaration *declaration = 0;
165                 unsigned index = 0;
166                 RefPtr<Expression> initializer;
167                 std::vector<RefPtr<Expression> *> references;
168         };
169
170         struct Aggregate
171         {
172                 VariableDeclaration *declaration = 0;
173                 Block *decl_scope = 0;
174                 NodeList<Statement>::iterator insert_point;
175                 std::vector<AggregateMember> members;
176                 bool referenced = false;
177                 bool members_referenced = false;
178         };
179
180         NodeList<Statement>::iterator insert_point;
181         std::map<Statement *, Aggregate> aggregates;
182         bool composite_reference = false;
183         Assignment::Target r_reference;
184         Aggregate *r_aggregate_ref = 0;
185
186 public:
187         bool apply(Stage &);
188
189 private:
190         virtual void visit(Block &);
191         virtual void visit(RefPtr<Expression> &);
192         virtual void visit(VariableReference &);
193         void visit_composite(RefPtr<Expression> &);
194         virtual void visit(MemberAccess &);
195         virtual void visit(BinaryExpression &);
196         virtual void visit(StructDeclaration &) { }
197         virtual void visit(VariableDeclaration &);
198         virtual void visit(InterfaceBlock &) { }
199         virtual void visit(FunctionDeclaration &);
200 };
201
202 /** Replaces expressions consisting entirely of literals with the results of
203 evaluating the expression.*/
204 class ConstantFolder: private TraversingVisitor
205 {
206 private:
207         VariableDeclaration *iteration_var = 0;
208         Variant iter_init_value;
209         Variant r_constant_value;
210         bool iteration_init = false;
211         bool r_constant = false;
212         bool r_literal = false;
213         bool r_uses_iter_var = false;
214         bool r_any_folded = false;
215
216 public:
217         bool apply(Stage &s) { s.content.visit(*this); return r_any_folded; }
218
219 private:
220         template<typename T>
221         static T evaluate_logical(char, T, T);
222         template<typename T>
223         static bool evaluate_relation(const char *, T, T);
224         template<typename T>
225         static T evaluate_arithmetic(char, T, T);
226         template<typename T>
227         static T evaluate_int_special_op(char, T, T);
228         template<typename T>
229         void convert_to_result(const Variant &);
230         void set_result(const Variant &, bool = false);
231
232         virtual void visit(RefPtr<Expression> &);
233         virtual void visit(Literal &);
234         virtual void visit(VariableReference &);
235         virtual void visit(MemberAccess &);
236         virtual void visit(Swizzle &);
237         virtual void visit(UnaryExpression &);
238         virtual void visit(BinaryExpression &);
239         virtual void visit(Assignment &);
240         virtual void visit(TernaryExpression &);
241         virtual void visit(FunctionCall &);
242         virtual void visit(VariableDeclaration &);
243         virtual void visit(Iteration &);
244 };
245
246 /** Removes conditional statements and loops where the condition can be
247 determined as constant at compile time.  Also removes such statements where
248 the body is empty and the condition has no side effects. */
249 class ConstantConditionEliminator: private TraversingVisitor
250 {
251 private:
252         enum ConstantStatus
253         {
254                 CONSTANT_FALSE,
255                 CONSTANT_TRUE,
256                 NOT_CONSTANT
257         };
258
259         NodeList<Statement>::iterator insert_point;
260         std::set<Node *> nodes_to_remove;
261         RefPtr<Expression> r_ternary_result;
262         bool r_external_side_effects = false;
263
264 public:
265         void apply(Stage &);
266
267 private:
268         ConstantStatus check_constant_condition(const Expression &);
269
270         virtual void visit(Block &);
271         virtual void visit(RefPtr<Expression> &);
272         virtual void visit(UnaryExpression &);
273         virtual void visit(Assignment &);
274         virtual void visit(TernaryExpression &);
275         virtual void visit(FunctionCall &);
276         virtual void visit(Conditional &);
277         virtual void visit(Iteration &);
278 };
279
280 /** Removes code which is never executed due to flow control statements. */
281 class UnreachableCodeRemover: private TraversingVisitor
282 {
283 private:
284         bool reachable = true;
285         std::set<Node *> unreachable_nodes;
286
287 public:
288         virtual bool apply(Stage &);
289
290 private:
291         virtual void visit(Block &);
292         virtual void visit(FunctionDeclaration &);
293         virtual void visit(Conditional &);
294         virtual void visit(Iteration &);
295         virtual void visit(Return &) { reachable = false; }
296         virtual void visit(Jump &) { reachable = false; }
297 };
298
299 /** Removes types which are not used anywhere. */
300 class UnusedTypeRemover: private TraversingVisitor
301 {
302 private:
303         std::set<Node *> unused_nodes;
304
305 public:
306         bool apply(Stage &);
307
308 private:
309         virtual void visit(RefPtr<Expression> &);
310         virtual void visit(BasicTypeDeclaration &);
311         virtual void visit(ImageTypeDeclaration &);
312         virtual void visit(StructDeclaration &);
313         virtual void visit(VariableDeclaration &);
314         virtual void visit(InterfaceBlock &);
315         virtual void visit(FunctionDeclaration &);
316 };
317
318 /** Removes variable declarations with no references to them.  Assignment
319 statements where the result is not used are also removed. */
320 class UnusedVariableRemover: private TraversingVisitor
321 {
322 private:
323         struct AssignmentInfo
324         {
325                 Node *node = 0;
326                 Assignment::Target target;
327                 std::vector<Node *> used_by;
328                 unsigned in_loop = 0;
329         };
330
331         struct VariableInfo
332         {
333                 std::vector<AssignmentInfo *> assignments;
334                 bool initialized = false;
335                 bool output = false;
336                 bool referenced = false;
337         };
338
339         typedef std::map<Statement *, VariableInfo> BlockVariableMap;
340
341         Stage *stage = 0;
342         BlockVariableMap variables;
343         std::list<AssignmentInfo> assignments;
344         Assignment *r_assignment = 0;
345         bool assignment_target = false;
346         bool r_side_effects = false;
347         bool in_struct = false;
348         bool composite_reference = false;
349         unsigned in_loop = 0;
350         std::vector<Node *> loop_ext_refs;
351         Assignment::Target r_reference;
352         std::set<Node *> unused_nodes;
353
354 public:
355         bool apply(Stage &);
356
357 private:
358         void referenced(const Assignment::Target &, Node &);
359         virtual void visit(VariableReference &);
360         virtual void visit(InterfaceBlockReference &);
361         void visit_composite(Expression &);
362         virtual void visit(MemberAccess &);
363         virtual void visit(Swizzle &);
364         virtual void visit(UnaryExpression &);
365         virtual void visit(BinaryExpression &);
366         virtual void visit(Assignment &);
367         virtual void visit(TernaryExpression &);
368         virtual void visit(FunctionCall &);
369         void record_assignment(const Assignment::Target &, Node &);
370         virtual void visit(ExpressionStatement &);
371         virtual void visit(StructDeclaration &);
372         virtual void visit(VariableDeclaration &);
373         virtual void visit(InterfaceBlock &);
374         void merge_variables(const BlockVariableMap &);
375         virtual void visit(FunctionDeclaration &);
376         virtual void visit(Conditional &);
377         virtual void visit(Iteration &);
378 };
379
380 /** Removes function declarations with no references to them. */
381 class UnusedFunctionRemover: private TraversingVisitor
382 {
383 private:
384         std::set<Node *> unused_nodes;
385         std::set<FunctionDeclaration *> used_definitions;
386
387 public:
388         bool apply(Stage &s);
389
390 private:
391         virtual void visit(FunctionCall &);
392         virtual void visit(FunctionDeclaration &);
393 };
394
395 } // namespace SL
396 } // namespace GL
397 } // namespace Msp
398
399 #endif