]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
9cb1b7f2e112a8d8d06d85dc92ab3f1d8561dcd5
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/algorithm.h>
2 #include <msp/core/maputils.h>
3 #include <msp/core/raii.h>
4 #include "reflect.h"
5 #include "spirv.h"
6
7 using namespace std;
8
9 namespace Msp {
10 namespace GL {
11 namespace SL {
12
13 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
14 {
15         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0, 0 },
16         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0, 0 },
17         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0, 0 },
18         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0, 0 },
19         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0, 0 },
20         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0, 0 },
21         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0, 0 },
22         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0, 0 },
23         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0, 0 },
24         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0, 0 },
25         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0, 0 },
26         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0, 0 },
27         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0, 0 },
28         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0, 0 },
29         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0, 0 },
30         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0, 0 },
31         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0, 0 },
32         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0, 0 },
33         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0, 0 },
34         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0, 0 },
35         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0, 0 },
36         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0, 0 },
37         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0, 0 },
38         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0, 0 },
39         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0, 0 },
40         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0, 0 },
41         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0, 0 },
42         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0, 0 },
43         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0, 0 },
44         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0, 0 },
45         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0, 0 },
46         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0, 0 },
47         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0, 0 },
48         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0, 0 },
49         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0, 0 },
50         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0, 0 },
51         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0, 0 },
52         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0, 0 },
53         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0, 0 },
54         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0, 0 },
55         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0, 0 },
56         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0, 0 },
57         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0, 0 },
58         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
59         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
60         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
61         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0, 0 },
62         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0, 0 },
63         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0, 0 },
64         { "isinf", "f", "", OP_IS_INF, { 1 }, 0, 0 },
65         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0, 0 },
66         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0, 0 },
67         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0, 0 },
68         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0, 0 },
69         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0, 0 },
70         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0, 0 },
71         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0, 0 },
72         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0, 0 },
73         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0, 0 },
74         { "matrixCompMult", "ff", "", 0, { 0 }, 0, &SpirVGenerator::visit_builtin_matrix_comp_mult },
75         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0, 0 },
76         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0, 0 },
77         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0, 0 },
78         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0, 0 },
79         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0, 0 },
80         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0, 0 },
81         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0, 0 },
82         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
83         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
84         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
85         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0, 0 },
86         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0, 0 },
87         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0, 0 },
88         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
89         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
90         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
91         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0, 0 },
92         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
93         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
94         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0, 0 },
95         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
96         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
97         { "any", "b", "", OP_ANY, { 1 }, 0, 0 },
98         { "all", "b", "", OP_ALL, { 1 }, 0, 0 },
99         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0, 0 },
100         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0, 0 },
101         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0, 0 },
102         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
103         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
104         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
105         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
106         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0, 0 },
107         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
108         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
109         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0, 0 },
110         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0, 0 },
111         { "textureSize", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
112         { "textureQueryLod", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
113         { "textureQueryLevels", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
114         { "textureSamples", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
115         { "texture", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
116         { "textureLod", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
117         { "texelFetch", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture_fetch },
118         { "imageSize", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
119         { "imageSamples", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
120         { "imageLoad", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture_fetch },
121         { "imageStore", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture_store },
122         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0, 0 },
123         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0, 0 },
124         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0, 0 },
125         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0, 0 },
126         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
127         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
128         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
129         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
130         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0, 0 },
131         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
132         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
133         { "interpolateAtCentroid", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
134         { "interpolateAtSample", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
135         { "interpolateAtOffset", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
136         { "", "", "", 0, { }, 0, 0 }
137 };
138
139 SpirVGenerator::SpirVGenerator():
140         writer(content)
141 { }
142
143 void SpirVGenerator::apply(Module &module, const Features &f)
144 {
145         features = f;
146         use_capability(CAP_SHADER);
147
148         for(Stage &s: module.stages)
149         {
150                 stage = &s;
151                 interface_layouts.clear();
152                 s.content.visit(*this);
153         }
154
155         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
156 }
157
158 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
159 {
160         if(iface=="in")
161                 return STORAGE_INPUT;
162         else if(iface=="out")
163                 return STORAGE_OUTPUT;
164         else if(iface=="uniform")
165                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
166         else if(iface.empty())
167                 return STORAGE_PRIVATE;
168         else
169                 throw invalid_argument("SpirVGenerator::get_interface_storage");
170 }
171
172 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
173 {
174         if(name=="gl_Position")
175                 return BUILTIN_POSITION;
176         else if(name=="gl_PointSize")
177                 return BUILTIN_POINT_SIZE;
178         else if(name=="gl_ClipDistance")
179                 return BUILTIN_CLIP_DISTANCE;
180         else if(name=="gl_VertexID")
181                 return BUILTIN_VERTEX_ID;
182         else if(name=="gl_InstanceID")
183                 return BUILTIN_INSTANCE_ID;
184         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
185                 return BUILTIN_PRIMITIVE_ID;
186         else if(name=="gl_InvocationID")
187                 return BUILTIN_INVOCATION_ID;
188         else if(name=="gl_Layer")
189                 return BUILTIN_LAYER;
190         else if(name=="gl_FragCoord")
191                 return BUILTIN_FRAG_COORD;
192         else if(name=="gl_PointCoord")
193                 return BUILTIN_POINT_COORD;
194         else if(name=="gl_FrontFacing")
195                 return BUILTIN_FRONT_FACING;
196         else if(name=="gl_SampleId")
197                 return BUILTIN_SAMPLE_ID;
198         else if(name=="gl_SamplePosition")
199                 return BUILTIN_SAMPLE_POSITION;
200         else if(name=="gl_FragDepth")
201                 return BUILTIN_FRAG_DEPTH;
202         else
203                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
204 }
205
206 SpirVFormat SpirVGenerator::get_format(const std::string &name)
207 {
208         if(name.empty())
209                 return FORMAT_UNKNOWN;
210         else if(name=="rgba32f")
211                 return FORMAT_RGBA32F;
212         else if(name=="rgba16f")
213                 return FORMAT_RGBA16F;
214         else if(name=="r32f")
215                 return FORMAT_R32F;
216         else if(name=="rgba8")
217                 return FORMAT_RGBA8;
218         else if(name=="rgba8_snorm")
219                 return FORMAT_RGBA8_SNORM;
220         else if(name=="rg32f")
221                 return FORMAT_RG32F;
222         else if(name=="rg16f")
223                 return FORMAT_RG16F;
224         else if(name=="r16f")
225                 return FORMAT_R16F;
226         else if(name=="rgba16")
227                 return FORMAT_RGBA16;
228         else if(name=="rg16")
229                 return FORMAT_RG16;
230         else if(name=="rg8")
231                 return FORMAT_RG8;
232         else if(name=="r16")
233                 return FORMAT_RG16;
234         else if(name=="r8")
235                 return FORMAT_RG8;
236         else if(name=="rgba16_snorm")
237                 return FORMAT_RGBA16_SNORM;
238         else if(name=="rg16_snorm")
239                 return FORMAT_RG16_SNORM;
240         else if(name=="rg8_snorm")
241                 return FORMAT_RG8_SNORM;
242         else if(name=="r16_snorm")
243                 return FORMAT_RG16_SNORM;
244         else if(name=="r8_snorm")
245                 return FORMAT_RG8_SNORM;
246         else
247                 throw invalid_argument("SpirVGenerator::get_format");
248 }
249
250 void SpirVGenerator::use_capability(Capability cap)
251 {
252         if(used_capabilities.count(cap))
253                 return;
254
255         used_capabilities.insert(cap);
256         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
257 }
258
259 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
260 {
261         Id &ext_id = imported_extension_ids[name];
262         if(!ext_id)
263         {
264                 ext_id = next_id++;
265                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
266                 writer.write(ext_id);
267                 writer.write_string(name);
268                 writer.end_op(OP_EXT_INST_IMPORT);
269         }
270         return ext_id;
271 }
272
273 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
274 {
275         return get_item(declared_ids, &node).id;
276 }
277
278 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
279 {
280         auto i = declared_ids.find(&node);
281         if(i!=declared_ids.end())
282         {
283                 if(i->second.type_id)
284                         throw key_error(&node);
285                 i->second.type_id = type_id;
286                 return i->second.id;
287         }
288
289         Id id = next_id++;
290         declared_ids.insert(make_pair(&node, Declaration(id, type_id)));
291         return id;
292 }
293
294 SpirVGenerator::Id SpirVGenerator::allocate_forward_id(Node &node)
295 {
296         auto i = declared_ids.find(&node);
297         if(i!=declared_ids.end())
298                 return i->second.id;
299
300         Id id = next_id++;
301         declared_ids.insert(make_pair(&node, Declaration(id, 0)));
302         return id;
303 }
304
305 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
306 {
307         Id const_id = next_id++;
308         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
309         {
310                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
311                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
312                 writer.write_op(content.globals, opcode, type_id, const_id);
313         }
314         else
315         {
316                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
317                 writer.write_op(content.globals, opcode, type_id, const_id, value);
318         }
319         return const_id;
320 }
321
322 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
323 {
324         if(value.check_type<bool>())
325                 return ConstantKey(type_id, value.value<bool>());
326         else if(value.check_type<int>())
327                 return ConstantKey(type_id, value.value<int>());
328         else if(value.check_type<unsigned>())
329                 return ConstantKey(type_id, value.value<unsigned>());
330         else if(value.check_type<float>())
331                 return ConstantKey(type_id, value.value<float>());
332         else
333                 throw invalid_argument("SpirVGenerator::get_constant_key");
334 }
335
336 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
337 {
338         ConstantKey key = get_constant_key(type_id, value);
339         Id &const_id = constant_ids[key];
340         if(!const_id)
341                 const_id = write_constant(type_id, key.int_value, false);
342         return const_id;
343 }
344
345 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
346 {
347         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
348         if(!const_id)
349         {
350                 const_id = next_id++;
351                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
352                 writer.write(type_id);
353                 writer.write(const_id);
354                 for(unsigned i=0; i<size; ++i)
355                         writer.write(scalar_id);
356                 writer.end_op(OP_CONSTANT_COMPOSITE);
357         }
358         return const_id;
359 }
360
361 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
362 {
363         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
364         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
365         if(!type_id)
366         {
367                 type_id = next_id++;
368                 if(size>1)
369                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
370                 else if(kind==BasicTypeDeclaration::VOID)
371                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
372                 else if(kind==BasicTypeDeclaration::BOOL)
373                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
374                 else if(kind==BasicTypeDeclaration::INT)
375                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
376                 else if(kind==BasicTypeDeclaration::FLOAT)
377                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
378                 else
379                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
380         }
381         return type_id;
382 }
383
384 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
385 {
386         auto i = standard_type_ids.find(TypeKey(kind, true));
387         return (i!=standard_type_ids.end() && i->second==type_id);
388 }
389
390 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id, bool extended_align)
391 {
392         Id base_type_id = get_id(base_type);
393         Id &array_type_id = array_type_ids[TypeKey(base_type_id, extended_align*0x400000 | size_id)];
394         if(!array_type_id)
395         {
396                 array_type_id = next_id++;
397                 if(size_id)
398                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
399                 else
400                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
401
402                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
403                 if(extended_align)
404                         stride = (stride+15)&~15U;
405                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
406         }
407
408         return array_type_id;
409 }
410
411 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
412 {
413         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
414         if(!ptr_type_id)
415         {
416                 ptr_type_id = next_id++;
417                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
418         }
419         return ptr_type_id;
420 }
421
422 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
423 {
424         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
425                 if(basic->kind==BasicTypeDeclaration::ARRAY)
426                 {
427                         if(!var.array_size)
428                                 throw logic_error("array without size");
429
430                         SetFlag set_const(constant_expression);
431                         r_expression_result_id = 0;
432                         var.array_size->visit(*this);
433                         return get_array_type_id(*basic->base_type, r_expression_result_id, basic->extended_alignment);
434                 }
435
436         return get_id(*var.type_declaration);
437 }
438
439 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
440 {
441         Id &load_result_id = variable_load_ids[&var];
442         if(!load_result_id)
443         {
444                 load_result_id = next_id++;
445                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
446         }
447         return load_result_id;
448 }
449
450 void SpirVGenerator::prune_loads(Id min_id)
451 {
452         for(auto i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
453         {
454                 if(i->second>=min_id)
455                         variable_load_ids.erase(i++);
456                 else
457                         ++i;
458         }
459 }
460
461 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
462 {
463         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
464         if(!constant_expression)
465         {
466                 if(!current_function)
467                         throw internal_error("non-constant expression outside a function");
468
469                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
470         }
471         else if(opcode==OP_COMPOSITE_CONSTRUCT)
472                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
473                         (n_args ? 1+has_result*2+n_args : 0));
474         else if(!spec_constant)
475                 throw internal_error("invalid non-specialization constant expression");
476         else
477                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
478
479         Id result_id = next_id++;
480         if(has_result)
481         {
482                 writer.write(type_id);
483                 writer.write(result_id);
484         }
485         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
486                 writer.write(opcode);
487
488         return result_id;
489 }
490
491 void SpirVGenerator::end_expression(Opcode opcode)
492 {
493         if(constant_expression)
494                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
495         writer.end_op(opcode);
496 }
497
498 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
499 {
500         Id result_id = begin_expression(opcode, type_id, 1);
501         writer.write(arg_id);
502         end_expression(opcode);
503         return result_id;
504 }
505
506 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
507 {
508         Id result_id = begin_expression(opcode, type_id, 2);
509         writer.write(left_id);
510         writer.write(right_id);
511         end_expression(opcode);
512         return result_id;
513 }
514
515 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
516 {
517         for(unsigned i=0; i<n_elems; ++i)
518         {
519                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
520                 writer.write(composite_id);
521                 writer.write(i);
522                 end_expression(OP_COMPOSITE_EXTRACT);
523         }
524 }
525
526 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
527 {
528         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
529         for(unsigned i=0; i<n_elems; ++i)
530                 writer.write(elem_ids[i]);
531         end_expression(OP_COMPOSITE_CONSTRUCT);
532
533         return result_id;
534 }
535
536 void SpirVGenerator::visit(Block &block)
537 {
538         for(const RefPtr<Statement> &s: block.body)
539                 s->visit(*this);
540 }
541
542 void SpirVGenerator::visit(Literal &literal)
543 {
544         Id type_id = get_id(*literal.type);
545         if(spec_constant)
546                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
547         else
548                 r_expression_result_id = get_constant_id(type_id, literal.value);
549         r_constant_result = true;
550 }
551
552 void SpirVGenerator::visit(VariableReference &var)
553 {
554         if(constant_expression || var.declaration->constant)
555         {
556                 if(!var.declaration->constant)
557                         throw internal_error("reference to non-constant variable in constant context");
558
559                 r_expression_result_id = get_id(*var.declaration);
560                 r_constant_result = true;
561                 return;
562         }
563         else if(!current_function)
564                 throw internal_error("non-constant context outside a function");
565
566         r_constant_result = false;
567         if(composite_access)
568         {
569                 r_expression_result_id = 0;
570                 if(!assignment_source_id)
571                 {
572                         auto i = variable_load_ids.find(var.declaration);
573                         if(i!=variable_load_ids.end())
574                                 r_expression_result_id = i->second;
575                 }
576                 if(!r_expression_result_id)
577                         r_composite_base = var.declaration;
578         }
579         else if(assignment_source_id)
580         {
581                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
582                 variable_load_ids[var.declaration] = assignment_source_id;
583                 r_expression_result_id = assignment_source_id;
584         }
585         else
586                 r_expression_result_id = get_load_id(*var.declaration);
587 }
588
589 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
590 {
591         Opcode opcode;
592         Id result_type_id = get_id(result_type);
593         Id access_type_id = result_type_id;
594         if(r_composite_base)
595         {
596                 if(constant_expression)
597                         throw internal_error("composite access through pointer in constant context");
598
599                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
600                 for(unsigned &i: r_composite_chain)
601                         i = (i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(i)) : i&0x3FFFFF);
602
603                 /* Find the storage class of the base and obtain appropriate pointer type
604                 for the result. */
605                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
606                 auto i = pointer_type_ids.begin();
607                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
608                 if(i==pointer_type_ids.end())
609                         throw internal_error("could not find storage class");
610                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
611
612                 opcode = OP_ACCESS_CHAIN;
613         }
614         else if(assignment_source_id)
615                 throw internal_error("assignment to temporary composite");
616         else
617         {
618                 for(unsigned &i: r_composite_chain)
619                         for(auto j=constant_ids.begin(); (i>=0x400000 && j!=constant_ids.end()); ++j)
620                                 if(j->second==(i&0x3FFFFF))
621                                         i = j->first.int_value;
622
623                 opcode = OP_COMPOSITE_EXTRACT;
624         }
625
626         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
627         writer.write(r_composite_base_id);
628         for(unsigned i: r_composite_chain)
629                 writer.write(i);
630         end_expression(opcode);
631
632         r_constant_result = false;
633         if(r_composite_base)
634         {
635                 if(assignment_source_id)
636                 {
637                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
638                         r_expression_result_id = assignment_source_id;
639                 }
640                 else
641                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
642         }
643         else
644                 r_expression_result_id = access_id;
645 }
646
647 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
648 {
649         if(!composite_access)
650         {
651                 r_composite_base = 0;
652                 r_composite_base_id = 0;
653                 r_composite_chain.clear();
654         }
655
656         {
657                 SetFlag set_composite(composite_access);
658                 base_expr.visit(*this);
659         }
660
661         if(!r_composite_base_id)
662                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
663
664         r_composite_chain.push_back(index);
665         if(!composite_access)
666                 generate_composite_access(type);
667         else
668                 r_expression_result_id = 0;
669 }
670
671 void SpirVGenerator::visit_isolated(Expression &expr)
672 {
673         SetForScope<Id> clear_assign(assignment_source_id, 0);
674         SetFlag clear_composite(composite_access, false);
675         SetForScope<Node *> clear_base(r_composite_base, 0);
676         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
677         vector<unsigned> saved_chain;
678         swap(saved_chain, r_composite_chain);
679         expr.visit(*this);
680         swap(saved_chain, r_composite_chain);
681 }
682
683 void SpirVGenerator::visit(MemberAccess &memacc)
684 {
685         visit_composite(*memacc.left, memacc.index, *memacc.type);
686 }
687
688 void SpirVGenerator::visit(Swizzle &swizzle)
689 {
690         if(swizzle.count==1)
691                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
692         else if(assignment_source_id)
693         {
694                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
695
696                 unsigned mask = 0;
697                 for(unsigned i=0; i<swizzle.count; ++i)
698                         mask |= 1<<swizzle.components[i];
699
700                 visit_isolated(*swizzle.left);
701
702                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
703                 writer.write(r_expression_result_id);
704                 writer.write(assignment_source_id);
705                 for(unsigned i=0; i<basic.size; ++i)
706                         writer.write(i+((mask>>i)&1)*basic.size);
707                 end_expression(OP_VECTOR_SHUFFLE);
708
709                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
710                 swizzle.left->visit(*this);
711
712                 r_expression_result_id = combined_id;
713         }
714         else
715         {
716                 swizzle.left->visit(*this);
717                 Id left_id = r_expression_result_id;
718
719                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
720                 writer.write(left_id);
721                 writer.write(left_id);
722                 for(unsigned i=0; i<swizzle.count; ++i)
723                         writer.write(swizzle.components[i]);
724                 end_expression(OP_VECTOR_SHUFFLE);
725         }
726         r_constant_result = false;
727 }
728
729 void SpirVGenerator::visit(UnaryExpression &unary)
730 {
731         if(composite_access)
732                 return visit_isolated(unary);
733
734         unary.expression->visit(*this);
735
736         char oper = unary.oper->token[0];
737         char oper2 = unary.oper->token[1];
738         if(oper=='+' && !oper2)
739                 return;
740
741         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
742         BasicTypeDeclaration &elem = *get_element_type(basic);
743
744         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
745                 /* SPIR-V allows constant operations on floating-point values only for
746                 OpenGL kernels. */
747                 throw internal_error("invalid operands for constant unary expression");
748
749         Id result_type_id = get_id(*unary.type);
750         Opcode opcode = OP_NOP;
751
752         r_constant_result = false;
753         if(oper=='!')
754                 opcode = OP_LOGICAL_NOT;
755         else if(oper=='~')
756                 opcode = OP_NOT;
757         else if(oper=='-' && !oper2)
758         {
759                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
760
761                 if(basic.kind==BasicTypeDeclaration::MATRIX)
762                 {
763                         Id column_type_id = get_id(*basic.base_type);
764                         unsigned n_columns = basic.size&0xFFFF;
765                         Id column_ids[4];
766                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
767                         for(unsigned i=0; i<n_columns; ++i)
768                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
769                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
770                         return;
771                 }
772         }
773         else if((oper=='+' || oper=='-') && oper2==oper)
774         {
775                 if(constant_expression)
776                         throw internal_error("increment or decrement in constant expression");
777
778                 Id one_id = 0;
779                 if(elem.kind==BasicTypeDeclaration::INT)
780                 {
781                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
782                         one_id = get_constant_id(get_id(elem), 1);
783                 }
784                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
785                 {
786                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
787                         one_id = get_constant_id(get_id(elem), 1.0f);
788                 }
789                 else
790                         throw internal_error("invalid increment/decrement");
791
792                 if(basic.kind==BasicTypeDeclaration::VECTOR)
793                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
794
795                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
796
797                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
798                 unary.expression->visit(*this);
799
800                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
801                 return;
802         }
803
804         if(opcode==OP_NOP)
805                 throw internal_error("unknown unary operator");
806
807         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
808 }
809
810 void SpirVGenerator::visit(BinaryExpression &binary)
811 {
812         char oper = binary.oper->token[0];
813         if(oper=='[')
814         {
815                 visit_isolated(*binary.right);
816                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
817         }
818         else if(composite_access)
819                 return visit_isolated(binary);
820
821         if(assignment_source_id)
822                 throw internal_error("invalid binary expression in assignment target");
823
824         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
825         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
826         // Expression resolver ensures that element types are the same
827         BasicTypeDeclaration &elem = *get_element_type(basic_left);
828
829         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
830                 /* SPIR-V allows constant operations on floating-point values only for
831                 OpenGL kernels. */
832                 throw internal_error("invalid operands for constant binary expression");
833
834         binary.left->visit(*this);
835         Id left_id = r_expression_result_id;
836         binary.right->visit(*this);
837         Id right_id = r_expression_result_id;
838
839         Id result_type_id = get_id(*binary.type);
840         Opcode opcode = OP_NOP;
841         bool swap_operands = false;
842
843         r_constant_result = false;
844
845         char oper2 = binary.oper->token[1];
846         if((oper=='<' || oper=='>') && oper2!=oper)
847         {
848                 if(basic_left.kind==BasicTypeDeclaration::INT)
849                 {
850                         if(basic_left.sign)
851                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
852                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
853                         else
854                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
855                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
856                 }
857                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
858                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
859                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
860         }
861         else if((oper=='=' || oper=='!') && oper2=='=')
862         {
863                 if(elem.kind==BasicTypeDeclaration::BOOL)
864                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
865                 else if(elem.kind==BasicTypeDeclaration::INT)
866                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
867                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
868                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
869
870                 if(opcode!=OP_NOP && basic_left.base_type)
871                 {
872                         /* The SPIR-V equality operations produce component-wise results, but
873                         GLSL operators return a single boolean.  Use the any/all operations to
874                         combine the results. */
875                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
876                         unsigned n_elems = basic_left.size&0xFFFF;
877                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
878
879                         Id compare_id = 0;
880                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
881                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
882                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
883                         {
884                                 Id column_type_id = get_id(*basic_left.base_type);
885                                 Id column_ids[8];
886                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
887                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
888
889                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
890                                 for(unsigned i=0; i<n_elems; ++i)
891                                 {
892                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
893                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
894                                 }
895
896                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
897                         }
898
899                         if(compare_id)
900                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
901                         return;
902                 }
903         }
904         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
905                 opcode = OP_LOGICAL_AND;
906         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
907                 opcode = OP_LOGICAL_OR;
908         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
909                 opcode = OP_LOGICAL_NOT_EQUAL;
910         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
911                 opcode = OP_BITWISE_AND;
912         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
913                 opcode = OP_BITWISE_OR;
914         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
915                 opcode = OP_BITWISE_XOR;
916         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
917                 opcode = OP_SHIFT_LEFT_LOGICAL;
918         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
919                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
920         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
921                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
922         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
923         {
924                 Opcode elem_op = OP_NOP;
925                 if(elem.kind==BasicTypeDeclaration::INT)
926                 {
927                         if(oper=='/')
928                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
929                         else
930                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
931                 }
932                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
933                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
934
935                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
936                 {
937                         /* Multiplication between floating-point vectors and matrices has
938                         dedicated operations. */
939                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
940                                 opcode = OP_MATRIX_TIMES_MATRIX;
941                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
942                         {
943                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
944                                         opcode = OP_VECTOR_TIMES_MATRIX;
945                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
946                                         opcode = OP_MATRIX_TIMES_VECTOR;
947                                 else
948                                 {
949                                         opcode = OP_MATRIX_TIMES_SCALAR;
950                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
951                                 }
952                         }
953                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
954                                 opcode = elem_op;
955                         else
956                         {
957                                 opcode = OP_VECTOR_TIMES_SCALAR;
958                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
959                         }
960                 }
961                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
962                 {
963                         /* One operand is scalar and the other is a vector or a matrix.
964                         Expand the scalar to a vector of appropriate size. */
965                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
966                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
967                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
968                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
969                         Id vector_type_id = get_id(*vector_type);
970
971                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
972                         for(unsigned i=0; i<vector_type->size; ++i)
973                                 writer.write(scalar_id);
974                         end_expression(OP_COMPOSITE_CONSTRUCT);
975
976                         scalar_id = expanded_id;
977
978                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
979                         {
980                                 // Apply matrix operation column-wise.
981                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
982
983                                 Id column_ids[4];
984                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
985                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
986
987                                 for(unsigned i=0; i<n_columns; ++i)
988                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
989
990                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
991                                 return;
992                         }
993                         else
994                                 opcode = elem_op;
995                 }
996                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
997                 {
998                         if(oper=='*')
999                                 throw internal_error("non-float matrix multiplication");
1000
1001                         /* Other operations involving matrices need to be performed
1002                         column-wise. */
1003                         Id column_type_id = get_id(*basic_left.base_type);
1004                         Id column_ids[8];
1005
1006                         unsigned n_columns = basic_left.size&0xFFFF;
1007                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
1008                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
1009
1010                         for(unsigned i=0; i<n_columns; ++i)
1011                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
1012
1013                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1014                         return;
1015                 }
1016                 else if(basic_left.kind==basic_right.kind)
1017                         // Both operands are either scalars or vectors.
1018                         opcode = elem_op;
1019         }
1020
1021         if(opcode==OP_NOP)
1022                 throw internal_error("unknown binary operator");
1023
1024         if(swap_operands)
1025                 swap(left_id, right_id);
1026
1027         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
1028 }
1029
1030 void SpirVGenerator::visit(Assignment &assign)
1031 {
1032         if(assign.oper->token[0]!='=')
1033                 visit(static_cast<BinaryExpression &>(assign));
1034         else
1035                 assign.right->visit(*this);
1036
1037         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
1038         assign.left->visit(*this);
1039         r_constant_result = false;
1040 }
1041
1042 void SpirVGenerator::visit(TernaryExpression &ternary)
1043 {
1044         if(composite_access)
1045                 return visit_isolated(ternary);
1046         if(constant_expression)
1047         {
1048                 ternary.condition->visit(*this);
1049                 Id condition_id = r_expression_result_id;
1050                 ternary.true_expr->visit(*this);
1051                 Id true_result_id = r_expression_result_id;
1052                 ternary.false_expr->visit(*this);
1053                 Id false_result_id = r_expression_result_id;
1054
1055                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
1056                 writer.write(condition_id);
1057                 writer.write(true_result_id);
1058                 writer.write(false_result_id);
1059                 end_expression(OP_SELECT);
1060
1061                 return;
1062         }
1063
1064         ternary.condition->visit(*this);
1065         Id condition_id = r_expression_result_id;
1066
1067         Id true_label_id = next_id++;
1068         Id false_label_id = next_id++;
1069         Id merge_block_id = next_id++;
1070         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1071         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
1072
1073         std::map<const VariableDeclaration *, Id> saved_load_ids = variable_load_ids;
1074
1075         writer.write_op_label(true_label_id);
1076         ternary.true_expr->visit(*this);
1077         Id true_result_id = r_expression_result_id;
1078         true_label_id = writer.get_current_block();
1079         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1080
1081         swap(saved_load_ids, variable_load_ids);
1082         writer.write_op_label(false_label_id);
1083         ternary.false_expr->visit(*this);
1084         Id false_result_id = r_expression_result_id;
1085         false_label_id = writer.get_current_block();
1086
1087         writer.write_op_label(merge_block_id);
1088         prune_loads(true_label_id);
1089         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
1090         writer.write(true_result_id);
1091         writer.write(true_label_id);
1092         writer.write(false_result_id);
1093         writer.write(false_label_id);
1094         end_expression(OP_PHI);
1095
1096         r_constant_result = false;
1097 }
1098
1099 void SpirVGenerator::visit(FunctionCall &call)
1100 {
1101         if(assignment_source_id)
1102                 throw internal_error("assignment to function call");
1103         else if(composite_access)
1104                 return visit_isolated(call);
1105         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
1106                 return call.arguments[0]->visit(*this);
1107
1108         vector<Id> argument_ids;
1109         argument_ids.reserve(call.arguments.size());
1110         bool all_args_const = true;
1111         for(const RefPtr<Expression> &a: call.arguments)
1112         {
1113                 a->visit(*this);
1114                 argument_ids.push_back(r_expression_result_id);
1115                 all_args_const &= r_constant_result;
1116         }
1117
1118         if(constant_expression && (!call.constructor || !all_args_const))
1119                 throw internal_error("function call in constant expression");
1120
1121         Id result_type_id = get_id(*call.type);
1122         r_constant_result = false;
1123
1124         if(call.constructor)
1125                 visit_constructor(call, argument_ids, all_args_const);
1126         else if(call.declaration->source==BUILTIN_SOURCE)
1127         {
1128                 string arg_types;
1129                 for(const RefPtr<Expression> &a: call.arguments)
1130                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>(a->type))
1131                         {
1132                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1133                                 switch(elem_arg.kind)
1134                                 {
1135                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1136                                 case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
1137                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1138                                 default: arg_types += '?';
1139                                 }
1140                         }
1141
1142                 const BuiltinFunctionInfo *builtin_info;
1143                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1144                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1145                                 break;
1146
1147                 if(builtin_info->capability)
1148                         use_capability(static_cast<Capability>(builtin_info->capability));
1149
1150                 if(builtin_info->opcode)
1151                 {
1152                         Opcode opcode;
1153                         if(builtin_info->extension[0])
1154                         {
1155                                 opcode = OP_EXT_INST;
1156                                 Id ext_id = import_extension(builtin_info->extension);
1157
1158                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1159                                 writer.write(ext_id);
1160                                 writer.write(builtin_info->opcode);
1161                         }
1162                         else
1163                         {
1164                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1165                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1166                         }
1167
1168                         for(unsigned i=0; i<call.arguments.size(); ++i)
1169                         {
1170                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1171                                         throw internal_error("invalid builtin function info");
1172                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1173                         }
1174
1175                         end_expression(opcode);
1176                 }
1177                 else if(builtin_info->handler)
1178                         (this->*(builtin_info->handler))(call, argument_ids);
1179                 else
1180                         throw internal_error("unknown builtin function "+call.name);
1181         }
1182         else
1183         {
1184                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1185                 writer.write(get_id(*call.declaration->definition));
1186                 for(Id i: argument_ids)
1187                         writer.write(i);
1188                 end_expression(OP_FUNCTION_CALL);
1189
1190                 // Any global variables the called function uses might have changed value
1191                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1192                 for(Node *n: dependencies)
1193                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1194                                 variable_load_ids.erase(var);
1195         }
1196 }
1197
1198 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1199 {
1200         Id result_type_id = get_id(*call.type);
1201
1202         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1203         if(!basic)
1204         {
1205                 if(dynamic_cast<const StructDeclaration *>(call.type))
1206                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1207                 else
1208                         throw internal_error("unconstructable type "+call.name);
1209                 return;
1210         }
1211
1212         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1213
1214         BasicTypeDeclaration &elem = *get_element_type(*basic);
1215         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1216         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1217
1218         if(basic->kind==BasicTypeDeclaration::MATRIX)
1219         {
1220                 Id col_type_id = get_id(*basic->base_type);
1221                 unsigned n_columns = basic->size&0xFFFF;
1222                 unsigned n_rows = basic->size>>16;
1223
1224                 Id column_ids[4];
1225                 if(call.arguments.size()==1)
1226                 {
1227                         // Construct diagonal matrix from a single scalar.
1228                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1229                         for(unsigned i=0; i<n_columns; ++i)
1230                         {
1231                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);
1232                                 for(unsigned j=0; j<n_rows; ++j)
1233                                         writer.write(j==i ? argument_ids[0] : zero_id);
1234                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1235                         }
1236                 }
1237                 else
1238                         // Construct a matrix from column vectors
1239                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1240
1241                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1242         }
1243         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1244         {
1245                 /* There's either a single scalar argument or multiple arguments
1246                 which make up the vector's components. */
1247                 if(call.arguments.size()==1)
1248                 {
1249                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1250                         for(unsigned i=0; i<basic->size; ++i)
1251                                 writer.write(argument_ids[0]);
1252                         end_expression(OP_COMPOSITE_CONSTRUCT);
1253                 }
1254                 else
1255                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1256         }
1257         else if(elem.kind==BasicTypeDeclaration::BOOL)
1258         {
1259                 if(constant_expression)
1260                         throw internal_error("unconverted constant");
1261
1262                 // Conversion to boolean is implemented as comparing against zero.
1263                 Id number_type_id = get_id(elem_arg0);
1264                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1265                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1266                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1267                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1268
1269                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1270                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1271         }
1272         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1273         {
1274                 if(constant_expression)
1275                         throw internal_error("unconverted constant");
1276
1277                 /* Conversion from boolean is implemented as selecting from zero
1278                 or one. */
1279                 Id number_type_id = get_id(elem);
1280                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1281                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1282                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1283                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1284                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1285                 {
1286                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1287                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1288                 }
1289
1290                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1291                 writer.write(argument_ids[0]);
1292                 writer.write(zero_id);
1293                 writer.write(one_id);
1294                 end_expression(OP_SELECT);
1295         }
1296         else
1297         {
1298                 if(constant_expression)
1299                         throw internal_error("unconverted constant");
1300
1301                 // Scalar or vector conversion between types of equal size.
1302                 Opcode opcode;
1303                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1304                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1305                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1306                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1307                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1308                         opcode = OP_BITCAST;
1309                 else
1310                         throw internal_error("invalid conversion");
1311
1312                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1313         }
1314 }
1315
1316 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1317 {
1318         if(argument_ids.size()!=2)
1319                 throw internal_error("invalid matrixCompMult call");
1320
1321         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1322         Id column_type_id = get_id(*basic_arg0.base_type);
1323         Id column_ids[8];
1324
1325         unsigned n_columns = basic_arg0.size&0xFFFF;
1326         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1327         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1328
1329         for(unsigned i=0; i<n_columns; ++i)
1330                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1331
1332         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1333 }
1334
1335 void SpirVGenerator::visit_builtin_texture_query(FunctionCall &call, const vector<Id> &argument_ids)
1336 {
1337         if(argument_ids.size()<1)
1338                 throw internal_error("invalid texture query call");
1339
1340         Opcode opcode;
1341         if(call.name=="textureSize")
1342                 opcode = OP_IMAGE_QUERY_SIZE_LOD;
1343         else if(call.name=="imageSize")
1344                 opcode = OP_IMAGE_QUERY_SIZE;
1345         else if(call.name=="textureQueryLod")
1346                 opcode = OP_IMAGE_QUERY_LOD;
1347         else if(call.name=="textureQueryLevels")
1348                 opcode = OP_IMAGE_QUERY_LEVELS;
1349         else if(call.name=="textureSamples" || call.name=="imageSamples")
1350                 opcode = OP_IMAGE_QUERY_SAMPLES;
1351         else
1352                 throw internal_error("invalid texture query call");
1353
1354         ImageTypeDeclaration &image_arg0 = dynamic_cast<ImageTypeDeclaration &>(*call.arguments[0]->type);
1355
1356         Id image_id;
1357         if(image_arg0.sampled && opcode!=OP_IMAGE_QUERY_LOD)
1358         {
1359                 Id image_type_id = get_item(image_type_ids, get_id(image_arg0));
1360                 image_id = write_expression(OP_IMAGE, image_type_id, argument_ids[0]);
1361         }
1362         else
1363                 image_id = argument_ids[0];
1364
1365         Id result_type_id = get_id(*call.type);
1366         r_expression_result_id = begin_expression(opcode, result_type_id, argument_ids.size());
1367         writer.write(image_id);
1368         for(unsigned i=1; i<argument_ids.size(); ++i)
1369                 writer.write(argument_ids[i]);
1370         end_expression(opcode);
1371 }
1372
1373 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1374 {
1375         if(argument_ids.size()<2)
1376                 throw internal_error("invalid texture sampling call");
1377
1378         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1379         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1380                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1381
1382         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1383
1384         Opcode opcode;
1385         Id result_type_id = get_id(*call.type);
1386         Id dref_id = 0;
1387         if(image.shadow)
1388         {
1389                 if(argument_ids.size()==2)
1390                 {
1391                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1392                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1393                         writer.write(argument_ids.back());
1394                         writer.write(basic_arg1.size-1);
1395                         end_expression(OP_COMPOSITE_EXTRACT);
1396                 }
1397                 else
1398                         dref_id = argument_ids[2];
1399
1400                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1401                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1402         }
1403         else
1404         {
1405                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1406                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1407         }
1408
1409         for(unsigned i=0; i<2; ++i)
1410                 writer.write(argument_ids[i]);
1411         if(dref_id)
1412                 writer.write(dref_id);
1413         if(explicit_lod)
1414         {
1415                 writer.write(2);  // Lod
1416                 writer.write(lod_id);
1417         }
1418
1419         end_expression(opcode);
1420 }
1421
1422 void SpirVGenerator::visit_builtin_texture_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1423 {
1424         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1425
1426         Opcode opcode;
1427         if(call.name=="texelFetch")
1428                 opcode = OP_IMAGE_FETCH;
1429         else if(call.name=="imageLoad")
1430                 opcode = OP_IMAGE_READ;
1431         else
1432                 throw internal_error("invalid texture fetch call");
1433
1434         bool need_sample = image.multisample;
1435         bool need_lod = (opcode==OP_IMAGE_FETCH && !need_sample);
1436
1437         if(argument_ids.size()!=2U+need_sample+need_lod)
1438                 throw internal_error("invalid texture fetch call");
1439
1440         r_expression_result_id = begin_expression(opcode, get_id(*call.type), 2+(need_lod|need_sample)+need_lod+need_sample);
1441         for(unsigned i=0; i<2; ++i)
1442                 writer.write(argument_ids[i]);
1443         if(need_lod || need_sample)
1444         {
1445                 writer.write(need_lod*0x02 | need_sample*0x40);
1446                 writer.write(argument_ids.back());
1447         }
1448         end_expression(opcode);
1449 }
1450
1451 void SpirVGenerator::visit_builtin_texture_store(FunctionCall &call, const vector<Id> &argument_ids)
1452 {
1453         if(argument_ids.size()!=3)
1454                 throw internal_error("invalid texture store call");
1455
1456         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1457
1458         begin_expression(OP_IMAGE_WRITE, get_id(*call.type), 3+image.multisample*2);
1459         for(unsigned i=0; i<2; ++i)
1460                 writer.write(argument_ids[i]);
1461         writer.write(argument_ids.back());
1462         if(image.multisample)
1463         {
1464                 writer.write(0x40);  // Sample
1465                 writer.write(argument_ids[2]);
1466         }
1467         end_expression(OP_IMAGE_WRITE);
1468
1469         r_expression_result_id = 0;
1470 }
1471
1472 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1473 {
1474         if(argument_ids.size()<1)
1475                 throw internal_error("invalid interpolate call");
1476         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1477         if(!var || !var->declaration || var->declaration->interface!="in")
1478                 throw internal_error("invalid interpolate call");
1479
1480         SpirVGlslStd450Opcode opcode;
1481         if(call.name=="interpolateAtCentroid")
1482                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1483         else if(call.name=="interpolateAtSample")
1484                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1485         else if(call.name=="interpolateAtOffset")
1486                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1487         else
1488                 throw internal_error("invalid interpolate call");
1489
1490         Id ext_id = import_extension("GLSL.std.450");
1491         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1492         writer.write(ext_id);
1493         writer.write(opcode);
1494         writer.write(get_id(*var->declaration));
1495         for(auto i=argument_ids.begin(); ++i!=argument_ids.end(); )
1496                 writer.write(*i);
1497         end_expression(OP_EXT_INST);
1498 }
1499
1500 void SpirVGenerator::visit(ExpressionStatement &expr)
1501 {
1502         expr.expression->visit(*this);
1503 }
1504
1505 void SpirVGenerator::visit(InterfaceLayout &layout)
1506 {
1507         interface_layouts.push_back(&layout);
1508 }
1509
1510 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1511 {
1512         for(const auto &kvp: declared_ids)
1513                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(kvp.first))
1514                         if(TypeComparer().apply(type, *type2))
1515                         {
1516                                 insert_unique(declared_ids, &type, kvp.second);
1517                                 return true;
1518                         }
1519
1520         return false;
1521 }
1522
1523 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1524 {
1525         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1526                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1527         if(!elem || elem->base_type)
1528                 return false;
1529         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1530                 return false;
1531
1532         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1533         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1534         writer.write_op_name(standard_id, basic.name);
1535
1536         return true;
1537 }
1538
1539 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1540 {
1541         if(check_standard_type(basic))
1542                 return;
1543         if(check_duplicate_type(basic))
1544                 return;
1545         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1546         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1547                 return;
1548
1549         Id type_id = allocate_id(basic, 0);
1550         writer.write_op_name(type_id, basic.name);
1551
1552         switch(basic.kind)
1553         {
1554         case BasicTypeDeclaration::INT:
1555                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1556                 break;
1557         case BasicTypeDeclaration::FLOAT:
1558                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1559                 break;
1560         case BasicTypeDeclaration::VECTOR:
1561                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1562                 break;
1563         case BasicTypeDeclaration::MATRIX:
1564                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1565                 break;
1566         default:
1567                 throw internal_error("unknown basic type");
1568         }
1569 }
1570
1571 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1572 {
1573         if(check_duplicate_type(image))
1574                 return;
1575
1576         Id type_id = allocate_id(image, 0);
1577
1578         Id image_id = (image.sampled ? next_id++ : type_id);
1579         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1580         writer.write(image_id);
1581         writer.write(get_id(*image.base_type));
1582         writer.write(image.dimensions-1);
1583         writer.write(image.shadow);
1584         writer.write(image.array);
1585         writer.write(image.multisample);
1586         writer.write(image.sampled ? 1 : 2);
1587         writer.write(get_format(image.format));
1588         writer.end_op(OP_TYPE_IMAGE);
1589
1590         if(image.sampled)
1591         {
1592                 writer.write_op_name(type_id, image.name);
1593                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1594                 insert_unique(image_type_ids, type_id, image_id);
1595         }
1596
1597         if(image.dimensions==ImageTypeDeclaration::ONE)
1598                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1599         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1600                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1601
1602         if(image.multisample && !image.sampled)
1603                 use_capability(CAP_STORAGE_IMAGE_MULTISAMPLE);
1604 }
1605
1606 void SpirVGenerator::visit(StructDeclaration &strct)
1607 {
1608         if(check_duplicate_type(strct))
1609                 return;
1610
1611         Id type_id = allocate_id(strct, 0);
1612         writer.write_op_name(type_id, (strct.block_name.empty() ? strct.name : strct.block_name));
1613
1614         if(!strct.block_name.empty())
1615                 writer.write_op_decorate(type_id, DECO_BLOCK);
1616
1617         bool builtin = !strct.block_name.compare(0, 3, "gl_");
1618         vector<Id> member_type_ids;
1619         member_type_ids.reserve(strct.members.body.size());
1620         for(const RefPtr<Statement> &s: strct.members.body)
1621         {
1622                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(s.get());
1623                 if(!var)
1624                         continue;
1625
1626                 unsigned index = member_type_ids.size();
1627                 member_type_ids.push_back(get_variable_type_id(*var));
1628
1629                 writer.write_op_member_name(type_id, index, var->name);
1630
1631                 if(builtin)
1632                 {
1633                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1634                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1635                 }
1636                 else
1637                 {
1638                         if(var->layout)
1639                         {
1640                                 for(const Layout::Qualifier &q: var->layout->qualifiers)
1641                                 {
1642                                         if(q.name=="offset")
1643                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, q.value);
1644                                         else if(q.name=="column_major")
1645                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1646                                         else if(q.name=="row_major")
1647                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1648                                 }
1649                         }
1650
1651                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1652                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1653                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1654                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1655                         {
1656                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1657                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1658                         }
1659                 }
1660         }
1661
1662         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1663         writer.write(type_id);
1664         for(Id i: member_type_ids)
1665                 writer.write(i);
1666         writer.end_op(OP_TYPE_STRUCT);
1667 }
1668
1669 void SpirVGenerator::visit(VariableDeclaration &var)
1670 {
1671         Id type_id = get_variable_type_id(var);
1672         Id var_id;
1673
1674         if(var.constant)
1675         {
1676                 if(!var.init_expression)
1677                         throw internal_error("const variable without initializer");
1678
1679                 int spec_id = get_layout_value(var.layout.get(), "constant_id");
1680                 Id *spec_var_id = (spec_id>=0 ? &declared_spec_ids[spec_id] : 0);
1681                 if(spec_id>=0 && *spec_var_id)
1682                 {
1683                         insert_unique(declared_ids, &var, Declaration(*spec_var_id, type_id));
1684                         return;
1685                 }
1686
1687                 SetFlag set_const(constant_expression);
1688                 SetFlag set_spec(spec_constant, spec_id>=0);
1689                 r_expression_result_id = 0;
1690                 var.init_expression->visit(*this);
1691                 var_id = r_expression_result_id;
1692                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1693                 if(spec_id>=0)
1694                 {
1695                         writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1696                         *spec_var_id = var_id;
1697                 }
1698         }
1699         else
1700         {
1701                 bool push_const = has_layout_qualifier(var.layout.get(), "push_constant");
1702
1703                 StorageClass storage;
1704                 if(current_function)
1705                         storage = STORAGE_FUNCTION;
1706                 else if(push_const)
1707                         storage = STORAGE_PUSH_CONSTANT;
1708                 else
1709                         storage = get_interface_storage(var.interface, var.block_declaration);
1710
1711                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1712                 if(var.interface=="uniform")
1713                 {
1714                         Id &uni_id = declared_uniform_ids[var.block_declaration ? "b"+var.block_declaration->block_name : "v"+var.name];
1715                         if(uni_id)
1716                         {
1717                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1718                                 return;
1719                         }
1720
1721                         uni_id = var_id = allocate_id(var, ptr_type_id);
1722                 }
1723                 else
1724                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1725
1726                 Id init_id = 0;
1727                 if(var.init_expression)
1728                 {
1729                         SetFlag set_const(constant_expression, !current_function);
1730                         r_expression_result_id = 0;
1731                         r_constant_result = false;
1732                         var.init_expression->visit(*this);
1733                         init_id = r_expression_result_id;
1734                 }
1735
1736                 vector<Word> &target = (current_function ? content.locals : content.globals);
1737                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1738                 writer.write(ptr_type_id);
1739                 writer.write(var_id);
1740                 writer.write(storage);
1741                 if(init_id && !current_function)
1742                         writer.write(init_id);
1743                 writer.end_op(OP_VARIABLE);
1744
1745                 if(var.layout)
1746                 {
1747                         for(const Layout::Qualifier &q: var.layout->qualifiers)
1748                         {
1749                                 if(q.name=="location")
1750                                         writer.write_op_decorate(var_id, DECO_LOCATION, q.value);
1751                                 else if(q.name=="set")
1752                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, q.value);
1753                                 else if(q.name=="binding")
1754                                         writer.write_op_decorate(var_id, DECO_BINDING, q.value);
1755                         }
1756                 }
1757                 if(!var.block_declaration && !var.name.compare(0, 3, "gl_"))
1758                 {
1759                         BuiltinSemantic semantic = get_builtin_semantic(var.name);
1760                         writer.write_op_decorate(var_id, DECO_BUILTIN, semantic);
1761                 }
1762
1763                 if(init_id && current_function)
1764                 {
1765                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1766                         variable_load_ids[&var] = init_id;
1767                 }
1768         }
1769
1770         if(var.name.find(' ')==string::npos)
1771                 writer.write_op_name(var_id, var.name);
1772 }
1773
1774 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1775 {
1776         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1777         switch(stage->type)
1778         {
1779         case Stage::VERTEX: writer.write(0); break;
1780         case Stage::GEOMETRY: writer.write(3); break;
1781         case Stage::FRAGMENT: writer.write(4); break;
1782         default: throw internal_error("unknown stage");
1783         }
1784         writer.write(func_id);
1785         writer.write_string(func.name);
1786
1787         set<Node *> dependencies = DependencyCollector().apply(func);
1788         for(Node *n: dependencies)
1789                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(n))
1790                         if(!var->interface.empty())
1791                                 writer.write(get_id(*n));
1792
1793         writer.end_op(OP_ENTRY_POINT);
1794
1795         if(stage->type==Stage::FRAGMENT)
1796         {
1797                 SpirVExecutionMode origin = (features.target_api==VULKAN ? EXEC_ORIGIN_UPPER_LEFT : EXEC_ORIGIN_LOWER_LEFT);
1798                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, origin);
1799         }
1800         else if(stage->type==Stage::GEOMETRY)
1801         {
1802                 use_capability(CAP_GEOMETRY);
1803                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INVOCATIONS, 1);
1804         }
1805
1806         for(const InterfaceLayout *i: interface_layouts)
1807         {
1808                 for(const Layout::Qualifier &q: i->layout.qualifiers)
1809                 {
1810                         if(q.name=="point")
1811                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1812                                         (i->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1813                         else if(q.name=="lines")
1814                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1815                         else if(q.name=="lines_adjacency")
1816                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1817                         else if(q.name=="triangles")
1818                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1819                         else if(q.name=="triangles_adjacency")
1820                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1821                         else if(q.name=="line_strip")
1822                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1823                         else if(q.name=="triangle_strip")
1824                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1825                         else if(q.name=="max_vertices")
1826                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, q.value);
1827                 }
1828         }
1829 }
1830
1831 void SpirVGenerator::visit(FunctionDeclaration &func)
1832 {
1833         if(func.source==BUILTIN_SOURCE)
1834                 return;
1835         else if(func.definition!=&func)
1836         {
1837                 if(func.definition)
1838                         allocate_forward_id(*func.definition);
1839                 return;
1840         }
1841
1842         Id return_type_id = get_id(*func.return_type_declaration);
1843         vector<unsigned> param_type_ids;
1844         param_type_ids.reserve(func.parameters.size());
1845         for(const RefPtr<VariableDeclaration> &p: func.parameters)
1846                 param_type_ids.push_back(get_variable_type_id(*p));
1847
1848         string sig_with_return = func.return_type+func.signature;
1849         Id &type_id = function_type_ids[sig_with_return];
1850         if(!type_id)
1851         {
1852                 type_id = next_id++;
1853                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1854                 writer.write(type_id);
1855                 writer.write(return_type_id);
1856                 for(unsigned i: param_type_ids)
1857                         writer.write(i);
1858                 writer.end_op(OP_TYPE_FUNCTION);
1859
1860                 writer.write_op_name(type_id, sig_with_return);
1861         }
1862
1863         Id func_id = allocate_id(func, type_id);
1864         writer.write_op_name(func_id, func.name+func.signature);
1865
1866         if(func.name=="main")
1867                 visit_entry_point(func, func_id);
1868
1869         writer.begin_op(content.functions, OP_FUNCTION, 5);
1870         writer.write(return_type_id);
1871         writer.write(func_id);
1872         writer.write(0);  // Function control flags (none)
1873         writer.write(type_id);
1874         writer.end_op(OP_FUNCTION);
1875
1876         for(unsigned i=0; i<func.parameters.size(); ++i)
1877         {
1878                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1879                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1880                 writer.write_op_name(param_id, func.parameters[i]->name);
1881                 // TODO This is probably incorrect if the parameter is assigned to.
1882                 variable_load_ids[func.parameters[i].get()] = param_id;
1883         }
1884
1885         reachable = true;
1886         writer.begin_function_body(next_id++);
1887         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1888         func.body.visit(*this);
1889
1890         if(writer.get_current_block())
1891         {
1892                 if(!reachable)
1893                         writer.write_op(content.function_body, OP_UNREACHABLE);
1894                 else
1895                 {
1896                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1897                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1898                                 writer.write_op(content.function_body, OP_RETURN);
1899                         else
1900                                 throw internal_error("missing return in non-void function");
1901                 }
1902         }
1903         writer.end_function_body();
1904         variable_load_ids.clear();
1905 }
1906
1907 void SpirVGenerator::visit(Conditional &cond)
1908 {
1909         cond.condition->visit(*this);
1910
1911         Id true_label_id = next_id++;
1912         Id merge_block_id = next_id++;
1913         Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
1914         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1915         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);
1916
1917         std::map<const VariableDeclaration *, Id> saved_load_ids = variable_load_ids;
1918
1919         writer.write_op_label(true_label_id);
1920         cond.body.visit(*this);
1921         if(writer.get_current_block())
1922                 writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1923
1924         bool reachable_if_true = reachable;
1925
1926         reachable = true;
1927         if(!cond.else_body.body.empty())
1928         {
1929                 swap(saved_load_ids, variable_load_ids);
1930                 writer.write_op_label(false_label_id);
1931                 cond.else_body.visit(*this);
1932                 reachable |= reachable_if_true;
1933         }
1934
1935         writer.write_op_label(merge_block_id);
1936         prune_loads(true_label_id);
1937 }
1938
1939 void SpirVGenerator::visit(Iteration &iter)
1940 {
1941         if(iter.init_statement)
1942                 iter.init_statement->visit(*this);
1943
1944         for(Node *n: AssignmentCollector().apply(iter))
1945                 if(VariableDeclaration *var = dynamic_cast<VariableDeclaration *>(n))
1946                         variable_load_ids.erase(var);
1947
1948         Id header_id = next_id++;
1949         Id continue_id = next_id++;
1950         Id merge_block_id = next_id++;
1951
1952         SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
1953         SetForScope<Id> set_continue(loop_continue_target_id, continue_id);
1954
1955         writer.write_op_label(header_id);
1956         writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)
1957
1958         Id body_id = next_id++;
1959         if(iter.condition)
1960         {
1961                 writer.write_op_label(next_id++);
1962                 iter.condition->visit(*this);
1963                 writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
1964         }
1965
1966         writer.write_op_label(body_id);
1967         iter.body.visit(*this);
1968
1969         writer.write_op_label(continue_id);
1970         if(iter.loop_expression)
1971                 iter.loop_expression->visit(*this);
1972         writer.write_op(content.function_body, OP_BRANCH, header_id);
1973
1974         writer.write_op_label(merge_block_id);
1975         prune_loads(header_id);
1976         reachable = true;
1977 }
1978
1979 void SpirVGenerator::visit(Return &ret)
1980 {
1981         if(ret.expression)
1982         {
1983                 ret.expression->visit(*this);
1984                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1985         }
1986         else
1987                 writer.write_op(content.function_body, OP_RETURN);
1988         reachable = false;
1989 }
1990
1991 void SpirVGenerator::visit(Jump &jump)
1992 {
1993         if(jump.keyword=="discard")
1994                 writer.write_op(content.function_body, OP_KILL);
1995         else if(jump.keyword=="break")
1996                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1997         else if(jump.keyword=="continue")
1998                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1999         else
2000                 throw internal_error("unknown jump");
2001         reachable = false;
2002 }
2003
2004
2005 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
2006         type_id(0)
2007 {
2008         switch(kind)
2009         {
2010         case BasicTypeDeclaration::VOID: detail = 'v'; break;
2011         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
2012         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
2013         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
2014         default: throw invalid_argument("TypeKey::TypeKey");
2015         }
2016 }
2017
2018 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
2019 {
2020         if(type_id!=other.type_id)
2021                 return type_id<other.type_id;
2022         return detail<other.detail;
2023 }
2024
2025
2026 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
2027 {
2028         if(type_id!=other.type_id)
2029                 return type_id<other.type_id;
2030         return int_value<other.int_value;
2031 }
2032
2033 } // namespace SL
2034 } // namespace GL
2035 } // namespace Msp