]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
72f73bd56bdf9c4ff1d3738342986b6fbf7a315f
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/maputils.h>
2 #include <msp/core/raii.h>
3 #include "reflect.h"
4 #include "spirv.h"
5
6 using namespace std;
7
8 namespace Msp {
9 namespace GL {
10 namespace SL {
11
12 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
13 {
14         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0, 0 },
15         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0, 0 },
16         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0, 0 },
17         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0, 0 },
18         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0, 0 },
19         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0, 0 },
20         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0, 0 },
21         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0, 0 },
22         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0, 0 },
23         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0, 0 },
24         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0, 0 },
25         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0, 0 },
26         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0, 0 },
27         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0, 0 },
28         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0, 0 },
29         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0, 0 },
30         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0, 0 },
31         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0, 0 },
32         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0, 0 },
33         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0, 0 },
34         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0, 0 },
35         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0, 0 },
36         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0, 0 },
37         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0, 0 },
38         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0, 0 },
39         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0, 0 },
40         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0, 0 },
41         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0, 0 },
42         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0, 0 },
43         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0, 0 },
44         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0, 0 },
45         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0, 0 },
46         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0, 0 },
47         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0, 0 },
48         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0, 0 },
49         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0, 0 },
50         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0, 0 },
51         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0, 0 },
52         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0, 0 },
53         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0, 0 },
54         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0, 0 },
55         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0, 0 },
56         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0, 0 },
57         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
58         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
59         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0, 0 },
60         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0, 0 },
61         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0, 0 },
62         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0, 0 },
63         { "isinf", "f", "", OP_IS_INF, { 1 }, 0, 0 },
64         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0, 0 },
65         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0, 0 },
66         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0, 0 },
67         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0, 0 },
68         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0, 0 },
69         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0, 0 },
70         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0, 0 },
71         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0, 0 },
72         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0, 0 },
73         { "matrixCompMult", "ff", "", 0, { 0 }, 0, &SpirVGenerator::visit_builtin_matrix_comp_mult },
74         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0, 0 },
75         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0, 0 },
76         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0, 0 },
77         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0, 0 },
78         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0, 0 },
79         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0, 0 },
80         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0, 0 },
81         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
82         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
83         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0, 0 },
84         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0, 0 },
85         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0, 0 },
86         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0, 0 },
87         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
88         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
89         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0, 0 },
90         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0, 0 },
91         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
92         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0, 0 },
93         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0, 0 },
94         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
95         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0, 0 },
96         { "any", "b", "", OP_ANY, { 1 }, 0, 0 },
97         { "all", "b", "", OP_ALL, { 1 }, 0, 0 },
98         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0, 0 },
99         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0, 0 },
100         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0, 0 },
101         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
102         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0, 0 },
103         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
104         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0, 0 },
105         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0, 0 },
106         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
107         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0, 0 },
108         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0, 0 },
109         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0, 0 },
110         { "textureSize", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
111         { "textureQueryLod", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
112         { "textureQueryLevels", "", "", 0, { }, CAP_IMAGE_QUERY, &SpirVGenerator::visit_builtin_texture_query },
113         { "texture", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
114         { "textureLod", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texture },
115         { "texelFetch", "", "", 0, { }, 0, &SpirVGenerator::visit_builtin_texel_fetch },
116         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0, 0 },
117         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0, 0 },
118         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0, 0 },
119         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0, 0 },
120         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
121         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
122         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
123         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
124         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0, 0 },
125         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
126         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, CAP_DERIVATIVE_CONTROL, 0 },
127         { "interpolateAtCentroid", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
128         { "interpolateAtSample", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
129         { "interpolateAtOffset", "", "", 0, { }, CAP_INTERPOLATION_FUNCTION, &SpirVGenerator::visit_builtin_interpolate },
130         { "", "", "", 0, { }, 0, 0 }
131 };
132
133 SpirVGenerator::SpirVGenerator():
134         stage(0),
135         current_function(0),
136         writer(content),
137         next_id(1),
138         r_expression_result_id(0),
139         constant_expression(false),
140         spec_constant(false),
141         reachable(false),
142         composite_access(false),
143         r_composite_base_id(0),
144         r_composite_base(0),
145         assignment_source_id(0),
146         loop_merge_block_id(0),
147         loop_continue_target_id(0)
148 { }
149
150 void SpirVGenerator::apply(Module &module)
151 {
152         use_capability(CAP_SHADER);
153
154         for(list<Stage>::iterator i=module.stages.begin(); i!=module.stages.end(); ++i)
155         {
156                 stage = &*i;
157                 interface_layouts.clear();
158                 i->content.visit(*this);
159         }
160
161         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
162 }
163
164 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
165 {
166         if(iface=="in")
167                 return STORAGE_INPUT;
168         else if(iface=="out")
169                 return STORAGE_OUTPUT;
170         else if(iface=="uniform")
171                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
172         else if(iface.empty())
173                 return STORAGE_PRIVATE;
174         else
175                 throw invalid_argument("SpirVGenerator::get_interface_storage");
176 }
177
178 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
179 {
180         if(name=="gl_Position")
181                 return BUILTIN_POSITION;
182         else if(name=="gl_PointSize")
183                 return BUILTIN_POINT_SIZE;
184         else if(name=="gl_ClipDistance")
185                 return BUILTIN_CLIP_DISTANCE;
186         else if(name=="gl_VertexID")
187                 return BUILTIN_VERTEX_ID;
188         else if(name=="gl_InstanceID")
189                 return BUILTIN_INSTANCE_ID;
190         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
191                 return BUILTIN_PRIMITIVE_ID;
192         else if(name=="gl_InvocationID")
193                 return BUILTIN_INVOCATION_ID;
194         else if(name=="gl_Layer")
195                 return BUILTIN_LAYER;
196         else if(name=="gl_FragCoord")
197                 return BUILTIN_FRAG_COORD;
198         else if(name=="gl_PointCoord")
199                 return BUILTIN_POINT_COORD;
200         else if(name=="gl_FrontFacing")
201                 return BUILTIN_FRONT_FACING;
202         else if(name=="gl_SampleId")
203                 return BUILTIN_SAMPLE_ID;
204         else if(name=="gl_SamplePosition")
205                 return BUILTIN_SAMPLE_POSITION;
206         else if(name=="gl_FragDepth")
207                 return BUILTIN_FRAG_DEPTH;
208         else
209                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
210 }
211
212 void SpirVGenerator::use_capability(Capability cap)
213 {
214         if(used_capabilities.count(cap))
215                 return;
216
217         used_capabilities.insert(cap);
218         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
219 }
220
221 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
222 {
223         Id &ext_id = imported_extension_ids[name];
224         if(!ext_id)
225         {
226                 ext_id = next_id++;
227                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
228                 writer.write(ext_id);
229                 writer.write_string(name);
230                 writer.end_op(OP_EXT_INST_IMPORT);
231         }
232         return ext_id;
233 }
234
235 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
236 {
237         return get_item(declared_ids, &node).id;
238 }
239
240 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
241 {
242         map<Node *, Declaration>::iterator i = declared_ids.find(&node);
243         if(i!=declared_ids.end())
244         {
245                 if(i->second.type_id)
246                         throw key_error(&node);
247                 i->second.type_id = type_id;
248                 return i->second.id;
249         }
250
251         Id id = next_id++;
252         declared_ids.insert(make_pair(&node, Declaration(id, type_id)));
253         return id;
254 }
255
256 SpirVGenerator::Id SpirVGenerator::allocate_forward_id(Node &node)
257 {
258         map<Node *, Declaration>::iterator i = declared_ids.find(&node);
259         if(i!=declared_ids.end())
260                 return i->second.id;
261
262         Id id = next_id++;
263         declared_ids.insert(make_pair(&node, Declaration(id, 0)));
264         return id;
265 }
266
267 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
268 {
269         Id const_id = next_id++;
270         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
271         {
272                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
273                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
274                 writer.write_op(content.globals, opcode, type_id, const_id);
275         }
276         else
277         {
278                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
279                 writer.write_op(content.globals, opcode, type_id, const_id, value);
280         }
281         return const_id;
282 }
283
284 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
285 {
286         if(value.check_type<bool>())
287                 return ConstantKey(type_id, value.value<bool>());
288         else if(value.check_type<int>())
289                 return ConstantKey(type_id, value.value<int>());
290         else if(value.check_type<unsigned>())
291                 return ConstantKey(type_id, value.value<unsigned>());
292         else if(value.check_type<float>())
293                 return ConstantKey(type_id, value.value<float>());
294         else
295                 throw invalid_argument("SpirVGenerator::get_constant_key");
296 }
297
298 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
299 {
300         ConstantKey key = get_constant_key(type_id, value);
301         Id &const_id = constant_ids[key];
302         if(!const_id)
303                 const_id = write_constant(type_id, key.int_value, false);
304         return const_id;
305 }
306
307 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
308 {
309         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
310         if(!const_id)
311         {
312                 const_id = next_id++;
313                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
314                 writer.write(type_id);
315                 writer.write(const_id);
316                 for(unsigned i=0; i<size; ++i)
317                         writer.write(scalar_id);
318                 writer.end_op(OP_CONSTANT_COMPOSITE);
319         }
320         return const_id;
321 }
322
323 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
324 {
325         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
326         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
327         if(!type_id)
328         {
329                 type_id = next_id++;
330                 if(size>1)
331                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
332                 else if(kind==BasicTypeDeclaration::VOID)
333                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
334                 else if(kind==BasicTypeDeclaration::BOOL)
335                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
336                 else if(kind==BasicTypeDeclaration::INT)
337                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
338                 else if(kind==BasicTypeDeclaration::FLOAT)
339                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
340                 else
341                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
342         }
343         return type_id;
344 }
345
346 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
347 {
348         map<TypeKey, Id>::const_iterator i = standard_type_ids.find(TypeKey(kind, true));
349         return (i!=standard_type_ids.end() && i->second==type_id);
350 }
351
352 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id)
353 {
354         Id base_type_id = get_id(base_type);
355         Id &array_type_id = array_type_ids[TypeKey(base_type_id, size_id)];
356         if(!array_type_id)
357         {
358                 array_type_id = next_id++;
359                 if(size_id)
360                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
361                 else
362                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
363
364                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
365                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
366         }
367
368         return array_type_id;
369 }
370
371 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
372 {
373         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
374         if(!ptr_type_id)
375         {
376                 ptr_type_id = next_id++;
377                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
378         }
379         return ptr_type_id;
380 }
381
382 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
383 {
384         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
385                 if(basic->kind==BasicTypeDeclaration::ARRAY)
386                 {
387                         Id size_id = 0;
388                         if(var.array_size)
389                         {
390                                 SetFlag set_const(constant_expression);
391                                 r_expression_result_id = 0;
392                                 var.array_size->visit(*this);
393                                 size_id = r_expression_result_id;
394                         }
395                         else
396                                 size_id = get_constant_id(get_standard_type_id(BasicTypeDeclaration::INT, 1), 1);
397                         return get_array_type_id(*basic->base_type, size_id);
398                 }
399
400         return get_id(*var.type_declaration);
401 }
402
403 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
404 {
405         Id &load_result_id = variable_load_ids[&var];
406         if(!load_result_id)
407         {
408                 load_result_id = next_id++;
409                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
410         }
411         return load_result_id;
412 }
413
414 void SpirVGenerator::prune_loads(Id min_id)
415 {
416         for(map<const VariableDeclaration *, Id>::iterator i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
417         {
418                 if(i->second>=min_id)
419                         variable_load_ids.erase(i++);
420                 else
421                         ++i;
422         }
423 }
424
425 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
426 {
427         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
428         if(!constant_expression)
429         {
430                 if(!current_function)
431                         throw internal_error("non-constant expression outside a function");
432
433                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
434         }
435         else if(opcode==OP_COMPOSITE_CONSTRUCT)
436                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
437                         (n_args ? 1+has_result*2+n_args : 0));
438         else if(!spec_constant)
439                 throw internal_error("invalid non-specialization constant expression");
440         else
441                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
442
443         Id result_id = next_id++;
444         if(has_result)
445         {
446                 writer.write(type_id);
447                 writer.write(result_id);
448         }
449         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
450                 writer.write(opcode);
451
452         return result_id;
453 }
454
455 void SpirVGenerator::end_expression(Opcode opcode)
456 {
457         if(constant_expression)
458                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
459         writer.end_op(opcode);
460 }
461
462 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
463 {
464         Id result_id = begin_expression(opcode, type_id, 1);
465         writer.write(arg_id);
466         end_expression(opcode);
467         return result_id;
468 }
469
470 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
471 {
472         Id result_id = begin_expression(opcode, type_id, 2);
473         writer.write(left_id);
474         writer.write(right_id);
475         end_expression(opcode);
476         return result_id;
477 }
478
479 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
480 {
481         for(unsigned i=0; i<n_elems; ++i)
482         {
483                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
484                 writer.write(composite_id);
485                 writer.write(i);
486                 end_expression(OP_COMPOSITE_EXTRACT);
487         }
488 }
489
490 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
491 {
492         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
493         for(unsigned i=0; i<n_elems; ++i)
494                 writer.write(elem_ids[i]);
495         end_expression(OP_COMPOSITE_CONSTRUCT);
496
497         return result_id;
498 }
499
500 void SpirVGenerator::visit(Block &block)
501 {
502         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
503                 (*i)->visit(*this);
504 }
505
506 void SpirVGenerator::visit(Literal &literal)
507 {
508         Id type_id = get_id(*literal.type);
509         if(spec_constant)
510                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
511         else
512                 r_expression_result_id = get_constant_id(type_id, literal.value);
513         r_constant_result = true;
514 }
515
516 void SpirVGenerator::visit(VariableReference &var)
517 {
518         if(constant_expression || var.declaration->constant)
519         {
520                 if(!var.declaration->constant)
521                         throw internal_error("reference to non-constant variable in constant context");
522
523                 r_expression_result_id = get_id(*var.declaration);
524                 r_constant_result = true;
525                 return;
526         }
527         else if(!current_function)
528                 throw internal_error("non-constant context outside a function");
529
530         r_constant_result = false;
531         if(composite_access)
532         {
533                 r_composite_base = var.declaration;
534                 r_expression_result_id = 0;
535         }
536         else if(assignment_source_id)
537         {
538                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
539                 variable_load_ids[var.declaration] = assignment_source_id;
540                 r_expression_result_id = assignment_source_id;
541         }
542         else
543                 r_expression_result_id = get_load_id(*var.declaration);
544 }
545
546 void SpirVGenerator::visit(InterfaceBlockReference &iface)
547 {
548         if(!composite_access || !current_function)
549                 throw internal_error("invalid interface block reference");
550
551         r_composite_base = iface.declaration;
552         r_expression_result_id = 0;
553         r_constant_result = false;
554 }
555
556 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
557 {
558         Opcode opcode;
559         Id result_type_id = get_id(result_type);
560         Id access_type_id = result_type_id;
561         if(r_composite_base)
562         {
563                 if(constant_expression)
564                         throw internal_error("composite access through pointer in constant context");
565
566                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
567                 for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
568                         *i = (*i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(*i)) : *i&0x3FFFFF);
569
570                 /* Find the storage class of the base and obtain appropriate pointer type
571                 for the result. */
572                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
573                 map<TypeKey, Id>::const_iterator i = pointer_type_ids.begin();
574                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
575                 if(i==pointer_type_ids.end())
576                         throw internal_error("could not find storage class");
577                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
578
579                 opcode = OP_ACCESS_CHAIN;
580         }
581         else if(assignment_source_id)
582                 throw internal_error("assignment to temporary composite");
583         else
584         {
585                 for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
586                         for(map<ConstantKey, Id>::iterator j=constant_ids.begin(); (*i>=0x400000 && j!=constant_ids.end()); ++j)
587                                 if(j->second==(*i&0x3FFFFF))
588                                         *i = j->first.int_value;
589
590                 opcode = OP_COMPOSITE_EXTRACT;
591         }
592
593         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
594         writer.write(r_composite_base_id);
595         for(vector<unsigned>::const_iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
596                 writer.write(*i);
597         end_expression(opcode);
598
599         r_constant_result = false;
600         if(r_composite_base)
601         {
602                 if(assignment_source_id)
603                 {
604                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
605                         r_expression_result_id = assignment_source_id;
606                 }
607                 else
608                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
609         }
610         else
611                 r_expression_result_id = access_id;
612 }
613
614 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
615 {
616         if(!composite_access)
617         {
618                 r_composite_base = 0;
619                 r_composite_base_id = 0;
620                 r_composite_chain.clear();
621         }
622
623         {
624                 SetFlag set_composite(composite_access);
625                 base_expr.visit(*this);
626         }
627
628         if(!r_composite_base_id)
629                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
630
631         r_composite_chain.push_back(index);
632         if(!composite_access)
633                 generate_composite_access(type);
634         else
635                 r_expression_result_id = 0;
636 }
637
638 void SpirVGenerator::visit_isolated(Expression &expr)
639 {
640         SetForScope<Id> clear_assign(assignment_source_id, 0);
641         SetFlag clear_composite(composite_access, false);
642         SetForScope<Node *> clear_base(r_composite_base, 0);
643         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
644         vector<unsigned> saved_chain;
645         swap(saved_chain, r_composite_chain);
646         expr.visit(*this);
647         swap(saved_chain, r_composite_chain);
648 }
649
650 void SpirVGenerator::visit(MemberAccess &memacc)
651 {
652         visit_composite(*memacc.left, memacc.index, *memacc.type);
653 }
654
655 void SpirVGenerator::visit(Swizzle &swizzle)
656 {
657         if(swizzle.count==1)
658                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
659         else if(assignment_source_id)
660         {
661                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
662
663                 unsigned mask = 0;
664                 for(unsigned i=0; i<swizzle.count; ++i)
665                         mask |= 1<<swizzle.components[i];
666
667                 visit_isolated(*swizzle.left);
668
669                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
670                 writer.write(r_expression_result_id);
671                 writer.write(assignment_source_id);
672                 for(unsigned i=0; i<basic.size; ++i)
673                         writer.write(i+((mask>>i)&1)*basic.size);
674                 end_expression(OP_VECTOR_SHUFFLE);
675
676                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
677                 swizzle.left->visit(*this);
678
679                 r_expression_result_id = combined_id;
680         }
681         else
682         {
683                 swizzle.left->visit(*this);
684                 Id left_id = r_expression_result_id;
685
686                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
687                 writer.write(left_id);
688                 writer.write(left_id);
689                 for(unsigned i=0; i<swizzle.count; ++i)
690                         writer.write(swizzle.components[i]);
691                 end_expression(OP_VECTOR_SHUFFLE);
692         }
693         r_constant_result = false;
694 }
695
696 void SpirVGenerator::visit(UnaryExpression &unary)
697 {
698         unary.expression->visit(*this);
699
700         char oper = unary.oper->token[0];
701         char oper2 = unary.oper->token[1];
702         if(oper=='+' && !oper2)
703                 return;
704
705         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
706         BasicTypeDeclaration &elem = *get_element_type(basic);
707
708         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
709                 /* SPIR-V allows constant operations on floating-point values only for
710                 OpenGL kernels. */
711                 throw internal_error("invalid operands for constant unary expression");
712
713         Id result_type_id = get_id(*unary.type);
714         Opcode opcode = OP_NOP;
715
716         r_constant_result = false;
717         if(oper=='!')
718                 opcode = OP_LOGICAL_NOT;
719         else if(oper=='~')
720                 opcode = OP_NOT;
721         else if(oper=='-' && !oper2)
722         {
723                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
724
725                 if(basic.kind==BasicTypeDeclaration::MATRIX)
726                 {
727                         Id column_type_id = get_id(*basic.base_type);
728                         unsigned n_columns = basic.size&0xFFFF;
729                         Id column_ids[4];
730                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
731                         for(unsigned i=0; i<n_columns; ++i)
732                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
733                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
734                         return;
735                 }
736         }
737         else if((oper=='+' || oper=='-') && oper2==oper)
738         {
739                 if(constant_expression)
740                         throw internal_error("increment or decrement in constant expression");
741
742                 Id one_id = 0;
743                 if(elem.kind==BasicTypeDeclaration::INT)
744                 {
745                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
746                         one_id = get_constant_id(get_id(elem), 1);
747                 }
748                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
749                 {
750                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
751                         one_id = get_constant_id(get_id(elem), 1.0f);
752                 }
753                 else
754                         throw internal_error("invalid increment/decrement");
755
756                 if(basic.kind==BasicTypeDeclaration::VECTOR)
757                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
758
759                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
760
761                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
762                 unary.expression->visit(*this);
763
764                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
765                 return;
766         }
767
768         if(opcode==OP_NOP)
769                 throw internal_error("unknown unary operator");
770
771         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
772 }
773
774 void SpirVGenerator::visit(BinaryExpression &binary)
775 {
776         char oper = binary.oper->token[0];
777         if(oper=='[')
778         {
779                 visit_isolated(*binary.right);
780                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
781         }
782
783         if(assignment_source_id)
784                 throw internal_error("invalid binary expression in assignment target");
785
786         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
787         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
788         // Expression resolver ensures that element types are the same
789         BasicTypeDeclaration &elem = *get_element_type(basic_left);
790
791         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
792                 /* SPIR-V allows constant operations on floating-point values only for
793                 OpenGL kernels. */
794                 throw internal_error("invalid operands for constant binary expression");
795
796         binary.left->visit(*this);
797         Id left_id = r_expression_result_id;
798         binary.right->visit(*this);
799         Id right_id = r_expression_result_id;
800
801         Id result_type_id = get_id(*binary.type);
802         Opcode opcode = OP_NOP;
803         bool swap_operands = false;
804
805         r_constant_result = false;
806
807         char oper2 = binary.oper->token[1];
808         if((oper=='<' || oper=='>') && oper2!=oper)
809         {
810                 if(basic_left.kind==BasicTypeDeclaration::INT)
811                 {
812                         if(basic_left.sign)
813                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
814                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
815                         else
816                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
817                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
818                 }
819                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
820                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
821                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
822         }
823         else if((oper=='=' || oper=='!') && oper2=='=')
824         {
825                 if(elem.kind==BasicTypeDeclaration::BOOL)
826                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
827                 else if(elem.kind==BasicTypeDeclaration::INT)
828                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
829                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
830                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
831
832                 if(opcode!=OP_NOP && basic_left.base_type)
833                 {
834                         /* The SPIR-V equality operations produce component-wise results, but
835                         GLSL operators return a single boolean.  Use the any/all operations to
836                         combine the results. */
837                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
838                         unsigned n_elems = basic_left.size&0xFFFF;
839                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
840
841                         Id compare_id = 0;
842                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
843                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
844                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
845                         {
846                                 Id column_type_id = get_id(*basic_left.base_type);
847                                 Id column_ids[8];
848                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
849                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
850
851                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
852                                 for(unsigned i=0; i<n_elems; ++i)
853                                 {
854                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
855                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
856                                 }
857
858                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
859                         }
860
861                         if(compare_id)
862                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
863                         return;
864                 }
865         }
866         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
867                 opcode = OP_LOGICAL_AND;
868         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
869                 opcode = OP_LOGICAL_OR;
870         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
871                 opcode = OP_LOGICAL_NOT_EQUAL;
872         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
873                 opcode = OP_BITWISE_AND;
874         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
875                 opcode = OP_BITWISE_OR;
876         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
877                 opcode = OP_BITWISE_XOR;
878         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
879                 opcode = OP_SHIFT_LEFT_LOGICAL;
880         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
881                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
882         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
883                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
884         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
885         {
886                 Opcode elem_op = OP_NOP;
887                 if(elem.kind==BasicTypeDeclaration::INT)
888                 {
889                         if(oper=='/')
890                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
891                         else
892                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
893                 }
894                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
895                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
896
897                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
898                 {
899                         /* Multiplication between floating-point vectors and matrices has
900                         dedicated operations. */
901                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
902                                 opcode = OP_MATRIX_TIMES_MATRIX;
903                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
904                         {
905                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
906                                         opcode = OP_VECTOR_TIMES_MATRIX;
907                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
908                                         opcode = OP_MATRIX_TIMES_VECTOR;
909                                 else
910                                 {
911                                         opcode = OP_MATRIX_TIMES_SCALAR;
912                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
913                                 }
914                         }
915                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
916                                 opcode = elem_op;
917                         else
918                         {
919                                 opcode = OP_VECTOR_TIMES_SCALAR;
920                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
921                         }
922                 }
923                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
924                 {
925                         /* One operand is scalar and the other is a vector or a matrix.
926                         Expand the scalar to a vector of appropriate size. */
927                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
928                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
929                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
930                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
931                         Id vector_type_id = get_id(*vector_type);
932
933                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
934                         for(unsigned i=0; i<vector_type->size; ++i)
935                                 writer.write(scalar_id);
936                         end_expression(OP_COMPOSITE_CONSTRUCT);
937
938                         scalar_id = expanded_id;
939
940                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
941                         {
942                                 // Apply matrix operation column-wise.
943                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
944
945                                 Id column_ids[4];
946                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
947                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
948
949                                 for(unsigned i=0; i<n_columns; ++i)
950                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
951
952                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
953                                 return;
954                         }
955                         else
956                                 opcode = elem_op;
957                 }
958                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
959                 {
960                         if(oper=='*')
961                                 throw internal_error("non-float matrix multiplication");
962
963                         /* Other operations involving matrices need to be performed
964                         column-wise. */
965                         Id column_type_id = get_id(*basic_left.base_type);
966                         Id column_ids[8];
967
968                         unsigned n_columns = basic_left.size&0xFFFF;
969                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
970                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
971
972                         for(unsigned i=0; i<n_columns; ++i)
973                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
974
975                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
976                         return;
977                 }
978                 else if(basic_left.kind==basic_right.kind)
979                         // Both operands are either scalars or vectors.
980                         opcode = elem_op;
981         }
982
983         if(opcode==OP_NOP)
984                 throw internal_error("unknown binary operator");
985
986         if(swap_operands)
987                 swap(left_id, right_id);
988
989         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
990 }
991
992 void SpirVGenerator::visit(Assignment &assign)
993 {
994         if(assign.oper->token[0]!='=')
995                 visit(static_cast<BinaryExpression &>(assign));
996         else
997                 assign.right->visit(*this);
998
999         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
1000         assign.left->visit(*this);
1001         r_constant_result = false;
1002 }
1003
1004 void SpirVGenerator::visit(TernaryExpression &ternary)
1005 {
1006         if(constant_expression)
1007         {
1008                 ternary.condition->visit(*this);
1009                 Id condition_id = r_expression_result_id;
1010                 ternary.true_expr->visit(*this);
1011                 Id true_result_id = r_expression_result_id;
1012                 ternary.false_expr->visit(*this);
1013                 Id false_result_id = r_expression_result_id;
1014
1015                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
1016                 writer.write(condition_id);
1017                 writer.write(true_result_id);
1018                 writer.write(false_result_id);
1019                 end_expression(OP_SELECT);
1020
1021                 return;
1022         }
1023
1024         ternary.condition->visit(*this);
1025         Id condition_id = r_expression_result_id;
1026
1027         Id true_label_id = next_id++;
1028         Id false_label_id = next_id++;
1029         Id merge_block_id = next_id++;
1030         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1031         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
1032
1033         writer.write_op_label(true_label_id);
1034         ternary.true_expr->visit(*this);
1035         Id true_result_id = r_expression_result_id;
1036         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1037
1038         writer.write_op_label(false_label_id);
1039         ternary.false_expr->visit(*this);
1040         Id false_result_id = r_expression_result_id;
1041
1042         writer.write_op_label(merge_block_id);
1043         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
1044         writer.write(true_result_id);
1045         writer.write(true_label_id);
1046         writer.write(false_result_id);
1047         writer.write(false_label_id);
1048         end_expression(OP_PHI);
1049
1050         r_constant_result = false;
1051 }
1052
1053 void SpirVGenerator::visit(FunctionCall &call)
1054 {
1055         if(assignment_source_id)
1056                 throw internal_error("assignment to function call");
1057         else if(composite_access)
1058                 return visit_isolated(call);
1059         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
1060                 return call.arguments[0]->visit(*this);
1061
1062         vector<Id> argument_ids;
1063         argument_ids.reserve(call.arguments.size());
1064         bool all_args_const = true;
1065         for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1066         {
1067                 (*i)->visit(*this);
1068                 argument_ids.push_back(r_expression_result_id);
1069                 all_args_const &= r_constant_result;
1070         }
1071
1072         if(constant_expression && (!call.constructor || !all_args_const))
1073                 throw internal_error("function call in constant expression");
1074
1075         Id result_type_id = get_id(*call.type);
1076         r_constant_result = false;
1077
1078         if(call.constructor)
1079                 visit_constructor(call, argument_ids, all_args_const);
1080         else if(call.declaration->source==BUILTIN_SOURCE)
1081         {
1082                 string arg_types;
1083                 for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1084                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>((*i)->type))
1085                         {
1086                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1087                                 switch(elem_arg.kind)
1088                                 {
1089                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1090                                 case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
1091                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1092                                 default: arg_types += '?';
1093                                 }
1094                         }
1095
1096                 const BuiltinFunctionInfo *builtin_info;
1097                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1098                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1099                                 break;
1100
1101                 if(builtin_info->capability)
1102                         use_capability(static_cast<Capability>(builtin_info->capability));
1103
1104                 if(builtin_info->opcode)
1105                 {
1106                         Opcode opcode;
1107                         if(builtin_info->extension[0])
1108                         {
1109                                 opcode = OP_EXT_INST;
1110                                 Id ext_id = import_extension(builtin_info->extension);
1111
1112                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1113                                 writer.write(ext_id);
1114                                 writer.write(builtin_info->opcode);
1115                         }
1116                         else
1117                         {
1118                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1119                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1120                         }
1121
1122                         for(unsigned i=0; i<call.arguments.size(); ++i)
1123                         {
1124                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1125                                         throw internal_error("invalid builtin function info");
1126                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1127                         }
1128
1129                         end_expression(opcode);
1130                 }
1131                 else if(builtin_info->handler)
1132                         (this->*(builtin_info->handler))(call, argument_ids);
1133                 else
1134                         throw internal_error("unknown builtin function "+call.name);
1135         }
1136         else
1137         {
1138                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1139                 writer.write(get_id(*call.declaration->definition));
1140                 for(vector<Id>::const_iterator i=argument_ids.begin(); i!=argument_ids.end(); ++i)
1141                         writer.write(*i);
1142                 end_expression(OP_FUNCTION_CALL);
1143
1144                 // Any global variables the called function uses might have changed value
1145                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1146                 for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1147                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1148                                 variable_load_ids.erase(var);
1149         }
1150 }
1151
1152 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1153 {
1154         Id result_type_id = get_id(*call.type);
1155
1156         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1157         if(!basic)
1158         {
1159                 if(dynamic_cast<const StructDeclaration *>(call.type))
1160                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1161                 else
1162                         throw internal_error("unconstructable type "+call.name);
1163                 return;
1164         }
1165
1166         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1167
1168         BasicTypeDeclaration &elem = *get_element_type(*basic);
1169         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1170         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1171
1172         if(basic->kind==BasicTypeDeclaration::MATRIX)
1173         {
1174                 Id col_type_id = get_id(*basic->base_type);
1175                 unsigned n_columns = basic->size&0xFFFF;
1176                 unsigned n_rows = basic->size>>16;
1177
1178                 Id column_ids[4];
1179                 if(call.arguments.size()==1)
1180                 {
1181                         // Construct diagonal matrix from a single scalar.
1182                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1183                         for(unsigned i=0; i<n_columns; ++i)
1184                         {
1185                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);
1186                                 for(unsigned j=0; j<n_rows; ++j)
1187                                         writer.write(j==i ? argument_ids[0] : zero_id);
1188                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1189                         }
1190                 }
1191                 else
1192                         // Construct a matrix from column vectors
1193                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1194
1195                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1196         }
1197         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1198         {
1199                 /* There's either a single scalar argument or multiple arguments
1200                 which make up the vector's components. */
1201                 if(call.arguments.size()==1)
1202                 {
1203                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1204                         for(unsigned i=0; i<basic->size; ++i)
1205                                 writer.write(argument_ids[0]);
1206                         end_expression(OP_COMPOSITE_CONSTRUCT);
1207                 }
1208                 else
1209                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1210         }
1211         else if(elem.kind==BasicTypeDeclaration::BOOL)
1212         {
1213                 if(constant_expression)
1214                         throw internal_error("unconverted constant");
1215
1216                 // Conversion to boolean is implemented as comparing against zero.
1217                 Id number_type_id = get_id(elem_arg0);
1218                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1219                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1220                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1221                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1222
1223                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1224                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1225         }
1226         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1227         {
1228                 if(constant_expression)
1229                         throw internal_error("unconverted constant");
1230
1231                 /* Conversion from boolean is implemented as selecting from zero
1232                 or one. */
1233                 Id number_type_id = get_id(elem);
1234                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1235                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1236                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1237                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1238                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1239                 {
1240                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1241                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1242                 }
1243
1244                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1245                 writer.write(argument_ids[0]);
1246                 writer.write(zero_id);
1247                 writer.write(one_id);
1248                 end_expression(OP_SELECT);
1249         }
1250         else
1251         {
1252                 if(constant_expression)
1253                         throw internal_error("unconverted constant");
1254
1255                 // Scalar or vector conversion between types of equal size.
1256                 Opcode opcode;
1257                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1258                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1259                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1260                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1261                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1262                         opcode = OP_BITCAST;
1263                 else
1264                         throw internal_error("invalid conversion");
1265
1266                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1267         }
1268 }
1269
1270 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1271 {
1272         if(argument_ids.size()!=2)
1273                 throw internal_error("invalid matrixCompMult call");
1274
1275         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1276         Id column_type_id = get_id(*basic_arg0.base_type);
1277         Id column_ids[8];
1278
1279         unsigned n_columns = basic_arg0.size&0xFFFF;
1280         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1281         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1282
1283         for(unsigned i=0; i<n_columns; ++i)
1284                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1285
1286         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1287 }
1288
1289 void SpirVGenerator::visit_builtin_texture_query(FunctionCall &call, const vector<Id> &argument_ids)
1290 {
1291         if(argument_ids.size()<1)
1292                 throw internal_error("invalid texture query call");
1293
1294         Opcode opcode;
1295         if(call.name=="textureSize")
1296                 opcode = OP_IMAGE_QUERY_SIZE_LOD;
1297         else if(call.name=="textureQueryLod")
1298                 opcode = OP_IMAGE_QUERY_LOD;
1299         else if(call.name=="textureQueryLevels")
1300                 opcode = OP_IMAGE_QUERY_LEVELS;
1301         else
1302                 throw internal_error("invalid texture query call");
1303
1304         ImageTypeDeclaration &image_arg0 = dynamic_cast<ImageTypeDeclaration &>(*call.arguments[0]->type);
1305
1306         Id image_id;
1307         if(image_arg0.sampled)
1308         {
1309                 Id image_type_id = get_item(image_type_ids, get_id(image_arg0));
1310                 image_id = write_expression(OP_IMAGE, image_type_id, argument_ids[0]);
1311         }
1312         else
1313                 image_id = argument_ids[0];
1314
1315         Id result_type_id = get_id(*call.type);
1316         r_expression_result_id = begin_expression(opcode, result_type_id, argument_ids.size());
1317         writer.write(image_id);
1318         for(unsigned i=1; i<argument_ids.size(); ++i)
1319                 writer.write(argument_ids[i]);
1320         end_expression(opcode);
1321 }
1322
/* Generates code for texture sampling builtins.  The first two arguments are
written as the leading operands of the sampling instruction.  Shadow images
additionally get a depth reference operand.  Throws internal_error if there
are fewer than two arguments. */
void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
{
	if(argument_ids.size()<2)
		throw internal_error("invalid texture sampling call");

	/* An explicit LOD is used for textureLod, and also for every sampling call
	outside of a fragment stage.  In the latter case the LOD is a constant zero
	unless the call is textureLod, which supplies it as its last argument. */
	bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
	Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
		get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));

	const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);

	Opcode opcode;
	Id result_type_id = get_id(*call.type);
	Id dref_id = 0;
	if(image.shadow)
	{
		if(argument_ids.size()==2)
		{
			/* With only two arguments the depth reference is extracted from
			the last component of the second argument. */
			const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
			dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
			writer.write(argument_ids.back());
			writer.write(basic_arg1.size-1);
			end_expression(OP_COMPOSITE_EXTRACT);
		}
		else
			// Otherwise the third argument is used as the depth reference.
			dref_id = argument_ids[2];

		// Two more words are needed for the operand mask and the LOD value.
		opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
		r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
	}
	else
	{
		opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
		r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
	}

	// The operands must be written in the same order as they were counted above.
	for(unsigned i=0; i<2; ++i)
		writer.write(argument_ids[i]);
	if(dref_id)
		writer.write(dref_id);
	if(explicit_lod)
	{
		writer.write(2);  // Lod
		writer.write(lod_id);
	}

	end_expression(opcode);
}
1371
1372 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1373 {
1374         if(argument_ids.size()!=3)
1375                 throw internal_error("invalid texelFetch call");
1376
1377         r_expression_result_id = begin_expression(OP_IMAGE_FETCH, get_id(*call.type), 4);
1378         for(unsigned i=0; i<2; ++i)
1379                 writer.write(argument_ids[i]);
1380         writer.write(2);  // Lod
1381         writer.write(argument_ids.back());
1382         end_expression(OP_IMAGE_FETCH);
1383 }
1384
1385 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1386 {
1387         if(argument_ids.size()<1)
1388                 throw internal_error("invalid interpolate call");
1389         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1390         if(!var || !var->declaration || var->declaration->interface!="in")
1391                 throw internal_error("invalid interpolate call");
1392
1393         SpirVGlslStd450Opcode opcode;
1394         if(call.name=="interpolateAtCentroid")
1395                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1396         else if(call.name=="interpolateAtSample")
1397                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1398         else if(call.name=="interpolateAtOffset")
1399                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1400         else
1401                 throw internal_error("invalid interpolate call");
1402
1403         Id ext_id = import_extension("GLSL.std.450");
1404         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1405         writer.write(ext_id);
1406         writer.write(opcode);
1407         writer.write(get_id(*var->declaration));
1408         for(vector<Id>::const_iterator i=argument_ids.begin(); ++i!=argument_ids.end(); )
1409                 writer.write(*i);
1410         end_expression(OP_EXT_INST);
1411 }
1412
// Generates code for an expression statement by visiting its expression.
void SpirVGenerator::visit(ExpressionStatement &expr)
{
	expr.expression->visit(*this);
}
1417
/* Stores the interface layout for later use.  No code is generated here; the
stored layouts are read by visit_entry_point to emit execution modes. */
void SpirVGenerator::visit(InterfaceLayout &layout)
{
	interface_layouts.push_back(&layout);
}
1422
1423 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1424 {
1425         for(map<Node *, Declaration>::const_iterator i=declared_ids.begin(); i!=declared_ids.end(); ++i)
1426                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(i->first))
1427                         if(TypeComparer().apply(type, *type2))
1428                         {
1429                                 insert_unique(declared_ids, &type, i->second);
1430                                 return true;
1431                         }
1432
1433         return false;
1434 }
1435
1436 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1437 {
1438         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1439                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1440         if(!elem || elem->base_type)
1441                 return false;
1442         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1443                 return false;
1444
1445         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1446         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1447         writer.write_op_name(standard_id, basic.name);
1448
1449         return true;
1450 }
1451
1452 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1453 {
1454         if(check_standard_type(basic))
1455                 return;
1456         if(check_duplicate_type(basic))
1457                 return;
1458         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1459         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1460                 return;
1461
1462         Id type_id = allocate_id(basic, 0);
1463         writer.write_op_name(type_id, basic.name);
1464
1465         switch(basic.kind)
1466         {
1467         case BasicTypeDeclaration::INT:
1468                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1469                 break;
1470         case BasicTypeDeclaration::FLOAT:
1471                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1472                 break;
1473         case BasicTypeDeclaration::VECTOR:
1474                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1475                 break;
1476         case BasicTypeDeclaration::MATRIX:
1477                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1478                 break;
1479         default:
1480                 throw internal_error("unknown basic type");
1481         }
1482 }
1483
1484 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1485 {
1486         if(check_duplicate_type(image))
1487                 return;
1488
1489         Id type_id = allocate_id(image, 0);
1490
1491         Id image_id = (image.sampled ? next_id++ : type_id);
1492         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1493         writer.write(image_id);
1494         writer.write(get_id(*image.base_type));
1495         writer.write(image.dimensions-1);
1496         writer.write(image.shadow);
1497         writer.write(image.array);
1498         writer.write(false);  // Multisample
1499         writer.write(image.sampled ? 1 : 2);
1500         writer.write(0);  // Format (unknown)
1501         writer.end_op(OP_TYPE_IMAGE);
1502
1503         if(image.sampled)
1504         {
1505                 writer.write_op_name(type_id, image.name);
1506                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1507                 insert_unique(image_type_ids, type_id, image_id);
1508         }
1509
1510         if(image.dimensions==ImageTypeDeclaration::ONE)
1511                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1512         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1513                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1514 }
1515
1516 void SpirVGenerator::visit(StructDeclaration &strct)
1517 {
1518         if(check_duplicate_type(strct))
1519                 return;
1520
1521         Id type_id = allocate_id(strct, 0);
1522         writer.write_op_name(type_id, strct.name);
1523
1524         if(strct.interface_block)
1525                 writer.write_op_decorate(type_id, DECO_BLOCK);
1526
1527         bool builtin = (strct.interface_block && !strct.interface_block->block_name.compare(0, 3, "gl_"));
1528         vector<Id> member_type_ids;
1529         member_type_ids.reserve(strct.members.body.size());
1530         for(NodeList<Statement>::const_iterator i=strct.members.body.begin(); i!=strct.members.body.end(); ++i)
1531         {
1532                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(i->get());
1533                 if(!var)
1534                         continue;
1535
1536                 unsigned index = member_type_ids.size();
1537                 member_type_ids.push_back(get_variable_type_id(*var));
1538
1539                 writer.write_op_member_name(type_id, index, var->name);
1540
1541                 if(builtin)
1542                 {
1543                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1544                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1545                 }
1546                 else
1547                 {
1548                         if(var->layout)
1549                         {
1550                                 const vector<Layout::Qualifier> &qualifiers = var->layout->qualifiers;
1551                                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1552                                 {
1553                                         if(j->name=="offset")
1554                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, j->value);
1555                                         else if(j->name=="column_major")
1556                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1557                                         else if(j->name=="row_major")
1558                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1559                                 }
1560                         }
1561
1562                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1563                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1564                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1565                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1566                         {
1567                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1568                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1569                         }
1570                 }
1571         }
1572
1573         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1574         writer.write(type_id);
1575         for(vector<Id>::const_iterator i=member_type_ids.begin(); i!=member_type_ids.end(); ++i)
1576                 writer.write(*i);
1577         writer.end_op(OP_TYPE_STRUCT);
1578 }
1579
1580 void SpirVGenerator::visit(VariableDeclaration &var)
1581 {
1582         const vector<Layout::Qualifier> *layout_ql = (var.layout ? &var.layout->qualifiers : 0);
1583
1584         int spec_id = -1;
1585         if(layout_ql)
1586         {
1587                 for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); (spec_id<0 && i!=layout_ql->end()); ++i)
1588                         if(i->name=="constant_id")
1589                                 spec_id = i->value;
1590         }
1591
1592         Id type_id = get_variable_type_id(var);
1593         Id var_id;
1594
1595         if(var.constant)
1596         {
1597                 if(!var.init_expression)
1598                         throw internal_error("const variable without initializer");
1599
1600                 SetFlag set_const(constant_expression);
1601                 SetFlag set_spec(spec_constant, spec_id>=0);
1602                 r_expression_result_id = 0;
1603                 var.init_expression->visit(*this);
1604                 var_id = r_expression_result_id;
1605                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1606                 writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1607
1608                 /* It's unclear what should be done if a specialization constant is
1609                 initialized with anything other than a literal.  GLSL doesn't seem to
1610                 prohibit that but SPIR-V says OpSpecConstantOp can't be updated via
1611                 specialization. */
1612         }
1613         else
1614         {
1615                 StorageClass storage = (current_function ? STORAGE_FUNCTION : get_interface_storage(var.interface, false));
1616                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1617                 if(var.interface=="uniform")
1618                 {
1619                         Id &uni_id = declared_uniform_ids["v"+var.name];
1620                         if(uni_id)
1621                         {
1622                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1623                                 return;
1624                         }
1625
1626                         uni_id = var_id = allocate_id(var, ptr_type_id);
1627                 }
1628                 else
1629                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1630
1631                 Id init_id = 0;
1632                 if(var.init_expression)
1633                 {
1634                         SetFlag set_const(constant_expression, !current_function);
1635                         r_expression_result_id = 0;
1636                         r_constant_result = false;
1637                         var.init_expression->visit(*this);
1638                         init_id = r_expression_result_id;
1639                 }
1640
1641                 vector<Word> &target = (current_function ? content.locals : content.globals);
1642                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1643                 writer.write(ptr_type_id);
1644                 writer.write(var_id);
1645                 writer.write(storage);
1646                 if(init_id && !current_function)
1647                         writer.write(init_id);
1648                 writer.end_op(OP_VARIABLE);
1649
1650                 if(layout_ql)
1651                 {
1652                         for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); i!=layout_ql->end(); ++i)
1653                         {
1654                                 if(i->name=="location")
1655                                         writer.write_op_decorate(var_id, DECO_LOCATION, i->value);
1656                                 else if(i->name=="set")
1657                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, i->value);
1658                                 else if(i->name=="binding")
1659                                         writer.write_op_decorate(var_id, DECO_BINDING, i->value);
1660                         }
1661                 }
1662
1663                 if(init_id && current_function)
1664                 {
1665                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1666                         variable_load_ids[&var] = init_id;
1667                 }
1668         }
1669
1670         writer.write_op_name(var_id, var.name);
1671 }
1672
1673 void SpirVGenerator::visit(InterfaceBlock &iface)
1674 {
1675         StorageClass storage = get_interface_storage(iface.interface, true);
1676         Id type_id;
1677         if(iface.array)
1678                 type_id = get_array_type_id(*iface.struct_declaration, 0);
1679         else
1680                 type_id = get_id(*iface.struct_declaration);
1681         Id ptr_type_id = get_pointer_type_id(type_id, storage);
1682
1683         Id block_id;
1684         if(iface.interface=="uniform")
1685         {
1686                 Id &uni_id = declared_uniform_ids["b"+iface.block_name];
1687                 if(uni_id)
1688                 {
1689                         insert_unique(declared_ids, &iface, Declaration(uni_id, ptr_type_id));
1690                         return;
1691                 }
1692
1693                 uni_id = block_id = allocate_id(iface, ptr_type_id);
1694         }
1695         else
1696                 block_id = allocate_id(iface, ptr_type_id);
1697         writer.write_op_name(block_id, iface.instance_name);
1698
1699         writer.write_op(content.globals, OP_VARIABLE, ptr_type_id, block_id, storage);
1700
1701         if(iface.layout)
1702         {
1703                 const vector<Layout::Qualifier> &qualifiers = iface.layout->qualifiers;
1704                 for(vector<Layout::Qualifier>::const_iterator i=qualifiers.begin(); i!=qualifiers.end(); ++i)
1705                         if(i->name=="binding")
1706                                 writer.write_op_decorate(block_id, DECO_BINDING, i->value);
1707         }
1708 }
1709
1710 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1711 {
1712         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1713         switch(stage->type)
1714         {
1715         case Stage::VERTEX: writer.write(0); break;
1716         case Stage::GEOMETRY: writer.write(3); break;
1717         case Stage::FRAGMENT: writer.write(4); break;
1718         default: throw internal_error("unknown stage");
1719         }
1720         writer.write(func_id);
1721         writer.write_string(func.name);
1722
1723         set<Node *> dependencies = DependencyCollector().apply(func);
1724         for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1725         {
1726                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1727                 {
1728                         if(!var->interface.empty())
1729                                 writer.write(get_id(**i));
1730                 }
1731                 else if(dynamic_cast<InterfaceBlock *>(*i))
1732                         writer.write(get_id(**i));
1733         }
1734
1735         writer.end_op(OP_ENTRY_POINT);
1736
1737         if(stage->type==Stage::FRAGMENT)
1738                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_ORIGIN_LOWER_LEFT);
1739         else if(stage->type==Stage::GEOMETRY)
1740                 use_capability(CAP_GEOMETRY);
1741
1742         for(vector<const InterfaceLayout *>::const_iterator i=interface_layouts.begin(); i!=interface_layouts.end(); ++i)
1743         {
1744                 const vector<Layout::Qualifier> &qualifiers = (*i)->layout.qualifiers;
1745                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1746                 {
1747                         if(j->name=="point")
1748                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1749                                         ((*i)->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1750                         else if(j->name=="lines")
1751                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1752                         else if(j->name=="lines_adjacency")
1753                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1754                         else if(j->name=="triangles")
1755                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1756                         else if(j->name=="triangles_adjacency")
1757                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1758                         else if(j->name=="line_strip")
1759                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1760                         else if(j->name=="triangle_strip")
1761                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1762                         else if(j->name=="max_vertices")
1763                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, j->value);
1764                 }
1765         }
1766 }
1767
1768 void SpirVGenerator::visit(FunctionDeclaration &func)
1769 {
1770         if(func.source==BUILTIN_SOURCE)
1771                 return;
1772         else if(func.definition!=&func)
1773         {
1774                 if(func.definition)
1775                         allocate_forward_id(*func.definition);
1776                 return;
1777         }
1778
1779         Id return_type_id = get_id(*func.return_type_declaration);
1780         vector<unsigned> param_type_ids;
1781         param_type_ids.reserve(func.parameters.size());
1782         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1783                 param_type_ids.push_back(get_variable_type_id(**i));
1784
1785         string sig_with_return = func.return_type+func.signature;
1786         Id &type_id = function_type_ids[sig_with_return];
1787         if(!type_id)
1788         {
1789                 type_id = next_id++;
1790                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1791                 writer.write(type_id);
1792                 writer.write(return_type_id);
1793                 for(vector<unsigned>::const_iterator i=param_type_ids.begin(); i!=param_type_ids.end(); ++i)
1794                         writer.write(*i);
1795                 writer.end_op(OP_TYPE_FUNCTION);
1796
1797                 writer.write_op_name(type_id, sig_with_return);
1798         }
1799
1800         Id func_id = allocate_id(func, type_id);
1801         writer.write_op_name(func_id, func.name+func.signature);
1802
1803         if(func.name=="main")
1804                 visit_entry_point(func, func_id);
1805
1806         writer.begin_op(content.functions, OP_FUNCTION, 5);
1807         writer.write(return_type_id);
1808         writer.write(func_id);
1809         writer.write(0);  // Function control flags (none)
1810         writer.write(type_id);
1811         writer.end_op(OP_FUNCTION);
1812
1813         for(unsigned i=0; i<func.parameters.size(); ++i)
1814         {
1815                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1816                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1817                 // TODO This is probably incorrect if the parameter is assigned to.
1818                 variable_load_ids[func.parameters[i].get()] = param_id;
1819         }
1820
1821         writer.begin_function_body(next_id++);
1822         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1823         func.body.visit(*this);
1824
1825         if(writer.has_current_block())
1826         {
1827                 if(!reachable)
1828                         writer.write_op(content.function_body, OP_UNREACHABLE);
1829                 else
1830                 {
1831                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1832                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1833                                 writer.write_op(content.function_body, OP_RETURN);
1834                         else
1835                                 throw internal_error("missing return in non-void function");
1836                 }
1837         }
1838         writer.end_function_body();
1839         variable_load_ids.clear();
1840 }
1841
/* Generates code for a conditional statement.  The condition is evaluated
first and its result is used to branch to either the true block or the false
block.  If there's no else body, the merge block doubles as the false block. */
void SpirVGenerator::visit(Conditional &cond)
{
	cond.condition->visit(*this);

	Id true_label_id = next_id++;
	Id merge_block_id = next_id++;
	// A separate id for the false block is only allocated if there is an else body.
	Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
	writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
	writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);

	writer.write_op_label(true_label_id);
	cond.body.visit(*this);
	// Jump over the else body, unless the true body already ended its block.
	if(writer.has_current_block())
		writer.write_op(content.function_body, OP_BRANCH, merge_block_id);

	bool reachable_if_true = reachable;

	/* Without an else body the code after the conditional is always treated as
	reachable.  With one, it's reachable if the end of either body is. */
	reachable = true;
	if(!cond.else_body.body.empty())
	{
		writer.write_op_label(false_label_id);
		cond.else_body.visit(*this);
		reachable |= reachable_if_true;
	}

	writer.write_op_label(merge_block_id);
	prune_loads(true_label_id);
}
1870
/* Generates code for a loop.  The blocks are emitted in the following order:
header, condition (if any), body, continue target and merge block.  The
continue target evaluates the loop expression and branches back to the
header. */
void SpirVGenerator::visit(Iteration &iter)
{
	if(iter.init_statement)
		iter.init_statement->visit(*this);

	/* Forget all previously loaded values before entering the loop.
	NOTE(review): presumably because the loop may modify variables between
	iterations, making earlier loads stale -- confirm. */
	variable_load_ids.clear();

	Id header_id = next_id++;
	Id continue_id = next_id++;
	Id merge_block_id = next_id++;

	// Make the targets for break and continue available while inside this loop.
	SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
	SetForScope<Id> set_continue(loop_continue_target_id, continue_id);

	writer.write_op_label(header_id);
	writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)

	Id body_id = next_id++;
	if(iter.condition)
	{
		// The condition is evaluated in a block of its own after the header.
		writer.write_op_label(next_id++);
		iter.condition->visit(*this);
		writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
	}

	writer.write_op_label(body_id);
	iter.body.visit(*this);

	writer.write_op_label(continue_id);
	if(iter.loop_expression)
		iter.loop_expression->visit(*this);
	writer.write_op(content.function_body, OP_BRANCH, header_id);

	writer.write_op_label(merge_block_id);
	prune_loads(header_id);
	// Code after the loop is always treated as reachable.
	reachable = true;
}
1908
1909 void SpirVGenerator::visit(Return &ret)
1910 {
1911         if(ret.expression)
1912         {
1913                 ret.expression->visit(*this);
1914                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1915         }
1916         else
1917                 writer.write_op(content.function_body, OP_RETURN);
1918         reachable = false;
1919 }
1920
1921 void SpirVGenerator::visit(Jump &jump)
1922 {
1923         if(jump.keyword=="discard")
1924                 writer.write_op(content.function_body, OP_KILL);
1925         else if(jump.keyword=="break")
1926                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1927         else if(jump.keyword=="continue")
1928                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1929         else
1930                 throw internal_error("unknown jump");
1931         reachable = false;
1932 }
1933
1934
1935 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
1936         type_id(0)
1937 {
1938         switch(kind)
1939         {
1940         case BasicTypeDeclaration::VOID: detail = 'v'; break;
1941         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
1942         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
1943         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
1944         default: throw invalid_argument("TypeKey::TypeKey");
1945         }
1946 }
1947
1948 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1949 {
1950         if(type_id!=other.type_id)
1951                 return type_id<other.type_id;
1952         return detail<other.detail;
1953 }
1954
1955
1956 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1957 {
1958         if(type_id!=other.type_id)
1959                 return type_id<other.type_id;
1960         return int_value<other.int_value;
1961 }
1962
1963 } // namespace SL
1964 } // namespace GL
1965 } // namespace Msp