git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
Add support for uint types in GLSL
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/maputils.h>
2 #include <msp/core/raii.h>
3 #include "reflect.h"
4 #include "spirv.h"
5
6 using namespace std;
7
8 namespace Msp {
9 namespace GL {
10 namespace SL {
11
12 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
13 {
14         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0 },
15         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0 },
16         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0 },
17         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0 },
18         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0 },
19         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0 },
20         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0 },
21         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0 },
22         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0 },
23         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0 },
24         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0 },
25         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0 },
26         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0 },
27         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0 },
28         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0 },
29         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0 },
30         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0 },
31         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0 },
32         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0 },
33         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0 },
34         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0 },
35         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0 },
36         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0 },
37         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0 },
38         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0 },
39         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0 },
40         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0 },
41         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0 },
42         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0 },
43         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0 },
44         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0 },
45         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0 },
46         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0 },
47         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0 },
48         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0 },
49         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0 },
50         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0 },
51         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0 },
52         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0 },
53         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0 },
54         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0 },
55         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0 },
56         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0 },
57         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0 },
58         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0 },
59         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0 },
60         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0 },
61         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0 },
62         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0 },
63         { "isinf", "f", "", OP_IS_INF, { 1 }, 0 },
64         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0 },
65         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0 },
66         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0 },
67         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0 },
68         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0 },
69         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0 },
70         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0 },
71         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0 },
72         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0 },
73         { "matrixCompMult", "ff", "", 0, { 0 }, &SpirVGenerator::visit_builtin_matrix_comp_mult },
74         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0 },
75         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0 },
76         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0 },
77         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0 },
78         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0 },
79         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0 },
80         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0 },
81         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0 },
82         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0 },
83         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0 },
84         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0 },
85         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0 },
86         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0 },
87         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
88         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
89         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
90         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0 },
91         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0 },
92         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0 },
93         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0 },
94         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0 },
95         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0 },
96         { "any", "b", "", OP_ANY, { 1 }, 0 },
97         { "all", "b", "", OP_ALL, { 1 }, 0 },
98         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0 },
99         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0 },
100         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0 },
101         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0 },
102         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0 },
103         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0 },
104         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0 },
105         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0 },
106         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0 },
107         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0 },
108         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0 },
109         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0 },
110         { "textureSize", "", "", OP_IMAGE_QUERY_SIZE_LOD, { 1, 2 }, 0 },
111         { "texture", "", "", 0, { }, &SpirVGenerator::visit_builtin_texture },
112         { "textureLod", "", "", 0, { }, &SpirVGenerator::visit_builtin_texture },
113         { "texelFetch", "", "", 0, { }, &SpirVGenerator::visit_builtin_texel_fetch },
114         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0 },
115         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0 },
116         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0 },
117         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0 },
118         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, 0 },
119         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, 0 },
120         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, 0 },
121         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, 0 },
122         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0 },
123         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, 0 },
124         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, 0 },
125         { "interpolateAtCentroid", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
126         { "interpolateAtSample", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
127         { "interpolateAtOffset", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
128         { "", "", "", 0, { }, 0 }
129 };
130
/* Creates a generator with no stage or function selected and all
expression-visiting state cleared.  ID allocation starts from 1; elsewhere in
this file an ID of zero is used to mean "not assigned yet". */
SpirVGenerator::SpirVGenerator():
        stage(0),
        current_function(0),
        writer(content),
        next_id(1),
        r_expression_result_id(0),
        constant_expression(false),
        spec_constant(false),
        reachable(false),
        composite_access(false),
        r_composite_base_id(0),
        r_composite_base(0),
        assignment_source_id(0),
        loop_merge_block_id(0),
        loop_continue_target_id(0)
{ }
147
148 void SpirVGenerator::apply(Module &module)
149 {
150         use_capability(CAP_SHADER);
151
152         for(list<Stage>::iterator i=module.stages.begin(); i!=module.stages.end(); ++i)
153         {
154                 stage = &*i;
155                 interface_layouts.clear();
156                 i->content.visit(*this);
157         }
158
159         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
160 }
161
162 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
163 {
164         if(iface=="in")
165                 return STORAGE_INPUT;
166         else if(iface=="out")
167                 return STORAGE_OUTPUT;
168         else if(iface=="uniform")
169                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
170         else if(iface.empty())
171                 return STORAGE_PRIVATE;
172         else
173                 throw invalid_argument("SpirVGenerator::get_interface_storage");
174 }
175
176 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
177 {
178         if(name=="gl_Position")
179                 return BUILTIN_POSITION;
180         else if(name=="gl_PointSize")
181                 return BUILTIN_POINT_SIZE;
182         else if(name=="gl_ClipDistance")
183                 return BUILTIN_CLIP_DISTANCE;
184         else if(name=="gl_VertexID")
185                 return BUILTIN_VERTEX_ID;
186         else if(name=="gl_InstanceID")
187                 return BUILTIN_INSTANCE_ID;
188         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
189                 return BUILTIN_PRIMITIVE_ID;
190         else if(name=="gl_InvocationID")
191                 return BUILTIN_INVOCATION_ID;
192         else if(name=="gl_Layer")
193                 return BUILTIN_LAYER;
194         else if(name=="gl_FragCoord")
195                 return BUILTIN_FRAG_COORD;
196         else if(name=="gl_PointCoord")
197                 return BUILTIN_POINT_COORD;
198         else if(name=="gl_FrontFacing")
199                 return BUILTIN_FRONT_FACING;
200         else if(name=="gl_SampleId")
201                 return BUILTIN_SAMPLE_ID;
202         else if(name=="gl_SamplePosition")
203                 return BUILTIN_SAMPLE_POSITION;
204         else if(name=="gl_FragDepth")
205                 return BUILTIN_FRAG_DEPTH;
206         else
207                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
208 }
209
210 void SpirVGenerator::use_capability(Capability cap)
211 {
212         if(used_capabilities.count(cap))
213                 return;
214
215         used_capabilities.insert(cap);
216         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
217 }
218
219 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
220 {
221         Id &ext_id = imported_extension_ids[name];
222         if(!ext_id)
223         {
224                 ext_id = next_id++;
225                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
226                 writer.write(ext_id);
227                 writer.write_string(name);
228                 writer.end_op(OP_EXT_INST_IMPORT);
229         }
230         return ext_id;
231 }
232
233 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
234 {
235         return get_item(declared_ids, &node).id;
236 }
237
238 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
239 {
240         Id id = next_id++;
241         insert_unique(declared_ids, &node, Declaration(id, type_id));
242         return id;
243 }
244
245 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
246 {
247         Id const_id = next_id++;
248         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
249         {
250                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
251                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
252                 writer.write_op(content.globals, opcode, type_id, const_id);
253         }
254         else
255         {
256                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
257                 writer.write_op(content.globals, opcode, type_id, const_id, value);
258         }
259         return const_id;
260 }
261
262 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
263 {
264         if(value.check_type<bool>())
265                 return ConstantKey(type_id, value.value<bool>());
266         else if(value.check_type<int>())
267                 return ConstantKey(type_id, value.value<int>());
268         else if(value.check_type<unsigned>())
269                 return ConstantKey(type_id, value.value<unsigned>());
270         else if(value.check_type<float>())
271                 return ConstantKey(type_id, value.value<float>());
272         else
273                 throw invalid_argument("SpirVGenerator::get_constant_key");
274 }
275
276 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
277 {
278         ConstantKey key = get_constant_key(type_id, value);
279         Id &const_id = constant_ids[key];
280         if(!const_id)
281                 const_id = write_constant(type_id, key.int_value, false);
282         return const_id;
283 }
284
285 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
286 {
287         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
288         if(!const_id)
289         {
290                 const_id = next_id++;
291                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
292                 writer.write(type_id);
293                 writer.write(const_id);
294                 for(unsigned i=0; i<size; ++i)
295                         writer.write(scalar_id);
296                 writer.end_op(OP_CONSTANT_COMPOSITE);
297         }
298         return const_id;
299 }
300
301 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
302 {
303         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
304         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
305         if(!type_id)
306         {
307                 type_id = next_id++;
308                 if(size>1)
309                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
310                 else if(kind==BasicTypeDeclaration::VOID)
311                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
312                 else if(kind==BasicTypeDeclaration::BOOL)
313                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
314                 else if(kind==BasicTypeDeclaration::INT)
315                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
316                 else if(kind==BasicTypeDeclaration::FLOAT)
317                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
318                 else
319                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
320         }
321         return type_id;
322 }
323
324 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
325 {
326         map<TypeKey, Id>::const_iterator i = standard_type_ids.find(TypeKey(kind, true));
327         return (i!=standard_type_ids.end() && i->second==type_id);
328 }
329
330 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id)
331 {
332         Id base_type_id = get_id(base_type);
333         Id &array_type_id = array_type_ids[TypeKey(base_type_id, size_id)];
334         if(!array_type_id)
335         {
336                 array_type_id = next_id++;
337                 if(size_id)
338                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
339                 else
340                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
341
342                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
343                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
344         }
345
346         return array_type_id;
347 }
348
349 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
350 {
351         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
352         if(!ptr_type_id)
353         {
354                 ptr_type_id = next_id++;
355                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
356         }
357         return ptr_type_id;
358 }
359
360 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
361 {
362         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
363                 if(basic->kind==BasicTypeDeclaration::ARRAY)
364                 {
365                         Id size_id = 0;
366                         if(var.array_size)
367                         {
368                                 SetFlag set_const(constant_expression);
369                                 r_expression_result_id = 0;
370                                 var.array_size->visit(*this);
371                                 size_id = r_expression_result_id;
372                         }
373                         else
374                                 size_id = get_constant_id(get_standard_type_id(BasicTypeDeclaration::INT, 1), 1);
375                         return get_array_type_id(*basic->base_type, size_id);
376                 }
377
378         return get_id(*var.type_declaration);
379 }
380
381 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
382 {
383         Id &load_result_id = variable_load_ids[&var];
384         if(!load_result_id)
385         {
386                 load_result_id = next_id++;
387                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
388         }
389         return load_result_id;
390 }
391
392 void SpirVGenerator::prune_loads(Id min_id)
393 {
394         for(map<const VariableDeclaration *, Id>::iterator i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
395         {
396                 if(i->second>=min_id)
397                         variable_load_ids.erase(i++);
398                 else
399                         ++i;
400         }
401 }
402
403 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
404 {
405         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
406         if(!constant_expression)
407         {
408                 if(!current_function)
409                         throw internal_error("non-constant expression outside a function");
410
411                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
412         }
413         else if(opcode==OP_COMPOSITE_CONSTRUCT)
414                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
415                         (n_args ? 1+has_result*2+n_args : 0));
416         else if(!spec_constant)
417                 throw internal_error("invalid non-specialization constant expression");
418         else
419                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
420
421         Id result_id = next_id++;
422         if(has_result)
423         {
424                 writer.write(type_id);
425                 writer.write(result_id);
426         }
427         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
428                 writer.write(opcode);
429
430         return result_id;
431 }
432
433 void SpirVGenerator::end_expression(Opcode opcode)
434 {
435         if(constant_expression)
436                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
437         writer.end_op(opcode);
438 }
439
440 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
441 {
442         Id result_id = begin_expression(opcode, type_id, 1);
443         writer.write(arg_id);
444         end_expression(opcode);
445         return result_id;
446 }
447
448 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
449 {
450         Id result_id = begin_expression(opcode, type_id, 2);
451         writer.write(left_id);
452         writer.write(right_id);
453         end_expression(opcode);
454         return result_id;
455 }
456
457 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
458 {
459         for(unsigned i=0; i<n_elems; ++i)
460         {
461                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
462                 writer.write(composite_id);
463                 writer.write(i);
464                 end_expression(OP_COMPOSITE_EXTRACT);
465         }
466 }
467
468 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
469 {
470         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
471         for(unsigned i=0; i<n_elems; ++i)
472                 writer.write(elem_ids[i]);
473         end_expression(OP_COMPOSITE_CONSTRUCT);
474
475         return result_id;
476 }
477
478 void SpirVGenerator::visit(Block &block)
479 {
480         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
481                 (*i)->visit(*this);
482 }
483
484 void SpirVGenerator::visit(Literal &literal)
485 {
486         Id type_id = get_id(*literal.type);
487         if(spec_constant)
488                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
489         else
490                 r_expression_result_id = get_constant_id(type_id, literal.value);
491         r_constant_result = true;
492 }
493
494 void SpirVGenerator::visit(VariableReference &var)
495 {
496         if(constant_expression || var.declaration->constant)
497         {
498                 if(!var.declaration->constant)
499                         throw internal_error("reference to non-constant variable in constant context");
500
501                 r_expression_result_id = get_id(*var.declaration);
502                 r_constant_result = true;
503                 return;
504         }
505         else if(!current_function)
506                 throw internal_error("non-constant context outside a function");
507
508         r_constant_result = false;
509         if(composite_access)
510         {
511                 r_composite_base = var.declaration;
512                 r_expression_result_id = 0;
513         }
514         else if(assignment_source_id)
515         {
516                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
517                 variable_load_ids[var.declaration] = assignment_source_id;
518                 r_expression_result_id = assignment_source_id;
519         }
520         else
521                 r_expression_result_id = get_load_id(*var.declaration);
522 }
523
524 void SpirVGenerator::visit(InterfaceBlockReference &iface)
525 {
526         if(!composite_access || !current_function)
527                 throw internal_error("invalid interface block reference");
528
529         r_composite_base = iface.declaration;
530         r_expression_result_id = 0;
531         r_constant_result = false;
532 }
533
/* Writes the instructions for the composite access accumulated in
r_composite_base, r_composite_base_id and r_composite_chain, and stores the
result in r_expression_result_id.  If a base declaration was recorded, the
access goes through OP_ACCESS_CHAIN followed by either a store (when
assignment_source_id is set) or a load.  Otherwise the value in
r_composite_base_id is indexed directly with OP_COMPOSITE_EXTRACT, which
cannot be the target of an assignment. */
void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
{
        Opcode opcode;
        Id result_type_id = get_id(result_type);
        Id access_type_id = result_type_id;
        if(r_composite_base)
        {
                if(constant_expression)
                        throw internal_error("composite access through pointer in constant context");

                /* The access chain needs IDs for all of its indices.  Chain entries
                below 0x400000 are plain index values and are replaced with the ID of
                an int constant; for larger entries only the low 22 bits are kept
                (these appear to be IDs tagged by the code which built the chain --
                confirm against that code). */
                Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
                for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
                        *i = (*i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(*i)) : *i&0x3FFFFF);

                /* Find the storage class of the base and obtain appropriate pointer type
                for the result. */
                const Declaration &base_decl = get_item(declared_ids, r_composite_base);
                map<TypeKey, Id>::const_iterator i = pointer_type_ids.begin();
                // Linear search for the pointer type entry whose ID is the base's type ID.
                for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
                if(i==pointer_type_ids.end())
                        throw internal_error("could not find storage class");
                access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));

                opcode = OP_ACCESS_CHAIN;
        }
        else if(assignment_source_id)
                throw internal_error("assignment to temporary composite");
        else
        {
                /* Here the chain is written out as plain index values.  Entries of
                0x400000 or more are matched by their low 22 bits against the IDs of
                known constants and replaced with the constant's value.  An entry
                with no matching constant is left unchanged. */
                for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
                        for(map<ConstantKey, Id>::iterator j=constant_ids.begin(); (*i>=0x400000 && j!=constant_ids.end()); ++j)
                                if(j->second==(*i&0x3FFFFF))
                                        *i = j->first.int_value;

                opcode = OP_COMPOSITE_EXTRACT;
        }

        // Operands are the base followed by every entry of the chain.
        Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
        writer.write(r_composite_base_id);
        for(vector<unsigned>::const_iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
                writer.write(*i);
        end_expression(opcode);

        r_constant_result = false;
        if(r_composite_base)
        {
                // The access chain produced a pointer, which is now stored to or loaded from.
                if(assignment_source_id)
                {
                        writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
                        r_expression_result_id = assignment_source_id;
                }
                else
                        r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
        }
        else
                r_expression_result_id = access_id;
}
591
592 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
593 {
594         if(!composite_access)
595         {
596                 r_composite_base = 0;
597                 r_composite_base_id = 0;
598                 r_composite_chain.clear();
599         }
600
601         {
602                 SetFlag set_composite(composite_access);
603                 base_expr.visit(*this);
604         }
605
606         if(!r_composite_base_id)
607                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
608
609         r_composite_chain.push_back(index);
610         if(!composite_access)
611                 generate_composite_access(type);
612         else
613                 r_expression_result_id = 0;
614 }
615
/* Visits an expression with the assignment source, composite access flag,
composite base and composite chain all cleared, so that it's evaluated
independently of any assignment or composite access currently in progress.
The chain is swapped back afterwards; the scoped helpers are expected to
restore the other members when they go out of scope. */
void SpirVGenerator::visit_isolated(Expression &expr)
{
        SetForScope<Id> clear_assign(assignment_source_id, 0);
        SetFlag clear_composite(composite_access, false);
        SetForScope<Node *> clear_base(r_composite_base, 0);
        SetForScope<Id> clear_base_id(r_composite_base_id, 0);
        // Swapping with an empty vector preserves the current chain while expr is visited.
        vector<unsigned> saved_chain;
        swap(saved_chain, r_composite_chain);
        expr.visit(*this);
        swap(saved_chain, r_composite_chain);
}
627
/* A member access is a composite access on the left-hand expression using the
member's index. */
void SpirVGenerator::visit(MemberAccess &memacc)
{
        visit_composite(*memacc.left, memacc.index, *memacc.type);
}
632
633 void SpirVGenerator::visit(Swizzle &swizzle)
634 {
635         if(swizzle.count==1)
636                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
637         else if(assignment_source_id)
638         {
639                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
640
641                 unsigned mask = 0;
642                 for(unsigned i=0; i<swizzle.count; ++i)
643                         mask |= 1<<swizzle.components[i];
644
645                 visit_isolated(*swizzle.left);
646
647                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
648                 writer.write(r_expression_result_id);
649                 writer.write(assignment_source_id);
650                 for(unsigned i=0; i<basic.size; ++i)
651                         writer.write(i+((mask>>i)&1)*basic.size);
652                 end_expression(OP_VECTOR_SHUFFLE);
653
654                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
655                 swizzle.left->visit(*this);
656
657                 r_expression_result_id = combined_id;
658         }
659         else
660         {
661                 swizzle.left->visit(*this);
662                 Id left_id = r_expression_result_id;
663
664                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
665                 writer.write(left_id);
666                 writer.write(left_id);
667                 for(unsigned i=0; i<swizzle.count; ++i)
668                         writer.write(swizzle.components[i]);
669                 end_expression(OP_VECTOR_SHUFFLE);
670         }
671         r_constant_result = false;
672 }
673
674 void SpirVGenerator::visit(UnaryExpression &unary)
675 {
676         unary.expression->visit(*this);
677
678         char oper = unary.oper->token[0];
679         char oper2 = unary.oper->token[1];
680         if(oper=='+' && !oper2)
681                 return;
682
683         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
684         BasicTypeDeclaration &elem = *get_element_type(basic);
685
686         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
687                 /* SPIR-V allows constant operations on floating-point values only for
688                 OpenGL kernels. */
689                 throw internal_error("invalid operands for constant unary expression");
690
691         Id result_type_id = get_id(*unary.type);
692         Opcode opcode = OP_NOP;
693
694         r_constant_result = false;
695         if(oper=='!')
696                 opcode = OP_LOGICAL_NOT;
697         else if(oper=='~')
698                 opcode = OP_NOT;
699         else if(oper=='-' && !oper2)
700         {
701                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
702
703                 if(basic.kind==BasicTypeDeclaration::MATRIX)
704                 {
705                         Id column_type_id = get_id(*basic.base_type);
706                         unsigned n_columns = basic.size&0xFFFF;
707                         Id column_ids[4];
708                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
709                         for(unsigned i=0; i<n_columns; ++i)
710                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
711                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
712                         return;
713                 }
714         }
715         else if((oper=='+' || oper=='-') && oper2==oper)
716         {
717                 if(constant_expression)
718                         throw internal_error("increment or decrement in constant expression");
719
720                 Id one_id = 0;
721                 if(elem.kind==BasicTypeDeclaration::INT)
722                 {
723                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
724                         one_id = get_constant_id(get_id(elem), 1);
725                 }
726                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
727                 {
728                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
729                         one_id = get_constant_id(get_id(elem), 1.0f);
730                 }
731                 else
732                         throw internal_error("invalid increment/decrement");
733
734                 if(basic.kind==BasicTypeDeclaration::VECTOR)
735                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
736
737                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
738
739                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
740                 unary.expression->visit(*this);
741
742                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
743                 return;
744         }
745
746         if(opcode==OP_NOP)
747                 throw internal_error("unknown unary operator");
748
749         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
750 }
751
752 void SpirVGenerator::visit(BinaryExpression &binary)
753 {
754         char oper = binary.oper->token[0];
755         if(oper=='[')
756         {
757                 visit_isolated(*binary.right);
758                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
759         }
760
761         if(assignment_source_id)
762                 throw internal_error("invalid binary expression in assignment target");
763
764         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
765         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
766         // Expression resolver ensures that element types are the same
767         BasicTypeDeclaration &elem = *get_element_type(basic_left);
768
769         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
770                 /* SPIR-V allows constant operations on floating-point values only for
771                 OpenGL kernels. */
772                 throw internal_error("invalid operands for constant binary expression");
773
774         binary.left->visit(*this);
775         Id left_id = r_expression_result_id;
776         binary.right->visit(*this);
777         Id right_id = r_expression_result_id;
778
779         Id result_type_id = get_id(*binary.type);
780         Opcode opcode = OP_NOP;
781         bool swap_operands = false;
782
783         r_constant_result = false;
784
785         char oper2 = binary.oper->token[1];
786         if((oper=='<' || oper=='>') && oper2!=oper)
787         {
788                 if(basic_left.kind==BasicTypeDeclaration::INT)
789                 {
790                         if(basic_left.sign)
791                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
792                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
793                         else
794                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
795                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
796                 }
797                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
798                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
799                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
800         }
801         else if((oper=='=' || oper=='!') && oper2=='=')
802         {
803                 if(elem.kind==BasicTypeDeclaration::BOOL)
804                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
805                 else if(elem.kind==BasicTypeDeclaration::INT)
806                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
807                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
808                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
809
810                 if(opcode!=OP_NOP && basic_left.base_type)
811                 {
812                         /* The SPIR-V equality operations produce component-wise results, but
813                         GLSL operators return a single boolean.  Use the any/all operations to
814                         combine the results. */
815                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
816                         unsigned n_elems = basic_left.size&0xFFFF;
817                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
818
819                         Id compare_id = 0;
820                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
821                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
822                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
823                         {
824                                 Id column_type_id = get_id(*basic_left.base_type);
825                                 Id column_ids[8];
826                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
827                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
828
829                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
830                                 for(unsigned i=0; i<n_elems; ++i)
831                                 {
832                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
833                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
834                                 }
835
836                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
837                         }
838
839                         if(compare_id)
840                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
841                         return;
842                 }
843         }
844         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
845                 opcode = OP_LOGICAL_AND;
846         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
847                 opcode = OP_LOGICAL_OR;
848         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
849                 opcode = OP_LOGICAL_NOT_EQUAL;
850         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
851                 opcode = OP_BITWISE_AND;
852         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
853                 opcode = OP_BITWISE_OR;
854         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
855                 opcode = OP_BITWISE_XOR;
856         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
857                 opcode = OP_SHIFT_LEFT_LOGICAL;
858         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
859                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
860         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
861                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
862         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
863         {
864                 Opcode elem_op = OP_NOP;
865                 if(elem.kind==BasicTypeDeclaration::INT)
866                 {
867                         if(oper=='/')
868                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
869                         else
870                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
871                 }
872                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
873                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
874
875                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
876                 {
877                         /* Multiplication between floating-point vectors and matrices has
878                         dedicated operations. */
879                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
880                                 opcode = OP_MATRIX_TIMES_MATRIX;
881                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
882                         {
883                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
884                                         opcode = OP_VECTOR_TIMES_MATRIX;
885                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
886                                         opcode = OP_MATRIX_TIMES_VECTOR;
887                                 else
888                                 {
889                                         opcode = OP_MATRIX_TIMES_SCALAR;
890                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
891                                 }
892                         }
893                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
894                                 opcode = elem_op;
895                         else
896                         {
897                                 opcode = OP_VECTOR_TIMES_SCALAR;
898                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
899                         }
900                 }
901                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
902                 {
903                         /* One operand is scalar and the other is a vector or a matrix.
904                         Expand the scalar to a vector of appropriate size. */
905                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
906                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
907                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
908                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
909                         Id vector_type_id = get_id(*vector_type);
910
911                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
912                         for(unsigned i=0; i<vector_type->size; ++i)
913                                 writer.write(scalar_id);
914                         end_expression(OP_COMPOSITE_CONSTRUCT);
915
916                         scalar_id = expanded_id;
917
918                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
919                         {
920                                 // Apply matrix operation column-wise.
921                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
922
923                                 Id column_ids[4];
924                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
925                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
926
927                                 for(unsigned i=0; i<n_columns; ++i)
928                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
929
930                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
931                                 return;
932                         }
933                         else
934                                 opcode = elem_op;
935                 }
936                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
937                 {
938                         if(oper=='*')
939                                 throw internal_error("non-float matrix multiplication");
940
941                         /* Other operations involving matrices need to be performed
942                         column-wise. */
943                         Id column_type_id = get_id(*basic_left.base_type);
944                         Id column_ids[8];
945
946                         unsigned n_columns = basic_left.size&0xFFFF;
947                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
948                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
949
950                         for(unsigned i=0; i<n_columns; ++i)
951                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
952
953                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
954                         return;
955                 }
956                 else if(basic_left.kind==basic_right.kind)
957                         // Both operands are either scalars or vectors.
958                         opcode = elem_op;
959         }
960
961         if(opcode==OP_NOP)
962                 throw internal_error("unknown binary operator");
963
964         if(swap_operands)
965                 swap(left_id, right_id);
966
967         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
968 }
969
970 void SpirVGenerator::visit(Assignment &assign)
971 {
972         if(assign.oper->token[0]!='=')
973                 visit(static_cast<BinaryExpression &>(assign));
974         else
975                 assign.right->visit(*this);
976
977         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
978         assign.left->visit(*this);
979         r_constant_result = false;
980 }
981
982 void SpirVGenerator::visit(TernaryExpression &ternary)
983 {
984         if(constant_expression)
985         {
986                 ternary.condition->visit(*this);
987                 Id condition_id = r_expression_result_id;
988                 ternary.true_expr->visit(*this);
989                 Id true_result_id = r_expression_result_id;
990                 ternary.false_expr->visit(*this);
991                 Id false_result_id = r_expression_result_id;
992
993                 r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
994                 writer.write(condition_id);
995                 writer.write(true_result_id);
996                 writer.write(false_result_id);
997                 end_expression(OP_SELECT);
998
999                 return;
1000         }
1001
1002         ternary.condition->visit(*this);
1003         Id condition_id = r_expression_result_id;
1004
1005         Id true_label_id = next_id++;
1006         Id false_label_id = next_id++;
1007         Id merge_block_id = next_id++;
1008         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1009         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);
1010
1011         writer.write_op_label(true_label_id);
1012         ternary.true_expr->visit(*this);
1013         Id true_result_id = r_expression_result_id;
1014         writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1015
1016         writer.write_op_label(false_label_id);
1017         ternary.false_expr->visit(*this);
1018         Id false_result_id = r_expression_result_id;
1019
1020         writer.write_op_label(merge_block_id);
1021         r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
1022         writer.write(true_result_id);
1023         writer.write(true_label_id);
1024         writer.write(false_result_id);
1025         writer.write(false_label_id);
1026         end_expression(OP_PHI);
1027
1028         r_constant_result = false;
1029 }
1030
1031 void SpirVGenerator::visit(FunctionCall &call)
1032 {
1033         if(assignment_source_id)
1034                 throw internal_error("assignment to function call");
1035         else if(composite_access)
1036                 return visit_isolated(call);
1037         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
1038                 return call.arguments[0]->visit(*this);
1039
1040         vector<Id> argument_ids;
1041         argument_ids.reserve(call.arguments.size());
1042         bool all_args_const = true;
1043         for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1044         {
1045                 (*i)->visit(*this);
1046                 argument_ids.push_back(r_expression_result_id);
1047                 all_args_const &= r_constant_result;
1048         }
1049
1050         if(constant_expression && (!call.constructor || !all_args_const))
1051                 throw internal_error("function call in constant expression");
1052
1053         Id result_type_id = get_id(*call.type);
1054
1055         if(call.constructor)
1056                 visit_constructor(call, argument_ids, all_args_const);
1057         else if(call.declaration->source==BUILTIN_SOURCE)
1058         {
1059                 string arg_types;
1060                 for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1061                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>((*i)->type))
1062                         {
1063                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1064                                 switch(elem_arg.kind)
1065                                 {
1066                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1067                                 case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
1068                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1069                                 default: arg_types += '?';
1070                                 }
1071                         }
1072
1073                 const BuiltinFunctionInfo *builtin_info;
1074                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1075                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1076                                 break;
1077
1078                 if(builtin_info->opcode)
1079                 {
1080                         Opcode opcode;
1081                         if(builtin_info->extension[0])
1082                         {
1083                                 opcode = OP_EXT_INST;
1084                                 Id ext_id = import_extension(builtin_info->extension);
1085
1086                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1087                                 writer.write(ext_id);
1088                                 writer.write(builtin_info->opcode);
1089                         }
1090                         else
1091                         {
1092                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1093                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1094                         }
1095
1096                         for(unsigned i=0; i<call.arguments.size(); ++i)
1097                         {
1098                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1099                                         throw internal_error("invalid builtin function info");
1100                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1101                         }
1102
1103                         end_expression(opcode);
1104                 }
1105                 else if(builtin_info->handler)
1106                         (this->*(builtin_info->handler))(call, argument_ids);
1107                 else
1108                         throw internal_error("unknown builtin function "+call.name);
1109         }
1110         else
1111         {
1112                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1113                 writer.write(get_id(*call.declaration->definition));
1114                 for(vector<Id>::const_iterator i=argument_ids.begin(); i!=argument_ids.end(); ++i)
1115                         writer.write(*i);
1116                 end_expression(OP_FUNCTION_CALL);
1117
1118                 // Any global variables the called function uses might have changed value
1119                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1120                 for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1121                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1122                                 variable_load_ids.erase(var);
1123         }
1124 }
1125
1126 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1127 {
1128         Id result_type_id = get_id(*call.type);
1129
1130         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1131         if(!basic)
1132         {
1133                 if(dynamic_cast<const StructDeclaration *>(call.type))
1134                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1135                 else
1136                         throw internal_error("unconstructable type "+call.name);
1137                 return;
1138         }
1139
1140         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1141
1142         BasicTypeDeclaration &elem = *get_element_type(*basic);
1143         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1144         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1145
1146         if(basic->kind==BasicTypeDeclaration::MATRIX)
1147         {
1148                 Id col_type_id = get_id(*basic->base_type);
1149                 unsigned n_columns = basic->size&0xFFFF;
1150                 unsigned n_rows = basic->size>>16;
1151
1152                 Id column_ids[4];
1153                 if(call.arguments.size()==1)
1154                 {
1155                         // Construct diagonal matrix from a single scalar.
1156                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1157                         for(unsigned i=0; i<n_columns; ++i)
1158                         {
1159                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);;
1160                                 for(unsigned j=0; j<n_rows; ++j)
1161                                         writer.write(j==i ? argument_ids[0] : zero_id);
1162                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1163                         }
1164                 }
1165                 else
1166                         // Construct a matrix from column vectors
1167                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1168
1169                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1170         }
1171         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1172         {
1173                 /* There's either a single scalar argument or multiple arguments
1174                 which make up the vector's components. */
1175                 if(call.arguments.size()==1)
1176                 {
1177                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1178                         for(unsigned i=0; i<basic->size; ++i)
1179                                 writer.write(argument_ids[0]);
1180                         end_expression(OP_COMPOSITE_CONSTRUCT);
1181                 }
1182                 else
1183                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1184         }
1185         else if(elem.kind==BasicTypeDeclaration::BOOL)
1186         {
1187                 if(constant_expression)
1188                         throw internal_error("unconverted constant");
1189
1190                 // Conversion to boolean is implemented as comparing against zero.
1191                 Id number_type_id = get_id(elem_arg0);
1192                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1193                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1194                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1195                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1196
1197                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1198                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1199         }
1200         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1201         {
1202                 if(constant_expression)
1203                         throw internal_error("unconverted constant");
1204
1205                 /* Conversion from boolean is implemented as selecting from zero
1206                 or one. */
1207                 Id number_type_id = get_id(elem);
1208                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1209                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1210                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1211                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1212                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1213                 {
1214                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1215                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1216                 }
1217
1218                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1219                 writer.write(argument_ids[0]);
1220                 writer.write(zero_id);
1221                 writer.write(one_id);
1222                 end_expression(OP_SELECT);
1223         }
1224         else
1225         {
1226                 if(constant_expression)
1227                         throw internal_error("unconverted constant");
1228
1229                 // Scalar or vector conversion between types of equal size.
1230                 Opcode opcode;
1231                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1232                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1233                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1234                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1235                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1236                         opcode = OP_BITCAST;
1237                 else
1238                         throw internal_error("invalid conversion");
1239
1240                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1241         }
1242 }
1243
1244 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1245 {
1246         if(argument_ids.size()!=2)
1247                 throw internal_error("invalid matrixCompMult call");
1248
1249         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1250         Id column_type_id = get_id(*basic_arg0.base_type);
1251         Id column_ids[8];
1252
1253         unsigned n_columns = basic_arg0.size&0xFFFF;
1254         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1255         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1256
1257         for(unsigned i=0; i<n_columns; ++i)
1258                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1259
1260         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1261 }
1262
1263 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1264 {
1265         if(argument_ids.size()<2)
1266                 throw internal_error("invalid texture sampling call");
1267
1268         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1269         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1270                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1271
1272         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1273
1274         Opcode opcode;
1275         Id result_type_id = get_id(*call.type);
1276         Id dref_id = 0;
1277         if(image.shadow)
1278         {
1279                 if(argument_ids.size()==2)
1280                 {
1281                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1282                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1283                         writer.write(argument_ids.back());
1284                         writer.write(basic_arg1.size-1);
1285                         end_expression(OP_COMPOSITE_EXTRACT);
1286                 }
1287                 else
1288                         dref_id = argument_ids[2];
1289
1290                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1291                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1292         }
1293         else
1294         {
1295                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1296                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1297         }
1298
1299         for(unsigned i=0; i<2; ++i)
1300                 writer.write(argument_ids[i]);
1301         if(dref_id)
1302                 writer.write(dref_id);
1303         if(explicit_lod)
1304         {
1305                 writer.write(2);  // Lod
1306                 writer.write(lod_id);
1307         }
1308
1309         end_expression(opcode);
1310 }
1311
1312 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1313 {
1314         if(argument_ids.size()!=3)
1315                 throw internal_error("invalid texelFetch call");
1316
1317         r_expression_result_id = begin_expression(OP_IMAGE_FETCH, get_id(*call.type), 4);
1318         for(unsigned i=0; i<2; ++i)
1319                 writer.write(argument_ids[i]);
1320         writer.write(2);  // Lod
1321         writer.write(argument_ids.back());
1322         end_expression(OP_IMAGE_FETCH);
1323 }
1324
1325 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1326 {
1327         if(argument_ids.size()<1)
1328                 throw internal_error("invalid interpolate call");
1329         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1330         if(!var || !var->declaration || var->declaration->interface!="in")
1331                 throw internal_error("invalid interpolate call");
1332
1333         SpirVGlslStd450Opcode opcode;
1334         if(call.name=="interpolateAtCentroid")
1335                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1336         else if(call.name=="interpolateAtSample")
1337                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1338         else if(call.name=="interpolateAtOffset")
1339                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1340         else
1341                 throw internal_error("invalid interpolate call");
1342
1343         use_capability(CAP_INTERPOLATION_FUNCTION);
1344
1345         Id ext_id = import_extension("GLSL.std.450");
1346         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1347         writer.write(ext_id);
1348         writer.write(opcode);
1349         writer.write(get_id(*var->declaration));
1350         for(vector<Id>::const_iterator i=argument_ids.begin(); ++i!=argument_ids.end(); )
1351                 writer.write(*i);
1352         end_expression(OP_EXT_INST);
1353 }
1354
1355 void SpirVGenerator::visit(ExpressionStatement &expr)
1356 {
1357         expr.expression->visit(*this);
1358 }
1359
1360 void SpirVGenerator::visit(InterfaceLayout &layout)
1361 {
1362         interface_layouts.push_back(&layout);
1363 }
1364
1365 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1366 {
1367         for(map<Node *, Declaration>::const_iterator i=declared_ids.begin(); i!=declared_ids.end(); ++i)
1368                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(i->first))
1369                         if(TypeComparer().apply(type, *type2))
1370                         {
1371                                 insert_unique(declared_ids, &type, i->second);
1372                                 return true;
1373                         }
1374
1375         return false;
1376 }
1377
1378 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1379 {
1380         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1381                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1382         if(!elem || elem->base_type)
1383                 return false;
1384         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1385                 return false;
1386
1387         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1388         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1389         writer.write_op_name(standard_id, basic.name);
1390
1391         return true;
1392 }
1393
1394 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1395 {
1396         if(check_standard_type(basic))
1397                 return;
1398         if(check_duplicate_type(basic))
1399                 return;
1400         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1401         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1402                 return;
1403
1404         Id type_id = allocate_id(basic, 0);
1405         writer.write_op_name(type_id, basic.name);
1406
1407         switch(basic.kind)
1408         {
1409         case BasicTypeDeclaration::INT:
1410                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1411                 break;
1412         case BasicTypeDeclaration::FLOAT:
1413                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1414                 break;
1415         case BasicTypeDeclaration::VECTOR:
1416                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1417                 break;
1418         case BasicTypeDeclaration::MATRIX:
1419                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1420                 break;
1421         default:
1422                 throw internal_error("unknown basic type");
1423         }
1424 }
1425
1426 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1427 {
1428         if(check_duplicate_type(image))
1429                 return;
1430
1431         Id type_id = allocate_id(image, 0);
1432
1433         Id image_id = (image.sampled ? next_id++ : type_id);
1434         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1435         writer.write(image_id);
1436         writer.write(get_id(*image.base_type));
1437         writer.write(image.dimensions-1);
1438         writer.write(image.shadow);
1439         writer.write(image.array);
1440         writer.write(false);  // Multisample
1441         writer.write(image.sampled ? 1 : 2);
1442         writer.write(0);  // Format (unknown)
1443         writer.end_op(OP_TYPE_IMAGE);
1444
1445         if(image.sampled)
1446         {
1447                 writer.write_op_name(type_id, image.name);
1448                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1449         }
1450
1451         if(image.dimensions==ImageTypeDeclaration::ONE)
1452                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1453         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1454                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1455 }
1456
1457 void SpirVGenerator::visit(StructDeclaration &strct)
1458 {
1459         if(check_duplicate_type(strct))
1460                 return;
1461
1462         Id type_id = allocate_id(strct, 0);
1463         writer.write_op_name(type_id, strct.name);
1464
1465         if(strct.interface_block)
1466                 writer.write_op_decorate(type_id, DECO_BLOCK);
1467
1468         bool builtin = (strct.interface_block && !strct.interface_block->block_name.compare(0, 3, "gl_"));
1469         vector<Id> member_type_ids;
1470         member_type_ids.reserve(strct.members.body.size());
1471         for(NodeList<Statement>::const_iterator i=strct.members.body.begin(); i!=strct.members.body.end(); ++i)
1472         {
1473                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(i->get());
1474                 if(!var)
1475                         continue;
1476
1477                 unsigned index = member_type_ids.size();
1478                 member_type_ids.push_back(get_variable_type_id(*var));
1479
1480                 writer.write_op_member_name(type_id, index, var->name);
1481
1482                 if(builtin)
1483                 {
1484                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1485                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1486                 }
1487                 else
1488                 {
1489                         if(var->layout)
1490                         {
1491                                 const vector<Layout::Qualifier> &qualifiers = var->layout->qualifiers;
1492                                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1493                                 {
1494                                         if(j->name=="offset")
1495                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, j->value);
1496                                         else if(j->name=="column_major")
1497                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1498                                         else if(j->name=="row_major")
1499                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1500                                 }
1501                         }
1502
1503                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1504                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1505                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1506                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1507                         {
1508                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1509                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1510                         }
1511                 }
1512         }
1513
1514         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1515         writer.write(type_id);
1516         for(vector<Id>::const_iterator i=member_type_ids.begin(); i!=member_type_ids.end(); ++i)
1517                 writer.write(*i);
1518         writer.end_op(OP_TYPE_STRUCT);
1519 }
1520
1521 void SpirVGenerator::visit(VariableDeclaration &var)
1522 {
1523         const vector<Layout::Qualifier> *layout_ql = (var.layout ? &var.layout->qualifiers : 0);
1524
1525         int spec_id = -1;
1526         if(layout_ql)
1527         {
1528                 for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); (spec_id<0 && i!=layout_ql->end()); ++i)
1529                         if(i->name=="constant_id")
1530                                 spec_id = i->value;
1531         }
1532
1533         Id type_id = get_variable_type_id(var);
1534         Id var_id;
1535
1536         if(var.constant)
1537         {
1538                 if(!var.init_expression)
1539                         throw internal_error("const variable without initializer");
1540
1541                 SetFlag set_const(constant_expression);
1542                 SetFlag set_spec(spec_constant, spec_id>=0);
1543                 r_expression_result_id = 0;
1544                 var.init_expression->visit(*this);
1545                 var_id = r_expression_result_id;
1546                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1547                 writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1548
1549                 /* It's unclear what should be done if a specialization constant is
1550                 initialized with anything other than a literal.  GLSL doesn't seem to
1551                 prohibit that but SPIR-V says OpSpecConstantOp can't be updated via
1552                 specialization. */
1553         }
1554         else
1555         {
1556                 StorageClass storage = (current_function ? STORAGE_FUNCTION : get_interface_storage(var.interface, false));
1557                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1558                 if(var.interface=="uniform")
1559                 {
1560                         Id &uni_id = declared_uniform_ids["v"+var.name];
1561                         if(uni_id)
1562                         {
1563                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1564                                 return;
1565                         }
1566
1567                         uni_id = var_id = allocate_id(var, ptr_type_id);
1568                 }
1569                 else
1570                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1571
1572                 Id init_id = 0;
1573                 if(var.init_expression)
1574                 {
1575                         SetFlag set_const(constant_expression, !current_function);
1576                         r_expression_result_id = 0;
1577                         var.init_expression->visit(*this);
1578                         init_id = r_expression_result_id;
1579                 }
1580
1581                 vector<Word> &target = (current_function ? content.locals : content.globals);
1582                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1583                 writer.write(ptr_type_id);
1584                 writer.write(var_id);
1585                 writer.write(storage);
1586                 if(init_id && !current_function)
1587                         writer.write(init_id);
1588                 writer.end_op(OP_VARIABLE);
1589
1590                 if(layout_ql)
1591                 {
1592                         for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); i!=layout_ql->end(); ++i)
1593                         {
1594                                 if(i->name=="location")
1595                                         writer.write_op_decorate(var_id, DECO_LOCATION, i->value);
1596                                 else if(i->name=="set")
1597                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, i->value);
1598                                 else if(i->name=="binding")
1599                                         writer.write_op_decorate(var_id, DECO_BINDING, i->value);
1600                         }
1601                 }
1602
1603                 if(init_id && current_function)
1604                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1605         }
1606
1607         writer.write_op_name(var_id, var.name);
1608 }
1609
1610 void SpirVGenerator::visit(InterfaceBlock &iface)
1611 {
1612         StorageClass storage = get_interface_storage(iface.interface, true);
1613         Id type_id;
1614         if(iface.array)
1615                 type_id = get_array_type_id(*iface.struct_declaration, 0);
1616         else
1617                 type_id = get_id(*iface.struct_declaration);
1618         Id ptr_type_id = get_pointer_type_id(type_id, storage);
1619
1620         Id block_id;
1621         if(iface.interface=="uniform")
1622         {
1623                 Id &uni_id = declared_uniform_ids["b"+iface.block_name];
1624                 if(uni_id)
1625                 {
1626                         insert_unique(declared_ids, &iface, Declaration(uni_id, ptr_type_id));
1627                         return;
1628                 }
1629
1630                 uni_id = block_id = allocate_id(iface, ptr_type_id);
1631         }
1632         else
1633                 block_id = allocate_id(iface, ptr_type_id);
1634         writer.write_op_name(block_id, iface.instance_name);
1635
1636         writer.write_op(content.globals, OP_VARIABLE, ptr_type_id, block_id, storage);
1637
1638         if(iface.layout)
1639         {
1640                 const vector<Layout::Qualifier> &qualifiers = iface.layout->qualifiers;
1641                 for(vector<Layout::Qualifier>::const_iterator i=qualifiers.begin(); i!=qualifiers.end(); ++i)
1642                         if(i->name=="binding")
1643                                 writer.write_op_decorate(block_id, DECO_BINDING, i->value);
1644         }
1645 }
1646
1647 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1648 {
1649         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1650         switch(stage->type)
1651         {
1652         case Stage::VERTEX: writer.write(0); break;
1653         case Stage::GEOMETRY: writer.write(3); break;
1654         case Stage::FRAGMENT: writer.write(4); break;
1655         default: throw internal_error("unknown stage");
1656         }
1657         writer.write(func_id);
1658         writer.write_string(func.name);
1659
1660         set<Node *> dependencies = DependencyCollector().apply(func);
1661         for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1662         {
1663                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1664                 {
1665                         if(!var->interface.empty())
1666                                 writer.write(get_id(**i));
1667                 }
1668                 else if(dynamic_cast<InterfaceBlock *>(*i))
1669                         writer.write(get_id(**i));
1670         }
1671
1672         writer.end_op(OP_ENTRY_POINT);
1673
1674         if(stage->type==Stage::FRAGMENT)
1675                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_ORIGIN_LOWER_LEFT);
1676         else if(stage->type==Stage::GEOMETRY)
1677                 use_capability(CAP_GEOMETRY);
1678
1679         for(vector<const InterfaceLayout *>::const_iterator i=interface_layouts.begin(); i!=interface_layouts.end(); ++i)
1680         {
1681                 const vector<Layout::Qualifier> &qualifiers = (*i)->layout.qualifiers;
1682                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1683                 {
1684                         if(j->name=="point")
1685                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1686                                         ((*i)->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1687                         else if(j->name=="lines")
1688                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1689                         else if(j->name=="lines_adjacency")
1690                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1691                         else if(j->name=="triangles")
1692                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1693                         else if(j->name=="triangles_adjacency")
1694                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1695                         else if(j->name=="line_strip")
1696                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1697                         else if(j->name=="triangle_strip")
1698                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1699                         else if(j->name=="max_vertices")
1700                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, j->value);
1701                 }
1702         }
1703 }
1704
1705 void SpirVGenerator::visit(FunctionDeclaration &func)
1706 {
1707         if(func.source==BUILTIN_SOURCE || func.definition!=&func)
1708                 return;
1709
1710         Id return_type_id = get_id(*func.return_type_declaration);
1711         vector<unsigned> param_type_ids;
1712         param_type_ids.reserve(func.parameters.size());
1713         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1714                 param_type_ids.push_back(get_variable_type_id(**i));
1715
1716         string sig_with_return = func.return_type+func.signature;
1717         Id &type_id = function_type_ids[sig_with_return];
1718         if(!type_id)
1719         {
1720                 type_id = next_id++;
1721                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1722                 writer.write(type_id);
1723                 writer.write(return_type_id);
1724                 for(vector<unsigned>::const_iterator i=param_type_ids.begin(); i!=param_type_ids.end(); ++i)
1725                         writer.write(*i);
1726                 writer.end_op(OP_TYPE_FUNCTION);
1727
1728                 writer.write_op_name(type_id, sig_with_return);
1729         }
1730
1731         Id func_id = allocate_id(func, type_id);
1732         writer.write_op_name(func_id, func.name+func.signature);
1733
1734         if(func.name=="main")
1735                 visit_entry_point(func, func_id);
1736
1737         writer.begin_op(content.functions, OP_FUNCTION, 5);
1738         writer.write(return_type_id);
1739         writer.write(func_id);
1740         writer.write(0);  // Function control flags (none)
1741         writer.write(type_id);
1742         writer.end_op(OP_FUNCTION);
1743
1744         for(unsigned i=0; i<func.parameters.size(); ++i)
1745         {
1746                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1747                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1748                 // TODO This is probably incorrect if the parameter is assigned to.
1749                 variable_load_ids[func.parameters[i].get()] = param_id;
1750         }
1751
1752         writer.begin_function_body(next_id++);
1753         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1754         func.body.visit(*this);
1755
1756         if(writer.has_current_block())
1757         {
1758                 if(!reachable)
1759                         writer.write_op(content.function_body, OP_UNREACHABLE);
1760                 else
1761                 {
1762                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1763                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1764                                 writer.write_op(content.function_body, OP_RETURN);
1765                         else
1766                                 throw internal_error("missing return in non-void function");
1767                 }
1768         }
1769         writer.end_function_body();
1770         variable_load_ids.clear();
1771 }
1772
1773 void SpirVGenerator::visit(Conditional &cond)
1774 {
1775         cond.condition->visit(*this);
1776
1777         Id true_label_id = next_id++;
1778         Id merge_block_id = next_id++;
1779         Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
1780         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1781         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);
1782
1783         writer.write_op_label(true_label_id);
1784         cond.body.visit(*this);
1785         if(writer.has_current_block())
1786                 writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1787
1788         bool reachable_if_true = reachable;
1789
1790         reachable = true;
1791         if(!cond.else_body.body.empty())
1792         {
1793                 writer.write_op_label(false_label_id);
1794                 cond.else_body.visit(*this);
1795                 reachable |= reachable_if_true;
1796         }
1797
1798         writer.write_op_label(merge_block_id);
1799         prune_loads(true_label_id);
1800 }
1801
1802 void SpirVGenerator::visit(Iteration &iter)
1803 {
1804         if(iter.init_statement)
1805                 iter.init_statement->visit(*this);
1806
1807         Id header_id = next_id++;
1808         Id continue_id = next_id++;
1809         Id merge_block_id = next_id++;
1810
1811         SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
1812         SetForScope<Id> set_continue(loop_continue_target_id, continue_id);
1813
1814         writer.write_op_label(header_id);
1815         writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)
1816
1817         Id body_id = next_id++;
1818         if(iter.condition)
1819         {
1820                 writer.write_op_label(next_id++);
1821                 iter.condition->visit(*this);
1822                 writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
1823         }
1824
1825         writer.write_op_label(body_id);
1826         iter.body.visit(*this);
1827
1828         writer.write_op_label(continue_id);
1829         if(iter.loop_expression)
1830                 iter.loop_expression->visit(*this);
1831         writer.write_op(content.function_body, OP_BRANCH, header_id);
1832
1833         writer.write_op_label(merge_block_id);
1834         prune_loads(header_id);
1835         reachable = true;
1836 }
1837
1838 void SpirVGenerator::visit(Return &ret)
1839 {
1840         if(ret.expression)
1841         {
1842                 ret.expression->visit(*this);
1843                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1844         }
1845         else
1846                 writer.write_op(content.function_body, OP_RETURN);
1847         reachable = false;
1848 }
1849
1850 void SpirVGenerator::visit(Jump &jump)
1851 {
1852         if(jump.keyword=="discard")
1853                 writer.write_op(content.function_body, OP_KILL);
1854         else if(jump.keyword=="break")
1855                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1856         else if(jump.keyword=="continue")
1857                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1858         else
1859                 throw internal_error("unknown jump");
1860         reachable = false;
1861 }
1862
1863
1864 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
1865         type_id(0)
1866 {
1867         switch(kind)
1868         {
1869         case BasicTypeDeclaration::VOID: detail = 'v'; break;
1870         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
1871         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
1872         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
1873         default: throw invalid_argument("TypeKey::TypeKey");
1874         }
1875 }
1876
1877 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1878 {
1879         if(type_id!=other.type_id)
1880                 return type_id<other.type_id;
1881         return detail<other.detail;
1882 }
1883
1884
1885 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1886 {
1887         if(type_id!=other.type_id)
1888                 return type_id<other.type_id;
1889         return int_value<other.int_value;
1890 }
1891
1892 } // namespace SL
1893 } // namespace GL
1894 } // namespace Msp