]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
Recognize composite constants when generating SPIR-V
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/maputils.h>
2 #include <msp/core/raii.h>
3 #include "reflect.h"
4 #include "spirv.h"
5
6 using namespace std;
7
8 namespace Msp {
9 namespace GL {
10 namespace SL {
11
12 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
13 {
14         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0 },
15         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0 },
16         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0 },
17         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0 },
18         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0 },
19         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0 },
20         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0 },
21         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0 },
22         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0 },
23         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0 },
24         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0 },
25         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0 },
26         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0 },
27         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0 },
28         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0 },
29         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0 },
30         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0 },
31         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0 },
32         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0 },
33         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0 },
34         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0 },
35         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0 },
36         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0 },
37         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0 },
38         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0 },
39         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0 },
40         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0 },
41         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0 },
42         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0 },
43         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0 },
44         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0 },
45         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0 },
46         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0 },
47         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0 },
48         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0 },
49         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0 },
50         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0 },
51         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0 },
52         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0 },
53         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0 },
54         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0 },
55         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0 },
56         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0 },
57         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0 },
58         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0 },
59         { "isinf", "f", "", OP_IS_INF, { 1 }, 0 },
60         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0 },
61         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0 },
62         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0 },
63         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0 },
64         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0 },
65         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0 },
66         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0 },
67         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0 },
68         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0 },
69         { "matrixCompMult", "ff", "", 0, { 0 }, &SpirVGenerator::visit_builtin_matrix_comp_mult },
70         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0 },
71         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0 },
72         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0 },
73         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0 },
74         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0 },
75         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0 },
76         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0 },
77         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0 },
78         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0 },
79         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0 },
80         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
81         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
82         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0 },
83         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0 },
84         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0 },
85         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0 },
86         { "any", "b", "", OP_ANY, { 1 }, 0 },
87         { "all", "b", "", OP_ALL, { 1 }, 0 },
88         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0 },
89         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0 },
90         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0 },
91         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0 },
92         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0 },
93         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0 },
94         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0 },
95         { "textureSize", "", "", OP_IMAGE_QUERY_SIZE_LOD, { 1, 2 }, 0 },
96         { "texture", "", "", 0, { }, &SpirVGenerator::visit_builtin_texture },
97         { "textureLod", "", "", 0, { }, &SpirVGenerator::visit_builtin_texture },
98         { "texelFetch", "", "", 0, { }, &SpirVGenerator::visit_builtin_texel_fetch },
99         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0 },
100         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0 },
101         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0 },
102         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0 },
103         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, 0 },
104         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, 0 },
105         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, 0 },
106         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, 0 },
107         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0 },
108         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, 0 },
109         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, 0 },
110         { "interpolateAtCentroid", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
111         { "interpolateAtSample", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
112         { "interpolateAtOffset", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
113         { "", "", "", 0, { }, 0 }
114 };
115
116 SpirVGenerator::SpirVGenerator():
117         stage(0),
118         current_function(0),
119         writer(content),
120         next_id(1),
121         r_expression_result_id(0),
122         constant_expression(false),
123         spec_constant(false),
124         reachable(false),
125         composite_access(false),
126         r_composite_base_id(0),
127         r_composite_base(0),
128         assignment_source_id(0),
129         loop_merge_block_id(0),
130         loop_continue_target_id(0)
131 { }
132
133 void SpirVGenerator::apply(Module &module)
134 {
135         use_capability(CAP_SHADER);
136
137         for(list<Stage>::iterator i=module.stages.begin(); i!=module.stages.end(); ++i)
138         {
139                 stage = &*i;
140                 interface_layouts.clear();
141                 i->content.visit(*this);
142         }
143
144         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
145 }
146
147 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
148 {
149         if(iface=="in")
150                 return STORAGE_INPUT;
151         else if(iface=="out")
152                 return STORAGE_OUTPUT;
153         else if(iface=="uniform")
154                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
155         else if(iface.empty())
156                 return STORAGE_PRIVATE;
157         else
158                 throw invalid_argument("SpirVGenerator::get_interface_storage");
159 }
160
161 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
162 {
163         if(name=="gl_Position")
164                 return BUILTIN_POSITION;
165         else if(name=="gl_PointSize")
166                 return BUILTIN_POINT_SIZE;
167         else if(name=="gl_ClipDistance")
168                 return BUILTIN_CLIP_DISTANCE;
169         else if(name=="gl_VertexID")
170                 return BUILTIN_VERTEX_ID;
171         else if(name=="gl_InstanceID")
172                 return BUILTIN_INSTANCE_ID;
173         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
174                 return BUILTIN_PRIMITIVE_ID;
175         else if(name=="gl_InvocationID")
176                 return BUILTIN_INVOCATION_ID;
177         else if(name=="gl_Layer")
178                 return BUILTIN_LAYER;
179         else if(name=="gl_FragCoord")
180                 return BUILTIN_FRAG_COORD;
181         else if(name=="gl_PointCoord")
182                 return BUILTIN_POINT_COORD;
183         else if(name=="gl_FrontFacing")
184                 return BUILTIN_FRONT_FACING;
185         else if(name=="gl_SampleId")
186                 return BUILTIN_SAMPLE_ID;
187         else if(name=="gl_SamplePosition")
188                 return BUILTIN_SAMPLE_POSITION;
189         else if(name=="gl_FragDepth")
190                 return BUILTIN_FRAG_DEPTH;
191         else
192                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
193 }
194
195 void SpirVGenerator::use_capability(Capability cap)
196 {
197         if(used_capabilities.count(cap))
198                 return;
199
200         used_capabilities.insert(cap);
201         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
202 }
203
204 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
205 {
206         Id &ext_id = imported_extension_ids[name];
207         if(!ext_id)
208         {
209                 ext_id = next_id++;
210                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
211                 writer.write(ext_id);
212                 writer.write_string(name);
213                 writer.end_op(OP_EXT_INST_IMPORT);
214         }
215         return ext_id;
216 }
217
218 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
219 {
220         return get_item(declared_ids, &node).id;
221 }
222
223 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
224 {
225         Id id = next_id++;
226         insert_unique(declared_ids, &node, Declaration(id, type_id));
227         return id;
228 }
229
230 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
231 {
232         Id const_id = next_id++;
233         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
234         {
235                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
236                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
237                 writer.write_op(content.globals, opcode, type_id, const_id);
238         }
239         else
240         {
241                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
242                 writer.write_op(content.globals, opcode, type_id, const_id, value);
243         }
244         return const_id;
245 }
246
247 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
248 {
249         if(value.check_type<bool>())
250                 return ConstantKey(type_id, value.value<bool>());
251         else if(value.check_type<int>())
252                 return ConstantKey(type_id, value.value<int>());
253         else if(value.check_type<float>())
254                 return ConstantKey(type_id, value.value<float>());
255         else
256                 throw invalid_argument("SpirVGenerator::get_constant_key");
257 }
258
259 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
260 {
261         ConstantKey key = get_constant_key(type_id, value);
262         Id &const_id = constant_ids[key];
263         if(!const_id)
264                 const_id = write_constant(type_id, key.int_value, false);
265         return const_id;
266 }
267
268 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
269 {
270         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
271         if(!const_id)
272         {
273                 const_id = next_id++;
274                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
275                 writer.write(type_id);
276                 writer.write(const_id);
277                 for(unsigned i=0; i<size; ++i)
278                         writer.write(scalar_id);
279                 writer.end_op(OP_CONSTANT_COMPOSITE);
280         }
281         return const_id;
282 }
283
284 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size)
285 {
286         Id base_id = (size>1 ? get_standard_type_id(kind, 1) : 0);
287         Id &type_id = standard_type_ids[TypeKey(base_id, (size>1 ? size : static_cast<unsigned>(kind)))];
288         if(!type_id)
289         {
290                 type_id = next_id++;
291                 if(size>1)
292                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
293                 else if(kind==BasicTypeDeclaration::VOID)
294                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
295                 else if(kind==BasicTypeDeclaration::BOOL)
296                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
297                 else if(kind==BasicTypeDeclaration::INT)
298                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, 1);
299                 else if(kind==BasicTypeDeclaration::FLOAT)
300                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
301                 else
302                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
303         }
304         return type_id;
305 }
306
307 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
308 {
309         map<TypeKey, Id>::const_iterator i = standard_type_ids.find(TypeKey(0, kind));
310         return (i!=standard_type_ids.end() && i->second==type_id);
311 }
312
313 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id)
314 {
315         Id base_type_id = get_id(base_type);
316         Id &array_type_id = array_type_ids[TypeKey(base_type_id, size_id)];
317         if(!array_type_id)
318         {
319                 array_type_id = next_id++;
320                 if(size_id)
321                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
322                 else
323                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
324
325                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
326                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
327         }
328
329         return array_type_id;
330 }
331
332 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
333 {
334         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
335         if(!ptr_type_id)
336         {
337                 ptr_type_id = next_id++;
338                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
339         }
340         return ptr_type_id;
341 }
342
343 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
344 {
345         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
346                 if(basic->kind==BasicTypeDeclaration::ARRAY)
347                 {
348                         Id size_id = 0;
349                         if(var.array_size)
350                         {
351                                 SetFlag set_const(constant_expression);
352                                 r_expression_result_id = 0;
353                                 var.array_size->visit(*this);
354                                 size_id = r_expression_result_id;
355                         }
356                         else
357                                 size_id = get_constant_id(get_standard_type_id(BasicTypeDeclaration::INT, 1), 1);
358                         return get_array_type_id(*basic->base_type, size_id);
359                 }
360
361         return get_id(*var.type_declaration);
362 }
363
364 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
365 {
366         Id &load_result_id = variable_load_ids[&var];
367         if(!load_result_id)
368         {
369                 load_result_id = next_id++;
370                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
371         }
372         return load_result_id;
373 }
374
375 void SpirVGenerator::prune_loads(Id min_id)
376 {
377         for(map<const VariableDeclaration *, Id>::iterator i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
378         {
379                 if(i->second>=min_id)
380                         variable_load_ids.erase(i++);
381                 else
382                         ++i;
383         }
384 }
385
386 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
387 {
388         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
389         if(!constant_expression)
390         {
391                 if(!current_function)
392                         throw internal_error("non-constant expression outside a function");
393
394                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
395         }
396         else if(opcode==OP_COMPOSITE_CONSTRUCT)
397                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
398                         (n_args ? 1+has_result*2+n_args : 0));
399         else if(!spec_constant)
400                 throw internal_error("invalid non-specialization constant expression");
401         else
402                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
403
404         Id result_id = next_id++;
405         if(has_result)
406         {
407                 writer.write(type_id);
408                 writer.write(result_id);
409         }
410         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
411                 writer.write(opcode);
412
413         return result_id;
414 }
415
416 void SpirVGenerator::end_expression(Opcode opcode)
417 {
418         if(constant_expression)
419                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
420         writer.end_op(opcode);
421 }
422
423 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
424 {
425         Id result_id = begin_expression(opcode, type_id, 1);
426         writer.write(arg_id);
427         end_expression(opcode);
428         return result_id;
429 }
430
431 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
432 {
433         Id result_id = begin_expression(opcode, type_id, 2);
434         writer.write(left_id);
435         writer.write(right_id);
436         end_expression(opcode);
437         return result_id;
438 }
439
440 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
441 {
442         for(unsigned i=0; i<n_elems; ++i)
443         {
444                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
445                 writer.write(composite_id);
446                 writer.write(i);
447                 end_expression(OP_COMPOSITE_EXTRACT);
448         }
449 }
450
451 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
452 {
453         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
454         for(unsigned i=0; i<n_elems; ++i)
455                 writer.write(elem_ids[i]);
456         end_expression(OP_COMPOSITE_CONSTRUCT);
457
458         return result_id;
459 }
460
461 void SpirVGenerator::visit(Block &block)
462 {
463         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
464                 (*i)->visit(*this);
465 }
466
467 void SpirVGenerator::visit(Literal &literal)
468 {
469         Id type_id = get_id(*literal.type);
470         if(spec_constant)
471                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
472         else
473                 r_expression_result_id = get_constant_id(type_id, literal.value);
474         r_constant_result = true;
475 }
476
477 void SpirVGenerator::visit(VariableReference &var)
478 {
479         if(constant_expression || var.declaration->constant)
480         {
481                 if(!var.declaration->constant)
482                         throw internal_error("reference to non-constant variable in constant context");
483
484                 r_expression_result_id = get_id(*var.declaration);
485                 r_constant_result = true;
486                 return;
487         }
488         else if(!current_function)
489                 throw internal_error("non-constant context outside a function");
490
491         r_constant_result = false;
492         if(composite_access)
493         {
494                 r_composite_base = var.declaration;
495                 r_expression_result_id = 0;
496         }
497         else if(assignment_source_id)
498         {
499                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
500                 variable_load_ids[var.declaration] = assignment_source_id;
501                 r_expression_result_id = assignment_source_id;
502         }
503         else
504                 r_expression_result_id = get_load_id(*var.declaration);
505 }
506
507 void SpirVGenerator::visit(InterfaceBlockReference &iface)
508 {
509         if(!composite_access || !current_function)
510                 throw internal_error("invalid interface block reference");
511
512         r_composite_base = iface.declaration;
513         r_expression_result_id = 0;
514         r_constant_result = false;
515 }
516
517 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
518 {
519         Opcode opcode;
520         Id result_type_id = get_id(result_type);
521         Id access_type_id = result_type_id;
522         if(r_composite_base)
523         {
524                 if(constant_expression)
525                         throw internal_error("composite access through pointer in constant context");
526
527                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
528                 for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
529                         *i = (*i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(*i)) : *i&0x3FFFFF);
530
531                 /* Find the storage class of the base and obtain appropriate pointer type
532                 for the result. */
533                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
534                 map<TypeKey, Id>::const_iterator i = pointer_type_ids.begin();
535                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
536                 if(i==pointer_type_ids.end())
537                         throw internal_error("could not find storage class");
538                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
539
540                 opcode = OP_ACCESS_CHAIN;
541         }
542         else if(assignment_source_id)
543                 throw internal_error("assignment to temporary composite");
544         else
545         {
546                 for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
547                         for(map<ConstantKey, Id>::iterator j=constant_ids.begin(); (*i>=0x400000 && j!=constant_ids.end()); ++j)
548                                 if(j->second==(*i&0x3FFFFF))
549                                         *i = j->first.int_value;
550
551                 opcode = OP_COMPOSITE_EXTRACT;
552         }
553
554         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
555         writer.write(r_composite_base_id);
556         for(vector<unsigned>::const_iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
557                 writer.write(*i);
558         end_expression(opcode);
559
560         r_constant_result = false;
561         if(r_composite_base)
562         {
563                 if(assignment_source_id)
564                 {
565                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
566                         r_expression_result_id = assignment_source_id;
567                 }
568                 else
569                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
570         }
571         else
572                 r_expression_result_id = access_id;
573 }
574
575 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
576 {
577         if(!composite_access)
578         {
579                 r_composite_base = 0;
580                 r_composite_base_id = 0;
581                 r_composite_chain.clear();
582         }
583
584         {
585                 SetFlag set_composite(composite_access);
586                 base_expr.visit(*this);
587         }
588
589         if(!r_composite_base_id)
590                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
591
592         r_composite_chain.push_back(index);
593         if(!composite_access)
594                 generate_composite_access(type);
595         else
596                 r_expression_result_id = 0;
597 }
598
599 void SpirVGenerator::visit_isolated(Expression &expr)
600 {
601         SetForScope<Id> clear_assign(assignment_source_id, 0);
602         SetFlag clear_composite(composite_access, false);
603         SetForScope<Node *> clear_base(r_composite_base, 0);
604         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
605         vector<unsigned> saved_chain;
606         swap(saved_chain, r_composite_chain);
607         expr.visit(*this);
608         swap(saved_chain, r_composite_chain);
609 }
610
611 void SpirVGenerator::visit(MemberAccess &memacc)
612 {
613         visit_composite(*memacc.left, memacc.index, *memacc.type);
614 }
615
616 void SpirVGenerator::visit(Swizzle &swizzle)
617 {
618         if(swizzle.count==1)
619                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
620         else if(assignment_source_id)
621         {
622                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
623
624                 unsigned mask = 0;
625                 for(unsigned i=0; i<swizzle.count; ++i)
626                         mask |= 1<<swizzle.components[i];
627
628                 visit_isolated(*swizzle.left);
629
630                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
631                 writer.write(r_expression_result_id);
632                 writer.write(assignment_source_id);
633                 for(unsigned i=0; i<basic.size; ++i)
634                         writer.write(i+((mask>>i)&1)*basic.size);
635                 end_expression(OP_VECTOR_SHUFFLE);
636
637                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
638                 swizzle.left->visit(*this);
639
640                 r_expression_result_id = combined_id;
641         }
642         else
643         {
644                 swizzle.left->visit(*this);
645                 Id left_id = r_expression_result_id;
646
647                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
648                 writer.write(left_id);
649                 writer.write(left_id);
650                 for(unsigned i=0; i<swizzle.count; ++i)
651                         writer.write(swizzle.components[i]);
652                 end_expression(OP_VECTOR_SHUFFLE);
653         }
654         r_constant_result = false;
655 }
656
657 void SpirVGenerator::visit(UnaryExpression &unary)
658 {
659         unary.expression->visit(*this);
660
661         char oper = unary.oper->token[0];
662         char oper2 = unary.oper->token[1];
663         if(oper=='+' && !oper2)
664                 return;
665
666         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
667         BasicTypeDeclaration &elem = *get_element_type(basic);
668
669         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
670                 /* SPIR-V allows constant operations on floating-point values only for
671                 OpenGL kernels. */
672                 throw internal_error("invalid operands for constant unary expression");
673
674         Id result_type_id = get_id(*unary.type);
675         Opcode opcode = OP_NOP;
676
677         r_constant_result = false;
678         if(oper=='!')
679                 opcode = OP_LOGICAL_NOT;
680         else if(oper=='~')
681                 opcode = OP_NOT;
682         else if(oper=='-' && !oper2)
683         {
684                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
685
686                 if(basic.kind==BasicTypeDeclaration::MATRIX)
687                 {
688                         Id column_type_id = get_id(*basic.base_type);
689                         unsigned n_columns = basic.size&0xFFFF;
690                         Id column_ids[4];
691                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
692                         for(unsigned i=0; i<n_columns; ++i)
693                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
694                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
695                         return;
696                 }
697         }
698         else if((oper=='+' || oper=='-') && oper2==oper)
699         {
700                 if(constant_expression)
701                         throw internal_error("increment or decrement in constant expression");
702
703                 Id one_id = 0;
704                 if(elem.kind==BasicTypeDeclaration::INT)
705                 {
706                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
707                         one_id = get_constant_id(get_id(elem), 1);
708                 }
709                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
710                 {
711                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
712                         one_id = get_constant_id(get_id(elem), 1.0f);
713                 }
714                 else
715                         throw internal_error("invalid increment/decrement");
716
717                 if(basic.kind==BasicTypeDeclaration::VECTOR)
718                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
719
720                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
721
722                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
723                 unary.expression->visit(*this);
724
725                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
726                 return;
727         }
728
729         if(opcode==OP_NOP)
730                 throw internal_error("unknown unary operator");
731
732         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
733 }
734
735 void SpirVGenerator::visit(BinaryExpression &binary)
736 {
737         char oper = binary.oper->token[0];
738         if(oper=='[')
739         {
740                 visit_isolated(*binary.right);
741                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
742         }
743
744         if(assignment_source_id)
745                 throw internal_error("invalid binary expression in assignment target");
746
747         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
748         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
749         // Expression resolver ensures that element types are the same
750         BasicTypeDeclaration &elem = *get_element_type(basic_left);
751
752         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
753                 /* SPIR-V allows constant operations on floating-point values only for
754                 OpenGL kernels. */
755                 throw internal_error("invalid operands for constant binary expression");
756
757         binary.left->visit(*this);
758         Id left_id = r_expression_result_id;
759         binary.right->visit(*this);
760         Id right_id = r_expression_result_id;
761
762         Id result_type_id = get_id(*binary.type);
763         Opcode opcode = OP_NOP;
764         bool swap_operands = false;
765
766         r_constant_result = false;
767
768         char oper2 = binary.oper->token[1];
769         if((oper=='<' || oper=='>') && oper2!=oper)
770         {
771                 if(basic_left.kind==BasicTypeDeclaration::INT)
772                         opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
773                                 (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
774                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
775                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
776                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
777         }
778         else if((oper=='=' || oper=='!') && oper2=='=')
779         {
780                 if(elem.kind==BasicTypeDeclaration::BOOL)
781                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
782                 else if(elem.kind==BasicTypeDeclaration::INT)
783                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
784                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
785                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
786
787                 if(opcode!=OP_NOP && basic_left.base_type)
788                 {
789                         /* The SPIR-V equality operations produce component-wise results, but
790                         GLSL operators return a single boolean.  Use the any/all operations to
791                         combine the results. */
792                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
793                         unsigned n_elems = basic_left.size&0xFFFF;
794                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
795
796                         Id compare_id = 0;
797                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
798                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
799                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
800                         {
801                                 Id column_type_id = get_id(*basic_left.base_type);
802                                 Id column_ids[8];
803                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
804                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
805
806                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
807                                 for(unsigned i=0; i<n_elems; ++i)
808                                 {
809                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
810                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
811                                 }
812
813                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
814                         }
815
816                         if(compare_id)
817                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
818                         return;
819                 }
820         }
821         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
822                 opcode = OP_LOGICAL_AND;
823         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
824                 opcode = OP_LOGICAL_OR;
825         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
826                 opcode = OP_LOGICAL_NOT_EQUAL;
827         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
828                 opcode = OP_BITWISE_AND;
829         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
830                 opcode = OP_BITWISE_OR;
831         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
832                 opcode = OP_BITWISE_XOR;
833         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
834                 opcode = OP_SHIFT_LEFT_LOGICAL;
835         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
836                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
837         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
838                 opcode = OP_S_MOD;
839         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
840         {
841                 Opcode elem_op = OP_NOP;
842                 if(elem.kind==BasicTypeDeclaration::INT)
843                         elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : oper=='*' ? OP_I_MUL : OP_S_DIV);
844                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
845                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
846
847                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
848                 {
849                         /* Multiplication between floating-point vectors and matrices has
850                         dedicated operations. */
851                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
852                                 opcode = OP_MATRIX_TIMES_MATRIX;
853                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
854                         {
855                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
856                                         opcode = OP_VECTOR_TIMES_MATRIX;
857                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
858                                         opcode = OP_MATRIX_TIMES_VECTOR;
859                                 else
860                                 {
861                                         opcode = OP_MATRIX_TIMES_SCALAR;
862                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
863                                 }
864                         }
865                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
866                                 opcode = elem_op;
867                         else
868                         {
869                                 opcode = OP_VECTOR_TIMES_SCALAR;
870                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
871                         }
872                 }
873                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
874                 {
875                         /* One operand is scalar and the other is a vector or a matrix.
876                         Expand the scalar to a vector of appropriate size. */
877                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
878                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
879                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
880                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
881                         Id vector_type_id = get_id(*vector_type);
882
883                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
884                         for(unsigned i=0; i<vector_type->size; ++i)
885                                 writer.write(scalar_id);
886                         end_expression(OP_COMPOSITE_CONSTRUCT);
887
888                         scalar_id = expanded_id;
889
890                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
891                         {
892                                 // Apply matrix operation column-wise.
893                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
894
895                                 Id column_ids[4];
896                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
897                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
898
899                                 for(unsigned i=0; i<n_columns; ++i)
900                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
901
902                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
903                                 return;
904                         }
905                         else
906                                 opcode = elem_op;
907                 }
908                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
909                 {
910                         if(oper=='*')
911                                 throw internal_error("non-float matrix multiplication");
912
913                         /* Other operations involving matrices need to be performed
914                         column-wise. */
915                         Id column_type_id = get_id(*basic_left.base_type);
916                         Id column_ids[8];
917
918                         unsigned n_columns = basic_left.size&0xFFFF;
919                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
920                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
921
922                         for(unsigned i=0; i<n_columns; ++i)
923                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
924
925                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
926                         return;
927                 }
928                 else if(basic_left.kind==basic_right.kind)
929                         // Both operands are either scalars or vectors.
930                         opcode = elem_op;
931         }
932
933         if(opcode==OP_NOP)
934                 throw internal_error("unknown binary operator");
935
936         if(swap_operands)
937                 swap(left_id, right_id);
938
939         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
940 }
941
942 void SpirVGenerator::visit(Assignment &assign)
943 {
944         if(assign.oper->token[0]!='=')
945                 visit(static_cast<BinaryExpression &>(assign));
946         else
947                 assign.right->visit(*this);
948
949         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
950         assign.left->visit(*this);
951         r_constant_result = false;
952 }
953
954 void SpirVGenerator::visit(TernaryExpression &ternary)
955 {
956         if(constant_expression)
957         {
958                 ternary.condition->visit(*this);
959                 Id condition_id = r_expression_result_id;
960                 ternary.true_expr->visit(*this);
961                 Id true_result_id = r_expression_result_id;
962                 ternary.false_expr->visit(*this);
963                 Id false_result_id = r_expression_result_id;
964
965                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
966                 writer.write(condition_id);
967                 writer.write(true_result_id);
968                 writer.write(false_result_id);
969                 end_expression(OP_SELECT);
970
971                 return;
972         }
973
974         ternary.condition->visit(*this);
975         Id condition_id = r_expression_result_id;
976
977         Id true_label_id = next_id++;
978         Id false_label_id = next_id++;
979         Id merge_block_id = next_id++;
980         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
981         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
982
983         writer.write_op_label(true_label_id);
984         ternary.true_expr->visit(*this);
985         Id true_result_id = r_expression_result_id;
986         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
987
988         writer.write_op_label(false_label_id);
989         ternary.false_expr->visit(*this);
990         Id false_result_id = r_expression_result_id;
991
992         writer.write_op_label(merge_block_id);
993         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
994         writer.write(true_result_id);
995         writer.write(true_label_id);
996         writer.write(false_result_id);
997         writer.write(false_label_id);
998         end_expression(OP_PHI);
999
1000         r_constant_result = false;
1001 }
1002
1003 void SpirVGenerator::visit(FunctionCall &call)
1004 {
1005         if(assignment_source_id)
1006                 throw internal_error("assignment to function call");
1007         else if(composite_access)
1008                 return visit_isolated(call);
1009         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
1010                 return call.arguments[0]->visit(*this);
1011
1012         vector<Id> argument_ids;
1013         argument_ids.reserve(call.arguments.size());
1014         bool all_args_const = true;
1015         for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1016         {
1017                 (*i)->visit(*this);
1018                 argument_ids.push_back(r_expression_result_id);
1019                 all_args_const &= r_constant_result;
1020         }
1021
1022         if(constant_expression && (!call.constructor || !all_args_const))
1023                 throw internal_error("function call in constant expression");
1024
1025         Id result_type_id = get_id(*call.type);
1026
1027         if(call.constructor)
1028                 visit_constructor(call, argument_ids, all_args_const);
1029         else if(call.declaration->source==BUILTIN_SOURCE)
1030         {
1031                 string arg_types;
1032                 for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1033                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>((*i)->type))
1034                         {
1035                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1036                                 switch(elem_arg.kind)
1037                                 {
1038                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1039                                 case BasicTypeDeclaration::INT: arg_types += 'i'; break;
1040                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1041                                 default: arg_types += '?';
1042                                 }
1043                         }
1044
1045                 const BuiltinFunctionInfo *builtin_info;
1046                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1047                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1048                                 break;
1049
1050                 if(builtin_info->opcode)
1051                 {
1052                         Opcode opcode;
1053                         if(builtin_info->extension[0])
1054                         {
1055                                 opcode = OP_EXT_INST;
1056                                 Id ext_id = import_extension(builtin_info->extension);
1057
1058                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1059                                 writer.write(ext_id);
1060                                 writer.write(builtin_info->opcode);
1061                         }
1062                         else
1063                         {
1064                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1065                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1066                         }
1067
1068                         for(unsigned i=0; i<call.arguments.size(); ++i)
1069                         {
1070                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1071                                         throw internal_error("invalid builtin function info");
1072                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1073                         }
1074
1075                         end_expression(opcode);
1076                 }
1077                 else if(builtin_info->handler)
1078                         (this->*(builtin_info->handler))(call, argument_ids);
1079                 else
1080                         throw internal_error("unknown builtin function "+call.name);
1081         }
1082         else
1083         {
1084                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1085                 writer.write(get_id(*call.declaration->definition));
1086                 for(vector<Id>::const_iterator i=argument_ids.begin(); i!=argument_ids.end(); ++i)
1087                         writer.write(*i);
1088                 end_expression(OP_FUNCTION_CALL);
1089
1090                 // Any global variables the called function uses might have changed value
1091                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1092                 for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1093                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1094                                 variable_load_ids.erase(var);
1095         }
1096 }
1097
1098 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1099 {
1100         Id result_type_id = get_id(*call.type);
1101
1102         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1103         if(!basic)
1104         {
1105                 if(dynamic_cast<const StructDeclaration *>(call.type))
1106                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1107                 else
1108                         throw internal_error("unconstructable type "+call.name);
1109                 return;
1110         }
1111
1112         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1113
1114         BasicTypeDeclaration &elem = *get_element_type(*basic);
1115         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1116         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1117
1118         if(basic->kind==BasicTypeDeclaration::MATRIX)
1119         {
1120                 Id col_type_id = get_id(*basic->base_type);
1121                 unsigned n_columns = basic->size&0xFFFF;
1122                 unsigned n_rows = basic->size>>16;
1123
1124                 Id column_ids[4];
1125                 if(call.arguments.size()==1)
1126                 {
1127                         // Construct diagonal matrix from a single scalar.
1128                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1129                         for(unsigned i=0; i<n_columns; ++i)
1130                         {
1131                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);;
1132                                 for(unsigned j=0; j<n_rows; ++j)
1133                                         writer.write(j==i ? argument_ids[0] : zero_id);
1134                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1135                         }
1136                 }
1137                 else
1138                         // Construct a matrix from column vectors
1139                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1140
1141                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1142         }
1143         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1144         {
1145                 /* There's either a single scalar argument or multiple arguments
1146                 which make up the vector's components. */
1147                 if(call.arguments.size()==1)
1148                 {
1149                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1150                         for(unsigned i=0; i<basic->size; ++i)
1151                                 writer.write(argument_ids[0]);
1152                         end_expression(OP_COMPOSITE_CONSTRUCT);
1153                 }
1154                 else
1155                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1156         }
1157         else if(elem.kind==BasicTypeDeclaration::BOOL)
1158         {
1159                 if(constant_expression)
1160                         throw internal_error("unconverted constant");
1161
1162                 // Conversion to boolean is implemented as comparing against zero.
1163                 Id number_type_id = get_id(elem_arg0);
1164                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1165                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1166                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1167                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1168
1169                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1170                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1171         }
1172         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1173         {
1174                 if(constant_expression)
1175                         throw internal_error("unconverted constant");
1176
1177                 /* Conversion from boolean is implemented as selecting from zero
1178                 or one. */
1179                 Id number_type_id = get_id(elem);
1180                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1181                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1182                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1183                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1184                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1185                 {
1186                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1187                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1188                 }
1189
1190                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1191                 writer.write(argument_ids[0]);
1192                 writer.write(zero_id);
1193                 writer.write(one_id);
1194                 end_expression(OP_SELECT);
1195         }
1196         else
1197         {
1198                 if(constant_expression)
1199                         throw internal_error("unconverted constant");
1200
1201                 // Scalar or vector conversion between types of equal size.
1202                 Opcode opcode;
1203                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1204                         opcode = OP_CONVERT_F_TO_S;
1205                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1206                         opcode = OP_CONVERT_S_TO_F;
1207                 else
1208                         throw internal_error("invalid conversion");
1209
1210                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1211         }
1212 }
1213
1214 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1215 {
1216         if(argument_ids.size()!=2)
1217                 throw internal_error("invalid matrixCompMult call");
1218
1219         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1220         Id column_type_id = get_id(*basic_arg0.base_type);
1221         Id column_ids[8];
1222
1223         unsigned n_columns = basic_arg0.size&0xFFFF;
1224         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1225         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1226
1227         for(unsigned i=0; i<n_columns; ++i)
1228                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1229
1230         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1231 }
1232
1233 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1234 {
1235         if(argument_ids.size()<2)
1236                 throw internal_error("invalid texture sampling call");
1237
1238         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1239         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1240                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1241
1242         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1243
1244         Opcode opcode;
1245         Id result_type_id = get_id(*call.type);
1246         Id dref_id = 0;
1247         if(image.shadow)
1248         {
1249                 if(argument_ids.size()==2)
1250                 {
1251                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1252                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1253                         writer.write(argument_ids.back());
1254                         writer.write(basic_arg1.size-1);
1255                         end_expression(OP_COMPOSITE_EXTRACT);
1256                 }
1257                 else
1258                         dref_id = argument_ids[2];
1259
1260                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1261                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1262         }
1263         else
1264         {
1265                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1266                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1267         }
1268
1269         for(unsigned i=0; i<2; ++i)
1270                 writer.write(argument_ids[i]);
1271         if(dref_id)
1272                 writer.write(dref_id);
1273         if(explicit_lod)
1274         {
1275                 writer.write(2);  // Lod
1276                 writer.write(lod_id);
1277         }
1278
1279         end_expression(opcode);
1280 }
1281
1282 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1283 {
1284         if(argument_ids.size()!=3)
1285                 throw internal_error("invalid texelFetch call");
1286
1287         r_expression_result_id = begin_expression(OP_IMAGE_FETCH, get_id(*call.type), 4);
1288         for(unsigned i=0; i<2; ++i)
1289                 writer.write(argument_ids[i]);
1290         writer.write(2);  // Lod
1291         writer.write(argument_ids.back());
1292         end_expression(OP_IMAGE_FETCH);
1293 }
1294
1295 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1296 {
1297         if(argument_ids.size()<1)
1298                 throw internal_error("invalid interpolate call");
1299         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1300         if(!var || !var->declaration || var->declaration->interface!="in")
1301                 throw internal_error("invalid interpolate call");
1302
1303         SpirVGlslStd450Opcode opcode;
1304         if(call.name=="interpolateAtCentroid")
1305                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1306         else if(call.name=="interpolateAtSample")
1307                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1308         else if(call.name=="interpolateAtOffset")
1309                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1310         else
1311                 throw internal_error("invalid interpolate call");
1312
1313         use_capability(CAP_INTERPOLATION_FUNCTION);
1314
1315         Id ext_id = import_extension("GLSL.std.450");
1316         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1317         writer.write(ext_id);
1318         writer.write(opcode);
1319         writer.write(get_id(*var->declaration));
1320         for(vector<Id>::const_iterator i=argument_ids.begin(); ++i!=argument_ids.end(); )
1321                 writer.write(*i);
1322         end_expression(OP_EXT_INST);
1323 }
1324
1325 void SpirVGenerator::visit(ExpressionStatement &expr)
1326 {
1327         expr.expression->visit(*this);
1328 }
1329
1330 void SpirVGenerator::visit(InterfaceLayout &layout)
1331 {
1332         interface_layouts.push_back(&layout);
1333 }
1334
1335 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1336 {
1337         for(map<Node *, Declaration>::const_iterator i=declared_ids.begin(); i!=declared_ids.end(); ++i)
1338                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(i->first))
1339                         if(TypeComparer().apply(type, *type2))
1340                         {
1341                                 insert_unique(declared_ids, &type, i->second);
1342                                 return true;
1343                         }
1344
1345         return false;
1346 }
1347
1348 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1349 {
1350         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1351                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1352         if(!elem || elem->base_type)
1353                 return false;
1354         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1355                 return false;
1356
1357         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1));
1358         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1359         writer.write_op_name(standard_id, basic.name);
1360
1361         return true;
1362 }
1363
1364 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1365 {
1366         if(check_standard_type(basic))
1367                 return;
1368         if(check_duplicate_type(basic))
1369                 return;
1370         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1371         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1372                 return;
1373
1374         Id type_id = allocate_id(basic, 0);
1375         writer.write_op_name(type_id, basic.name);
1376
1377         switch(basic.kind)
1378         {
1379         case BasicTypeDeclaration::INT:
1380                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, 1);
1381                 break;
1382         case BasicTypeDeclaration::FLOAT:
1383                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1384                 break;
1385         case BasicTypeDeclaration::VECTOR:
1386                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1387                 break;
1388         case BasicTypeDeclaration::MATRIX:
1389                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1390                 break;
1391         default:
1392                 throw internal_error("unknown basic type");
1393         }
1394 }
1395
1396 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1397 {
1398         if(check_duplicate_type(image))
1399                 return;
1400
1401         Id type_id = allocate_id(image, 0);
1402
1403         Id image_id = (image.sampled ? next_id++ : type_id);
1404         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1405         writer.write(image_id);
1406         writer.write(get_id(*image.base_type));
1407         writer.write(image.dimensions-1);
1408         writer.write(image.shadow);
1409         writer.write(image.array);
1410         writer.write(false);  // Multisample
1411         writer.write(image.sampled ? 1 : 2);
1412         writer.write(0);  // Format (unknown)
1413         writer.end_op(OP_TYPE_IMAGE);
1414
1415         if(image.sampled)
1416         {
1417                 writer.write_op_name(type_id, image.name);
1418                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1419         }
1420
1421         if(image.dimensions==ImageTypeDeclaration::ONE)
1422                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1423         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1424                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1425 }
1426
1427 void SpirVGenerator::visit(StructDeclaration &strct)
1428 {
1429         if(check_duplicate_type(strct))
1430                 return;
1431
1432         Id type_id = allocate_id(strct, 0);
1433         writer.write_op_name(type_id, strct.name);
1434
1435         if(strct.interface_block)
1436                 writer.write_op_decorate(type_id, DECO_BLOCK);
1437
1438         bool builtin = (strct.interface_block && !strct.interface_block->block_name.compare(0, 3, "gl_"));
1439         vector<Id> member_type_ids;
1440         member_type_ids.reserve(strct.members.body.size());
1441         for(NodeList<Statement>::const_iterator i=strct.members.body.begin(); i!=strct.members.body.end(); ++i)
1442         {
1443                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(i->get());
1444                 if(!var)
1445                         continue;
1446
1447                 unsigned index = member_type_ids.size();
1448                 member_type_ids.push_back(get_variable_type_id(*var));
1449
1450                 writer.write_op_member_name(type_id, index, var->name);
1451
1452                 if(builtin)
1453                 {
1454                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1455                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1456                 }
1457                 else
1458                 {
1459                         if(var->layout)
1460                         {
1461                                 const vector<Layout::Qualifier> &qualifiers = var->layout->qualifiers;
1462                                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1463                                 {
1464                                         if(j->name=="offset")
1465                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, j->value);
1466                                         else if(j->name=="column_major")
1467                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1468                                         else if(j->name=="row_major")
1469                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1470                                 }
1471                         }
1472
1473                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1474                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1475                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1476                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1477                         {
1478                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1479                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1480                         }
1481                 }
1482         }
1483
1484         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1485         writer.write(type_id);
1486         for(vector<Id>::const_iterator i=member_type_ids.begin(); i!=member_type_ids.end(); ++i)
1487                 writer.write(*i);
1488         writer.end_op(OP_TYPE_STRUCT);
1489 }
1490
1491 void SpirVGenerator::visit(VariableDeclaration &var)
1492 {
1493         const vector<Layout::Qualifier> *layout_ql = (var.layout ? &var.layout->qualifiers : 0);
1494
1495         int spec_id = -1;
1496         if(layout_ql)
1497         {
1498                 for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); (spec_id<0 && i!=layout_ql->end()); ++i)
1499                         if(i->name=="constant_id")
1500                                 spec_id = i->value;
1501         }
1502
1503         Id type_id = get_variable_type_id(var);
1504         Id var_id;
1505
1506         if(var.constant)
1507         {
1508                 if(!var.init_expression)
1509                         throw internal_error("const variable without initializer");
1510
1511                 SetFlag set_const(constant_expression);
1512                 SetFlag set_spec(spec_constant, spec_id>=0);
1513                 r_expression_result_id = 0;
1514                 var.init_expression->visit(*this);
1515                 var_id = r_expression_result_id;
1516                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1517                 writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1518
1519                 /* It's unclear what should be done if a specialization constant is
1520                 initialized with anything other than a literal.  GLSL doesn't seem to
1521                 prohibit that but SPIR-V says OpSpecConstantOp can't be updated via
1522                 specialization. */
1523         }
1524         else
1525         {
1526                 StorageClass storage = (current_function ? STORAGE_FUNCTION : get_interface_storage(var.interface, false));
1527                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1528                 if(var.interface=="uniform")
1529                 {
1530                         Id &uni_id = declared_uniform_ids["v"+var.name];
1531                         if(uni_id)
1532                         {
1533                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1534                                 return;
1535                         }
1536
1537                         uni_id = var_id = allocate_id(var, ptr_type_id);
1538                 }
1539                 else
1540                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1541
1542                 Id init_id = 0;
1543                 if(var.init_expression)
1544                 {
1545                         SetFlag set_const(constant_expression, !current_function);
1546                         r_expression_result_id = 0;
1547                         var.init_expression->visit(*this);
1548                         init_id = r_expression_result_id;
1549                 }
1550
1551                 vector<Word> &target = (current_function ? content.locals : content.globals);
1552                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1553                 writer.write(ptr_type_id);
1554                 writer.write(var_id);
1555                 writer.write(storage);
1556                 if(init_id && !current_function)
1557                         writer.write(init_id);
1558                 writer.end_op(OP_VARIABLE);
1559
1560                 if(layout_ql)
1561                 {
1562                         for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); i!=layout_ql->end(); ++i)
1563                         {
1564                                 if(i->name=="location")
1565                                         writer.write_op_decorate(var_id, DECO_LOCATION, i->value);
1566                                 else if(i->name=="set")
1567                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, i->value);
1568                                 else if(i->name=="binding")
1569                                         writer.write_op_decorate(var_id, DECO_BINDING, i->value);
1570                         }
1571                 }
1572
1573                 if(init_id && current_function)
1574                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1575         }
1576
1577         writer.write_op_name(var_id, var.name);
1578 }
1579
1580 void SpirVGenerator::visit(InterfaceBlock &iface)
1581 {
1582         StorageClass storage = get_interface_storage(iface.interface, true);
1583         Id type_id;
1584         if(iface.array)
1585                 type_id = get_array_type_id(*iface.struct_declaration, 0);
1586         else
1587                 type_id = get_id(*iface.struct_declaration);
1588         Id ptr_type_id = get_pointer_type_id(type_id, storage);
1589
1590         Id block_id;
1591         if(iface.interface=="uniform")
1592         {
1593                 Id &uni_id = declared_uniform_ids["b"+iface.block_name];
1594                 if(uni_id)
1595                 {
1596                         insert_unique(declared_ids, &iface, Declaration(uni_id, ptr_type_id));
1597                         return;
1598                 }
1599
1600                 uni_id = block_id = allocate_id(iface, ptr_type_id);
1601         }
1602         else
1603                 block_id = allocate_id(iface, ptr_type_id);
1604         writer.write_op_name(block_id, iface.instance_name);
1605
1606         writer.write_op(content.globals, OP_VARIABLE, ptr_type_id, block_id, storage);
1607
1608         if(iface.layout)
1609         {
1610                 const vector<Layout::Qualifier> &qualifiers = iface.layout->qualifiers;
1611                 for(vector<Layout::Qualifier>::const_iterator i=qualifiers.begin(); i!=qualifiers.end(); ++i)
1612                         if(i->name=="binding")
1613                                 writer.write_op_decorate(block_id, DECO_BINDING, i->value);
1614         }
1615 }
1616
1617 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1618 {
1619         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1620         switch(stage->type)
1621         {
1622         case Stage::VERTEX: writer.write(0); break;
1623         case Stage::GEOMETRY: writer.write(3); break;
1624         case Stage::FRAGMENT: writer.write(4); break;
1625         default: throw internal_error("unknown stage");
1626         }
1627         writer.write(func_id);
1628         writer.write_string(func.name);
1629
1630         set<Node *> dependencies = DependencyCollector().apply(func);
1631         for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1632         {
1633                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1634                 {
1635                         if(!var->interface.empty())
1636                                 writer.write(get_id(**i));
1637                 }
1638                 else if(dynamic_cast<InterfaceBlock *>(*i))
1639                         writer.write(get_id(**i));
1640         }
1641
1642         writer.end_op(OP_ENTRY_POINT);
1643
1644         if(stage->type==Stage::FRAGMENT)
1645                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_ORIGIN_LOWER_LEFT);
1646         else if(stage->type==Stage::GEOMETRY)
1647                 use_capability(CAP_GEOMETRY);
1648
1649         for(vector<const InterfaceLayout *>::const_iterator i=interface_layouts.begin(); i!=interface_layouts.end(); ++i)
1650         {
1651                 const vector<Layout::Qualifier> &qualifiers = (*i)->layout.qualifiers;
1652                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1653                 {
1654                         if(j->name=="point")
1655                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1656                                         ((*i)->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1657                         else if(j->name=="lines")
1658                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1659                         else if(j->name=="lines_adjacency")
1660                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1661                         else if(j->name=="triangles")
1662                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1663                         else if(j->name=="triangles_adjacency")
1664                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1665                         else if(j->name=="line_strip")
1666                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1667                         else if(j->name=="triangle_strip")
1668                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1669                         else if(j->name=="max_vertices")
1670                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, j->value);
1671                 }
1672         }
1673 }
1674
1675 void SpirVGenerator::visit(FunctionDeclaration &func)
1676 {
1677         if(func.source==BUILTIN_SOURCE || func.definition!=&func)
1678                 return;
1679
1680         Id return_type_id = get_id(*func.return_type_declaration);
1681         vector<unsigned> param_type_ids;
1682         param_type_ids.reserve(func.parameters.size());
1683         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1684                 param_type_ids.push_back(get_variable_type_id(**i));
1685
1686         string sig_with_return = func.return_type+func.signature;
1687         Id &type_id = function_type_ids[sig_with_return];
1688         if(!type_id)
1689         {
1690                 type_id = next_id++;
1691                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1692                 writer.write(type_id);
1693                 writer.write(return_type_id);
1694                 for(vector<unsigned>::const_iterator i=param_type_ids.begin(); i!=param_type_ids.end(); ++i)
1695                         writer.write(*i);
1696                 writer.end_op(OP_TYPE_FUNCTION);
1697
1698                 writer.write_op_name(type_id, sig_with_return);
1699         }
1700
1701         Id func_id = allocate_id(func, type_id);
1702         writer.write_op_name(func_id, func.name+func.signature);
1703
1704         if(func.name=="main")
1705                 visit_entry_point(func, func_id);
1706
1707         writer.begin_op(content.functions, OP_FUNCTION, 5);
1708         writer.write(return_type_id);
1709         writer.write(func_id);
1710         writer.write(0);  // Function control flags (none)
1711         writer.write(type_id);
1712         writer.end_op(OP_FUNCTION);
1713
1714         for(unsigned i=0; i<func.parameters.size(); ++i)
1715         {
1716                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1717                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1718                 // TODO This is probably incorrect if the parameter is assigned to.
1719                 variable_load_ids[func.parameters[i].get()] = param_id;
1720         }
1721
1722         writer.begin_function_body(next_id++);
1723         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1724         func.body.visit(*this);
1725
1726         if(writer.has_current_block())
1727         {
1728                 if(!reachable)
1729                         writer.write_op(content.function_body, OP_UNREACHABLE);
1730                 else
1731                 {
1732                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1733                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1734                                 writer.write_op(content.function_body, OP_RETURN);
1735                         else
1736                                 throw internal_error("missing return in non-void function");
1737                 }
1738         }
1739         writer.end_function_body();
1740         variable_load_ids.clear();
1741 }
1742
1743 void SpirVGenerator::visit(Conditional &cond)
1744 {
1745         cond.condition->visit(*this);
1746
1747         Id true_label_id = next_id++;
1748         Id merge_block_id = next_id++;
1749         Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
1750         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1751         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);
1752
1753         writer.write_op_label(true_label_id);
1754         cond.body.visit(*this);
1755         if(writer.has_current_block())
1756                 writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1757
1758         bool reachable_if_true = reachable;
1759
1760         reachable = true;
1761         if(!cond.else_body.body.empty())
1762         {
1763                 writer.write_op_label(false_label_id);
1764                 cond.else_body.visit(*this);
1765                 reachable |= reachable_if_true;
1766         }
1767
1768         writer.write_op_label(merge_block_id);
1769         prune_loads(true_label_id);
1770 }
1771
1772 void SpirVGenerator::visit(Iteration &iter)
1773 {
1774         if(iter.init_statement)
1775                 iter.init_statement->visit(*this);
1776
1777         Id header_id = next_id++;
1778         Id continue_id = next_id++;
1779         Id merge_block_id = next_id++;
1780
1781         SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
1782         SetForScope<Id> set_continue(loop_continue_target_id, continue_id);
1783
1784         writer.write_op_label(header_id);
1785         writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)
1786
1787         Id body_id = next_id++;
1788         if(iter.condition)
1789         {
1790                 writer.write_op_label(next_id++);
1791                 iter.condition->visit(*this);
1792                 writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
1793         }
1794
1795         writer.write_op_label(body_id);
1796         iter.body.visit(*this);
1797
1798         writer.write_op_label(continue_id);
1799         if(iter.loop_expression)
1800                 iter.loop_expression->visit(*this);
1801         writer.write_op(content.function_body, OP_BRANCH, header_id);
1802
1803         writer.write_op_label(merge_block_id);
1804         prune_loads(header_id);
1805         reachable = true;
1806 }
1807
1808 void SpirVGenerator::visit(Return &ret)
1809 {
1810         if(ret.expression)
1811         {
1812                 ret.expression->visit(*this);
1813                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1814         }
1815         else
1816                 writer.write_op(content.function_body, OP_RETURN);
1817         reachable = false;
1818 }
1819
1820 void SpirVGenerator::visit(Jump &jump)
1821 {
1822         if(jump.keyword=="discard")
1823                 writer.write_op(content.function_body, OP_KILL);
1824         else if(jump.keyword=="break")
1825                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1826         else if(jump.keyword=="continue")
1827                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1828         else
1829                 throw internal_error("unknown jump");
1830         reachable = false;
1831 }
1832
1833
1834 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1835 {
1836         if(type_id!=other.type_id)
1837                 return type_id<other.type_id;
1838         return detail<other.detail;
1839 }
1840
1841
1842 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1843 {
1844         if(type_id!=other.type_id)
1845                 return type_id<other.type_id;
1846         return int_value<other.int_value;
1847 }
1848
1849 } // namespace SL
1850 } // namespace GL
1851 } // namespace Msp