]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
Preparatory refactoring of the texelFetch implementation
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/algorithm.h>
2 #include <msp/core/maputils.h>
3 #include <msp/core/raii.h>
4 #include "reflect.h"
5 #include "spirv.h"
6
7 using namespace std;
8
9 namespace Msp {
10 namespace GL {
11 namespace SL {
12
13 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
14 {
15         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0, 0 },
16         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0, 0 },
17         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0, 0 },
18         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0, 0 },
19         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0, 0 },
20         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0, 0 },
21         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0, 0 },
22         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0, 0 },
23         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0, 0 },
24         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0, 0 },
25         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0, 0 },
26         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0, 0 },
27         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0, 0 },
28         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0, 0 },
29         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0, 0 },
30         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0, 0 },
31         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0, 0 },
32         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0, 0 },
33         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0, 0 },
34         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0, 0 },
35         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0, 0 },
36         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0, 0 },
37         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0, 0 },
38         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0, 0 },
39         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0, 0 },
40         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0, 0 },
41         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0, 0 },
42         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0, 0 },
43         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0, 0 },
44         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0, 0 },
45         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0, 0 },
46         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0, 0 },
47         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0, 0 },
48         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0, 0 },
49         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0, 0 },
50         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0, 0 },
51         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0, 0 },
52         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0, 0 },
53         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0, 0 },
54         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0, 0 },
55         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0, 0 },
56         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0, 0 },
57         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0, 0 },
58         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
59         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
60         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
61         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0, 0 },
62         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0, 0 },
63         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0, 0 },
64         { "isinf", "f", "", OP_IS_INF, { 1 }, 0, 0 },
65         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0, 0 },
66         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0, 0 },
67         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0, 0 },
68         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0, 0 },
69         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0, 0 },
70         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0, 0 },
71         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0, 0 },
72         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0, 0 },
73         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0, 0 },
74         { "matrixCompMult", "ff", "", 0, { 0 }, 0, &SpirVGenerator::visit_builtin_matrix_comp_mult },
75         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0, 0 },
76         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0, 0 },
77         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0, 0 },
78         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0, 0 },
79         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0, 0 },
80         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0, 0 },
81         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0, 0 },
82         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
83         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
84         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
85         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0, 0 },
86         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0, 0 },
87         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0, 0 },
88         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
89         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
90         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
91         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0, 0 },
92         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
93         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
94         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0, 0 },
95         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
96         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
97         { "any", "b", "", OP_ANY, { 1 }, 0, 0 },
98         { "all", "b", "", OP_ALL, { 1 }, 0, 0 },
99         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0, 0 },
100         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0, 0 },
101         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0, 0 },
102         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
103         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
104         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
105         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
106         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0, 0 },
107         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
108         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
109         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0, 0 },
110         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0, 0 },
111         { "textureSize", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
112         { "textureQueryLod", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
113         { "textureQueryLevels", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
114         { "textureSamples", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
115         { "texture", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
116         { "textureLod", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
117         { "texelFetch", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texel_fetch },
118         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0, 0 },
119         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0, 0 },
120         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0, 0 },
121         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0, 0 },
122         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
123         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
124         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
125         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
126         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0, 0 },
127         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
128         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
129         { "interpolateAtCentroid", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
130         { "interpolateAtSample", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
131         { "interpolateAtOffset", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
132         { "", "", "", 0, { }, 0, 0 }
133 };
134
135 SpirVGenerator::SpirVGenerator():
136         writer(content)
137 { }
138
139 void SpirVGenerator::apply(Module &module, const Features &f)
140 {
141         features = f;
142         use_capability(CAP_SHADER);
143
144         for(Stage &s: module.stages)
145         {
146                 stage = &s;
147                 interface_layouts.clear();
148                 s.content.visit(*this);
149         }
150
151         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
152 }
153
154 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
155 {
156         if(iface=="in")
157                 return STORAGE_INPUT;
158         else if(iface=="out")
159                 return STORAGE_OUTPUT;
160         else if(iface=="uniform")
161                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
162         else if(iface.empty())
163                 return STORAGE_PRIVATE;
164         else
165                 throw invalid_argument("SpirVGenerator::get_interface_storage");
166 }
167
168 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
169 {
170         if(name=="gl_Position")
171                 return BUILTIN_POSITION;
172         else if(name=="gl_PointSize")
173                 return BUILTIN_POINT_SIZE;
174         else if(name=="gl_ClipDistance")
175                 return BUILTIN_CLIP_DISTANCE;
176         else if(name=="gl_VertexID")
177                 return BUILTIN_VERTEX_ID;
178         else if(name=="gl_InstanceID")
179                 return BUILTIN_INSTANCE_ID;
180         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
181                 return BUILTIN_PRIMITIVE_ID;
182         else if(name=="gl_InvocationID")
183                 return BUILTIN_INVOCATION_ID;
184         else if(name=="gl_Layer")
185                 return BUILTIN_LAYER;
186         else if(name=="gl_FragCoord")
187                 return BUILTIN_FRAG_COORD;
188         else if(name=="gl_PointCoord")
189                 return BUILTIN_POINT_COORD;
190         else if(name=="gl_FrontFacing")
191                 return BUILTIN_FRONT_FACING;
192         else if(name=="gl_SampleId")
193                 return BUILTIN_SAMPLE_ID;
194         else if(name=="gl_SamplePosition")
195                 return BUILTIN_SAMPLE_POSITION;
196         else if(name=="gl_FragDepth")
197                 return BUILTIN_FRAG_DEPTH;
198         else
199                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
200 }
201
202 void SpirVGenerator::use_capability(Capability cap)
203 {
204         if(used_capabilities.count(cap))
205                 return;
206
207         used_capabilities.insert(cap);
208         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
209 }
210
211 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
212 {
213         Id &ext_id = imported_extension_ids[name];
214         if(!ext_id)
215         {
216                 ext_id = next_id++;
217                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
218                 writer.write(ext_id);
219                 writer.write_string(name);
220                 writer.end_op(OP_EXT_INST_IMPORT);
221         }
222         return ext_id;
223 }
224
225 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
226 {
227         return get_item(declared_ids, &node).id;
228 }
229
230 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
231 {
232         auto i = declared_ids.find(&node);
233         if(i!=declared_ids.end())
234         {
235                 if(i->second.type_id)
236                         throw key_error(&node);
237                 i->second.type_id = type_id;
238                 return i->second.id;
239         }
240
241         Id id = next_id++;
242         declared_ids.insert(make_pair(&node, Declaration(id, type_id)));
243         return id;
244 }
245
246 SpirVGenerator::Id SpirVGenerator::allocate_forward_id(Node &node)
247 {
248         auto i = declared_ids.find(&node);
249         if(i!=declared_ids.end())
250                 return i->second.id;
251
252         Id id = next_id++;
253         declared_ids.insert(make_pair(&node, Declaration(id, 0)));
254         return id;
255 }
256
257 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
258 {
259         Id const_id = next_id++;
260         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
261         {
262                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
263                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
264                 writer.write_op(content.globals, opcode, type_id, const_id);
265         }
266         else
267         {
268                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
269                 writer.write_op(content.globals, opcode, type_id, const_id, value);
270         }
271         return const_id;
272 }
273
274 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
275 {
276         if(value.check_type<bool>())
277                 return ConstantKey(type_id, value.value<bool>());
278         else if(value.check_type<int>())
279                 return ConstantKey(type_id, value.value<int>());
280         else if(value.check_type<unsigned>())
281                 return ConstantKey(type_id, value.value<unsigned>());
282         else if(value.check_type<float>())
283                 return ConstantKey(type_id, value.value<float>());
284         else
285                 throw invalid_argument("SpirVGenerator::get_constant_key");
286 }
287
288 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
289 {
290         ConstantKey key = get_constant_key(type_id, value);
291         Id &const_id = constant_ids[key];
292         if(!const_id)
293                 const_id = write_constant(type_id, key.int_value, false);
294         return const_id;
295 }
296
297 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
298 {
299         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
300         if(!const_id)
301         {
302                 const_id = next_id++;
303                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
304                 writer.write(type_id);
305                 writer.write(const_id);
306                 for(unsigned i=0; i<size; ++i)
307                         writer.write(scalar_id);
308                 writer.end_op(OP_CONSTANT_COMPOSITE);
309         }
310         return const_id;
311 }
312
313 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
314 {
315         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
316         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
317         if(!type_id)
318         {
319                 type_id = next_id++;
320                 if(size>1)
321                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
322                 else if(kind==BasicTypeDeclaration::VOID)
323                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
324                 else if(kind==BasicTypeDeclaration::BOOL)
325                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
326                 else if(kind==BasicTypeDeclaration::INT)
327                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
328                 else if(kind==BasicTypeDeclaration::FLOAT)
329                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
330                 else
331                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
332         }
333         return type_id;
334 }
335
336 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
337 {
338         auto i = standard_type_ids.find(TypeKey(kind, true));
339         return (i!=standard_type_ids.end() && i->second==type_id);
340 }
341
342 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id, bool extended_align)
343 {
344         Id base_type_id = get_id(base_type);
345         Id &array_type_id = array_type_ids[TypeKey(base_type_id, extended_align*0x400000 | size_id)];
346         if(!array_type_id)
347         {
348                 array_type_id = next_id++;
349                 if(size_id)
350                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
351                 else
352                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
353
354                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
355                 if(extended_align)
356                         stride = (stride+15)&~15U;
357                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
358         }
359
360         return array_type_id;
361 }
362
363 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
364 {
365         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
366         if(!ptr_type_id)
367         {
368                 ptr_type_id = next_id++;
369                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
370         }
371         return ptr_type_id;
372 }
373
374 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
375 {
376         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
377                 if(basic->kind==BasicTypeDeclaration::ARRAY)
378                 {
379                         if(!var.array_size)
380                                 throw logic_error("array without size");
381
382                         SetFlag set_const(constant_expression);
383                         r_expression_result_id = 0;
384                         var.array_size->visit(*this);
385                         return get_array_type_id(*basic->base_type, r_expression_result_id, basic->extended_alignment);
386                 }
387
388         return get_id(*var.type_declaration);
389 }
390
391 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
392 {
393         Id &load_result_id = variable_load_ids[&var];
394         if(!load_result_id)
395         {
396                 load_result_id = next_id++;
397                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
398         }
399         return load_result_id;
400 }
401
402 void SpirVGenerator::prune_loads(Id min_id)
403 {
404         for(auto i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
405         {
406                 if(i->second>=min_id)
407                         variable_load_ids.erase(i++);
408                 else
409                         ++i;
410         }
411 }
412
/* Opens the instruction for an expression and writes its result type and
result id.  Callers then write the operands and close the instruction with
end_expression(opcode).

Outside constant context the instruction goes into the current function
body.  In constant context it goes into the globals section:
OpCompositeConstruct becomes OpConstantComposite (or
OpSpecConstantComposite), and any other opcode becomes OpSpecConstantOp
with the real opcode as a literal operand.  A non-composite opcode in a
non-specialization constant context throws internal_error.

Returns the new result id.  An id is consumed even when the instruction has
no result. */
413 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
414 {
        // Void-typed instructions have no result; OpFunctionCall always has one in SPIR-V.
415         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
416         if(!constant_expression)
417         {
418                 if(!current_function)
419                         throw internal_error("non-constant expression outside a function");
420
                // Word count = opcode word + (type, result) + operands.
                // NOTE(review): 0 is passed when n_args is 0; presumably the writer
                // then computes the length in end_op — confirm.
421                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
422         }
423         else if(opcode==OP_COMPOSITE_CONSTRUCT)
424                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
425                         (n_args ? 1+has_result*2+n_args : 0));
426         else if(!spec_constant)
427                 throw internal_error("invalid non-specialization constant expression");
428         else
                // One extra word for the embedded opcode literal of OpSpecConstantOp.
429                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
430
431         Id result_id = next_id++;
432         if(has_result)
433         {
434                 writer.write(type_id);
435                 writer.write(result_id);
436         }
        // OpSpecConstantOp: the actual operation follows the result id.
        // NOTE(review): this tests spec_constant rather than constant_expression,
        // so it assumes spec_constant is only set in constant context.
437         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
438                 writer.write(opcode);
439
440         return result_id;
441 }
442
443 void SpirVGenerator::end_expression(Opcode opcode)
444 {
445         if(constant_expression)
446                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
447         writer.end_op(opcode);
448 }
449
450 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
451 {
452         Id result_id = begin_expression(opcode, type_id, 1);
453         writer.write(arg_id);
454         end_expression(opcode);
455         return result_id;
456 }
457
458 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
459 {
460         Id result_id = begin_expression(opcode, type_id, 2);
461         writer.write(left_id);
462         writer.write(right_id);
463         end_expression(opcode);
464         return result_id;
465 }
466
467 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
468 {
469         for(unsigned i=0; i<n_elems; ++i)
470         {
471                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
472                 writer.write(composite_id);
473                 writer.write(i);
474                 end_expression(OP_COMPOSITE_EXTRACT);
475         }
476 }
477
478 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
479 {
480         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
481         for(unsigned i=0; i<n_elems; ++i)
482                 writer.write(elem_ids[i]);
483         end_expression(OP_COMPOSITE_CONSTRUCT);
484
485         return result_id;
486 }
487
488 void SpirVGenerator::visit(Block &block)
489 {
490         for(const RefPtr<Statement> &s: block.body)
491                 s->visit(*this);
492 }
493
494 void SpirVGenerator::visit(Literal &literal)
495 {
496         Id type_id = get_id(*literal.type);
497         if(spec_constant)
498                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
499         else
500                 r_expression_result_id = get_constant_id(type_id, literal.value);
501         r_constant_result = true;
502 }
503
/* Handles a reference to a variable in one of four roles:
- constant: the declared id of the constant is the value (non-constant
  variables are rejected in constant context);
- part of a composite access chain: reuse a cached load if reading,
  otherwise record the variable as the pointer base of the chain;
- assignment target: store assignment_source_id into the variable and
  remember it as the variable's current loaded value;
- plain read: load the variable (cached per variable by get_load_id). */
504 void SpirVGenerator::visit(VariableReference &var)
505 {
506         if(constant_expression || var.declaration->constant)
507         {
508                 if(!var.declaration->constant)
509                         throw internal_error("reference to non-constant variable in constant context");
510
511                 r_expression_result_id = get_id(*var.declaration);
512                 r_constant_result = true;
513                 return;
514         }
515         else if(!current_function)
516                 throw internal_error("non-constant context outside a function");
517
518         r_constant_result = false;
519         if(composite_access)
520         {
                // A cached load lets the access use OpCompositeExtract on the value.
                // Otherwise (or when assigning) the variable becomes the OpAccessChain base.
521                 r_expression_result_id = 0;
522                 if(!assignment_source_id)
523                 {
524                         auto i = variable_load_ids.find(var.declaration);
525                         if(i!=variable_load_ids.end())
526                                 r_expression_result_id = i->second;
527                 }
528                 if(!r_expression_result_id)
529                         r_composite_base = var.declaration;
530         }
531         else if(assignment_source_id)
532         {
533                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
                // The stored value is now the variable's current value; later reads reuse it.
534                 variable_load_ids[var.declaration] = assignment_source_id;
535                 r_expression_result_id = assignment_source_id;
536         }
537         else
538                 r_expression_result_id = get_load_id(*var.declaration);
539 }
540
/* Emits the access described by r_composite_chain on r_composite_base_id
and sets r_expression_result_id.

Chain encoding: entries below 0x400000 are literal indices.  Entries with
bit 22 (0x400000) set carry a constant id in their low 22 bits.  Presumably
the array subscript visitor (not shown here) produces these.

With a pointer base (r_composite_base set), an OpAccessChain is emitted.
Its indices must be ids, so literals are turned into int32 constants.  The
access is followed by OpStore when assigning, or OpLoad otherwise.  With a
value base, OpCompositeExtract is emitted, which needs literal indices.
Assigning through a value base throws internal_error. */
541 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
542 {
543         Opcode opcode;
544         Id result_type_id = get_id(result_type);
545         Id access_type_id = result_type_id;
546         if(r_composite_base)
547         {
548                 if(constant_expression)
549                         throw internal_error("composite access through pointer in constant context");
550
                // Literal indices -> int32 constant ids; encoded ids -> decoded ids.
551                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
552                 for(unsigned &i: r_composite_chain)
553                         i = (i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(i)) : i&0x3FFFFF);
554
555                 /* Find the storage class of the base and obtain appropriate pointer type
556                 for the result. */
                // Linear search: the storage class is the detail part of the pointer type's key.
557                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
558                 auto i = pointer_type_ids.begin();
559                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
560                 if(i==pointer_type_ids.end())
561                         throw internal_error("could not find storage class");
562                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
563
564                 opcode = OP_ACCESS_CHAIN;
565         }
566         else if(assignment_source_id)
567                 throw internal_error("assignment to temporary composite");
568         else
569         {
                // Decode constant-id entries back to literal values by reverse lookup in constant_ids.
                // NOTE(review): if no constant matches, the entry keeps its encoded value.
570                 for(unsigned &i: r_composite_chain)
571                         for(auto j=constant_ids.begin(); (i>=0x400000 && j!=constant_ids.end()); ++j)
572                                 if(j->second==(i&0x3FFFFF))
573                                         i = j->first.int_value;
574
575                 opcode = OP_COMPOSITE_EXTRACT;
576         }
577
578         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
579         writer.write(r_composite_base_id);
580         for(unsigned i: r_composite_chain)
581                 writer.write(i);
582         end_expression(opcode);
583
584         r_constant_result = false;
585         if(r_composite_base)
586         {
587                 if(assignment_source_id)
588                 {
589                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
590                         r_expression_result_id = assignment_source_id;
591                 }
592                 else
593                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
594         }
595         else
596                 r_expression_result_id = access_id;
597 }
598
/* Shared handling for member access and single-component swizzle.

Nested accesses are flattened into one chain.  The outermost call (entered
with composite_access false) resets the chain state.  The base expression
is then visited with composite_access set, so inner levels only append
their index.  Finally the outermost level emits the whole access via
generate_composite_access.  Inner levels leave r_expression_result_id as 0. */
599 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
600 {
601         if(!composite_access)
602         {
603                 r_composite_base = 0;
604                 r_composite_base_id = 0;
605                 r_composite_chain.clear();
606         }
607
608         {
609                 SetFlag set_composite(composite_access);
610                 base_expr.visit(*this);
611         }
612
        // Only the innermost level sets the base id: the variable's pointer id when
        // there is a pointer base, otherwise the base expression's value.
613         if(!r_composite_base_id)
614                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
615
616         r_composite_chain.push_back(index);
617         if(!composite_access)
618                 generate_composite_access(type);
619         else
620                 r_expression_result_id = 0;
621 }
622
623 void SpirVGenerator::visit_isolated(Expression &expr)
624 {
625         SetForScope<Id> clear_assign(assignment_source_id, 0);
626         SetFlag clear_composite(composite_access, false);
627         SetForScope<Node *> clear_base(r_composite_base, 0);
628         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
629         vector<unsigned> saved_chain;
630         swap(saved_chain, r_composite_chain);
631         expr.visit(*this);
632         swap(saved_chain, r_composite_chain);
633 }
634
635 void SpirVGenerator::visit(MemberAccess &memacc)
636 {
637         visit_composite(*memacc.left, memacc.index, *memacc.type);
638 }
639
/* Generates code for a vector swizzle.
- Single component: handled as a composite access into that component.
- Multi-component assignment target: the current vector is read, and an
  OpVectorShuffle merges it with the assigned value (selected components from
  the value, the rest from the old vector).  The merged vector is then stored
  back by re-visiting the left-hand side with it as the assignment source.
- Multi-component read: the vector is shuffled with itself to pick the
  components. */
640 void SpirVGenerator::visit(Swizzle &swizzle)
641 {
642         if(swizzle.count==1)
643                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
644         else if(assignment_source_id)
645         {
646                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
647
                // Bit n of mask is set if component n of the target is written.
648                 unsigned mask = 0;
649                 for(unsigned i=0; i<swizzle.count; ++i)
650                         mask |= 1<<swizzle.components[i];
651
652                 visit_isolated(*swizzle.left);
653
                // Component i comes from the assigned value (index i+size) if written, else from the old vector.
                // NOTE(review): this assumes assignment_source_id is already a full-size vector with the assigned
                // components in their destination positions; confirm against the assignment visitor.
654                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
655                 writer.write(r_expression_result_id);
656                 writer.write(assignment_source_id);
657                 for(unsigned i=0; i<basic.size; ++i)
658                         writer.write(i+((mask>>i)&1)*basic.size);
659                 end_expression(OP_VECTOR_SHUFFLE);
660
661                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
662                 swizzle.left->visit(*this);
663
664                 r_expression_result_id = combined_id;
665         }
666         else
667         {
668                 swizzle.left->visit(*this);
669                 Id left_id = r_expression_result_id;
670
                // Both shuffle inputs are the same vector; the indices pick the swizzled components.
671                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
672                 writer.write(left_id);
673                 writer.write(left_id);
674                 for(unsigned i=0; i<swizzle.count; ++i)
675                         writer.write(swizzle.components[i]);
676                 end_expression(OP_VECTOR_SHUFFLE);
677         }
678         r_constant_result = false;
679 }
680
681 void SpirVGenerator::visit(UnaryExpression &unary)
682 {
683         if(composite_access)
684                 return visit_isolated(unary);
685
686         unary.expression->visit(*this);
687
688         char oper = unary.oper->token[0];
689         char oper2 = unary.oper->token[1];
690         if(oper=='+' && !oper2)
691                 return;
692
693         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
694         BasicTypeDeclaration &elem = *get_element_type(basic);
695
696         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
697                 /* SPIR-V allows constant operations on floating-point values only for
698                 OpenGL kernels. */
699                 throw internal_error("invalid operands for constant unary expression");
700
701         Id result_type_id = get_id(*unary.type);
702         Opcode opcode = OP_NOP;
703
704         r_constant_result = false;
705         if(oper=='!')
706                 opcode = OP_LOGICAL_NOT;
707         else if(oper=='~')
708                 opcode = OP_NOT;
709         else if(oper=='-' && !oper2)
710         {
711                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
712
713                 if(basic.kind==BasicTypeDeclaration::MATRIX)
714                 {
715                         Id column_type_id = get_id(*basic.base_type);
716                         unsigned n_columns = basic.size&0xFFFF;
717                         Id column_ids[4];
718                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
719                         for(unsigned i=0; i<n_columns; ++i)
720                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
721                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
722                         return;
723                 }
724         }
725         else if((oper=='+' || oper=='-') && oper2==oper)
726         {
727                 if(constant_expression)
728                         throw internal_error("increment or decrement in constant expression");
729
730                 Id one_id = 0;
731                 if(elem.kind==BasicTypeDeclaration::INT)
732                 {
733                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
734                         one_id = get_constant_id(get_id(elem), 1);
735                 }
736                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
737                 {
738                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
739                         one_id = get_constant_id(get_id(elem), 1.0f);
740                 }
741                 else
742                         throw internal_error("invalid increment/decrement");
743
744                 if(basic.kind==BasicTypeDeclaration::VECTOR)
745                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
746
747                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
748
749                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
750                 unary.expression->visit(*this);
751
752                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
753                 return;
754         }
755
756         if(opcode==OP_NOP)
757                 throw internal_error("unknown unary operator");
758
759         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
760 }
761
762 void SpirVGenerator::visit(BinaryExpression &binary)
763 {
764         char oper = binary.oper->token[0];
765         if(oper=='[')
766         {
767                 visit_isolated(*binary.right);
768                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
769         }
770         else if(composite_access)
771                 return visit_isolated(binary);
772
773         if(assignment_source_id)
774                 throw internal_error("invalid binary expression in assignment target");
775
776         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
777         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
778         // Expression resolver ensures that element types are the same
779         BasicTypeDeclaration &elem = *get_element_type(basic_left);
780
781         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
782                 /* SPIR-V allows constant operations on floating-point values only for
783                 OpenGL kernels. */
784                 throw internal_error("invalid operands for constant binary expression");
785
786         binary.left->visit(*this);
787         Id left_id = r_expression_result_id;
788         binary.right->visit(*this);
789         Id right_id = r_expression_result_id;
790
791         Id result_type_id = get_id(*binary.type);
792         Opcode opcode = OP_NOP;
793         bool swap_operands = false;
794
795         r_constant_result = false;
796
797         char oper2 = binary.oper->token[1];
798         if((oper=='<' || oper=='>') && oper2!=oper)
799         {
800                 if(basic_left.kind==BasicTypeDeclaration::INT)
801                 {
802                         if(basic_left.sign)
803                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
804                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
805                         else
806                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
807                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
808                 }
809                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
810                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
811                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
812         }
813         else if((oper=='=' || oper=='!') && oper2=='=')
814         {
815                 if(elem.kind==BasicTypeDeclaration::BOOL)
816                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
817                 else if(elem.kind==BasicTypeDeclaration::INT)
818                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
819                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
820                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
821
822                 if(opcode!=OP_NOP && basic_left.base_type)
823                 {
824                         /* The SPIR-V equality operations produce component-wise results, but
825                         GLSL operators return a single boolean.  Use the any/all operations to
826                         combine the results. */
827                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
828                         unsigned n_elems = basic_left.size&0xFFFF;
829                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
830
831                         Id compare_id = 0;
832                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
833                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
834                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
835                         {
836                                 Id column_type_id = get_id(*basic_left.base_type);
837                                 Id column_ids[8];
838                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
839                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
840
841                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
842                                 for(unsigned i=0; i<n_elems; ++i)
843                                 {
844                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
845                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
846                                 }
847
848                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
849                         }
850
851                         if(compare_id)
852                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
853                         return;
854                 }
855         }
856         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
857                 opcode = OP_LOGICAL_AND;
858         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
859                 opcode = OP_LOGICAL_OR;
860         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
861                 opcode = OP_LOGICAL_NOT_EQUAL;
862         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
863                 opcode = OP_BITWISE_AND;
864         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
865                 opcode = OP_BITWISE_OR;
866         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
867                 opcode = OP_BITWISE_XOR;
868         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
869                 opcode = OP_SHIFT_LEFT_LOGICAL;
870         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
871                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
872         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
873                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
874         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
875         {
876                 Opcode elem_op = OP_NOP;
877                 if(elem.kind==BasicTypeDeclaration::INT)
878                 {
879                         if(oper=='/')
880                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
881                         else
882                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
883                 }
884                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
885                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
886
887                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
888                 {
889                         /* Multiplication between floating-point vectors and matrices has
890                         dedicated operations. */
891                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
892                                 opcode = OP_MATRIX_TIMES_MATRIX;
893                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
894                         {
895                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
896                                         opcode = OP_VECTOR_TIMES_MATRIX;
897                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
898                                         opcode = OP_MATRIX_TIMES_VECTOR;
899                                 else
900                                 {
901                                         opcode = OP_MATRIX_TIMES_SCALAR;
902                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
903                                 }
904                         }
905                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
906                                 opcode = elem_op;
907                         else
908                         {
909                                 opcode = OP_VECTOR_TIMES_SCALAR;
910                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
911                         }
912                 }
913                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
914                 {
915                         /* One operand is scalar and the other is a vector or a matrix.
916                         Expand the scalar to a vector of appropriate size. */
917                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
918                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
919                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
920                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
921                         Id vector_type_id = get_id(*vector_type);
922
923                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
924                         for(unsigned i=0; i<vector_type->size; ++i)
925                                 writer.write(scalar_id);
926                         end_expression(OP_COMPOSITE_CONSTRUCT);
927
928                         scalar_id = expanded_id;
929
930                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
931                         {
932                                 // Apply matrix operation column-wise.
933                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
934
935                                 Id column_ids[4];
936                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
937                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
938
939                                 for(unsigned i=0; i<n_columns; ++i)
940                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
941
942                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
943                                 return;
944                         }
945                         else
946                                 opcode = elem_op;
947                 }
948                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
949                 {
950                         if(oper=='*')
951                                 throw internal_error("non-float matrix multiplication");
952
953                         /* Other operations involving matrices need to be performed
954                         column-wise. */
955                         Id column_type_id = get_id(*basic_left.base_type);
956                         Id column_ids[8];
957
958                         unsigned n_columns = basic_left.size&0xFFFF;
959                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
960                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
961
962                         for(unsigned i=0; i<n_columns; ++i)
963                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
964
965                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
966                         return;
967                 }
968                 else if(basic_left.kind==basic_right.kind)
969                         // Both operands are either scalars or vectors.
970                         opcode = elem_op;
971         }
972
973         if(opcode==OP_NOP)
974                 throw internal_error("unknown binary operator");
975
976         if(swap_operands)
977                 swap(left_id, right_id);
978
979         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
980 }
981
982 void SpirVGenerator::visit(Assignment &assign)
983 {
984         if(assign.oper->token[0]!='=')
985                 visit(static_cast<BinaryExpression &>(assign));
986         else
987                 assign.right->visit(*this);
988
989         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
990         assign.left->visit(*this);
991         r_constant_result = false;
992 }
993
994 void SpirVGenerator::visit(TernaryExpression &ternary)
995 {
996         if(composite_access)
997                 return visit_isolated(ternary);
998         if(constant_expression)
999         {
1000                 ternary.condition->visit(*this);
1001                 Id condition_id = r_expression_result_id;
1002                 ternary.true_expr->visit(*this);
1003                 Id true_result_id = r_expression_result_id;
1004                 ternary.false_expr->visit(*this);
1005                 Id false_result_id = r_expression_result_id;
1006
1007                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
1008                 writer.write(condition_id);
1009                 writer.write(true_result_id);
1010                 writer.write(false_result_id);
1011                 end_expression(OP_SELECT);
1012
1013                 return;
1014         }
1015
1016         ternary.condition->visit(*this);
1017         Id condition_id = r_expression_result_id;
1018
1019         Id true_label_id = next_id++;
1020         Id false_label_id = next_id++;
1021         Id merge_block_id = next_id++;
1022         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1023         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
1024
1025         std::map<const VariableDeclaration *, Id> saved_load_ids = variable_load_ids;
1026
1027         writer.write_op_label(true_label_id);
1028         ternary.true_expr->visit(*this);
1029         Id true_result_id = r_expression_result_id;
1030         true_label_id = writer.get_current_block();
1031         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1032
1033         swap(saved_load_ids, variable_load_ids);
1034         writer.write_op_label(false_label_id);
1035         ternary.false_expr->visit(*this);
1036         Id false_result_id = r_expression_result_id;
1037         false_label_id = writer.get_current_block();
1038
1039         writer.write_op_label(merge_block_id);
1040         prune_loads(true_label_id);
1041         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
1042         writer.write(true_result_id);
1043         writer.write(true_label_id);
1044         writer.write(false_result_id);
1045         writer.write(false_label_id);
1046         end_expression(OP_PHI);
1047
1048         r_constant_result = false;
1049 }
1050
1051 void SpirVGenerator::visit(FunctionCall &call)
1052 {
1053         if(assignment_source_id)
1054                 throw internal_error("assignment to function call");
1055         else if(composite_access)
1056                 return visit_isolated(call);
1057         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
1058                 return call.arguments[0]->visit(*this);
1059
1060         vector<Id> argument_ids;
1061         argument_ids.reserve(call.arguments.size());
1062         bool all_args_const = true;
1063         for(const RefPtr<Expression> &a: call.arguments)
1064         {
1065                 a->visit(*this);
1066                 argument_ids.push_back(r_expression_result_id);
1067                 all_args_const &= r_constant_result;
1068         }
1069
1070         if(constant_expression && (!call.constructor || !all_args_const))
1071                 throw internal_error("function call in constant expression");
1072
1073         Id result_type_id = get_id(*call.type);
1074         r_constant_result = false;
1075
1076         if(call.constructor)
1077                 visit_constructor(call, argument_ids, all_args_const);
1078         else if(call.declaration->source==BUILTIN_SOURCE)
1079         {
1080                 string arg_types;
1081                 for(const RefPtr<Expression> &a: call.arguments)
1082                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>(a->type))
1083                         {
1084                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1085                                 switch(elem_arg.kind)
1086                                 {
1087                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1088                                 case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
1089                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1090                                 default: arg_types += '?';
1091                                 }
1092                         }
1093
1094                 const BuiltinFunctionInfo *builtin_info;
1095                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1096                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1097                                 break;
1098
1099                 if(builtin_info->capability)
1100                         use_capability(static_cast<Capability>(builtin_info->capability));
1101
1102                 if(builtin_info->opcode)
1103                 {
1104                         Opcode opcode;
1105                         if(builtin_info->extension[0])
1106                         {
1107                                 opcode = OP_EXT_INST;
1108                                 Id ext_id = import_extension(builtin_info->extension);
1109
1110                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1111                                 writer.write(ext_id);
1112                                 writer.write(builtin_info->opcode);
1113                         }
1114                         else
1115                         {
1116                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1117                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1118                         }
1119
1120                         for(unsigned i=0; i<call.arguments.size(); ++i)
1121                         {
1122                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1123                                         throw internal_error("invalid builtin function info");
1124                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1125                         }
1126
1127                         end_expression(opcode);
1128                 }
1129                 else if(builtin_info->handler)
1130                         (this->*(builtin_info->handler))(call, argument_ids);
1131                 else
1132                         throw internal_error("unknown builtin function "+call.name);
1133         }
1134         else
1135         {
1136                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1137                 writer.write(get_id(*call.declaration->definition));
1138                 for(Id i: argument_ids)
1139                         writer.write(i);
1140                 end_expression(OP_FUNCTION_CALL);
1141
1142                 // Any global variables the called function uses might have changed value
1143                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1144                 for(Node *n: dependencies)
1145                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1146                                 variable_load_ids.erase(var);
1147         }
1148 }
1149
1150 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1151 {
1152         Id result_type_id = get_id(*call.type);
1153
1154         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1155         if(!basic)
1156         {
1157                 if(dynamic_cast<const StructDeclaration *>(call.type))
1158                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1159                 else
1160                         throw internal_error("unconstructable type "+call.name);
1161                 return;
1162         }
1163
1164         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1165
1166         BasicTypeDeclaration &elem = *get_element_type(*basic);
1167         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1168         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1169
1170         if(basic->kind==BasicTypeDeclaration::MATRIX)
1171         {
1172                 Id col_type_id = get_id(*basic->base_type);
1173                 unsigned n_columns = basic->size&0xFFFF;
1174                 unsigned n_rows = basic->size>>16;
1175
1176                 Id column_ids[4];
1177                 if(call.arguments.size()==1)
1178                 {
1179                         // Construct diagonal matrix from a single scalar.
1180                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1181                         for(unsigned i=0; i<n_columns; ++i)
1182                         {
1183                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);
1184                                 for(unsigned j=0; j<n_rows; ++j)
1185                                         writer.write(j==i ? argument_ids[0] : zero_id);
1186                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1187                         }
1188                 }
1189                 else
1190                         // Construct a matrix from column vectors
1191                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1192
1193                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1194         }
1195         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1196         {
1197                 /* There's either a single scalar argument or multiple arguments
1198                 which make up the vector's components. */
1199                 if(call.arguments.size()==1)
1200                 {
1201                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1202                         for(unsigned i=0; i<basic->size; ++i)
1203                                 writer.write(argument_ids[0]);
1204                         end_expression(OP_COMPOSITE_CONSTRUCT);
1205                 }
1206                 else
1207                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1208         }
1209         else if(elem.kind==BasicTypeDeclaration::BOOL)
1210         {
1211                 if(constant_expression)
1212                         throw internal_error("unconverted constant");
1213
1214                 // Conversion to boolean is implemented as comparing against zero.
1215                 Id number_type_id = get_id(elem_arg0);
1216                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1217                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1218                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1219                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1220
1221                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1222                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1223         }
1224         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1225         {
1226                 if(constant_expression)
1227                         throw internal_error("unconverted constant");
1228
1229                 /* Conversion from boolean is implemented as selecting from zero
1230                 or one. */
1231                 Id number_type_id = get_id(elem);
1232                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1233                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1234                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1235                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1236                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1237                 {
1238                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1239                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1240                 }
1241
1242                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1243                 writer.write(argument_ids[0]);
1244                 writer.write(zero_id);
1245                 writer.write(one_id);
1246                 end_expression(OP_SELECT);
1247         }
1248         else
1249         {
1250                 if(constant_expression)
1251                         throw internal_error("unconverted constant");
1252
1253                 // Scalar or vector conversion between types of equal size.
1254                 Opcode opcode;
1255                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1256                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1257                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1258                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1259                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1260                         opcode = OP_BITCAST;
1261                 else
1262                         throw internal_error("invalid conversion");
1263
1264                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1265         }
1266 }
1267
1268 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1269 {
1270         if(argument_ids.size()!=2)
1271                 throw internal_error("invalid matrixCompMult call");
1272
1273         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1274         Id column_type_id = get_id(*basic_arg0.base_type);
1275         Id column_ids[8];
1276
1277         unsigned n_columns = basic_arg0.size&0xFFFF;
1278         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1279         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1280
1281         for(unsigned i=0; i<n_columns; ++i)
1282                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1283
1284         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1285 }
1286
1287 void SpirVGenerator::visit_builtin_texture_query(FunctionCall &call, const vector<Id> &argument_ids)
1288 {
1289         if(argument_ids.size()<1)
1290                 throw internal_error("invalid texture query call");
1291
1292         Opcode opcode;
1293         if(call.name=="textureSize")
1294                 opcode = OP_IMAGE_QUERY_SIZE_LOD;
1295         else if(call.name=="textureQueryLod")
1296                 opcode = OP_IMAGE_QUERY_LOD;
1297         else if(call.name=="textureQueryLevels")
1298                 opcode = OP_IMAGE_QUERY_LEVELS;
1299         else if(call.name=="textureSamples")
1300                 opcode = OP_IMAGE_QUERY_SAMPLES;
1301         else
1302                 throw internal_error("invalid texture query call");
1303
1304         ImageTypeDeclaration &image_arg0 = dynamic_cast<ImageTypeDeclaration &>(*call.arguments[0]->type);
1305
1306         Id image_id;
1307         if(image_arg0.sampled && opcode!=OP_IMAGE_QUERY_LOD)
1308         {
1309                 Id image_type_id = get_item(image_type_ids, get_id(image_arg0));
1310                 image_id = write_expression(OP_IMAGE, image_type_id, argument_ids[0]);
1311         }
1312         else
1313                 image_id = argument_ids[0];
1314
1315         Id result_type_id = get_id(*call.type);
1316         r_expression_result_id = begin_expression(opcode, result_type_id, argument_ids.size());
1317         writer.write(image_id);
1318         for(unsigned i=1; i<argument_ids.size(); ++i)
1319                 writer.write(argument_ids[i]);
1320         end_expression(opcode);
1321 }
1322
1323 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1324 {
1325         if(argument_ids.size()<2)
1326                 throw internal_error("invalid texture sampling call");
1327
1328         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1329         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1330                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1331
1332         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1333
1334         Opcode opcode;
1335         Id result_type_id = get_id(*call.type);
1336         Id dref_id = 0;
1337         if(image.shadow)
1338         {
1339                 if(argument_ids.size()==2)
1340                 {
1341                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1342                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1343                         writer.write(argument_ids.back());
1344                         writer.write(basic_arg1.size-1);
1345                         end_expression(OP_COMPOSITE_EXTRACT);
1346                 }
1347                 else
1348                         dref_id = argument_ids[2];
1349
1350                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1351                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1352         }
1353         else
1354         {
1355                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1356                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1357         }
1358
1359         for(unsigned i=0; i<2; ++i)
1360                 writer.write(argument_ids[i]);
1361         if(dref_id)
1362                 writer.write(dref_id);
1363         if(explicit_lod)
1364         {
1365                 writer.write(2);  // Lod
1366                 writer.write(lod_id);
1367         }
1368
1369         end_expression(opcode);
1370 }
1371
1372 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1373 {
1374         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1375
1376         Opcode opcode = OP_IMAGE_FETCH;
1377
1378         bool need_sample = image.multisample;
1379         bool need_lod = !need_sample;
1380
1381         if(argument_ids.size()!=2U+need_sample+need_lod)
1382                 throw internal_error("invalid texture fetch call");
1383
1384         r_expression_result_id = begin_expression(opcode, get_id(*call.type), 2+(need_lod|need_sample)+need_lod+need_sample);
1385         for(unsigned i=0; i<2; ++i)
1386                 writer.write(argument_ids[i]);
1387         if(need_lod || need_sample)
1388         {
1389                 writer.write(need_lod*0x02 | need_sample*0x40);
1390                 writer.write(argument_ids.back());
1391         }
1392         end_expression(opcode);
1393 }
1394
1395 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1396 {
1397         if(argument_ids.size()<1)
1398                 throw internal_error("invalid interpolate call");
1399         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1400         if(!var || !var->declaration || var->declaration->interface!="in")
1401                 throw internal_error("invalid interpolate call");
1402
1403         SpirVGlslStd450Opcode opcode;
1404         if(call.name=="interpolateAtCentroid")
1405                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1406         else if(call.name=="interpolateAtSample")
1407                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1408         else if(call.name=="interpolateAtOffset")
1409                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1410         else
1411                 throw internal_error("invalid interpolate call");
1412
1413         Id ext_id = import_extension("GLSL.std.450");
1414         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1415         writer.write(ext_id);
1416         writer.write(opcode);
1417         writer.write(get_id(*var->declaration));
1418         for(auto i=argument_ids.begin(); ++i!=argument_ids.end(); )
1419                 writer.write(*i);
1420         end_expression(OP_EXT_INST);
1421 }
1422
1423 void SpirVGenerator::visit(ExpressionStatement &expr)
1424 {
1425         expr.expression->visit(*this);
1426 }
1427
1428 void SpirVGenerator::visit(InterfaceLayout &layout)
1429 {
1430         interface_layouts.push_back(&layout);
1431 }
1432
1433 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1434 {
1435         for(const auto &kvp: declared_ids)
1436                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(kvp.first))
1437                         if(TypeComparer().apply(type, *type2))
1438                         {
1439                                 insert_unique(declared_ids, &type, kvp.second);
1440                                 return true;
1441                         }
1442
1443         return false;
1444 }
1445
1446 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1447 {
1448         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1449                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1450         if(!elem || elem->base_type)
1451                 return false;
1452         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1453                 return false;
1454
1455         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1456         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1457         writer.write_op_name(standard_id, basic.name);
1458
1459         return true;
1460 }
1461
1462 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1463 {
1464         if(check_standard_type(basic))
1465                 return;
1466         if(check_duplicate_type(basic))
1467                 return;
1468         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1469         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1470                 return;
1471
1472         Id type_id = allocate_id(basic, 0);
1473         writer.write_op_name(type_id, basic.name);
1474
1475         switch(basic.kind)
1476         {
1477         case BasicTypeDeclaration::INT:
1478                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1479                 break;
1480         case BasicTypeDeclaration::FLOAT:
1481                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1482                 break;
1483         case BasicTypeDeclaration::VECTOR:
1484                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1485                 break;
1486         case BasicTypeDeclaration::MATRIX:
1487                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1488                 break;
1489         default:
1490                 throw internal_error("unknown basic type");
1491         }
1492 }
1493
1494 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1495 {
1496         if(check_duplicate_type(image))
1497                 return;
1498
1499         Id type_id = allocate_id(image, 0);
1500
1501         Id image_id = (image.sampled ? next_id++ : type_id);
1502         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1503         writer.write(image_id);
1504         writer.write(get_id(*image.base_type));
1505         writer.write(image.dimensions-1);
1506         writer.write(image.shadow);
1507         writer.write(image.array);
1508         writer.write(image.multisample);
1509         writer.write(image.sampled ? 1 : 2);
1510         writer.write(0);  // Format (unknown)
1511         writer.end_op(OP_TYPE_IMAGE);
1512
1513         if(image.sampled)
1514         {
1515                 writer.write_op_name(type_id, image.name);
1516                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1517                 insert_unique(image_type_ids, type_id, image_id);
1518         }
1519
1520         if(image.dimensions==ImageTypeDeclaration::ONE)
1521                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1522         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1523                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1524 }
1525
1526 void SpirVGenerator::visit(StructDeclaration &strct)
1527 {
1528         if(check_duplicate_type(strct))
1529                 return;
1530
1531         Id type_id = allocate_id(strct, 0);
1532         writer.write_op_name(type_id, (strct.block_name.empty() ? strct.name : strct.block_name));
1533
1534         if(!strct.block_name.empty())
1535                 writer.write_op_decorate(type_id, DECO_BLOCK);
1536
1537         bool builtin = !strct.block_name.compare(0, 3, "gl_");
1538         vector<Id> member_type_ids;
1539         member_type_ids.reserve(strct.members.body.size());
1540         for(const RefPtr<Statement> &s: strct.members.body)
1541         {
1542                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(s.get());
1543                 if(!var)
1544                         continue;
1545
1546                 unsigned index = member_type_ids.size();
1547                 member_type_ids.push_back(get_variable_type_id(*var));
1548
1549                 writer.write_op_member_name(type_id, index, var->name);
1550
1551                 if(builtin)
1552                 {
1553                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1554                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1555                 }
1556                 else
1557                 {
1558                         if(var->layout)
1559                         {
1560                                 for(const Layout::Qualifier &q: var->layout->qualifiers)
1561                                 {
1562                                         if(q.name=="offset")
1563                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, q.value);
1564                                         else if(q.name=="column_major")
1565                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1566                                         else if(q.name=="row_major")
1567                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1568                                 }
1569                         }
1570
1571                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1572                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1573                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1574                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1575                         {
1576                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1577                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1578                         }
1579                 }
1580         }
1581
1582         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1583         writer.write(type_id);
1584         for(Id i: member_type_ids)
1585                 writer.write(i);
1586         writer.end_op(OP_TYPE_STRUCT);
1587 }
1588
1589 void SpirVGenerator::visit(VariableDeclaration &var)
1590 {
1591         Id type_id = get_variable_type_id(var);
1592         Id var_id;
1593
1594         if(var.constant)
1595         {
1596                 if(!var.init_expression)
1597                         throw internal_error("const variable without initializer");
1598
1599                 int spec_id = get_layout_value(var.layout.get(), "constant_id");
1600                 Id *spec_var_id = (spec_id>=0 ? &declared_spec_ids[spec_id] : 0);
1601                 if(spec_id>=0 && *spec_var_id)
1602                 {
1603                         insert_unique(declared_ids, &var, Declaration(*spec_var_id, type_id));
1604                         return;
1605                 }
1606
1607                 SetFlag set_const(constant_expression);
1608                 SetFlag set_spec(spec_constant, spec_id>=0);
1609                 r_expression_result_id = 0;
1610                 var.init_expression->visit(*this);
1611                 var_id = r_expression_result_id;
1612                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1613                 if(spec_id>=0)
1614                 {
1615                         writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1616                         *spec_var_id = var_id;
1617                 }
1618         }
1619         else
1620         {
1621                 bool push_const = has_layout_qualifier(var.layout.get(), "push_constant");
1622
1623                 StorageClass storage;
1624                 if(current_function)
1625                         storage = STORAGE_FUNCTION;
1626                 else if(push_const)
1627                         storage = STORAGE_PUSH_CONSTANT;
1628                 else
1629                         storage = get_interface_storage(var.interface, var.block_declaration);
1630
1631                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1632                 if(var.interface=="uniform")
1633                 {
1634                         Id &uni_id = declared_uniform_ids[var.block_declaration ? "b"+var.block_declaration->block_name : "v"+var.name];
1635                         if(uni_id)
1636                         {
1637                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1638                                 return;
1639                         }
1640
1641                         uni_id = var_id = allocate_id(var, ptr_type_id);
1642                 }
1643                 else
1644                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1645
1646                 Id init_id = 0;
1647                 if(var.init_expression)
1648                 {
1649                         SetFlag set_const(constant_expression, !current_function);
1650                         r_expression_result_id = 0;
1651                         r_constant_result = false;
1652                         var.init_expression->visit(*this);
1653                         init_id = r_expression_result_id;
1654                 }
1655
1656                 vector<Word> &target = (current_function ? content.locals : content.globals);
1657                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1658                 writer.write(ptr_type_id);
1659                 writer.write(var_id);
1660                 writer.write(storage);
1661                 if(init_id && !current_function)
1662                         writer.write(init_id);
1663                 writer.end_op(OP_VARIABLE);
1664
1665                 if(var.layout)
1666                 {
1667                         for(const Layout::Qualifier &q: var.layout->qualifiers)
1668                         {
1669                                 if(q.name=="location")
1670                                         writer.write_op_decorate(var_id, DECO_LOCATION, q.value);
1671                                 else if(q.name=="set")
1672                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, q.value);
1673                                 else if(q.name=="binding")
1674                                         writer.write_op_decorate(var_id, DECO_BINDING, q.value);
1675                         }
1676                 }
1677                 if(!var.block_declaration && !var.name.compare(0, 3, "gl_"))
1678                 {
1679                         BuiltinSemantic semantic = get_builtin_semantic(var.name);
1680                         writer.write_op_decorate(var_id, DECO_BUILTIN, semantic);
1681                 }
1682
1683                 if(init_id && current_function)
1684                 {
1685                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1686                         variable_load_ids[&var] = init_id;
1687                 }
1688         }
1689
1690         if(var.name.find(' ')==string::npos)
1691                 writer.write_op_name(var_id, var.name);
1692 }
1693
1694 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1695 {
1696         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1697         switch(stage->type)
1698         {
1699         case Stage::VERTEX: writer.write(0); break;
1700         case Stage::GEOMETRY: writer.write(3); break;
1701         case Stage::FRAGMENT: writer.write(4); break;
1702         default: throw internal_error("unknown stage");
1703         }
1704         writer.write(func_id);
1705         writer.write_string(func.name);
1706
1707         set<Node *> dependencies = DependencyCollector().apply(func);
1708         for(Node *n: dependencies)
1709                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1710                         if(!var->interface.empty())
1711                                 writer.write(get_id(*n));
1712
1713         writer.end_op(OP_ENTRY_POINT);
1714
1715         if(stage->type==Stage::FRAGMENT)
1716         {
1717                 SpirVExecutionMode origin = (features.target_api==VULKAN ? EXEC_ORIGIN_UPPER_LEFT : EXEC_ORIGIN_LOWER_LEFT);
1718                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, origin);
1719         }
1720         else if(stage->type==Stage::GEOMETRY)
1721         {
1722                 use_capability(CAP_GEOMETRY);
1723                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INVOCATIONS, 1);
1724         }
1725
1726         for(const InterfaceLayout *i: interface_layouts)
1727         {
1728                 for(const Layout::Qualifier &q: i->layout.qualifiers)
1729                 {
1730                         if(q.name=="point")
1731                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1732                                         (i->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1733                         else if(q.name=="lines")
1734                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1735                         else if(q.name=="lines_adjacency")
1736                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1737                         else if(q.name=="triangles")
1738                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1739                         else if(q.name=="triangles_adjacency")
1740                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1741                         else if(q.name=="line_strip")
1742                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1743                         else if(q.name=="triangle_strip")
1744                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1745                         else if(q.name=="max_vertices")
1746                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, q.value);
1747                 }
1748         }
1749 }
1750
1751 void SpirVGenerator::visit(FunctionDeclaration &func)
1752 {
1753         if(func.source==BUILTIN_SOURCE)
1754                 return;
1755         else if(func.definition!=&func)
1756         {
1757                 if(func.definition)
1758                         allocate_forward_id(*func.definition);
1759                 return;
1760         }
1761
1762         Id return_type_id = get_id(*func.return_type_declaration);
1763         vector<unsigned> param_type_ids;
1764         param_type_ids.reserve(func.parameters.size());
1765         for(const RefPtr<VariableDeclaration> &p: func.parameters)
1766                 param_type_ids.push_back(get_variable_type_id(*p));
1767
1768         string sig_with_return = func.return_type+func.signature;
1769         Id &type_id = function_type_ids[sig_with_return];
1770         if(!type_id)
1771         {
1772                 type_id = next_id++;
1773                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1774                 writer.write(type_id);
1775                 writer.write(return_type_id);
1776                 for(unsigned i: param_type_ids)
1777                         writer.write(i);
1778                 writer.end_op(OP_TYPE_FUNCTION);
1779
1780                 writer.write_op_name(type_id, sig_with_return);
1781         }
1782
1783         Id func_id = allocate_id(func, type_id);
1784         writer.write_op_name(func_id, func.name+func.signature);
1785
1786         if(func.name=="main")
1787                 visit_entry_point(func, func_id);
1788
1789         writer.begin_op(content.functions, OP_FUNCTION, 5);
1790         writer.write(return_type_id);
1791         writer.write(func_id);
1792         writer.write(0);  // Function control flags (none)
1793         writer.write(type_id);
1794         writer.end_op(OP_FUNCTION);
1795
1796         for(unsigned i=0; i<func.parameters.size(); ++i)
1797         {
1798                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1799                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1800                 writer.write_op_name(param_id, func.parameters[i]->name);
1801                 // TODO This is probably incorrect if the parameter is assigned to.
1802                 variable_load_ids[func.parameters[i].get()] = param_id;
1803         }
1804
1805         reachable = true;
1806         writer.begin_function_body(next_id++);
1807         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1808         func.body.visit(*this);
1809
1810         if(writer.get_current_block())
1811         {
1812                 if(!reachable)
1813                         writer.write_op(content.function_body, OP_UNREACHABLE);
1814                 else
1815                 {
1816                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1817                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1818                                 writer.write_op(content.function_body, OP_RETURN);
1819                         else
1820                                 throw internal_error("missing return in non-void function");
1821                 }
1822         }
1823         writer.end_function_body();
1824         variable_load_ids.clear();
1825 }
1826
1827 void SpirVGenerator::visit(Conditional &cond)
1828 {
1829         cond.condition->visit(*this);
1830
1831         Id true_label_id = next_id++;
1832         Id merge_block_id = next_id++;
1833         Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
1834         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1835         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);
1836
1837         std::map<const VariableDeclaration *, Id> saved_load_ids = variable_load_ids;
1838
1839         writer.write_op_label(true_label_id);
1840         cond.body.visit(*this);
1841         if(writer.get_current_block())
1842                 writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1843
1844         bool reachable_if_true = reachable;
1845
1846         reachable = true;
1847         if(!cond.else_body.body.empty())
1848         {
1849                 swap(saved_load_ids, variable_load_ids);
1850                 writer.write_op_label(false_label_id);
1851                 cond.else_body.visit(*this);
1852                 reachable |= reachable_if_true;
1853         }
1854
1855         writer.write_op_label(merge_block_id);
1856         prune_loads(true_label_id);
1857 }
1858
1859 void SpirVGenerator::visit(Iteration &iter)
1860 {
1861         if(iter.init_statement)
1862                 iter.init_statement->visit(*this);
1863
1864         for(Node *n: AssignmentCollector().apply(iter))
1865                 if(VariableDeclaration *var = dynamic_cast<VariableDeclaration *>(n))
1866                         variable_load_ids.erase(var);
1867
1868         Id header_id = next_id++;
1869         Id continue_id = next_id++;
1870         Id merge_block_id = next_id++;
1871
1872         SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
1873         SetForScope<Id> set_continue(loop_continue_target_id, continue_id);
1874
1875         writer.write_op_label(header_id);
1876         writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)
1877
1878         Id body_id = next_id++;
1879         if(iter.condition)
1880         {
1881                 writer.write_op_label(next_id++);
1882                 iter.condition->visit(*this);
1883                 writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
1884         }
1885
1886         writer.write_op_label(body_id);
1887         iter.body.visit(*this);
1888
1889         writer.write_op_label(continue_id);
1890         if(iter.loop_expression)
1891                 iter.loop_expression->visit(*this);
1892         writer.write_op(content.function_body, OP_BRANCH, header_id);
1893
1894         writer.write_op_label(merge_block_id);
1895         prune_loads(header_id);
1896         reachable = true;
1897 }
1898
1899 void SpirVGenerator::visit(Return &ret)
1900 {
1901         if(ret.expression)
1902         {
1903                 ret.expression->visit(*this);
1904                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1905         }
1906         else
1907                 writer.write_op(content.function_body, OP_RETURN);
1908         reachable = false;
1909 }
1910
1911 void SpirVGenerator::visit(Jump &jump)
1912 {
1913         if(jump.keyword=="discard")
1914                 writer.write_op(content.function_body, OP_KILL);
1915         else if(jump.keyword=="break")
1916                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1917         else if(jump.keyword=="continue")
1918                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1919         else
1920                 throw internal_error("unknown jump");
1921         reachable = false;
1922 }
1923
1924
1925 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
1926         type_id(0)
1927 {
1928         switch(kind)
1929         {
1930         case BasicTypeDeclaration::VOID: detail = 'v'; break;
1931         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
1932         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
1933         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
1934         default: throw invalid_argument("TypeKey::TypeKey");
1935         }
1936 }
1937
1938 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1939 {
1940         if(type_id!=other.type_id)
1941                 return type_id<other.type_id;
1942         return detail<other.detail;
1943 }
1944
1945
1946 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1947 {
1948         if(type_id!=other.type_id)
1949                 return type_id<other.type_id;
1950         return int_value<other.int_value;
1951 }
1952
1953 } // namespace SL
1954 } // namespace GL
1955 } // namespace Msp