]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
1da748a64eb297f23f5f135aa3743af92eabbaea
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/algorithm.h>
2 #include <msp/core/maputils.h>
3 #include <msp/core/raii.h>
4 #include "reflect.h"
5 #include "spirv.h"
6
7 using namespace std;
8
9 namespace Msp {
10 namespace GL {
11 namespace SL {
12
13 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
14 {
15         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0, 0 },
16         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0, 0 },
17         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0, 0 },
18         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0, 0 },
19         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0, 0 },
20         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0, 0 },
21         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0, 0 },
22         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0, 0 },
23         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0, 0 },
24         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0, 0 },
25         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0, 0 },
26         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0, 0 },
27         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0, 0 },
28         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0, 0 },
29         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0, 0 },
30         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0, 0 },
31         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0, 0 },
32         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0, 0 },
33         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0, 0 },
34         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0, 0 },
35         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0, 0 },
36         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0, 0 },
37         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0, 0 },
38         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0, 0 },
39         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0, 0 },
40         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0, 0 },
41         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0, 0 },
42         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0, 0 },
43         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0, 0 },
44         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0, 0 },
45         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0, 0 },
46         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0, 0 },
47         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0, 0 },
48         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0, 0 },
49         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0, 0 },
50         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0, 0 },
51         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0, 0 },
52         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0, 0 },
53         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0, 0 },
54         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0, 0 },
55         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0, 0 },
56         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0, 0 },
57         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0, 0 },
58         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
59         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
60         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
61         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0, 0 },
62         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0, 0 },
63         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0, 0 },
64         { "isinf", "f", "", OP_IS_INF, { 1 }, 0, 0 },
65         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0, 0 },
66         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0, 0 },
67         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0, 0 },
68         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0, 0 },
69         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0, 0 },
70         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0, 0 },
71         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0, 0 },
72         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0, 0 },
73         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0, 0 },
74         { "matrixCompMult", "ff", "", 0, { 0 }, 0, &SpirVGenerator::visit_builtin_matrix_comp_mult },
75         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0, 0 },
76         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0, 0 },
77         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0, 0 },
78         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0, 0 },
79         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0, 0 },
80         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0, 0 },
81         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0, 0 },
82         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
83         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
84         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
85         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0, 0 },
86         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0, 0 },
87         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0, 0 },
88         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
89         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
90         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
91         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0, 0 },
92         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
93         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
94         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0, 0 },
95         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
96         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
97         { "any", "b", "", OP_ANY, { 1 }, 0, 0 },
98         { "all", "b", "", OP_ALL, { 1 }, 0, 0 },
99         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0, 0 },
100         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0, 0 },
101         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0, 0 },
102         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
103         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
104         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
105         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
106         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0, 0 },
107         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
108         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
109         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0, 0 },
110         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0, 0 },
111         { "textureSize", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
112         { "textureQueryLod", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
113         { "textureQueryLevels", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
114         { "texture", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
115         { "textureLod", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
116         { "texelFetch", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texel_fetch },
117         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0, 0 },
118         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0, 0 },
119         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0, 0 },
120         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0, 0 },
121         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
122         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
123         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
124         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
125         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0, 0 },
126         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
127         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
128         { "interpolateAtCentroid", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
129         { "interpolateAtSample", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
130         { "interpolateAtOffset", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
131         { "", "", "", 0, { }, 0, 0 }
132 };
133
/* Creates a generator with empty state.  The writer emits into this
generator's own content sections.  Id allocation starts at 1; throughout this
file 0 is used to mean "no id allocated yet". */
SpirVGenerator::SpirVGenerator():
        stage(0),
        current_function(0),
        writer(content),
        next_id(1),
        r_expression_result_id(0),
        constant_expression(false),
        spec_constant(false),
        reachable(false),
        composite_access(false),
        r_composite_base_id(0),
        r_composite_base(0),
        assignment_source_id(0),
        loop_merge_block_id(0),
        loop_continue_target_id(0)
{ }
150
151 void SpirVGenerator::apply(Module &module)
152 {
153         use_capability(CAP_SHADER);
154
155         for(Stage &s: module.stages)
156         {
157                 stage = &s;
158                 interface_layouts.clear();
159                 s.content.visit(*this);
160         }
161
162         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
163 }
164
165 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
166 {
167         if(iface=="in")
168                 return STORAGE_INPUT;
169         else if(iface=="out")
170                 return STORAGE_OUTPUT;
171         else if(iface=="uniform")
172                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
173         else if(iface.empty())
174                 return STORAGE_PRIVATE;
175         else
176                 throw invalid_argument("SpirVGenerator::get_interface_storage");
177 }
178
179 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
180 {
181         if(name=="gl_Position")
182                 return BUILTIN_POSITION;
183         else if(name=="gl_PointSize")
184                 return BUILTIN_POINT_SIZE;
185         else if(name=="gl_ClipDistance")
186                 return BUILTIN_CLIP_DISTANCE;
187         else if(name=="gl_VertexID")
188                 return BUILTIN_VERTEX_ID;
189         else if(name=="gl_InstanceID")
190                 return BUILTIN_INSTANCE_ID;
191         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
192                 return BUILTIN_PRIMITIVE_ID;
193         else if(name=="gl_InvocationID")
194                 return BUILTIN_INVOCATION_ID;
195         else if(name=="gl_Layer")
196                 return BUILTIN_LAYER;
197         else if(name=="gl_FragCoord")
198                 return BUILTIN_FRAG_COORD;
199         else if(name=="gl_PointCoord")
200                 return BUILTIN_POINT_COORD;
201         else if(name=="gl_FrontFacing")
202                 return BUILTIN_FRONT_FACING;
203         else if(name=="gl_SampleId")
204                 return BUILTIN_SAMPLE_ID;
205         else if(name=="gl_SamplePosition")
206                 return BUILTIN_SAMPLE_POSITION;
207         else if(name=="gl_FragDepth")
208                 return BUILTIN_FRAG_DEPTH;
209         else
210                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
211 }
212
213 void SpirVGenerator::use_capability(Capability cap)
214 {
215         if(used_capabilities.count(cap))
216                 return;
217
218         used_capabilities.insert(cap);
219         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
220 }
221
222 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
223 {
224         Id &ext_id = imported_extension_ids[name];
225         if(!ext_id)
226         {
227                 ext_id = next_id++;
228                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
229                 writer.write(ext_id);
230                 writer.write_string(name);
231                 writer.end_op(OP_EXT_INST_IMPORT);
232         }
233         return ext_id;
234 }
235
236 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
237 {
238         return get_item(declared_ids, &node).id;
239 }
240
241 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
242 {
243         auto i = declared_ids.find(&node);
244         if(i!=declared_ids.end())
245         {
246                 if(i->second.type_id)
247                         throw key_error(&node);
248                 i->second.type_id = type_id;
249                 return i->second.id;
250         }
251
252         Id id = next_id++;
253         declared_ids.insert(make_pair(&node, Declaration(id, type_id)));
254         return id;
255 }
256
257 SpirVGenerator::Id SpirVGenerator::allocate_forward_id(Node &node)
258 {
259         auto i = declared_ids.find(&node);
260         if(i!=declared_ids.end())
261                 return i->second.id;
262
263         Id id = next_id++;
264         declared_ids.insert(make_pair(&node, Declaration(id, 0)));
265         return id;
266 }
267
268 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
269 {
270         Id const_id = next_id++;
271         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
272         {
273                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
274                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
275                 writer.write_op(content.globals, opcode, type_id, const_id);
276         }
277         else
278         {
279                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
280                 writer.write_op(content.globals, opcode, type_id, const_id, value);
281         }
282         return const_id;
283 }
284
285 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
286 {
287         if(value.check_type<bool>())
288                 return ConstantKey(type_id, value.value<bool>());
289         else if(value.check_type<int>())
290                 return ConstantKey(type_id, value.value<int>());
291         else if(value.check_type<unsigned>())
292                 return ConstantKey(type_id, value.value<unsigned>());
293         else if(value.check_type<float>())
294                 return ConstantKey(type_id, value.value<float>());
295         else
296                 throw invalid_argument("SpirVGenerator::get_constant_key");
297 }
298
299 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
300 {
301         ConstantKey key = get_constant_key(type_id, value);
302         Id &const_id = constant_ids[key];
303         if(!const_id)
304                 const_id = write_constant(type_id, key.int_value, false);
305         return const_id;
306 }
307
308 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
309 {
310         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
311         if(!const_id)
312         {
313                 const_id = next_id++;
314                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
315                 writer.write(type_id);
316                 writer.write(const_id);
317                 for(unsigned i=0; i<size; ++i)
318                         writer.write(scalar_id);
319                 writer.end_op(OP_CONSTANT_COMPOSITE);
320         }
321         return const_id;
322 }
323
324 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
325 {
326         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
327         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
328         if(!type_id)
329         {
330                 type_id = next_id++;
331                 if(size>1)
332                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
333                 else if(kind==BasicTypeDeclaration::VOID)
334                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
335                 else if(kind==BasicTypeDeclaration::BOOL)
336                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
337                 else if(kind==BasicTypeDeclaration::INT)
338                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
339                 else if(kind==BasicTypeDeclaration::FLOAT)
340                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
341                 else
342                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
343         }
344         return type_id;
345 }
346
347 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
348 {
349         auto i = standard_type_ids.find(TypeKey(kind, true));
350         return (i!=standard_type_ids.end() && i->second==type_id);
351 }
352
353 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id)
354 {
355         Id base_type_id = get_id(base_type);
356         Id &array_type_id = array_type_ids[TypeKey(base_type_id, size_id)];
357         if(!array_type_id)
358         {
359                 array_type_id = next_id++;
360                 if(size_id)
361                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
362                 else
363                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
364
365                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
366                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
367         }
368
369         return array_type_id;
370 }
371
372 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
373 {
374         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
375         if(!ptr_type_id)
376         {
377                 ptr_type_id = next_id++;
378                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
379         }
380         return ptr_type_id;
381 }
382
383 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
384 {
385         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
386                 if(basic->kind==BasicTypeDeclaration::ARRAY)
387                 {
388                         Id size_id = 0;
389                         if(var.array_size)
390                         {
391                                 SetFlag set_const(constant_expression);
392                                 r_expression_result_id = 0;
393                                 var.array_size->visit(*this);
394                                 size_id = r_expression_result_id;
395                         }
396                         else
397                                 size_id = get_constant_id(get_standard_type_id(BasicTypeDeclaration::INT, 1), 1);
398                         return get_array_type_id(*basic->base_type, size_id);
399                 }
400
401         return get_id(*var.type_declaration);
402 }
403
404 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
405 {
406         Id &load_result_id = variable_load_ids[&var];
407         if(!load_result_id)
408         {
409                 load_result_id = next_id++;
410                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
411         }
412         return load_result_id;
413 }
414
415 void SpirVGenerator::prune_loads(Id min_id)
416 {
417         for(auto i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
418         {
419                 if(i->second>=min_id)
420                         variable_load_ids.erase(i++);
421                 else
422                         ++i;
423         }
424 }
425
/* Starts emitting an instruction and returns the id allocated for its result.
The caller writes the n_args operand words and then calls end_expression with
the same opcode.

- Outside constant expressions the instruction goes into the current function
  body.
- In constant expressions, composite construction becomes
  OpConstantComposite or OpSpecConstantComposite.
- Any other constant operation is only allowed for specialization constants.
  It is wrapped in OpSpecConstantOp, with the real opcode written as its first
  operand.

A result id is always consumed from next_id, even for instructions that have
no result. */
SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
{
        // Void-typed instructions other than function calls have no result type/id words.
        bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
        if(!constant_expression)
        {
                if(!current_function)
                        throw internal_error("non-constant expression outside a function");

                /* Word count: opcode word, optional result type and id, operands.  0 is
                passed when there are no operands (presumably letting the writer compute
                the size at end_op). */
                writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
        }
        else if(opcode==OP_COMPOSITE_CONSTRUCT)
                writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
                        (n_args ? 1+has_result*2+n_args : 0));
        else if(!spec_constant)
                throw internal_error("invalid non-specialization constant expression");
        else
                // One extra word for the embedded opcode operand of OpSpecConstantOp.
                writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));

        Id result_id = next_id++;
        if(has_result)
        {
                writer.write(type_id);
                writer.write(result_id);
        }
        /* NOTE(review): this check does not look at constant_expression.  It relies
        on spec_constant only being set while a constant expression is generated;
        confirm. */
        if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
                writer.write(opcode);

        return result_id;
}
455
456 void SpirVGenerator::end_expression(Opcode opcode)
457 {
458         if(constant_expression)
459                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
460         writer.end_op(opcode);
461 }
462
463 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
464 {
465         Id result_id = begin_expression(opcode, type_id, 1);
466         writer.write(arg_id);
467         end_expression(opcode);
468         return result_id;
469 }
470
471 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
472 {
473         Id result_id = begin_expression(opcode, type_id, 2);
474         writer.write(left_id);
475         writer.write(right_id);
476         end_expression(opcode);
477         return result_id;
478 }
479
480 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
481 {
482         for(unsigned i=0; i<n_elems; ++i)
483         {
484                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
485                 writer.write(composite_id);
486                 writer.write(i);
487                 end_expression(OP_COMPOSITE_EXTRACT);
488         }
489 }
490
491 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
492 {
493         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
494         for(unsigned i=0; i<n_elems; ++i)
495                 writer.write(elem_ids[i]);
496         end_expression(OP_COMPOSITE_CONSTRUCT);
497
498         return result_id;
499 }
500
501 void SpirVGenerator::visit(Block &block)
502 {
503         for(const RefPtr<Statement> &s: block.body)
504                 s->visit(*this);
505 }
506
507 void SpirVGenerator::visit(Literal &literal)
508 {
509         Id type_id = get_id(*literal.type);
510         if(spec_constant)
511                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
512         else
513                 r_expression_result_id = get_constant_id(type_id, literal.value);
514         r_constant_result = true;
515 }
516
517 void SpirVGenerator::visit(VariableReference &var)
518 {
519         if(constant_expression || var.declaration->constant)
520         {
521                 if(!var.declaration->constant)
522                         throw internal_error("reference to non-constant variable in constant context");
523
524                 r_expression_result_id = get_id(*var.declaration);
525                 r_constant_result = true;
526                 return;
527         }
528         else if(!current_function)
529                 throw internal_error("non-constant context outside a function");
530
531         r_constant_result = false;
532         if(composite_access)
533         {
534                 r_composite_base = var.declaration;
535                 r_expression_result_id = 0;
536         }
537         else if(assignment_source_id)
538         {
539                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
540                 variable_load_ids[var.declaration] = assignment_source_id;
541                 r_expression_result_id = assignment_source_id;
542         }
543         else
544                 r_expression_result_id = get_load_id(*var.declaration);
545 }
546
547 void SpirVGenerator::visit(InterfaceBlockReference &iface)
548 {
549         if(!composite_access || !current_function)
550                 throw internal_error("invalid interface block reference");
551
552         r_composite_base = iface.declaration;
553         r_expression_result_id = 0;
554         r_constant_result = false;
555 }
556
557 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
558 {
559         Opcode opcode;
560         Id result_type_id = get_id(result_type);
561         Id access_type_id = result_type_id;
562         if(r_composite_base)
563         {
564                 if(constant_expression)
565                         throw internal_error("composite access through pointer in constant context");
566
567                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
568                 for(unsigned &i: r_composite_chain)
569                         i = (i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(i)) : i&0x3FFFFF);
570
571                 /* Find the storage class of the base and obtain appropriate pointer type
572                 for the result. */
573                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
574                 auto i = pointer_type_ids.begin();
575                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
576                 if(i==pointer_type_ids.end())
577                         throw internal_error("could not find storage class");
578                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
579
580                 opcode = OP_ACCESS_CHAIN;
581         }
582         else if(assignment_source_id)
583                 throw internal_error("assignment to temporary composite");
584         else
585         {
586                 for(unsigned i: r_composite_chain)
587                         for(auto j=constant_ids.begin(); (i>=0x400000 && j!=constant_ids.end()); ++j)
588                                 if(j->second==(i&0x3FFFFF))
589                                         i = j->first.int_value;
590
591                 opcode = OP_COMPOSITE_EXTRACT;
592         }
593
594         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
595         writer.write(r_composite_base_id);
596         for(unsigned i: r_composite_chain)
597                 writer.write(i);
598         end_expression(opcode);
599
600         r_constant_result = false;
601         if(r_composite_base)
602         {
603                 if(assignment_source_id)
604                 {
605                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
606                         r_expression_result_id = assignment_source_id;
607                 }
608                 else
609                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
610         }
611         else
612                 r_expression_result_id = access_id;
613 }
614
615 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
616 {
617         if(!composite_access)
618         {
619                 r_composite_base = 0;
620                 r_composite_base_id = 0;
621                 r_composite_chain.clear();
622         }
623
624         {
625                 SetFlag set_composite(composite_access);
626                 base_expr.visit(*this);
627         }
628
629         if(!r_composite_base_id)
630                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
631
632         r_composite_chain.push_back(index);
633         if(!composite_access)
634                 generate_composite_access(type);
635         else
636                 r_expression_result_id = 0;
637 }
638
639 void SpirVGenerator::visit_isolated(Expression &expr)
640 {
641         SetForScope<Id> clear_assign(assignment_source_id, 0);
642         SetFlag clear_composite(composite_access, false);
643         SetForScope<Node *> clear_base(r_composite_base, 0);
644         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
645         vector<unsigned> saved_chain;
646         swap(saved_chain, r_composite_chain);
647         expr.visit(*this);
648         swap(saved_chain, r_composite_chain);
649 }
650
651 void SpirVGenerator::visit(MemberAccess &memacc)
652 {
653         visit_composite(*memacc.left, memacc.index, *memacc.type);
654 }
655
/* Generates code for a vector swizzle.
- A single component is a plain composite access.
- Reading several components is an OpVectorShuffle of the left vector with
  itself.
- Assigning to several components shuffles the current value of the left
  vector with the assigned value, then stores the combined vector back to the
  left expression. */
void SpirVGenerator::visit(Swizzle &swizzle)
{
        if(swizzle.count==1)
                visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
        else if(assignment_source_id)
        {
                const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);

                // Bit n set means component n of the left vector is written.
                unsigned mask = 0;
                for(unsigned i=0; i<swizzle.count; ++i)
                        mask |= 1<<swizzle.components[i];

                // Read the current value of the left vector without triggering a store.
                visit_isolated(*swizzle.left);

                /* NOTE(review): written components take index i+basic.size, i.e. component
                i of the second operand (assignment_source_id).  That is correct only if the
                source has the left vector's size, with values at the swizzled positions.
                A swizzle-sized source (e.g. a vec2 for v.zw = ...) would be indexed out of
                range.  Confirm how the caller prepares assignment_source_id. */
                Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
                writer.write(r_expression_result_id);
                writer.write(assignment_source_id);
                for(unsigned i=0; i<basic.size; ++i)
                        writer.write(i+((mask>>i)&1)*basic.size);
                end_expression(OP_VECTOR_SHUFFLE);

                // Visiting the left side again with the combined value as source stores it.
                SetForScope<Id> set_assign(assignment_source_id, combined_id);
                swizzle.left->visit(*this);

                // NOTE(review): the result is the full combined vector, not the assigned value.
                r_expression_result_id = combined_id;
        }
        else
        {
                swizzle.left->visit(*this);
                Id left_id = r_expression_result_id;

                // Both operands are the same vector; the literals pick the swizzled components.
                r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
                writer.write(left_id);
                writer.write(left_id);
                for(unsigned i=0; i<swizzle.count; ++i)
                        writer.write(swizzle.components[i]);
                end_expression(OP_VECTOR_SHUFFLE);
        }
        r_constant_result = false;
}
696
697 void SpirVGenerator::visit(UnaryExpression &unary)
698 {
699         unary.expression->visit(*this);
700
701         char oper = unary.oper->token[0];
702         char oper2 = unary.oper->token[1];
703         if(oper=='+' && !oper2)
704                 return;
705
706         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
707         BasicTypeDeclaration &elem = *get_element_type(basic);
708
709         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
710                 /* SPIR-V allows constant operations on floating-point values only for
711                 OpenGL kernels. */
712                 throw internal_error("invalid operands for constant unary expression");
713
714         Id result_type_id = get_id(*unary.type);
715         Opcode opcode = OP_NOP;
716
717         r_constant_result = false;
718         if(oper=='!')
719                 opcode = OP_LOGICAL_NOT;
720         else if(oper=='~')
721                 opcode = OP_NOT;
722         else if(oper=='-' && !oper2)
723         {
724                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
725
726                 if(basic.kind==BasicTypeDeclaration::MATRIX)
727                 {
728                         Id column_type_id = get_id(*basic.base_type);
729                         unsigned n_columns = basic.size&0xFFFF;
730                         Id column_ids[4];
731                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
732                         for(unsigned i=0; i<n_columns; ++i)
733                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
734                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
735                         return;
736                 }
737         }
738         else if((oper=='+' || oper=='-') && oper2==oper)
739         {
740                 if(constant_expression)
741                         throw internal_error("increment or decrement in constant expression");
742
743                 Id one_id = 0;
744                 if(elem.kind==BasicTypeDeclaration::INT)
745                 {
746                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
747                         one_id = get_constant_id(get_id(elem), 1);
748                 }
749                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
750                 {
751                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
752                         one_id = get_constant_id(get_id(elem), 1.0f);
753                 }
754                 else
755                         throw internal_error("invalid increment/decrement");
756
757                 if(basic.kind==BasicTypeDeclaration::VECTOR)
758                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
759
760                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
761
762                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
763                 unary.expression->visit(*this);
764
765                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
766                 return;
767         }
768
769         if(opcode==OP_NOP)
770                 throw internal_error("unknown unary operator");
771
772         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
773 }
774
775 void SpirVGenerator::visit(BinaryExpression &binary)
776 {
777         char oper = binary.oper->token[0];
778         if(oper=='[')
779         {
780                 visit_isolated(*binary.right);
781                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
782         }
783
784         if(assignment_source_id)
785                 throw internal_error("invalid binary expression in assignment target");
786
787         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
788         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
789         // Expression resolver ensures that element types are the same
790         BasicTypeDeclaration &elem = *get_element_type(basic_left);
791
792         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
793                 /* SPIR-V allows constant operations on floating-point values only for
794                 OpenGL kernels. */
795                 throw internal_error("invalid operands for constant binary expression");
796
797         binary.left->visit(*this);
798         Id left_id = r_expression_result_id;
799         binary.right->visit(*this);
800         Id right_id = r_expression_result_id;
801
802         Id result_type_id = get_id(*binary.type);
803         Opcode opcode = OP_NOP;
804         bool swap_operands = false;
805
806         r_constant_result = false;
807
808         char oper2 = binary.oper->token[1];
809         if((oper=='<' || oper=='>') && oper2!=oper)
810         {
811                 if(basic_left.kind==BasicTypeDeclaration::INT)
812                 {
813                         if(basic_left.sign)
814                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
815                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
816                         else
817                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
818                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
819                 }
820                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
821                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
822                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
823         }
824         else if((oper=='=' || oper=='!') && oper2=='=')
825         {
826                 if(elem.kind==BasicTypeDeclaration::BOOL)
827                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
828                 else if(elem.kind==BasicTypeDeclaration::INT)
829                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
830                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
831                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
832
833                 if(opcode!=OP_NOP && basic_left.base_type)
834                 {
835                         /* The SPIR-V equality operations produce component-wise results, but
836                         GLSL operators return a single boolean.  Use the any/all operations to
837                         combine the results. */
838                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
839                         unsigned n_elems = basic_left.size&0xFFFF;
840                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
841
842                         Id compare_id = 0;
843                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
844                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
845                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
846                         {
847                                 Id column_type_id = get_id(*basic_left.base_type);
848                                 Id column_ids[8];
849                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
850                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
851
852                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
853                                 for(unsigned i=0; i<n_elems; ++i)
854                                 {
855                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
856                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
857                                 }
858
859                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
860                         }
861
862                         if(compare_id)
863                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
864                         return;
865                 }
866         }
867         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
868                 opcode = OP_LOGICAL_AND;
869         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
870                 opcode = OP_LOGICAL_OR;
871         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
872                 opcode = OP_LOGICAL_NOT_EQUAL;
873         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
874                 opcode = OP_BITWISE_AND;
875         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
876                 opcode = OP_BITWISE_OR;
877         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
878                 opcode = OP_BITWISE_XOR;
879         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
880                 opcode = OP_SHIFT_LEFT_LOGICAL;
881         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
882                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
883         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
884                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
885         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
886         {
887                 Opcode elem_op = OP_NOP;
888                 if(elem.kind==BasicTypeDeclaration::INT)
889                 {
890                         if(oper=='/')
891                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
892                         else
893                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
894                 }
895                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
896                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
897
898                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
899                 {
900                         /* Multiplication between floating-point vectors and matrices has
901                         dedicated operations. */
902                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
903                                 opcode = OP_MATRIX_TIMES_MATRIX;
904                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
905                         {
906                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
907                                         opcode = OP_VECTOR_TIMES_MATRIX;
908                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
909                                         opcode = OP_MATRIX_TIMES_VECTOR;
910                                 else
911                                 {
912                                         opcode = OP_MATRIX_TIMES_SCALAR;
913                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
914                                 }
915                         }
916                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
917                                 opcode = elem_op;
918                         else
919                         {
920                                 opcode = OP_VECTOR_TIMES_SCALAR;
921                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
922                         }
923                 }
924                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
925                 {
926                         /* One operand is scalar and the other is a vector or a matrix.
927                         Expand the scalar to a vector of appropriate size. */
928                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
929                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
930                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
931                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
932                         Id vector_type_id = get_id(*vector_type);
933
934                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
935                         for(unsigned i=0; i<vector_type->size; ++i)
936                                 writer.write(scalar_id);
937                         end_expression(OP_COMPOSITE_CONSTRUCT);
938
939                         scalar_id = expanded_id;
940
941                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
942                         {
943                                 // Apply matrix operation column-wise.
944                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
945
946                                 Id column_ids[4];
947                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
948                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
949
950                                 for(unsigned i=0; i<n_columns; ++i)
951                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
952
953                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
954                                 return;
955                         }
956                         else
957                                 opcode = elem_op;
958                 }
959                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
960                 {
961                         if(oper=='*')
962                                 throw internal_error("non-float matrix multiplication");
963
964                         /* Other operations involving matrices need to be performed
965                         column-wise. */
966                         Id column_type_id = get_id(*basic_left.base_type);
967                         Id column_ids[8];
968
969                         unsigned n_columns = basic_left.size&0xFFFF;
970                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
971                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
972
973                         for(unsigned i=0; i<n_columns; ++i)
974                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
975
976                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
977                         return;
978                 }
979                 else if(basic_left.kind==basic_right.kind)
980                         // Both operands are either scalars or vectors.
981                         opcode = elem_op;
982         }
983
984         if(opcode==OP_NOP)
985                 throw internal_error("unknown binary operator");
986
987         if(swap_operands)
988                 swap(left_id, right_id);
989
990         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
991 }
992
993 void SpirVGenerator::visit(Assignment &assign)
994 {
995         if(assign.oper->token[0]!='=')
996                 visit(static_cast<BinaryExpression &>(assign));
997         else
998                 assign.right->visit(*this);
999
1000         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
1001         assign.left->visit(*this);
1002         r_constant_result = false;
1003 }
1004
1005 void SpirVGenerator::visit(TernaryExpression &ternary)
1006 {
1007         if(constant_expression)
1008         {
1009                 ternary.condition->visit(*this);
1010                 Id condition_id = r_expression_result_id;
1011                 ternary.true_expr->visit(*this);
1012                 Id true_result_id = r_expression_result_id;
1013                 ternary.false_expr->visit(*this);
1014                 Id false_result_id = r_expression_result_id;
1015
1016                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
1017                 writer.write(condition_id);
1018                 writer.write(true_result_id);
1019                 writer.write(false_result_id);
1020                 end_expression(OP_SELECT);
1021
1022                 return;
1023         }
1024
1025         ternary.condition->visit(*this);
1026         Id condition_id = r_expression_result_id;
1027
1028         Id true_label_id = next_id++;
1029         Id false_label_id = next_id++;
1030         Id merge_block_id = next_id++;
1031         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1032         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
1033
1034         writer.write_op_label(true_label_id);
1035         ternary.true_expr->visit(*this);
1036         Id true_result_id = r_expression_result_id;
1037         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1038
1039         writer.write_op_label(false_label_id);
1040         ternary.false_expr->visit(*this);
1041         Id false_result_id = r_expression_result_id;
1042
1043         writer.write_op_label(merge_block_id);
1044         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
1045         writer.write(true_result_id);
1046         writer.write(true_label_id);
1047         writer.write(false_result_id);
1048         writer.write(false_label_id);
1049         end_expression(OP_PHI);
1050
1051         r_constant_result = false;
1052 }
1053
/* Generates code for a function call.  Three kinds of calls are handled:
- Constructors are delegated to visit_constructor.
- Builtin functions are looked up in builtin_functions and emitted either as a
  single (possibly extended) instruction or through a handler member function.
- Calls to user-defined functions are emitted as OpFunctionCall. */
void SpirVGenerator::visit(FunctionCall &call)
{
	// A function call result cannot be assigned to.
	if(assignment_source_id)
		throw internal_error("assignment to function call");
	/* Inside a composite access chain, compute the call as a standalone value
	first; visit_isolated clears the composite-access state for the nested
	visit. */
	else if(composite_access)
		return visit_isolated(call);
	// A single-argument constructor of the argument's own type is just the argument.
	else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
		return call.arguments[0]->visit(*this);

	// Evaluate all arguments in order, tracking whether every one was constant.
	vector<Id> argument_ids;
	argument_ids.reserve(call.arguments.size());
	bool all_args_const = true;
	for(const RefPtr<Expression> &a: call.arguments)
	{
		a->visit(*this);
		argument_ids.push_back(r_expression_result_id);
		all_args_const &= r_constant_result;
	}

	// Only constructors with constant arguments may appear in constant expressions.
	if(constant_expression && (!call.constructor || !all_args_const))
		throw internal_error("function call in constant expression");

	Id result_type_id = get_id(*call.type);
	r_constant_result = false;

	if(call.constructor)
		visit_constructor(call, argument_ids, all_args_const);
	else if(call.declaration->source==BUILTIN_SOURCE)
	{
		/* Build a signature string with one character per basic-typed argument
		(b=bool, i=signed int, u=unsigned int, f=float).  Non-basic arguments
		such as images contribute nothing. */
		string arg_types;
		for(const RefPtr<Expression> &a: call.arguments)
			if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>(a->type))
			{
				BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
				switch(elem_arg.kind)
				{
				case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
				case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
				case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
				default: arg_types += '?';
				}
			}

		/* Find the first entry with a matching name whose arg_types is empty
		(matches any signature) or equal to the signature.  The scan stops at
		an entry with an empty function name.  That entry presumably terminates
		the table with no opcode or handler, so an unmatched call reaches the
		"unknown builtin function" error below. */
		const BuiltinFunctionInfo *builtin_info;
		for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
			if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
				break;

		if(builtin_info->capability)
			use_capability(static_cast<Capability>(builtin_info->capability));

		if(builtin_info->opcode)
		{
			Opcode opcode;
			if(builtin_info->extension[0])
			{
				// Extended instruction: OpExtInst <set id> <instruction number> <operands...>
				opcode = OP_EXT_INST;
				Id ext_id = import_extension(builtin_info->extension);

				r_expression_result_id = begin_expression(opcode, result_type_id);
				writer.write(ext_id);
				writer.write(builtin_info->opcode);
			}
			else
			{
				// Core instruction: the table stores the SPIR-V opcode directly.
				opcode = static_cast<Opcode>(builtin_info->opcode);
				r_expression_result_id = begin_expression(opcode, result_type_id);
			}

			/* arg_order[i] is the 1-based index of the call argument to pass as
			operand i.
			NOTE(review): assumes the arg_order array in BuiltinFunctionInfo has at
			least as many slots as the largest builtin argument count. */
			for(unsigned i=0; i<call.arguments.size(); ++i)
			{
				if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
					throw internal_error("invalid builtin function info");
				writer.write(argument_ids[builtin_info->arg_order[i]-1]);
			}

			end_expression(opcode);
		}
		else if(builtin_info->handler)
			(this->*(builtin_info->handler))(call, argument_ids);
		else
			throw internal_error("unknown builtin function "+call.name);
	}
	else
	{
		// Call to a user-defined function.
		r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
		writer.write(get_id(*call.declaration->definition));
		for(Id i: argument_ids)
			writer.write(i);
		end_expression(OP_FUNCTION_CALL);

		// Any global variables the called function uses might have changed value
		set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
		for(Node *n: dependencies)
			if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
				variable_load_ids.erase(var);
	}
}
1152
1153 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1154 {
1155         Id result_type_id = get_id(*call.type);
1156
1157         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1158         if(!basic)
1159         {
1160                 if(dynamic_cast<const StructDeclaration *>(call.type))
1161                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1162                 else
1163                         throw internal_error("unconstructable type "+call.name);
1164                 return;
1165         }
1166
1167         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1168
1169         BasicTypeDeclaration &elem = *get_element_type(*basic);
1170         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1171         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1172
1173         if(basic->kind==BasicTypeDeclaration::MATRIX)
1174         {
1175                 Id col_type_id = get_id(*basic->base_type);
1176                 unsigned n_columns = basic->size&0xFFFF;
1177                 unsigned n_rows = basic->size>>16;
1178
1179                 Id column_ids[4];
1180                 if(call.arguments.size()==1)
1181                 {
1182                         // Construct diagonal matrix from a single scalar.
1183                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1184                         for(unsigned i=0; i<n_columns; ++i)
1185                         {
1186                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);
1187                                 for(unsigned j=0; j<n_rows; ++j)
1188                                         writer.write(j==i ? argument_ids[0] : zero_id);
1189                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1190                         }
1191                 }
1192                 else
1193                         // Construct a matrix from column vectors
1194                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1195
1196                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1197         }
1198         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1199         {
1200                 /* There's either a single scalar argument or multiple arguments
1201                 which make up the vector's components. */
1202                 if(call.arguments.size()==1)
1203                 {
1204                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1205                         for(unsigned i=0; i<basic->size; ++i)
1206                                 writer.write(argument_ids[0]);
1207                         end_expression(OP_COMPOSITE_CONSTRUCT);
1208                 }
1209                 else
1210                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1211         }
1212         else if(elem.kind==BasicTypeDeclaration::BOOL)
1213         {
1214                 if(constant_expression)
1215                         throw internal_error("unconverted constant");
1216
1217                 // Conversion to boolean is implemented as comparing against zero.
1218                 Id number_type_id = get_id(elem_arg0);
1219                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1220                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1221                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1222                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1223
1224                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1225                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1226         }
1227         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1228         {
1229                 if(constant_expression)
1230                         throw internal_error("unconverted constant");
1231
1232                 /* Conversion from boolean is implemented as selecting from zero
1233                 or one. */
1234                 Id number_type_id = get_id(elem);
1235                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1236                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1237                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1238                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1239                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1240                 {
1241                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1242                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1243                 }
1244
1245                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1246                 writer.write(argument_ids[0]);
1247                 writer.write(zero_id);
1248                 writer.write(one_id);
1249                 end_expression(OP_SELECT);
1250         }
1251         else
1252         {
1253                 if(constant_expression)
1254                         throw internal_error("unconverted constant");
1255
1256                 // Scalar or vector conversion between types of equal size.
1257                 Opcode opcode;
1258                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1259                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1260                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1261                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1262                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1263                         opcode = OP_BITCAST;
1264                 else
1265                         throw internal_error("invalid conversion");
1266
1267                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1268         }
1269 }
1270
1271 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1272 {
1273         if(argument_ids.size()!=2)
1274                 throw internal_error("invalid matrixCompMult call");
1275
1276         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1277         Id column_type_id = get_id(*basic_arg0.base_type);
1278         Id column_ids[8];
1279
1280         unsigned n_columns = basic_arg0.size&0xFFFF;
1281         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1282         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1283
1284         for(unsigned i=0; i<n_columns; ++i)
1285                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1286
1287         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1288 }
1289
1290 void SpirVGenerator::visit_builtin_texture_query(FunctionCall &call, const vector<Id> &argument_ids)
1291 {
1292         if(argument_ids.size()<1)
1293                 throw internal_error("invalid texture query call");
1294
1295         Opcode opcode;
1296         if(call.name=="textureSize")
1297                 opcode = OP_IMAGE_QUERY_SIZE_LOD;
1298         else if(call.name=="textureQueryLod")
1299                 opcode = OP_IMAGE_QUERY_LOD;
1300         else if(call.name=="textureQueryLevels")
1301                 opcode = OP_IMAGE_QUERY_LEVELS;
1302         else
1303                 throw internal_error("invalid texture query call");
1304
1305         ImageTypeDeclaration &image_arg0 = dynamic_cast<ImageTypeDeclaration &>(*call.arguments[0]->type);
1306
1307         Id image_id;
1308         if(image_arg0.sampled)
1309         {
1310                 Id image_type_id = get_item(image_type_ids, get_id(image_arg0));
1311                 image_id = write_expression(OP_IMAGE, image_type_id, argument_ids[0]);
1312         }
1313         else
1314                 image_id = argument_ids[0];
1315
1316         Id result_type_id = get_id(*call.type);
1317         r_expression_result_id = begin_expression(opcode, result_type_id, argument_ids.size());
1318         writer.write(image_id);
1319         for(unsigned i=1; i<argument_ids.size(); ++i)
1320                 writer.write(argument_ids[i]);
1321         end_expression(opcode);
1322 }
1323
1324 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1325 {
1326         if(argument_ids.size()<2)
1327                 throw internal_error("invalid texture sampling call");
1328
1329         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1330         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1331                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1332
1333         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1334
1335         Opcode opcode;
1336         Id result_type_id = get_id(*call.type);
1337         Id dref_id = 0;
1338         if(image.shadow)
1339         {
1340                 if(argument_ids.size()==2)
1341                 {
1342                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1343                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1344                         writer.write(argument_ids.back());
1345                         writer.write(basic_arg1.size-1);
1346                         end_expression(OP_COMPOSITE_EXTRACT);
1347                 }
1348                 else
1349                         dref_id = argument_ids[2];
1350
1351                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1352                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1353         }
1354         else
1355         {
1356                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1357                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1358         }
1359
1360         for(unsigned i=0; i<2; ++i)
1361                 writer.write(argument_ids[i]);
1362         if(dref_id)
1363                 writer.write(dref_id);
1364         if(explicit_lod)
1365         {
1366                 writer.write(2);  // Lod
1367                 writer.write(lod_id);
1368         }
1369
1370         end_expression(opcode);
1371 }
1372
1373 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1374 {
1375         if(argument_ids.size()!=3)
1376                 throw internal_error("invalid texelFetch call");
1377
1378         r_expression_result_id = begin_expression(OP_IMAGE_FETCH, get_id(*call.type), 4);
1379         for(unsigned i=0; i<2; ++i)
1380                 writer.write(argument_ids[i]);
1381         writer.write(2);  // Lod
1382         writer.write(argument_ids.back());
1383         end_expression(OP_IMAGE_FETCH);
1384 }
1385
1386 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1387 {
1388         if(argument_ids.size()<1)
1389                 throw internal_error("invalid interpolate call");
1390         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1391         if(!var || !var->declaration || var->declaration->interface!="in")
1392                 throw internal_error("invalid interpolate call");
1393
1394         SpirVGlslStd450Opcode opcode;
1395         if(call.name=="interpolateAtCentroid")
1396                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1397         else if(call.name=="interpolateAtSample")
1398                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1399         else if(call.name=="interpolateAtOffset")
1400                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1401         else
1402                 throw internal_error("invalid interpolate call");
1403
1404         Id ext_id = import_extension("GLSL.std.450");
1405         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1406         writer.write(ext_id);
1407         writer.write(opcode);
1408         writer.write(get_id(*var->declaration));
1409         for(auto i=argument_ids.begin(); ++i!=argument_ids.end(); )
1410                 writer.write(*i);
1411         end_expression(OP_EXT_INST);
1412 }
1413
1414 void SpirVGenerator::visit(ExpressionStatement &expr)
1415 {
1416         expr.expression->visit(*this);
1417 }
1418
1419 void SpirVGenerator::visit(InterfaceLayout &layout)
1420 {
1421         interface_layouts.push_back(&layout);
1422 }
1423
1424 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1425 {
1426         for(const auto &kvp: declared_ids)
1427                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(kvp.first))
1428                         if(TypeComparer().apply(type, *type2))
1429                         {
1430                                 insert_unique(declared_ids, &type, kvp.second);
1431                                 return true;
1432                         }
1433
1434         return false;
1435 }
1436
1437 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1438 {
1439         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1440                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1441         if(!elem || elem->base_type)
1442                 return false;
1443         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1444                 return false;
1445
1446         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1447         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1448         writer.write_op_name(standard_id, basic.name);
1449
1450         return true;
1451 }
1452
1453 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1454 {
1455         if(check_standard_type(basic))
1456                 return;
1457         if(check_duplicate_type(basic))
1458                 return;
1459         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1460         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1461                 return;
1462
1463         Id type_id = allocate_id(basic, 0);
1464         writer.write_op_name(type_id, basic.name);
1465
1466         switch(basic.kind)
1467         {
1468         case BasicTypeDeclaration::INT:
1469                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1470                 break;
1471         case BasicTypeDeclaration::FLOAT:
1472                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1473                 break;
1474         case BasicTypeDeclaration::VECTOR:
1475                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1476                 break;
1477         case BasicTypeDeclaration::MATRIX:
1478                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1479                 break;
1480         default:
1481                 throw internal_error("unknown basic type");
1482         }
1483 }
1484
1485 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1486 {
1487         if(check_duplicate_type(image))
1488                 return;
1489
1490         Id type_id = allocate_id(image, 0);
1491
1492         Id image_id = (image.sampled ? next_id++ : type_id);
1493         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1494         writer.write(image_id);
1495         writer.write(get_id(*image.base_type));
1496         writer.write(image.dimensions-1);
1497         writer.write(image.shadow);
1498         writer.write(image.array);
1499         writer.write(false);  // Multisample
1500         writer.write(image.sampled ? 1 : 2);
1501         writer.write(0);  // Format (unknown)
1502         writer.end_op(OP_TYPE_IMAGE);
1503
1504         if(image.sampled)
1505         {
1506                 writer.write_op_name(type_id, image.name);
1507                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1508                 insert_unique(image_type_ids, type_id, image_id);
1509         }
1510
1511         if(image.dimensions==ImageTypeDeclaration::ONE)
1512                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1513         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1514                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1515 }
1516
1517 void SpirVGenerator::visit(StructDeclaration &strct)
1518 {
1519         if(check_duplicate_type(strct))
1520                 return;
1521
1522         Id type_id = allocate_id(strct, 0);
1523         writer.write_op_name(type_id, strct.name);
1524
1525         if(strct.interface_block)
1526                 writer.write_op_decorate(type_id, DECO_BLOCK);
1527
1528         bool builtin = (strct.interface_block && !strct.interface_block->block_name.compare(0, 3, "gl_"));
1529         vector<Id> member_type_ids;
1530         member_type_ids.reserve(strct.members.body.size());
1531         for(const RefPtr<Statement> &s: strct.members.body)
1532         {
1533                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(s.get());
1534                 if(!var)
1535                         continue;
1536
1537                 unsigned index = member_type_ids.size();
1538                 member_type_ids.push_back(get_variable_type_id(*var));
1539
1540                 writer.write_op_member_name(type_id, index, var->name);
1541
1542                 if(builtin)
1543                 {
1544                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1545                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1546                 }
1547                 else
1548                 {
1549                         if(var->layout)
1550                         {
1551                                 for(const Layout::Qualifier &q: var->layout->qualifiers)
1552                                 {
1553                                         if(q.name=="offset")
1554                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, q.value);
1555                                         else if(q.name=="column_major")
1556                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1557                                         else if(q.name=="row_major")
1558                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1559                                 }
1560                         }
1561
1562                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1563                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1564                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1565                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1566                         {
1567                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1568                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1569                         }
1570                 }
1571         }
1572
1573         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1574         writer.write(type_id);
1575         for(Id i: member_type_ids)
1576                 writer.write(i);
1577         writer.end_op(OP_TYPE_STRUCT);
1578 }
1579
1580 void SpirVGenerator::visit(VariableDeclaration &var)
1581 {
1582         const vector<Layout::Qualifier> *layout_ql = (var.layout ? &var.layout->qualifiers : 0);
1583
1584         int spec_id = -1;
1585         if(layout_ql)
1586         {
1587                 auto i = find_member(*layout_ql, string("constant_id"), &Layout::Qualifier::name);
1588                 if(i!=layout_ql->end())
1589                         spec_id = i->value;
1590         }
1591
1592         Id type_id = get_variable_type_id(var);
1593         Id var_id;
1594
1595         if(var.constant)
1596         {
1597                 if(!var.init_expression)
1598                         throw internal_error("const variable without initializer");
1599
1600                 SetFlag set_const(constant_expression);
1601                 SetFlag set_spec(spec_constant, spec_id>=0);
1602                 r_expression_result_id = 0;
1603                 var.init_expression->visit(*this);
1604                 var_id = r_expression_result_id;
1605                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1606                 writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1607
1608                 /* It's unclear what should be done if a specialization constant is
1609                 initialized with anything other than a literal.  GLSL doesn't seem to
1610                 prohibit that but SPIR-V says OpSpecConstantOp can't be updated via
1611                 specialization. */
1612         }
1613         else
1614         {
1615                 StorageClass storage = (current_function ? STORAGE_FUNCTION : get_interface_storage(var.interface, false));
1616                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1617                 if(var.interface=="uniform")
1618                 {
1619                         Id &uni_id = declared_uniform_ids["v"+var.name];
1620                         if(uni_id)
1621                         {
1622                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1623                                 return;
1624                         }
1625
1626                         uni_id = var_id = allocate_id(var, ptr_type_id);
1627                 }
1628                 else
1629                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1630
1631                 Id init_id = 0;
1632                 if(var.init_expression)
1633                 {
1634                         SetFlag set_const(constant_expression, !current_function);
1635                         r_expression_result_id = 0;
1636                         r_constant_result = false;
1637                         var.init_expression->visit(*this);
1638                         init_id = r_expression_result_id;
1639                 }
1640
1641                 vector<Word> &target = (current_function ? content.locals : content.globals);
1642                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1643                 writer.write(ptr_type_id);
1644                 writer.write(var_id);
1645                 writer.write(storage);
1646                 if(init_id && !current_function)
1647                         writer.write(init_id);
1648                 writer.end_op(OP_VARIABLE);
1649
1650                 if(layout_ql)
1651                 {
1652                         for(const Layout::Qualifier &q: *layout_ql)
1653                         {
1654                                 if(q.name=="location")
1655                                         writer.write_op_decorate(var_id, DECO_LOCATION, q.value);
1656                                 else if(q.name=="set")
1657                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, q.value);
1658                                 else if(q.name=="binding")
1659                                         writer.write_op_decorate(var_id, DECO_BINDING, q.value);
1660                         }
1661                 }
1662
1663                 if(init_id && current_function)
1664                 {
1665                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1666                         variable_load_ids[&var] = init_id;
1667                 }
1668         }
1669
1670         writer.write_op_name(var_id, var.name);
1671 }
1672
1673 void SpirVGenerator::visit(InterfaceBlock &iface)
1674 {
1675         StorageClass storage = get_interface_storage(iface.interface, true);
1676         Id type_id;
1677         if(iface.array)
1678                 type_id = get_array_type_id(*iface.struct_declaration, 0);
1679         else
1680                 type_id = get_id(*iface.struct_declaration);
1681         Id ptr_type_id = get_pointer_type_id(type_id, storage);
1682
1683         Id block_id;
1684         if(iface.interface=="uniform")
1685         {
1686                 Id &uni_id = declared_uniform_ids["b"+iface.block_name];
1687                 if(uni_id)
1688                 {
1689                         insert_unique(declared_ids, &iface, Declaration(uni_id, ptr_type_id));
1690                         return;
1691                 }
1692
1693                 uni_id = block_id = allocate_id(iface, ptr_type_id);
1694         }
1695         else
1696                 block_id = allocate_id(iface, ptr_type_id);
1697         writer.write_op_name(block_id, iface.instance_name);
1698
1699         writer.write_op(content.globals, OP_VARIABLE, ptr_type_id, block_id, storage);
1700
1701         if(iface.layout)
1702         {
1703                 auto i = find_member(iface.layout->qualifiers, string("binding"), &Layout::Qualifier::name);
1704                 if(i!=iface.layout->qualifiers.end())
1705                         writer.write_op_decorate(block_id, DECO_BINDING, i->value);
1706         }
1707 }
1708
1709 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1710 {
1711         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1712         switch(stage->type)
1713         {
1714         case Stage::VERTEX: writer.write(0); break;
1715         case Stage::GEOMETRY: writer.write(3); break;
1716         case Stage::FRAGMENT: writer.write(4); break;
1717         default: throw internal_error("unknown stage");
1718         }
1719         writer.write(func_id);
1720         writer.write_string(func.name);
1721
1722         set<Node *> dependencies = DependencyCollector().apply(func);
1723         for(Node *n: dependencies)
1724         {
1725                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1726                 {
1727                         if(!var->interface.empty())
1728                                 writer.write(get_id(*n));
1729                 }
1730                 else if(dynamic_cast<InterfaceBlock *>(n))
1731                         writer.write(get_id(*n));
1732         }
1733
1734         writer.end_op(OP_ENTRY_POINT);
1735
1736         if(stage->type==Stage::FRAGMENT)
1737                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_ORIGIN_LOWER_LEFT);
1738         else if(stage->type==Stage::GEOMETRY)
1739                 use_capability(CAP_GEOMETRY);
1740
1741         for(const InterfaceLayout *i: interface_layouts)
1742         {
1743                 for(const Layout::Qualifier &q: i->layout.qualifiers)
1744                 {
1745                         if(q.name=="point")
1746                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1747                                         (i->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1748                         else if(q.name=="lines")
1749                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1750                         else if(q.name=="lines_adjacency")
1751                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1752                         else if(q.name=="triangles")
1753                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1754                         else if(q.name=="triangles_adjacency")
1755                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1756                         else if(q.name=="line_strip")
1757                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1758                         else if(q.name=="triangle_strip")
1759                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1760                         else if(q.name=="max_vertices")
1761                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, q.value);
1762                 }
1763         }
1764 }
1765
1766 void SpirVGenerator::visit(FunctionDeclaration &func)
1767 {
1768         if(func.source==BUILTIN_SOURCE)
1769                 return;
1770         else if(func.definition!=&func)
1771         {
1772                 if(func.definition)
1773                         allocate_forward_id(*func.definition);
1774                 return;
1775         }
1776
1777         Id return_type_id = get_id(*func.return_type_declaration);
1778         vector<unsigned> param_type_ids;
1779         param_type_ids.reserve(func.parameters.size());
1780         for(const RefPtr<VariableDeclaration> &p: func.parameters)
1781                 param_type_ids.push_back(get_variable_type_id(*p));
1782
1783         string sig_with_return = func.return_type+func.signature;
1784         Id &type_id = function_type_ids[sig_with_return];
1785         if(!type_id)
1786         {
1787                 type_id = next_id++;
1788                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1789                 writer.write(type_id);
1790                 writer.write(return_type_id);
1791                 for(unsigned i: param_type_ids)
1792                         writer.write(i);
1793                 writer.end_op(OP_TYPE_FUNCTION);
1794
1795                 writer.write_op_name(type_id, sig_with_return);
1796         }
1797
1798         Id func_id = allocate_id(func, type_id);
1799         writer.write_op_name(func_id, func.name+func.signature);
1800
1801         if(func.name=="main")
1802                 visit_entry_point(func, func_id);
1803
1804         writer.begin_op(content.functions, OP_FUNCTION, 5);
1805         writer.write(return_type_id);
1806         writer.write(func_id);
1807         writer.write(0);  // Function control flags (none)
1808         writer.write(type_id);
1809         writer.end_op(OP_FUNCTION);
1810
1811         for(unsigned i=0; i<func.parameters.size(); ++i)
1812         {
1813                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1814                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1815                 // TODO This is probably incorrect if the parameter is assigned to.
1816                 variable_load_ids[func.parameters[i].get()] = param_id;
1817         }
1818
1819         writer.begin_function_body(next_id++);
1820         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1821         func.body.visit(*this);
1822
1823         if(writer.has_current_block())
1824         {
1825                 if(!reachable)
1826                         writer.write_op(content.function_body, OP_UNREACHABLE);
1827                 else
1828                 {
1829                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1830                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1831                                 writer.write_op(content.function_body, OP_RETURN);
1832                         else
1833                                 throw internal_error("missing return in non-void function");
1834                 }
1835         }
1836         writer.end_function_body();
1837         variable_load_ids.clear();
1838 }
1839
/* Emits an if/else statement as a SPIR-V selection construct.  The current
block ends with OpSelectionMerge and OpBranchConditional.  The true body, the
optional else body and the merge block each start with their own label. */
void SpirVGenerator::visit(Conditional &cond)
{
        cond.condition->visit(*this);

        Id true_label_id = next_id++;
        Id merge_block_id = next_id++;
        // Without an else body the false branch goes directly to the merge block
        Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
        writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
        writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);

        writer.write_op_label(true_label_id);
        cond.body.visit(*this);
        // Only branch to the merge block if the true body left a block open
        if(writer.has_current_block())
                writer.write_op(content.function_body, OP_BRANCH, merge_block_id);

        bool reachable_if_true = reachable;

        /* Code after the conditional is reachable if either path can fall
        through to it.  With no else body the false path always can. */
        reachable = true;
        if(!cond.else_body.body.empty())
        {
                /* NOTE(review): no explicit branch to the merge block is written
                after the else body; this presumably relies on write_op_label
                terminating an open block with a branch.  Confirm in the writer. */
                writer.write_op_label(false_label_id);
                cond.else_body.visit(*this);
                reachable |= reachable_if_true;
        }

        writer.write_op_label(merge_block_id);
        // Presumably discards cached loads recorded inside the branches (since true_label_id); TODO confirm prune_loads semantics
        prune_loads(true_label_id);
}
1868
/* Emits a loop as a SPIR-V loop construct.  The structure is: a header block
with OpLoopMerge, an optional condition block, the body, a continue block that
evaluates the loop expression and branches back to the header, and a merge
block.  Inside the body, break and continue branch to the merge and continue
blocks through loop_merge_block_id and loop_continue_target_id. */
void SpirVGenerator::visit(Iteration &iter)
{
        if(iter.init_statement)
                iter.init_statement->visit(*this);

        // Cached loads are discarded; presumably because the loop body may store to the variables
        variable_load_ids.clear();

        Id header_id = next_id++;
        Id continue_id = next_id++;
        Id merge_block_id = next_id++;

        SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
        SetForScope<Id> set_continue(loop_continue_target_id, continue_id);

        writer.write_op_label(header_id);
        writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)

        Id body_id = next_id++;
        if(iter.condition)
        {
                // The condition is evaluated in its own block; a false result exits the loop
                writer.write_op_label(next_id++);
                iter.condition->visit(*this);
                writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
        }

        writer.write_op_label(body_id);
        iter.body.visit(*this);

        writer.write_op_label(continue_id);
        if(iter.loop_expression)
                iter.loop_expression->visit(*this);
        writer.write_op(content.function_body, OP_BRANCH, header_id);

        writer.write_op_label(merge_block_id);
        // Presumably discards cached loads recorded inside the loop (since header_id); TODO confirm prune_loads semantics
        prune_loads(header_id);
        // Code after the loop is treated as reachable
        reachable = true;
}
1906
1907 void SpirVGenerator::visit(Return &ret)
1908 {
1909         if(ret.expression)
1910         {
1911                 ret.expression->visit(*this);
1912                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1913         }
1914         else
1915                 writer.write_op(content.function_body, OP_RETURN);
1916         reachable = false;
1917 }
1918
1919 void SpirVGenerator::visit(Jump &jump)
1920 {
1921         if(jump.keyword=="discard")
1922                 writer.write_op(content.function_body, OP_KILL);
1923         else if(jump.keyword=="break")
1924                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1925         else if(jump.keyword=="continue")
1926                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1927         else
1928                 throw internal_error("unknown jump");
1929         reachable = false;
1930 }
1931
1932
1933 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
1934         type_id(0)
1935 {
1936         switch(kind)
1937         {
1938         case BasicTypeDeclaration::VOID: detail = 'v'; break;
1939         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
1940         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
1941         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
1942         default: throw invalid_argument("TypeKey::TypeKey");
1943         }
1944 }
1945
1946 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1947 {
1948         if(type_id!=other.type_id)
1949                 return type_id<other.type_id;
1950         return detail<other.detail;
1951 }
1952
1953
1954 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1955 {
1956         if(type_id!=other.type_id)
1957                 return type_id<other.type_id;
1958         return int_value<other.int_value;
1959 }
1960
1961 } // namespace SL
1962 } // namespace GL
1963 } // namespace Msp