]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
Use extended alignment in SPIR-V struct layout when necessary
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/algorithm.h>
2 #include <msp/core/maputils.h>
3 #include <msp/core/raii.h>
4 #include "reflect.h"
5 #include "spirv.h"
6
7 using namespace std;
8
9 namespace Msp {
10 namespace GL {
11 namespace SL {
12
13 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
14 {
15         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0, 0 },
16         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0, 0 },
17         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0, 0 },
18         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0, 0 },
19         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0, 0 },
20         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0, 0 },
21         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0, 0 },
22         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0, 0 },
23         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0, 0 },
24         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0, 0 },
25         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0, 0 },
26         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0, 0 },
27         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0, 0 },
28         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0, 0 },
29         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0, 0 },
30         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0, 0 },
31         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0, 0 },
32         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0, 0 },
33         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0, 0 },
34         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0, 0 },
35         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0, 0 },
36         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0, 0 },
37         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0, 0 },
38         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0, 0 },
39         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0, 0 },
40         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0, 0 },
41         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0, 0 },
42         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0, 0 },
43         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0, 0 },
44         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0, 0 },
45         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0, 0 },
46         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0, 0 },
47         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0, 0 },
48         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0, 0 },
49         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0, 0 },
50         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0, 0 },
51         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0, 0 },
52         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0, 0 },
53         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0, 0 },
54         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0, 0 },
55         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0, 0 },
56         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0, 0 },
57         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0, 0 },
58         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
59         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
60         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
61         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0, 0 },
62         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0, 0 },
63         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0, 0 },
64         { "isinf", "f", "", OP_IS_INF, { 1 }, 0, 0 },
65         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0, 0 },
66         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0, 0 },
67         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0, 0 },
68         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0, 0 },
69         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0, 0 },
70         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0, 0 },
71         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0, 0 },
72         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0, 0 },
73         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0, 0 },
74         { "matrixCompMult", "ff", "", 0, { 0 }, 0, &SpirVGenerator::visit_builtin_matrix_comp_mult },
75         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0, 0 },
76         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0, 0 },
77         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0, 0 },
78         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0, 0 },
79         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0, 0 },
80         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0, 0 },
81         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0, 0 },
82         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
83         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
84         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
85         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0, 0 },
86         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0, 0 },
87         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0, 0 },
88         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
89         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
90         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
91         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0, 0 },
92         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
93         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
94         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0, 0 },
95         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
96         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
97         { "any", "b", "", OP_ANY, { 1 }, 0, 0 },
98         { "all", "b", "", OP_ALL, { 1 }, 0, 0 },
99         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0, 0 },
100         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0, 0 },
101         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0, 0 },
102         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
103         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
104         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
105         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
106         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0, 0 },
107         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
108         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
109         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0, 0 },
110         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0, 0 },
111         { "textureSize", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
112         { "textureQueryLod", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
113         { "textureQueryLevels", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
114         { "texture", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
115         { "textureLod", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
116         { "texelFetch", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texel_fetch },
117         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0, 0 },
118         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0, 0 },
119         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0, 0 },
120         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0, 0 },
121         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
122         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
123         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
124         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
125         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0, 0 },
126         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
127         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
128         { "interpolateAtCentroid", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
129         { "interpolateAtSample", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
130         { "interpolateAtOffset", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
131         { "", "", "", 0, { }, 0, 0 }
132 };
133
134 SpirVGenerator::SpirVGenerator():
135         writer(content)
136 { }
137
138 void SpirVGenerator::apply(Module &module, const Features &f)
139 {
140         features = f;
141         use_capability(CAP_SHADER);
142
143         for(Stage &s: module.stages)
144         {
145                 stage = &s;
146                 interface_layouts.clear();
147                 s.content.visit(*this);
148         }
149
150         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
151 }
152
153 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
154 {
155         if(iface=="in")
156                 return STORAGE_INPUT;
157         else if(iface=="out")
158                 return STORAGE_OUTPUT;
159         else if(iface=="uniform")
160                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
161         else if(iface.empty())
162                 return STORAGE_PRIVATE;
163         else
164                 throw invalid_argument("SpirVGenerator::get_interface_storage");
165 }
166
167 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
168 {
169         if(name=="gl_Position")
170                 return BUILTIN_POSITION;
171         else if(name=="gl_PointSize")
172                 return BUILTIN_POINT_SIZE;
173         else if(name=="gl_ClipDistance")
174                 return BUILTIN_CLIP_DISTANCE;
175         else if(name=="gl_VertexID")
176                 return BUILTIN_VERTEX_ID;
177         else if(name=="gl_InstanceID")
178                 return BUILTIN_INSTANCE_ID;
179         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
180                 return BUILTIN_PRIMITIVE_ID;
181         else if(name=="gl_InvocationID")
182                 return BUILTIN_INVOCATION_ID;
183         else if(name=="gl_Layer")
184                 return BUILTIN_LAYER;
185         else if(name=="gl_FragCoord")
186                 return BUILTIN_FRAG_COORD;
187         else if(name=="gl_PointCoord")
188                 return BUILTIN_POINT_COORD;
189         else if(name=="gl_FrontFacing")
190                 return BUILTIN_FRONT_FACING;
191         else if(name=="gl_SampleId")
192                 return BUILTIN_SAMPLE_ID;
193         else if(name=="gl_SamplePosition")
194                 return BUILTIN_SAMPLE_POSITION;
195         else if(name=="gl_FragDepth")
196                 return BUILTIN_FRAG_DEPTH;
197         else
198                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
199 }
200
201 void SpirVGenerator::use_capability(Capability cap)
202 {
203         if(used_capabilities.count(cap))
204                 return;
205
206         used_capabilities.insert(cap);
207         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
208 }
209
210 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
211 {
212         Id &ext_id = imported_extension_ids[name];
213         if(!ext_id)
214         {
215                 ext_id = next_id++;
216                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
217                 writer.write(ext_id);
218                 writer.write_string(name);
219                 writer.end_op(OP_EXT_INST_IMPORT);
220         }
221         return ext_id;
222 }
223
224 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
225 {
226         return get_item(declared_ids, &node).id;
227 }
228
229 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
230 {
231         auto i = declared_ids.find(&node);
232         if(i!=declared_ids.end())
233         {
234                 if(i->second.type_id)
235                         throw key_error(&node);
236                 i->second.type_id = type_id;
237                 return i->second.id;
238         }
239
240         Id id = next_id++;
241         declared_ids.insert(make_pair(&node, Declaration(id, type_id)));
242         return id;
243 }
244
245 SpirVGenerator::Id SpirVGenerator::allocate_forward_id(Node &node)
246 {
247         auto i = declared_ids.find(&node);
248         if(i!=declared_ids.end())
249                 return i->second.id;
250
251         Id id = next_id++;
252         declared_ids.insert(make_pair(&node, Declaration(id, 0)));
253         return id;
254 }
255
256 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
257 {
258         Id const_id = next_id++;
259         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
260         {
261                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
262                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
263                 writer.write_op(content.globals, opcode, type_id, const_id);
264         }
265         else
266         {
267                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
268                 writer.write_op(content.globals, opcode, type_id, const_id, value);
269         }
270         return const_id;
271 }
272
273 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
274 {
275         if(value.check_type<bool>())
276                 return ConstantKey(type_id, value.value<bool>());
277         else if(value.check_type<int>())
278                 return ConstantKey(type_id, value.value<int>());
279         else if(value.check_type<unsigned>())
280                 return ConstantKey(type_id, value.value<unsigned>());
281         else if(value.check_type<float>())
282                 return ConstantKey(type_id, value.value<float>());
283         else
284                 throw invalid_argument("SpirVGenerator::get_constant_key");
285 }
286
287 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
288 {
289         ConstantKey key = get_constant_key(type_id, value);
290         Id &const_id = constant_ids[key];
291         if(!const_id)
292                 const_id = write_constant(type_id, key.int_value, false);
293         return const_id;
294 }
295
296 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
297 {
298         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
299         if(!const_id)
300         {
301                 const_id = next_id++;
302                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
303                 writer.write(type_id);
304                 writer.write(const_id);
305                 for(unsigned i=0; i<size; ++i)
306                         writer.write(scalar_id);
307                 writer.end_op(OP_CONSTANT_COMPOSITE);
308         }
309         return const_id;
310 }
311
312 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
313 {
314         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
315         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
316         if(!type_id)
317         {
318                 type_id = next_id++;
319                 if(size>1)
320                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
321                 else if(kind==BasicTypeDeclaration::VOID)
322                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
323                 else if(kind==BasicTypeDeclaration::BOOL)
324                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
325                 else if(kind==BasicTypeDeclaration::INT)
326                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
327                 else if(kind==BasicTypeDeclaration::FLOAT)
328                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
329                 else
330                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
331         }
332         return type_id;
333 }
334
335 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
336 {
337         auto i = standard_type_ids.find(TypeKey(kind, true));
338         return (i!=standard_type_ids.end() && i->second==type_id);
339 }
340
341 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id, bool extended_align)
342 {
343         Id base_type_id = get_id(base_type);
344         Id &array_type_id = array_type_ids[TypeKey(base_type_id, extended_align*0x400000 | size_id)];
345         if(!array_type_id)
346         {
347                 array_type_id = next_id++;
348                 if(size_id)
349                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
350                 else
351                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
352
353                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
354                 if(extended_align)
355                         stride = (stride+15)&~15U;
356                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
357         }
358
359         return array_type_id;
360 }
361
362 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
363 {
364         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
365         if(!ptr_type_id)
366         {
367                 ptr_type_id = next_id++;
368                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
369         }
370         return ptr_type_id;
371 }
372
373 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
374 {
375         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
376                 if(basic->kind==BasicTypeDeclaration::ARRAY)
377                 {
378                         Id size_id = 0;
379                         if(var.array_size)
380                         {
381                                 SetFlag set_const(constant_expression);
382                                 r_expression_result_id = 0;
383                                 var.array_size->visit(*this);
384                                 size_id = r_expression_result_id;
385                         }
386                         else
387                                 size_id = get_constant_id(get_standard_type_id(BasicTypeDeclaration::INT, 1), 1);
388                         return get_array_type_id(*basic->base_type, size_id, true);
389                 }
390
391         return get_id(*var.type_declaration);
392 }
393
394 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
395 {
396         Id &load_result_id = variable_load_ids[&var];
397         if(!load_result_id)
398         {
399                 load_result_id = next_id++;
400                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
401         }
402         return load_result_id;
403 }
404
405 void SpirVGenerator::prune_loads(Id min_id)
406 {
407         for(auto i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
408         {
409                 if(i->second>=min_id)
410                         variable_load_ids.erase(i++);
411                 else
412                         ++i;
413         }
414 }
415
416 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
417 {
418         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
419         if(!constant_expression)
420         {
421                 if(!current_function)
422                         throw internal_error("non-constant expression outside a function");
423
424                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
425         }
426         else if(opcode==OP_COMPOSITE_CONSTRUCT)
427                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
428                         (n_args ? 1+has_result*2+n_args : 0));
429         else if(!spec_constant)
430                 throw internal_error("invalid non-specialization constant expression");
431         else
432                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
433
434         Id result_id = next_id++;
435         if(has_result)
436         {
437                 writer.write(type_id);
438                 writer.write(result_id);
439         }
440         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
441                 writer.write(opcode);
442
443         return result_id;
444 }
445
446 void SpirVGenerator::end_expression(Opcode opcode)
447 {
448         if(constant_expression)
449                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
450         writer.end_op(opcode);
451 }
452
453 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
454 {
455         Id result_id = begin_expression(opcode, type_id, 1);
456         writer.write(arg_id);
457         end_expression(opcode);
458         return result_id;
459 }
460
461 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
462 {
463         Id result_id = begin_expression(opcode, type_id, 2);
464         writer.write(left_id);
465         writer.write(right_id);
466         end_expression(opcode);
467         return result_id;
468 }
469
470 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
471 {
472         for(unsigned i=0; i<n_elems; ++i)
473         {
474                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
475                 writer.write(composite_id);
476                 writer.write(i);
477                 end_expression(OP_COMPOSITE_EXTRACT);
478         }
479 }
480
481 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
482 {
483         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
484         for(unsigned i=0; i<n_elems; ++i)
485                 writer.write(elem_ids[i]);
486         end_expression(OP_COMPOSITE_CONSTRUCT);
487
488         return result_id;
489 }
490
491 void SpirVGenerator::visit(Block &block)
492 {
493         for(const RefPtr<Statement> &s: block.body)
494                 s->visit(*this);
495 }
496
497 void SpirVGenerator::visit(Literal &literal)
498 {
499         Id type_id = get_id(*literal.type);
500         if(spec_constant)
501                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
502         else
503                 r_expression_result_id = get_constant_id(type_id, literal.value);
504         r_constant_result = true;
505 }
506
507 void SpirVGenerator::visit(VariableReference &var)
508 {
509         if(constant_expression || var.declaration->constant)
510         {
511                 if(!var.declaration->constant)
512                         throw internal_error("reference to non-constant variable in constant context");
513
514                 r_expression_result_id = get_id(*var.declaration);
515                 r_constant_result = true;
516                 return;
517         }
518         else if(!current_function)
519                 throw internal_error("non-constant context outside a function");
520
521         r_constant_result = false;
522         if(composite_access)
523         {
524                 r_expression_result_id = 0;
525                 if(!assignment_source_id)
526                 {
527                         auto i = variable_load_ids.find(var.declaration);
528                         if(i!=variable_load_ids.end())
529                                 r_expression_result_id = i->second;
530                 }
531                 if(!r_expression_result_id)
532                         r_composite_base = var.declaration;
533         }
534         else if(assignment_source_id)
535         {
536                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
537                 variable_load_ids[var.declaration] = assignment_source_id;
538                 r_expression_result_id = assignment_source_id;
539         }
540         else
541                 r_expression_result_id = get_load_id(*var.declaration);
542 }
543
544 void SpirVGenerator::visit(InterfaceBlockReference &iface)
545 {
546         if(!composite_access || !current_function)
547                 throw internal_error("invalid interface block reference");
548
549         r_composite_base = iface.declaration;
550         r_expression_result_id = 0;
551         r_constant_result = false;
552 }
553
554 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
555 {
556         Opcode opcode;
557         Id result_type_id = get_id(result_type);
558         Id access_type_id = result_type_id;
559         if(r_composite_base)
560         {
561                 if(constant_expression)
562                         throw internal_error("composite access through pointer in constant context");
563
564                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
565                 for(unsigned &i: r_composite_chain)
566                         i = (i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(i)) : i&0x3FFFFF);
567
568                 /* Find the storage class of the base and obtain appropriate pointer type
569                 for the result. */
570                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
571                 auto i = pointer_type_ids.begin();
572                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
573                 if(i==pointer_type_ids.end())
574                         throw internal_error("could not find storage class");
575                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
576
577                 opcode = OP_ACCESS_CHAIN;
578         }
579         else if(assignment_source_id)
580                 throw internal_error("assignment to temporary composite");
581         else
582         {
583                 for(unsigned &i: r_composite_chain)
584                         for(auto j=constant_ids.begin(); (i>=0x400000 && j!=constant_ids.end()); ++j)
585                                 if(j->second==(i&0x3FFFFF))
586                                         i = j->first.int_value;
587
588                 opcode = OP_COMPOSITE_EXTRACT;
589         }
590
591         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
592         writer.write(r_composite_base_id);
593         for(unsigned i: r_composite_chain)
594                 writer.write(i);
595         end_expression(opcode);
596
597         r_constant_result = false;
598         if(r_composite_base)
599         {
600                 if(assignment_source_id)
601                 {
602                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
603                         r_expression_result_id = assignment_source_id;
604                 }
605                 else
606                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
607         }
608         else
609                 r_expression_result_id = access_id;
610 }
611
612 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
613 {
614         if(!composite_access)
615         {
616                 r_composite_base = 0;
617                 r_composite_base_id = 0;
618                 r_composite_chain.clear();
619         }
620
621         {
622                 SetFlag set_composite(composite_access);
623                 base_expr.visit(*this);
624         }
625
626         if(!r_composite_base_id)
627                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
628
629         r_composite_chain.push_back(index);
630         if(!composite_access)
631                 generate_composite_access(type);
632         else
633                 r_expression_result_id = 0;
634 }
635
636 void SpirVGenerator::visit_isolated(Expression &expr)
637 {
638         SetForScope<Id> clear_assign(assignment_source_id, 0);
639         SetFlag clear_composite(composite_access, false);
640         SetForScope<Node *> clear_base(r_composite_base, 0);
641         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
642         vector<unsigned> saved_chain;
643         swap(saved_chain, r_composite_chain);
644         expr.visit(*this);
645         swap(saved_chain, r_composite_chain);
646 }
647
648 void SpirVGenerator::visit(MemberAccess &memacc)
649 {
650         visit_composite(*memacc.left, memacc.index, *memacc.type);
651 }
652
653 void SpirVGenerator::visit(Swizzle &swizzle)
654 {
655         if(swizzle.count==1)
656                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
657         else if(assignment_source_id)
658         {
659                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
660
661                 unsigned mask = 0;
662                 for(unsigned i=0; i<swizzle.count; ++i)
663                         mask |= 1<<swizzle.components[i];
664
665                 visit_isolated(*swizzle.left);
666
667                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
668                 writer.write(r_expression_result_id);
669                 writer.write(assignment_source_id);
670                 for(unsigned i=0; i<basic.size; ++i)
671                         writer.write(i+((mask>>i)&1)*basic.size);
672                 end_expression(OP_VECTOR_SHUFFLE);
673
674                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
675                 swizzle.left->visit(*this);
676
677                 r_expression_result_id = combined_id;
678         }
679         else
680         {
681                 swizzle.left->visit(*this);
682                 Id left_id = r_expression_result_id;
683
684                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
685                 writer.write(left_id);
686                 writer.write(left_id);
687                 for(unsigned i=0; i<swizzle.count; ++i)
688                         writer.write(swizzle.components[i]);
689                 end_expression(OP_VECTOR_SHUFFLE);
690         }
691         r_constant_result = false;
692 }
693
/* Generates code for a unary expression.

Inside a composite access the expression is evaluated through visit_isolated.
Unary plus leaves the operand result untouched.  Logical and bitwise not and
scalar or vector negation map to a single instruction.  Matrix negation is
applied to each column.  Increment and decrement compute the new value and
then visit the operand a second time with that value as the assignment
source.  Throws internal_error for operand types or operators that are not
handled here. */
694 void SpirVGenerator::visit(UnaryExpression &unary)
695 {
696         if(composite_access)
697                 return visit_isolated(unary);
698
699         unary.expression->visit(*this);
700
            // The first two characters of the token tell -x apart from --x and +x from ++x.
701         char oper = unary.oper->token[0];
702         char oper2 = unary.oper->token[1];
703         if(oper=='+' && !oper2)
                    // Unary plus: the operand result is already the result.
704                 return;
705
706         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
707         BasicTypeDeclaration &elem = *get_element_type(basic);
708
709         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
710                 /* SPIR-V allows constant operations on floating-point values only for
711                 OpenCL kernels. */
712                 throw internal_error("invalid operands for constant unary expression");
713
714         Id result_type_id = get_id(*unary.type);
            // OP_NOP doubles as a marker for an operator that was not recognized.
715         Opcode opcode = OP_NOP;
716
717         r_constant_result = false;
718         if(oper=='!')
719                 opcode = OP_LOGICAL_NOT;
720         else if(oper=='~')
721                 opcode = OP_NOT;
722         else if(oper=='-' && !oper2)
723         {
724                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
725
726                 if(basic.kind==BasicTypeDeclaration::MATRIX)
727                 {
                            /* Negate a matrix column by column.  The low 16 bits of the
                            size hold the column count used here. */
728                         Id column_type_id = get_id(*basic.base_type);
729                         unsigned n_columns = basic.size&0xFFFF;
730                         Id column_ids[4];
731                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
732                         for(unsigned i=0; i<n_columns; ++i)
733                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
734                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
735                         return;
736                 }
737         }
738         else if((oper=='+' || oper=='-') && oper2==oper)
739         {
                    // Increment or decrement: add or subtract a constant one of the element type.
740                 if(constant_expression)
741                         throw internal_error("increment or decrement in constant expression");
742
743                 Id one_id = 0;
744                 if(elem.kind==BasicTypeDeclaration::INT)
745                 {
746                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
747                         one_id = get_constant_id(get_id(elem), 1);
748                 }
749                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
750                 {
751                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
752                         one_id = get_constant_id(get_id(elem), 1.0f);
753                 }
754                 else
755                         throw internal_error("invalid increment/decrement");
756
                    // For a vector operand the one is widened to a vector constant of the same size.
757                 if(basic.kind==BasicTypeDeclaration::VECTOR)
758                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
759
760                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
761
                    /* Visit the operand again, this time as an assignment target for
                    the new value. */
762                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
763                 unary.expression->visit(*this);
764
                    /* A postfix operator keeps the result id left by the second visit,
                    any other form yields the new value.
                    NOTE(review): presumably the assignment visit leaves the previously
                    loaded value in r_expression_result_id - confirm. */
765                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
766                 return;
767         }
768
769         if(opcode==OP_NOP)
770                 throw internal_error("unknown unary operator");
771
772         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
773 }
774
775 void SpirVGenerator::visit(BinaryExpression &binary)
776 {
777         char oper = binary.oper->token[0];
778         if(oper=='[')
779         {
780                 visit_isolated(*binary.right);
781                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
782         }
783         else if(composite_access)
784                 return visit_isolated(binary);
785
786         if(assignment_source_id)
787                 throw internal_error("invalid binary expression in assignment target");
788
789         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
790         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
791         // Expression resolver ensures that element types are the same
792         BasicTypeDeclaration &elem = *get_element_type(basic_left);
793
794         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
795                 /* SPIR-V allows constant operations on floating-point values only for
796                 OpenGL kernels. */
797                 throw internal_error("invalid operands for constant binary expression");
798
799         binary.left->visit(*this);
800         Id left_id = r_expression_result_id;
801         binary.right->visit(*this);
802         Id right_id = r_expression_result_id;
803
804         Id result_type_id = get_id(*binary.type);
805         Opcode opcode = OP_NOP;
806         bool swap_operands = false;
807
808         r_constant_result = false;
809
810         char oper2 = binary.oper->token[1];
811         if((oper=='<' || oper=='>') && oper2!=oper)
812         {
813                 if(basic_left.kind==BasicTypeDeclaration::INT)
814                 {
815                         if(basic_left.sign)
816                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
817                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
818                         else
819                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
820                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
821                 }
822                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
823                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
824                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
825         }
826         else if((oper=='=' || oper=='!') && oper2=='=')
827         {
828                 if(elem.kind==BasicTypeDeclaration::BOOL)
829                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
830                 else if(elem.kind==BasicTypeDeclaration::INT)
831                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
832                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
833                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
834
835                 if(opcode!=OP_NOP && basic_left.base_type)
836                 {
837                         /* The SPIR-V equality operations produce component-wise results, but
838                         GLSL operators return a single boolean.  Use the any/all operations to
839                         combine the results. */
840                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
841                         unsigned n_elems = basic_left.size&0xFFFF;
842                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
843
844                         Id compare_id = 0;
845                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
846                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
847                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
848                         {
849                                 Id column_type_id = get_id(*basic_left.base_type);
850                                 Id column_ids[8];
851                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
852                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
853
854                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
855                                 for(unsigned i=0; i<n_elems; ++i)
856                                 {
857                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
858                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
859                                 }
860
861                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
862                         }
863
864                         if(compare_id)
865                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
866                         return;
867                 }
868         }
869         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
870                 opcode = OP_LOGICAL_AND;
871         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
872                 opcode = OP_LOGICAL_OR;
873         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
874                 opcode = OP_LOGICAL_NOT_EQUAL;
875         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
876                 opcode = OP_BITWISE_AND;
877         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
878                 opcode = OP_BITWISE_OR;
879         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
880                 opcode = OP_BITWISE_XOR;
881         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
882                 opcode = OP_SHIFT_LEFT_LOGICAL;
883         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
884                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
885         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
886                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
887         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
888         {
889                 Opcode elem_op = OP_NOP;
890                 if(elem.kind==BasicTypeDeclaration::INT)
891                 {
892                         if(oper=='/')
893                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
894                         else
895                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
896                 }
897                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
898                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
899
900                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
901                 {
902                         /* Multiplication between floating-point vectors and matrices has
903                         dedicated operations. */
904                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
905                                 opcode = OP_MATRIX_TIMES_MATRIX;
906                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
907                         {
908                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
909                                         opcode = OP_VECTOR_TIMES_MATRIX;
910                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
911                                         opcode = OP_MATRIX_TIMES_VECTOR;
912                                 else
913                                 {
914                                         opcode = OP_MATRIX_TIMES_SCALAR;
915                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
916                                 }
917                         }
918                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
919                                 opcode = elem_op;
920                         else
921                         {
922                                 opcode = OP_VECTOR_TIMES_SCALAR;
923                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
924                         }
925                 }
926                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
927                 {
928                         /* One operand is scalar and the other is a vector or a matrix.
929                         Expand the scalar to a vector of appropriate size. */
930                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
931                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
932                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
933                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
934                         Id vector_type_id = get_id(*vector_type);
935
936                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
937                         for(unsigned i=0; i<vector_type->size; ++i)
938                                 writer.write(scalar_id);
939                         end_expression(OP_COMPOSITE_CONSTRUCT);
940
941                         scalar_id = expanded_id;
942
943                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
944                         {
945                                 // Apply matrix operation column-wise.
946                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
947
948                                 Id column_ids[4];
949                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
950                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
951
952                                 for(unsigned i=0; i<n_columns; ++i)
953                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
954
955                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
956                                 return;
957                         }
958                         else
959                                 opcode = elem_op;
960                 }
961                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
962                 {
963                         if(oper=='*')
964                                 throw internal_error("non-float matrix multiplication");
965
966                         /* Other operations involving matrices need to be performed
967                         column-wise. */
968                         Id column_type_id = get_id(*basic_left.base_type);
969                         Id column_ids[8];
970
971                         unsigned n_columns = basic_left.size&0xFFFF;
972                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
973                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
974
975                         for(unsigned i=0; i<n_columns; ++i)
976                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
977
978                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
979                         return;
980                 }
981                 else if(basic_left.kind==basic_right.kind)
982                         // Both operands are either scalars or vectors.
983                         opcode = elem_op;
984         }
985
986         if(opcode==OP_NOP)
987                 throw internal_error("unknown binary operator");
988
989         if(swap_operands)
990                 swap(left_id, right_id);
991
992         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
993 }
994
995 void SpirVGenerator::visit(Assignment &assign)
996 {
997         if(assign.oper->token[0]!='=')
998                 visit(static_cast<BinaryExpression &>(assign));
999         else
1000                 assign.right->visit(*this);
1001
1002         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
1003         assign.left->visit(*this);
1004         r_constant_result = false;
1005 }
1006
1007 void SpirVGenerator::visit(TernaryExpression &ternary)
1008 {
1009         if(composite_access)
1010                 return visit_isolated(ternary);
1011         if(constant_expression)
1012         {
1013                 ternary.condition->visit(*this);
1014                 Id condition_id = r_expression_result_id;
1015                 ternary.true_expr->visit(*this);
1016                 Id true_result_id = r_expression_result_id;
1017                 ternary.false_expr->visit(*this);
1018                 Id false_result_id = r_expression_result_id;
1019
1020                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
1021                 writer.write(condition_id);
1022                 writer.write(true_result_id);
1023                 writer.write(false_result_id);
1024                 end_expression(OP_SELECT);
1025
1026                 return;
1027         }
1028
1029         ternary.condition->visit(*this);
1030         Id condition_id = r_expression_result_id;
1031
1032         Id true_label_id = next_id++;
1033         Id false_label_id = next_id++;
1034         Id merge_block_id = next_id++;
1035         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1036         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
1037
1038         writer.write_op_label(true_label_id);
1039         ternary.true_expr->visit(*this);
1040         Id true_result_id = r_expression_result_id;
1041         true_label_id = writer.get_current_block();
1042         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1043
1044         writer.write_op_label(false_label_id);
1045         ternary.false_expr->visit(*this);
1046         Id false_result_id = r_expression_result_id;
1047         false_label_id = writer.get_current_block();
1048
1049         writer.write_op_label(merge_block_id);
1050         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
1051         writer.write(true_result_id);
1052         writer.write(true_label_id);
1053         writer.write(false_result_id);
1054         writer.write(false_label_id);
1055         end_expression(OP_PHI);
1056
1057         r_constant_result = false;
1058 }
1059
/* Generates code for a function call.

Three kinds of calls are handled: constructors (delegated to
visit_constructor), builtin functions (looked up from the builtin_functions
table and emitted either as a single instruction or through a handler member
function) and calls to other functions (emitted as a function call
instruction).

Throws internal_error if the call is an assignment target, if a constant
expression contains anything other than a constructor with constant
arguments, or if a builtin function has no usable table entry. */
1060 void SpirVGenerator::visit(FunctionCall &call)
1061 {
1062         if(assignment_source_id)
1063                 throw internal_error("assignment to function call");
1064         else if(composite_access)
1065                 return visit_isolated(call);
1066         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
                     // Constructing a value from one of the same type: use the argument as is.
1067                 return call.arguments[0]->visit(*this);
1068
             // Evaluate all arguments in order and note whether every one was constant.
1069         vector<Id> argument_ids;
1070         argument_ids.reserve(call.arguments.size());
1071         bool all_args_const = true;
1072         for(const RefPtr<Expression> &a: call.arguments)
1073         {
1074                 a->visit(*this);
1075                 argument_ids.push_back(r_expression_result_id);
1076                 all_args_const &= r_constant_result;
1077         }
1078
1079         if(constant_expression && (!call.constructor || !all_args_const))
1080                 throw internal_error("function call in constant expression");
1081
1082         Id result_type_id = get_id(*call.type);
1083         r_constant_result = false;
1084
1085         if(call.constructor)
1086                 visit_constructor(call, argument_ids, all_args_const);
1087         else if(call.declaration->source==BUILTIN_SOURCE)
1088         {
                     /* Describe the arguments with one character per argument of a
                     basic type: b for bool, i for signed int, u for unsigned int, f
                     for float and a question mark for anything else.  Arguments of
                     other types contribute nothing. */
1089                 string arg_types;
1090                 for(const RefPtr<Expression> &a: call.arguments)
1091                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>(a->type))
1092                         {
1093                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1094                                 switch(elem_arg.kind)
1095                                 {
1096                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1097                                 case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
1098                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1099                                 default: arg_types += '?';
1100                                 }
1101                         }
1102
                     /* Find the first table entry with a matching name and argument
                     types.  An entry with empty arg_types matches any arguments.  The
                     scan stops at an entry with an empty function name, so when
                     nothing matches builtin_info is left pointing at that entry. */
1103                 const BuiltinFunctionInfo *builtin_info;
1104                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1105                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1106                                 break;
1107
1108                 if(builtin_info->capability)
1109                         use_capability(static_cast<Capability>(builtin_info->capability));
1110
1111                 if(builtin_info->opcode)
1112                 {
1113                         Opcode opcode;
1114                         if(builtin_info->extension[0])
1115                         {
                                     /* The function lives in an extended instruction set.  The
                                     entry opcode is written as an operand after the id of the
                                     imported extension. */
1116                                 opcode = OP_EXT_INST;
1117                                 Id ext_id = import_extension(builtin_info->extension);
1118
1119                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1120                                 writer.write(ext_id);
1121                                 writer.write(builtin_info->opcode);
1122                         }
1123                         else
1124                         {
1125                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1126                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1127                         }
1128
                             /* arg_order holds one-based argument numbers, giving the order
                             in which the arguments are written as operands.  Zero or a
                             number past the argument count is rejected. */
1129                         for(unsigned i=0; i<call.arguments.size(); ++i)
1130                         {
1131                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1132                                         throw internal_error("invalid builtin function info");
1133                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1134                         }
1135
1136                         end_expression(opcode);
1137                 }
1138                 else if(builtin_info->handler)
                             // No direct opcode: let the member function named by the entry generate the code.
1139                         (this->*(builtin_info->handler))(call, argument_ids);
1140                 else
1141                         throw internal_error("unknown builtin function "+call.name);
1142         }
1143         else
1144         {
                     // Operands are the id of the function definition followed by the arguments.
1145                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1146                 writer.write(get_id(*call.declaration->definition));
1147                 for(Id i: argument_ids)
1148                         writer.write(i);
1149                 end_expression(OP_FUNCTION_CALL);
1150
1151                 // Any global variables the called function uses might have changed value
1152                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1153                 for(Node *n: dependencies)
1154                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1155                                 variable_load_ids.erase(var);
1156         }
1157 }
1158
1159 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1160 {
1161         Id result_type_id = get_id(*call.type);
1162
1163         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1164         if(!basic)
1165         {
1166                 if(dynamic_cast<const StructDeclaration *>(call.type))
1167                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1168                 else
1169                         throw internal_error("unconstructable type "+call.name);
1170                 return;
1171         }
1172
1173         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1174
1175         BasicTypeDeclaration &elem = *get_element_type(*basic);
1176         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1177         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1178
1179         if(basic->kind==BasicTypeDeclaration::MATRIX)
1180         {
1181                 Id col_type_id = get_id(*basic->base_type);
1182                 unsigned n_columns = basic->size&0xFFFF;
1183                 unsigned n_rows = basic->size>>16;
1184
1185                 Id column_ids[4];
1186                 if(call.arguments.size()==1)
1187                 {
1188                         // Construct diagonal matrix from a single scalar.
1189                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1190                         for(unsigned i=0; i<n_columns; ++i)
1191                         {
1192                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);
1193                                 for(unsigned j=0; j<n_rows; ++j)
1194                                         writer.write(j==i ? argument_ids[0] : zero_id);
1195                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1196                         }
1197                 }
1198                 else
1199                         // Construct a matrix from column vectors
1200                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1201
1202                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1203         }
1204         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1205         {
1206                 /* There's either a single scalar argument or multiple arguments
1207                 which make up the vector's components. */
1208                 if(call.arguments.size()==1)
1209                 {
1210                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1211                         for(unsigned i=0; i<basic->size; ++i)
1212                                 writer.write(argument_ids[0]);
1213                         end_expression(OP_COMPOSITE_CONSTRUCT);
1214                 }
1215                 else
1216                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1217         }
1218         else if(elem.kind==BasicTypeDeclaration::BOOL)
1219         {
1220                 if(constant_expression)
1221                         throw internal_error("unconverted constant");
1222
1223                 // Conversion to boolean is implemented as comparing against zero.
1224                 Id number_type_id = get_id(elem_arg0);
1225                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1226                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1227                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1228                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1229
1230                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1231                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1232         }
1233         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1234         {
1235                 if(constant_expression)
1236                         throw internal_error("unconverted constant");
1237
1238                 /* Conversion from boolean is implemented as selecting from zero
1239                 or one. */
1240                 Id number_type_id = get_id(elem);
1241                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1242                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1243                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1244                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1245                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1246                 {
1247                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1248                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1249                 }
1250
1251                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1252                 writer.write(argument_ids[0]);
1253                 writer.write(zero_id);
1254                 writer.write(one_id);
1255                 end_expression(OP_SELECT);
1256         }
1257         else
1258         {
1259                 if(constant_expression)
1260                         throw internal_error("unconverted constant");
1261
1262                 // Scalar or vector conversion between types of equal size.
1263                 Opcode opcode;
1264                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1265                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1266                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1267                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1268                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1269                         opcode = OP_BITCAST;
1270                 else
1271                         throw internal_error("invalid conversion");
1272
1273                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1274         }
1275 }
1276
1277 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1278 {
1279         if(argument_ids.size()!=2)
1280                 throw internal_error("invalid matrixCompMult call");
1281
1282         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1283         Id column_type_id = get_id(*basic_arg0.base_type);
1284         Id column_ids[8];
1285
1286         unsigned n_columns = basic_arg0.size&0xFFFF;
1287         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1288         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1289
1290         for(unsigned i=0; i<n_columns; ++i)
1291                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1292
1293         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1294 }
1295
1296 void SpirVGenerator::visit_builtin_texture_query(FunctionCall &call, const vector<Id> &argument_ids)
1297 {
1298         if(argument_ids.size()<1)
1299                 throw internal_error("invalid texture query call");
1300
1301         Opcode opcode;
1302         if(call.name=="textureSize")
1303                 opcode = OP_IMAGE_QUERY_SIZE_LOD;
1304         else if(call.name=="textureQueryLod")
1305                 opcode = OP_IMAGE_QUERY_LOD;
1306         else if(call.name=="textureQueryLevels")
1307                 opcode = OP_IMAGE_QUERY_LEVELS;
1308         else
1309                 throw internal_error("invalid texture query call");
1310
1311         ImageTypeDeclaration &image_arg0 = dynamic_cast<ImageTypeDeclaration &>(*call.arguments[0]->type);
1312
1313         Id image_id;
1314         if(image_arg0.sampled)
1315         {
1316                 Id image_type_id = get_item(image_type_ids, get_id(image_arg0));
1317                 image_id = write_expression(OP_IMAGE, image_type_id, argument_ids[0]);
1318         }
1319         else
1320                 image_id = argument_ids[0];
1321
1322         Id result_type_id = get_id(*call.type);
1323         r_expression_result_id = begin_expression(opcode, result_type_id, argument_ids.size());
1324         writer.write(image_id);
1325         for(unsigned i=1; i<argument_ids.size(); ++i)
1326                 writer.write(argument_ids[i]);
1327         end_expression(opcode);
1328 }
1329
1330 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1331 {
1332         if(argument_ids.size()<2)
1333                 throw internal_error("invalid texture sampling call");
1334
1335         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1336         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1337                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1338
1339         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1340
1341         Opcode opcode;
1342         Id result_type_id = get_id(*call.type);
1343         Id dref_id = 0;
1344         if(image.shadow)
1345         {
1346                 if(argument_ids.size()==2)
1347                 {
1348                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1349                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1350                         writer.write(argument_ids.back());
1351                         writer.write(basic_arg1.size-1);
1352                         end_expression(OP_COMPOSITE_EXTRACT);
1353                 }
1354                 else
1355                         dref_id = argument_ids[2];
1356
1357                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1358                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1359         }
1360         else
1361         {
1362                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1363                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1364         }
1365
1366         for(unsigned i=0; i<2; ++i)
1367                 writer.write(argument_ids[i]);
1368         if(dref_id)
1369                 writer.write(dref_id);
1370         if(explicit_lod)
1371         {
1372                 writer.write(2);  // Lod
1373                 writer.write(lod_id);
1374         }
1375
1376         end_expression(opcode);
1377 }
1378
1379 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1380 {
1381         if(argument_ids.size()!=3)
1382                 throw internal_error("invalid texelFetch call");
1383
1384         r_expression_result_id = begin_expression(OP_IMAGE_FETCH, get_id(*call.type), 4);
1385         for(unsigned i=0; i<2; ++i)
1386                 writer.write(argument_ids[i]);
1387         writer.write(2);  // Lod
1388         writer.write(argument_ids.back());
1389         end_expression(OP_IMAGE_FETCH);
1390 }
1391
1392 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1393 {
1394         if(argument_ids.size()<1)
1395                 throw internal_error("invalid interpolate call");
1396         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1397         if(!var || !var->declaration || var->declaration->interface!="in")
1398                 throw internal_error("invalid interpolate call");
1399
1400         SpirVGlslStd450Opcode opcode;
1401         if(call.name=="interpolateAtCentroid")
1402                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1403         else if(call.name=="interpolateAtSample")
1404                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1405         else if(call.name=="interpolateAtOffset")
1406                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1407         else
1408                 throw internal_error("invalid interpolate call");
1409
1410         Id ext_id = import_extension("GLSL.std.450");
1411         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1412         writer.write(ext_id);
1413         writer.write(opcode);
1414         writer.write(get_id(*var->declaration));
1415         for(auto i=argument_ids.begin(); ++i!=argument_ids.end(); )
1416                 writer.write(*i);
1417         end_expression(OP_EXT_INST);
1418 }
1419
1420 void SpirVGenerator::visit(ExpressionStatement &expr)
1421 {
1422         expr.expression->visit(*this);
1423 }
1424
1425 void SpirVGenerator::visit(InterfaceLayout &layout)
1426 {
1427         interface_layouts.push_back(&layout);
1428 }
1429
1430 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1431 {
1432         for(const auto &kvp: declared_ids)
1433                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(kvp.first))
1434                         if(TypeComparer().apply(type, *type2))
1435                         {
1436                                 insert_unique(declared_ids, &type, kvp.second);
1437                                 return true;
1438                         }
1439
1440         return false;
1441 }
1442
1443 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1444 {
1445         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1446                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1447         if(!elem || elem->base_type)
1448                 return false;
1449         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1450                 return false;
1451
1452         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1453         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1454         writer.write_op_name(standard_id, basic.name);
1455
1456         return true;
1457 }
1458
1459 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1460 {
1461         if(check_standard_type(basic))
1462                 return;
1463         if(check_duplicate_type(basic))
1464                 return;
1465         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1466         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1467                 return;
1468
1469         Id type_id = allocate_id(basic, 0);
1470         writer.write_op_name(type_id, basic.name);
1471
1472         switch(basic.kind)
1473         {
1474         case BasicTypeDeclaration::INT:
1475                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1476                 break;
1477         case BasicTypeDeclaration::FLOAT:
1478                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1479                 break;
1480         case BasicTypeDeclaration::VECTOR:
1481                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1482                 break;
1483         case BasicTypeDeclaration::MATRIX:
1484                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1485                 break;
1486         default:
1487                 throw internal_error("unknown basic type");
1488         }
1489 }
1490
1491 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1492 {
1493         if(check_duplicate_type(image))
1494                 return;
1495
1496         Id type_id = allocate_id(image, 0);
1497
1498         Id image_id = (image.sampled ? next_id++ : type_id);
1499         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1500         writer.write(image_id);
1501         writer.write(get_id(*image.base_type));
1502         writer.write(image.dimensions-1);
1503         writer.write(image.shadow);
1504         writer.write(image.array);
1505         writer.write(false);  // Multisample
1506         writer.write(image.sampled ? 1 : 2);
1507         writer.write(0);  // Format (unknown)
1508         writer.end_op(OP_TYPE_IMAGE);
1509
1510         if(image.sampled)
1511         {
1512                 writer.write_op_name(type_id, image.name);
1513                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1514                 insert_unique(image_type_ids, type_id, image_id);
1515         }
1516
1517         if(image.dimensions==ImageTypeDeclaration::ONE)
1518                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1519         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1520                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1521 }
1522
1523 void SpirVGenerator::visit(StructDeclaration &strct)
1524 {
1525         if(check_duplicate_type(strct))
1526                 return;
1527
1528         Id type_id = allocate_id(strct, 0);
1529         writer.write_op_name(type_id, strct.name);
1530
1531         if(strct.interface_block)
1532                 writer.write_op_decorate(type_id, DECO_BLOCK);
1533
1534         bool builtin = (strct.interface_block && !strct.interface_block->block_name.compare(0, 3, "gl_"));
1535         vector<Id> member_type_ids;
1536         member_type_ids.reserve(strct.members.body.size());
1537         for(const RefPtr<Statement> &s: strct.members.body)
1538         {
1539                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(s.get());
1540                 if(!var)
1541                         continue;
1542
1543                 unsigned index = member_type_ids.size();
1544                 member_type_ids.push_back(get_variable_type_id(*var));
1545
1546                 writer.write_op_member_name(type_id, index, var->name);
1547
1548                 if(builtin)
1549                 {
1550                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1551                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1552                 }
1553                 else
1554                 {
1555                         if(var->layout)
1556                         {
1557                                 for(const Layout::Qualifier &q: var->layout->qualifiers)
1558                                 {
1559                                         if(q.name=="offset")
1560                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, q.value);
1561                                         else if(q.name=="column_major")
1562                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1563                                         else if(q.name=="row_major")
1564                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1565                                 }
1566                         }
1567
1568                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1569                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1570                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1571                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1572                         {
1573                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1574                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1575                         }
1576                 }
1577         }
1578
1579         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1580         writer.write(type_id);
1581         for(Id i: member_type_ids)
1582                 writer.write(i);
1583         writer.end_op(OP_TYPE_STRUCT);
1584 }
1585
1586 void SpirVGenerator::visit(VariableDeclaration &var)
1587 {
1588         Id type_id = get_variable_type_id(var);
1589         Id var_id;
1590
1591         if(var.constant)
1592         {
1593                 if(!var.init_expression)
1594                         throw internal_error("const variable without initializer");
1595
1596                 int spec_id = get_layout_value(var.layout.get(), "constant_id");
1597                 Id *spec_var_id = (spec_id>=0 ? &declared_spec_ids[spec_id] : 0);
1598                 if(spec_id>=0 && *spec_var_id)
1599                 {
1600                         insert_unique(declared_ids, &var, Declaration(*spec_var_id, type_id));
1601                         return;
1602                 }
1603
1604                 SetFlag set_const(constant_expression);
1605                 SetFlag set_spec(spec_constant, spec_id>=0);
1606                 r_expression_result_id = 0;
1607                 var.init_expression->visit(*this);
1608                 var_id = r_expression_result_id;
1609                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1610                 if(spec_id>=0)
1611                 {
1612                         writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1613                         *spec_var_id = var_id;
1614                 }
1615         }
1616         else
1617         {
1618                 StorageClass storage = (current_function ? STORAGE_FUNCTION : get_interface_storage(var.interface, false));
1619                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1620                 if(var.interface=="uniform")
1621                 {
1622                         Id &uni_id = declared_uniform_ids["v"+var.name];
1623                         if(uni_id)
1624                         {
1625                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1626                                 return;
1627                         }
1628
1629                         uni_id = var_id = allocate_id(var, ptr_type_id);
1630                 }
1631                 else
1632                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1633
1634                 Id init_id = 0;
1635                 if(var.init_expression)
1636                 {
1637                         SetFlag set_const(constant_expression, !current_function);
1638                         r_expression_result_id = 0;
1639                         r_constant_result = false;
1640                         var.init_expression->visit(*this);
1641                         init_id = r_expression_result_id;
1642                 }
1643
1644                 vector<Word> &target = (current_function ? content.locals : content.globals);
1645                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1646                 writer.write(ptr_type_id);
1647                 writer.write(var_id);
1648                 writer.write(storage);
1649                 if(init_id && !current_function)
1650                         writer.write(init_id);
1651                 writer.end_op(OP_VARIABLE);
1652
1653                 if(var.layout)
1654                 {
1655                         for(const Layout::Qualifier &q: var.layout->qualifiers)
1656                         {
1657                                 if(q.name=="location")
1658                                         writer.write_op_decorate(var_id, DECO_LOCATION, q.value);
1659                                 else if(q.name=="set")
1660                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, q.value);
1661                                 else if(q.name=="binding")
1662                                         writer.write_op_decorate(var_id, DECO_BINDING, q.value);
1663                         }
1664                 }
1665                 if(!var.name.compare(0, 3, "gl_"))
1666                 {
1667                         BuiltinSemantic semantic = get_builtin_semantic(var.name);
1668                         writer.write_op_decorate(var_id, DECO_BUILTIN, semantic);
1669                 }
1670
1671                 if(init_id && current_function)
1672                 {
1673                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1674                         variable_load_ids[&var] = init_id;
1675                 }
1676         }
1677
1678         writer.write_op_name(var_id, var.name);
1679 }
1680
1681 void SpirVGenerator::visit(InterfaceBlock &iface)
1682 {
1683         bool push_const = has_layout_qualifier(iface.layout.get(), "push_constant");
1684
1685         StorageClass storage = (push_const ? STORAGE_PUSH_CONSTANT : get_interface_storage(iface.interface, true));
1686         Id type_id;
1687         if(iface.array)
1688                 type_id = get_array_type_id(*iface.struct_declaration, 0, true);
1689         else
1690                 type_id = get_id(*iface.struct_declaration);
1691         Id ptr_type_id = get_pointer_type_id(type_id, storage);
1692
1693         Id block_id;
1694         if(iface.interface=="uniform")
1695         {
1696                 Id &uni_id = declared_uniform_ids["b"+iface.block_name];
1697                 if(uni_id)
1698                 {
1699                         insert_unique(declared_ids, &iface, Declaration(uni_id, ptr_type_id));
1700                         return;
1701                 }
1702
1703                 uni_id = block_id = allocate_id(iface, ptr_type_id);
1704         }
1705         else
1706                 block_id = allocate_id(iface, ptr_type_id);
1707         writer.write_op_name(block_id, iface.instance_name);
1708
1709         writer.write_op(content.globals, OP_VARIABLE, ptr_type_id, block_id, storage);
1710
1711         if(iface.layout)
1712         {
1713                 for(const Layout::Qualifier &q: iface.layout->qualifiers)
1714                 {
1715                         if(q.name=="set")
1716                                 writer.write_op_decorate(block_id, DECO_DESCRIPTOR_SET, q.value);
1717                         else if(q.name=="binding")
1718                                 writer.write_op_decorate(block_id, DECO_BINDING, q.value);
1719                 }
1720         }
1721 }
1722
1723 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1724 {
1725         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1726         switch(stage->type)
1727         {
1728         case Stage::VERTEX: writer.write(0); break;
1729         case Stage::GEOMETRY: writer.write(3); break;
1730         case Stage::FRAGMENT: writer.write(4); break;
1731         default: throw internal_error("unknown stage");
1732         }
1733         writer.write(func_id);
1734         writer.write_string(func.name);
1735
1736         set<Node *> dependencies = DependencyCollector().apply(func);
1737         for(Node *n: dependencies)
1738         {
1739                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1740                 {
1741                         if(!var->interface.empty())
1742                                 writer.write(get_id(*n));
1743                 }
1744                 else if(dynamic_cast<InterfaceBlock *>(n))
1745                         writer.write(get_id(*n));
1746         }
1747
1748         writer.end_op(OP_ENTRY_POINT);
1749
1750         if(stage->type==Stage::FRAGMENT)
1751         {
1752                 SpirVExecutionMode origin = (features.target_api==VULKAN ? EXEC_ORIGIN_UPPER_LEFT : EXEC_ORIGIN_LOWER_LEFT);
1753                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, origin);
1754         }
1755         else if(stage->type==Stage::GEOMETRY)
1756                 use_capability(CAP_GEOMETRY);
1757
1758         for(const InterfaceLayout *i: interface_layouts)
1759         {
1760                 for(const Layout::Qualifier &q: i->layout.qualifiers)
1761                 {
1762                         if(q.name=="point")
1763                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1764                                         (i->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1765                         else if(q.name=="lines")
1766                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1767                         else if(q.name=="lines_adjacency")
1768                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1769                         else if(q.name=="triangles")
1770                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1771                         else if(q.name=="triangles_adjacency")
1772                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1773                         else if(q.name=="line_strip")
1774                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1775                         else if(q.name=="triangle_strip")
1776                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1777                         else if(q.name=="max_vertices")
1778                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, q.value);
1779                 }
1780         }
1781 }
1782
1783 void SpirVGenerator::visit(FunctionDeclaration &func)
1784 {
1785         if(func.source==BUILTIN_SOURCE)
1786                 return;
1787         else if(func.definition!=&func)
1788         {
1789                 if(func.definition)
1790                         allocate_forward_id(*func.definition);
1791                 return;
1792         }
1793
1794         Id return_type_id = get_id(*func.return_type_declaration);
1795         vector<unsigned> param_type_ids;
1796         param_type_ids.reserve(func.parameters.size());
1797         for(const RefPtr<VariableDeclaration> &p: func.parameters)
1798                 param_type_ids.push_back(get_variable_type_id(*p));
1799
1800         string sig_with_return = func.return_type+func.signature;
1801         Id &type_id = function_type_ids[sig_with_return];
1802         if(!type_id)
1803         {
1804                 type_id = next_id++;
1805                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1806                 writer.write(type_id);
1807                 writer.write(return_type_id);
1808                 for(unsigned i: param_type_ids)
1809                         writer.write(i);
1810                 writer.end_op(OP_TYPE_FUNCTION);
1811
1812                 writer.write_op_name(type_id, sig_with_return);
1813         }
1814
1815         Id func_id = allocate_id(func, type_id);
1816         writer.write_op_name(func_id, func.name+func.signature);
1817
1818         if(func.name=="main")
1819                 visit_entry_point(func, func_id);
1820
1821         writer.begin_op(content.functions, OP_FUNCTION, 5);
1822         writer.write(return_type_id);
1823         writer.write(func_id);
1824         writer.write(0);  // Function control flags (none)
1825         writer.write(type_id);
1826         writer.end_op(OP_FUNCTION);
1827
1828         for(unsigned i=0; i<func.parameters.size(); ++i)
1829         {
1830                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1831                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1832                 // TODO This is probably incorrect if the parameter is assigned to.
1833                 variable_load_ids[func.parameters[i].get()] = param_id;
1834         }
1835
1836         reachable = true;
1837         writer.begin_function_body(next_id++);
1838         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1839         func.body.visit(*this);
1840
1841         if(writer.get_current_block())
1842         {
1843                 if(!reachable)
1844                         writer.write_op(content.function_body, OP_UNREACHABLE);
1845                 else
1846                 {
1847                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1848                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1849                                 writer.write_op(content.function_body, OP_RETURN);
1850                         else
1851                                 throw internal_error("missing return in non-void function");
1852                 }
1853         }
1854         writer.end_function_body();
1855         variable_load_ids.clear();
1856 }
1857
1858 void SpirVGenerator::visit(Conditional &cond)
1859 {
1860         cond.condition->visit(*this);
1861
1862         Id true_label_id = next_id++;
1863         Id merge_block_id = next_id++;
1864         Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
1865         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1866         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);
1867
1868         std::map<const VariableDeclaration *, Id> saved_load_ids = variable_load_ids;
1869
1870         writer.write_op_label(true_label_id);
1871         cond.body.visit(*this);
1872         if(writer.get_current_block())
1873                 writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1874
1875         bool reachable_if_true = reachable;
1876
1877         reachable = true;
1878         if(!cond.else_body.body.empty())
1879         {
1880                 swap(saved_load_ids, variable_load_ids);
1881                 writer.write_op_label(false_label_id);
1882                 cond.else_body.visit(*this);
1883                 reachable |= reachable_if_true;
1884         }
1885
1886         writer.write_op_label(merge_block_id);
1887         prune_loads(true_label_id);
1888 }
1889
1890 void SpirVGenerator::visit(Iteration &iter)
1891 {
1892         if(iter.init_statement)
1893                 iter.init_statement->visit(*this);
1894
1895         for(Node *n: AssignmentCollector().apply(iter))
1896                 if(VariableDeclaration *var = dynamic_cast<VariableDeclaration *>(n))
1897                         variable_load_ids.erase(var);
1898
1899         Id header_id = next_id++;
1900         Id continue_id = next_id++;
1901         Id merge_block_id = next_id++;
1902
1903         SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
1904         SetForScope<Id> set_continue(loop_continue_target_id, continue_id);
1905
1906         writer.write_op_label(header_id);
1907         writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)
1908
1909         Id body_id = next_id++;
1910         if(iter.condition)
1911         {
1912                 writer.write_op_label(next_id++);
1913                 iter.condition->visit(*this);
1914                 writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
1915         }
1916
1917         writer.write_op_label(body_id);
1918         iter.body.visit(*this);
1919
1920         writer.write_op_label(continue_id);
1921         if(iter.loop_expression)
1922                 iter.loop_expression->visit(*this);
1923         writer.write_op(content.function_body, OP_BRANCH, header_id);
1924
1925         writer.write_op_label(merge_block_id);
1926         prune_loads(header_id);
1927         reachable = true;
1928 }
1929
1930 void SpirVGenerator::visit(Return &ret)
1931 {
1932         if(ret.expression)
1933         {
1934                 ret.expression->visit(*this);
1935                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1936         }
1937         else
1938                 writer.write_op(content.function_body, OP_RETURN);
1939         reachable = false;
1940 }
1941
1942 void SpirVGenerator::visit(Jump &jump)
1943 {
1944         if(jump.keyword=="discard")
1945                 writer.write_op(content.function_body, OP_KILL);
1946         else if(jump.keyword=="break")
1947                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1948         else if(jump.keyword=="continue")
1949                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1950         else
1951                 throw internal_error("unknown jump");
1952         reachable = false;
1953 }
1954
1955
1956 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
1957         type_id(0)
1958 {
1959         switch(kind)
1960         {
1961         case BasicTypeDeclaration::VOID: detail = 'v'; break;
1962         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
1963         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
1964         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
1965         default: throw invalid_argument("TypeKey::TypeKey");
1966         }
1967 }
1968
1969 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1970 {
1971         if(type_id!=other.type_id)
1972                 return type_id<other.type_id;
1973         return detail<other.detail;
1974 }
1975
1976
1977 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1978 {
1979         if(type_id!=other.type_id)
1980                 return type_id<other.type_id;
1981         return int_value<other.int_value;
1982 }
1983
1984 } // namespace SL
1985 } // namespace GL
1986 } // namespace Msp