]> git.tdb.fi Git - libs/gl.git/blob - source/glsl/spirv.cpp
7d6edf656a4097607e70b1a68208f3c523e9d212
[libs/gl.git] / source / glsl / spirv.cpp
1 #include <msp/core/maputils.h>
2 #include <msp/core/raii.h>
3 #include "reflect.h"
4 #include "spirv.h"
5
6 using namespace std;
7
8 namespace Msp {
9 namespace GL {
10 namespace SL {
11
12 const SpirVGenerator::BuiltinFunctionInfo SpirVGenerator::builtin_functions[] =
13 {
14         { "radians", "f", "GLSL.std.450", GLSL450_RADIANS, { 1 }, 0 },
15         { "degrees", "f", "GLSL.std.450", GLSL450_DEGREES, { 1 }, 0 },
16         { "sin", "f", "GLSL.std.450", GLSL450_SIN, { 1 }, 0 },
17         { "cos", "f", "GLSL.std.450", GLSL450_COS, { 1 }, 0 },
18         { "tan", "f", "GLSL.std.450", GLSL450_TAN, { 1 }, 0 },
19         { "asin", "f", "GLSL.std.450", GLSL450_ASIN, { 1 }, 0 },
20         { "acos", "f", "GLSL.std.450", GLSL450_ACOS, { 1 }, 0 },
21         { "atan", "f", "GLSL.std.450", GLSL450_ATAN, { 1 }, 0 },
22         { "atan", "ff", "GLSL.std.450", GLSL450_ATAN2, { 1, 2 }, 0 },
23         { "sinh", "f", "GLSL.std.450", GLSL450_SINH, { 1 }, 0 },
24         { "cosh", "f", "GLSL.std.450", GLSL450_COSH, { 1 }, 0 },
25         { "tanh", "f", "GLSL.std.450", GLSL450_TANH, { 1 }, 0 },
26         { "asinh", "f", "GLSL.std.450", GLSL450_ASINH, { 1 }, 0 },
27         { "acosh", "f", "GLSL.std.450", GLSL450_ACOSH, { 1 }, 0 },
28         { "atanh", "f", "GLSL.std.450", GLSL450_ATANH, { 1 }, 0 },
29         { "pow", "ff", "GLSL.std.450", GLSL450_POW, { 1, 2 }, 0 },
30         { "exp", "f", "GLSL.std.450", GLSL450_EXP, { 1 }, 0 },
31         { "log", "f", "GLSL.std.450", GLSL450_LOG, { 1 }, 0 },
32         { "exp2", "f", "GLSL.std.450", GLSL450_EXP2, { 1 }, 0 },
33         { "log2", "f", "GLSL.std.450", GLSL450_LOG2, { 1 }, 0 },
34         { "sqrt", "f", "GLSL.std.450", GLSL450_SQRT, { 1 }, 0 },
35         { "inversesqrt", "f", "GLSL.std.450", GLSL450_INVERSE_SQRT, { 1 }, 0 },
36         { "abs", "f", "GLSL.std.450", GLSL450_F_ABS, { 1 }, 0 },
37         { "abs", "i", "GLSL.std.450", GLSL450_S_ABS, { 1 }, 0 },
38         { "sign", "f", "GLSL.std.450", GLSL450_F_SIGN, { 1 }, 0 },
39         { "sign", "i", "GLSL.std.450", GLSL450_S_SIGN, { 1 }, 0 },
40         { "floor", "f", "GLSL.std.450", GLSL450_FLOOR, { 1 }, 0 },
41         { "trunc", "f", "GLSL.std.450", GLSL450_TRUNC, { 1 }, 0 },
42         { "round", "f", "GLSL.std.450", GLSL450_ROUND, { 1 }, 0 },
43         { "roundEven", "f", "GLSL.std.450", GLSL450_ROUND_EVEN, { 1 }, 0 },
44         { "ceil", "f", "GLSL.std.450", GLSL450_CEIL, { 1 }, 0 },
45         { "fract", "f", "GLSL.std.450", GLSL450_FRACT, { 1 }, 0 },
46         { "mod", "f", "", OP_F_MOD, { 1, 2 }, 0 },
47         { "min", "ff", "GLSL.std.450", GLSL450_F_MIN, { 1, 2 }, 0 },
48         { "min", "ii", "GLSL.std.450", GLSL450_S_MIN, { 1, 2 }, 0 },
49         { "min", "uu", "GLSL.std.450", GLSL450_U_MIN, { 1, 2 }, 0 },
50         { "max", "ff", "GLSL.std.450", GLSL450_F_MAX, { 1, 2 }, 0 },
51         { "max", "ii", "GLSL.std.450", GLSL450_S_MAX, { 1, 2 }, 0 },
52         { "max", "uu", "GLSL.std.450", GLSL450_U_MAX, { 1, 2 }, 0 },
53         { "clamp", "fff", "GLSL.std.450", GLSL450_F_CLAMP, { 1, 2, 3 }, 0 },
54         { "clamp", "iii", "GLSL.std.450", GLSL450_S_CLAMP, { 1, 2, 3 }, 0 },
55         { "clamp", "uuu", "GLSL.std.450", GLSL450_U_CLAMP, { 1, 2, 3 }, 0 },
56         { "mix", "fff", "GLSL.std.450", GLSL450_F_MIX, { 1, 2, 3 }, 0 },
57         { "mix", "ffb", "", OP_SELECT, { 3, 2, 1 }, 0 },
58         { "mix", "iib", "", OP_SELECT, { 3, 2, 1 }, 0 },
59         { "mix", "uub", "", OP_SELECT, { 3, 2, 1 }, 0 },
60         { "step", "ff", "GLSL.std.450", GLSL450_F_STEP, { 1, 2 }, 0 },
61         { "smoothstep", "fff", "GLSL.std.450", GLSL450_F_SMOOTH_STEP, { 1, 2, 3 }, 0 },
62         { "isnan", "f", "", OP_IS_NAN, { 1 }, 0 },
63         { "isinf", "f", "", OP_IS_INF, { 1 }, 0 },
64         { "fma", "fff", "GLSL.std.450", GLSL450_F_FMA, { 1, 2, 3 }, 0 },
65         { "length", "f", "GLSL.std.450", GLSL450_LENGTH, { 1 }, 0 },
66         { "distance", "ff", "GLSL.std.450", GLSL450_DISTANCE, { 1, 2 }, 0 },
67         { "dot", "ff", "", OP_DOT, { 1, 2 }, 0 },
68         { "cross", "ff", "GLSL.std.450", GLSL450_CROSS, { 1, 2 }, 0 },
69         { "normalize", "f", "GLSL.std.450", GLSL450_NORMALIZE, { 1 }, 0 },
70         { "faceforward", "fff", "GLSL.std.450", GLSL450_FACE_FORWARD, { 1, 2, 3 }, 0 },
71         { "reflect", "ff", "GLSL.std.450", GLSL450_REFLECT, { 1, 2 }, 0 },
72         { "refract", "fff", "GLSL.std.450", GLSL450_REFRACT, { 1, 2, 3 }, 0 },
73         { "matrixCompMult", "ff", "", 0, { 0 }, &SpirVGenerator::visit_builtin_matrix_comp_mult },
74         { "outerProduct", "ff", "", OP_OUTER_PRODUCT, { 1, 2 }, 0 },
75         { "transpose", "f", "", OP_TRANSPOSE, { 1 }, 0 },
76         { "determinant", "f", "GLSL.std.450", GLSL450_DETERMINANT, { 1 }, 0 },
77         { "inverse", "f", "GLSL.std.450", GLSL450_MATRIX_INVERSE, { 1 }, 0 },
78         { "lessThan", "ff", "", OP_F_ORD_LESS_THAN, { 1, 2 }, 0 },
79         { "lessThan", "ii", "", OP_S_LESS_THAN, { 1, 2 }, 0 },
80         { "lessThan", "uu", "", OP_U_LESS_THAN, { 1, 2 }, 0 },
81         { "lessThanEqual", "ff", "", OP_F_ORD_LESS_THAN_EQUAL, { 1, 2 }, 0 },
82         { "lessThanEqual", "ii", "", OP_S_LESS_THAN_EQUAL, { 1, 2 }, 0 },
83         { "lessThanEqual", "uu", "", OP_U_LESS_THAN_EQUAL, { 1, 2 }, 0 },
84         { "greaterThan", "ff", "", OP_F_ORD_GREATER_THAN, { 1, 2 }, 0 },
85         { "greaterThan", "ii", "", OP_S_GREATER_THAN, { 1, 2 }, 0 },
86         { "greaterThan", "uu", "", OP_U_GREATER_THAN, { 1, 2 }, 0 },
87         { "greaterThanEqual", "ff", "", OP_F_ORD_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
88         { "greaterThanEqual", "ii", "", OP_S_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
89         { "greaterThanEqual", "uu", "", OP_U_GREATER_THAN_EQUAL, { 1, 2 }, 0 },
90         { "equal", "ff", "", OP_F_ORD_EQUAL, { 1, 2 }, 0 },
91         { "equal", "ii", "", OP_I_EQUAL, { 1, 2 }, 0 },
92         { "equal", "uu", "", OP_I_EQUAL, { 1, 2 }, 0 },
93         { "notEqual", "ff", "", OP_F_ORD_NOT_EQUAL, { 1, 2 }, 0 },
94         { "notEqual", "ii", "", OP_I_NOT_EQUAL, { 1, 2 }, 0 },
95         { "notEqual", "uu", "", OP_I_NOT_EQUAL, { 1, 2 }, 0 },
96         { "any", "b", "", OP_ANY, { 1 }, 0 },
97         { "all", "b", "", OP_ALL, { 1 }, 0 },
98         { "not", "b", "", OP_LOGICAL_NOT, { 1 }, 0 },
99         { "bitfieldExtract", "iii", "", OP_BIT_FIELD_S_EXTRACT, { 1, 2, 3 }, 0 },
100         { "bitfieldExtract", "uii", "", OP_BIT_FIELD_U_EXTRACT, { 1, 2, 3 }, 0 },
101         { "bitfieldInsert", "iiii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0 },
102         { "bitfieldInsert", "uuii", "", OP_BIT_FIELD_INSERT, { 1, 2, 3, 4 }, 0 },
103         { "bitfieldReverse", "i", "", OP_BIT_REVERSE, { 1 }, 0 },
104         { "bitfieldReverse", "u", "", OP_BIT_REVERSE, { 1 }, 0 },
105         { "bitCount", "i", "", OP_BIT_COUNT, { 1 }, 0 },
106         { "findLSB", "i", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0 },
107         { "findLSB", "u", "GLSL.std.450", GLSL450_FIND_I_LSB, { 1 }, 0 },
108         { "findMSB", "i", "GLSL.std.450", GLSL450_FIND_S_MSB, { 1 }, 0 },
109         { "findMSB", "u", "GLSL.std.450", GLSL450_FIND_U_MSB, { 1 }, 0 },
110         { "textureSize", "", "", OP_IMAGE_QUERY_SIZE_LOD, { 1, 2 }, 0 },
111         { "texture", "", "", 0, { }, &SpirVGenerator::visit_builtin_texture },
112         { "textureLod", "", "", 0, { }, &SpirVGenerator::visit_builtin_texture },
113         { "texelFetch", "", "", 0, { }, &SpirVGenerator::visit_builtin_texel_fetch },
114         { "EmitVertex", "", "", OP_EMIT_VERTEX, { }, 0 },
115         { "EndPrimitive", "", "", OP_END_PRIMITIVE, { }, 0 },
116         { "dFdx", "f", "", OP_DP_DX, { 1 }, 0 },
117         { "dFdy", "f", "", OP_DP_DY, { 1 }, 0 },
118         { "dFdxFine", "f", "", OP_DP_DX_FINE, { 1 }, 0 },
119         { "dFdyFine", "f", "", OP_DP_DY_FINE, { 1 }, 0 },
120         { "dFdxCoarse", "f", "", OP_DP_DX_COARSE, { 1 }, 0 },
121         { "dFdyCoarse", "f", "", OP_DP_DY_COARSE, { 1 }, 0 },
122         { "fwidth", "f", "", OP_FWIDTH, { 1 }, 0 },
123         { "fwidthFine", "f", "", OP_FWIDTH_FINE, { 1 }, 0 },
124         { "fwidthCoarse", "f", "", OP_FWIDTH_COARSE, { 1 }, 0 },
125         { "interpolateAtCentroid", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
126         { "interpolateAtSample", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
127         { "interpolateAtOffset", "", "", 0, { }, &SpirVGenerator::visit_builtin_interpolate },
128         { "", "", "", 0, { }, 0 }
129 };
130
131 SpirVGenerator::SpirVGenerator():
132         stage(0),
133         current_function(0),
134         writer(content),
135         next_id(1),
136         r_expression_result_id(0),
137         constant_expression(false),
138         spec_constant(false),
139         reachable(false),
140         composite_access(false),
141         r_composite_base_id(0),
142         r_composite_base(0),
143         assignment_source_id(0),
144         loop_merge_block_id(0),
145         loop_continue_target_id(0)
146 { }
147
148 void SpirVGenerator::apply(Module &module)
149 {
150         use_capability(CAP_SHADER);
151
152         for(list<Stage>::iterator i=module.stages.begin(); i!=module.stages.end(); ++i)
153         {
154                 stage = &*i;
155                 interface_layouts.clear();
156                 i->content.visit(*this);
157         }
158
159         writer.finalize(SPIRV_GENERATOR_MSP, next_id);
160 }
161
162 SpirVGenerator::StorageClass SpirVGenerator::get_interface_storage(const string &iface, bool block)
163 {
164         if(iface=="in")
165                 return STORAGE_INPUT;
166         else if(iface=="out")
167                 return STORAGE_OUTPUT;
168         else if(iface=="uniform")
169                 return (block ? STORAGE_UNIFORM : STORAGE_UNIFORM_CONSTANT);
170         else if(iface.empty())
171                 return STORAGE_PRIVATE;
172         else
173                 throw invalid_argument("SpirVGenerator::get_interface_storage");
174 }
175
176 SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const string &name)
177 {
178         if(name=="gl_Position")
179                 return BUILTIN_POSITION;
180         else if(name=="gl_PointSize")
181                 return BUILTIN_POINT_SIZE;
182         else if(name=="gl_ClipDistance")
183                 return BUILTIN_CLIP_DISTANCE;
184         else if(name=="gl_VertexID")
185                 return BUILTIN_VERTEX_ID;
186         else if(name=="gl_InstanceID")
187                 return BUILTIN_INSTANCE_ID;
188         else if(name=="gl_PrimitiveID" || name=="gl_PrimitiveIDIn")
189                 return BUILTIN_PRIMITIVE_ID;
190         else if(name=="gl_InvocationID")
191                 return BUILTIN_INVOCATION_ID;
192         else if(name=="gl_Layer")
193                 return BUILTIN_LAYER;
194         else if(name=="gl_FragCoord")
195                 return BUILTIN_FRAG_COORD;
196         else if(name=="gl_PointCoord")
197                 return BUILTIN_POINT_COORD;
198         else if(name=="gl_FrontFacing")
199                 return BUILTIN_FRONT_FACING;
200         else if(name=="gl_SampleId")
201                 return BUILTIN_SAMPLE_ID;
202         else if(name=="gl_SamplePosition")
203                 return BUILTIN_SAMPLE_POSITION;
204         else if(name=="gl_FragDepth")
205                 return BUILTIN_FRAG_DEPTH;
206         else
207                 throw invalid_argument("SpirVGenerator::get_builtin_semantic");
208 }
209
210 void SpirVGenerator::use_capability(Capability cap)
211 {
212         if(used_capabilities.count(cap))
213                 return;
214
215         used_capabilities.insert(cap);
216         writer.write_op(content.capabilities, OP_CAPABILITY, cap);
217 }
218
219 SpirVGenerator::Id SpirVGenerator::import_extension(const string &name)
220 {
221         Id &ext_id = imported_extension_ids[name];
222         if(!ext_id)
223         {
224                 ext_id = next_id++;
225                 writer.begin_op(content.extensions, OP_EXT_INST_IMPORT);
226                 writer.write(ext_id);
227                 writer.write_string(name);
228                 writer.end_op(OP_EXT_INST_IMPORT);
229         }
230         return ext_id;
231 }
232
233 SpirVGenerator::Id SpirVGenerator::get_id(Node &node) const
234 {
235         return get_item(declared_ids, &node).id;
236 }
237
238 SpirVGenerator::Id SpirVGenerator::allocate_id(Node &node, Id type_id)
239 {
240         Id id = next_id++;
241         insert_unique(declared_ids, &node, Declaration(id, type_id));
242         return id;
243 }
244
245 SpirVGenerator::Id SpirVGenerator::write_constant(Id type_id, Word value, bool spec)
246 {
247         Id const_id = next_id++;
248         if(is_scalar_type(type_id, BasicTypeDeclaration::BOOL))
249         {
250                 Opcode opcode = (value ? (spec ? OP_SPEC_CONSTANT_TRUE : OP_CONSTANT_TRUE) :
251                         (spec ? OP_SPEC_CONSTANT_FALSE : OP_CONSTANT_FALSE));
252                 writer.write_op(content.globals, opcode, type_id, const_id);
253         }
254         else
255         {
256                 Opcode opcode = (spec ? OP_SPEC_CONSTANT : OP_CONSTANT);
257                 writer.write_op(content.globals, opcode, type_id, const_id, value);
258         }
259         return const_id;
260 }
261
262 SpirVGenerator::ConstantKey SpirVGenerator::get_constant_key(Id type_id, const Variant &value)
263 {
264         if(value.check_type<bool>())
265                 return ConstantKey(type_id, value.value<bool>());
266         else if(value.check_type<int>())
267                 return ConstantKey(type_id, value.value<int>());
268         else if(value.check_type<unsigned>())
269                 return ConstantKey(type_id, value.value<unsigned>());
270         else if(value.check_type<float>())
271                 return ConstantKey(type_id, value.value<float>());
272         else
273                 throw invalid_argument("SpirVGenerator::get_constant_key");
274 }
275
276 SpirVGenerator::Id SpirVGenerator::get_constant_id(Id type_id, const Variant &value)
277 {
278         ConstantKey key = get_constant_key(type_id, value);
279         Id &const_id = constant_ids[key];
280         if(!const_id)
281                 const_id = write_constant(type_id, key.int_value, false);
282         return const_id;
283 }
284
285 SpirVGenerator::Id SpirVGenerator::get_vector_constant_id(Id type_id, unsigned size, Id scalar_id)
286 {
287         Id &const_id = constant_ids[get_constant_key(type_id, static_cast<int>(scalar_id))];
288         if(!const_id)
289         {
290                 const_id = next_id++;
291                 writer.begin_op(content.globals, OP_CONSTANT_COMPOSITE, 3+size);
292                 writer.write(type_id);
293                 writer.write(const_id);
294                 for(unsigned i=0; i<size; ++i)
295                         writer.write(scalar_id);
296                 writer.end_op(OP_CONSTANT_COMPOSITE);
297         }
298         return const_id;
299 }
300
301 SpirVGenerator::Id SpirVGenerator::get_standard_type_id(BasicTypeDeclaration::Kind kind, unsigned size, bool sign)
302 {
303         Id base_id = (size>1 ? get_standard_type_id(kind, 1, sign) : 0);
304         Id &type_id = standard_type_ids[base_id ? TypeKey(base_id, size) : TypeKey(kind, sign)];
305         if(!type_id)
306         {
307                 type_id = next_id++;
308                 if(size>1)
309                         writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, base_id, size);
310                 else if(kind==BasicTypeDeclaration::VOID)
311                         writer.write_op(content.globals, OP_TYPE_VOID, type_id);
312                 else if(kind==BasicTypeDeclaration::BOOL)
313                         writer.write_op(content.globals, OP_TYPE_BOOL, type_id);
314                 else if(kind==BasicTypeDeclaration::INT)
315                         writer.write_op(content.globals, OP_TYPE_INT, type_id, 32, sign);
316                 else if(kind==BasicTypeDeclaration::FLOAT)
317                         writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, 32);
318                 else
319                         throw invalid_argument("SpirVGenerator::get_standard_type_id");
320         }
321         return type_id;
322 }
323
324 bool SpirVGenerator::is_scalar_type(Id type_id, BasicTypeDeclaration::Kind kind) const
325 {
326         map<TypeKey, Id>::const_iterator i = standard_type_ids.find(TypeKey(kind, true));
327         return (i!=standard_type_ids.end() && i->second==type_id);
328 }
329
330 SpirVGenerator::Id SpirVGenerator::get_array_type_id(TypeDeclaration &base_type, Id size_id)
331 {
332         Id base_type_id = get_id(base_type);
333         Id &array_type_id = array_type_ids[TypeKey(base_type_id, size_id)];
334         if(!array_type_id)
335         {
336                 array_type_id = next_id++;
337                 if(size_id)
338                         writer.write_op(content.globals, OP_TYPE_ARRAY, array_type_id, base_type_id, size_id);
339                 else
340                         writer.write_op(content.globals, OP_TYPE_RUNTIME_ARRAY, array_type_id, base_type_id);
341
342                 unsigned stride = MemoryRequirementsCalculator().apply(base_type).stride;
343                 writer.write_op_decorate(array_type_id, DECO_ARRAY_STRIDE, stride);
344         }
345
346         return array_type_id;
347 }
348
349 SpirVGenerator::Id SpirVGenerator::get_pointer_type_id(Id type_id, StorageClass storage)
350 {
351         Id &ptr_type_id = pointer_type_ids[TypeKey(type_id, storage)];
352         if(!ptr_type_id)
353         {
354                 ptr_type_id = next_id++;
355                 writer.write_op(content.globals, OP_TYPE_POINTER, ptr_type_id, storage, type_id);
356         }
357         return ptr_type_id;
358 }
359
360 SpirVGenerator::Id SpirVGenerator::get_variable_type_id(const VariableDeclaration &var)
361 {
362         if(const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var.type_declaration))
363                 if(basic->kind==BasicTypeDeclaration::ARRAY)
364                 {
365                         Id size_id = 0;
366                         if(var.array_size)
367                         {
368                                 SetFlag set_const(constant_expression);
369                                 r_expression_result_id = 0;
370                                 var.array_size->visit(*this);
371                                 size_id = r_expression_result_id;
372                         }
373                         else
374                                 size_id = get_constant_id(get_standard_type_id(BasicTypeDeclaration::INT, 1), 1);
375                         return get_array_type_id(*basic->base_type, size_id);
376                 }
377
378         return get_id(*var.type_declaration);
379 }
380
381 SpirVGenerator::Id SpirVGenerator::get_load_id(VariableDeclaration &var)
382 {
383         Id &load_result_id = variable_load_ids[&var];
384         if(!load_result_id)
385         {
386                 load_result_id = next_id++;
387                 writer.write_op(content.function_body, OP_LOAD, get_variable_type_id(var), load_result_id, get_id(var));
388         }
389         return load_result_id;
390 }
391
392 void SpirVGenerator::prune_loads(Id min_id)
393 {
394         for(map<const VariableDeclaration *, Id>::iterator i=variable_load_ids.begin(); i!=variable_load_ids.end(); )
395         {
396                 if(i->second>=min_id)
397                         variable_load_ids.erase(i++);
398                 else
399                         ++i;
400         }
401 }
402
403 SpirVGenerator::Id SpirVGenerator::begin_expression(Opcode opcode, Id type_id, unsigned n_args)
404 {
405         bool has_result = (opcode==OP_FUNCTION_CALL || !is_scalar_type(type_id, BasicTypeDeclaration::VOID));
406         if(!constant_expression)
407         {
408                 if(!current_function)
409                         throw internal_error("non-constant expression outside a function");
410
411                 writer.begin_op(content.function_body, opcode, (n_args ? 1+has_result*2+n_args : 0));
412         }
413         else if(opcode==OP_COMPOSITE_CONSTRUCT)
414                 writer.begin_op(content.globals, (spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE),
415                         (n_args ? 1+has_result*2+n_args : 0));
416         else if(!spec_constant)
417                 throw internal_error("invalid non-specialization constant expression");
418         else
419                 writer.begin_op(content.globals, OP_SPEC_CONSTANT_OP, (n_args ? 2+has_result*2+n_args : 0));
420
421         Id result_id = next_id++;
422         if(has_result)
423         {
424                 writer.write(type_id);
425                 writer.write(result_id);
426         }
427         if(spec_constant && opcode!=OP_COMPOSITE_CONSTRUCT)
428                 writer.write(opcode);
429
430         return result_id;
431 }
432
433 void SpirVGenerator::end_expression(Opcode opcode)
434 {
435         if(constant_expression)
436                 opcode = (opcode==OP_COMPOSITE_CONSTRUCT ? spec_constant ? OP_SPEC_CONSTANT_COMPOSITE : OP_CONSTANT_COMPOSITE : OP_SPEC_CONSTANT_OP);
437         writer.end_op(opcode);
438 }
439
440 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id arg_id)
441 {
442         Id result_id = begin_expression(opcode, type_id, 1);
443         writer.write(arg_id);
444         end_expression(opcode);
445         return result_id;
446 }
447
448 SpirVGenerator::Id SpirVGenerator::write_expression(Opcode opcode, Id type_id, Id left_id, Id right_id)
449 {
450         Id result_id = begin_expression(opcode, type_id, 2);
451         writer.write(left_id);
452         writer.write(right_id);
453         end_expression(opcode);
454         return result_id;
455 }
456
457 void SpirVGenerator::write_deconstruct(Id elem_type_id, Id composite_id, Id *elem_ids, unsigned n_elems)
458 {
459         for(unsigned i=0; i<n_elems; ++i)
460         {
461                 elem_ids[i] = begin_expression(OP_COMPOSITE_EXTRACT, elem_type_id, 2);
462                 writer.write(composite_id);
463                 writer.write(i);
464                 end_expression(OP_COMPOSITE_EXTRACT);
465         }
466 }
467
468 SpirVGenerator::Id SpirVGenerator::write_construct(Id type_id, const Id *elem_ids, unsigned n_elems)
469 {
470         Id result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, type_id, n_elems);
471         for(unsigned i=0; i<n_elems; ++i)
472                 writer.write(elem_ids[i]);
473         end_expression(OP_COMPOSITE_CONSTRUCT);
474
475         return result_id;
476 }
477
478 void SpirVGenerator::visit(Block &block)
479 {
480         for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
481                 (*i)->visit(*this);
482 }
483
484 void SpirVGenerator::visit(Literal &literal)
485 {
486         Id type_id = get_id(*literal.type);
487         if(spec_constant)
488                 r_expression_result_id = write_constant(type_id, get_constant_key(type_id, literal.value).int_value, true);
489         else
490                 r_expression_result_id = get_constant_id(type_id, literal.value);
491         r_constant_result = true;
492 }
493
494 void SpirVGenerator::visit(VariableReference &var)
495 {
496         if(constant_expression || var.declaration->constant)
497         {
498                 if(!var.declaration->constant)
499                         throw internal_error("reference to non-constant variable in constant context");
500
501                 r_expression_result_id = get_id(*var.declaration);
502                 r_constant_result = true;
503                 return;
504         }
505         else if(!current_function)
506                 throw internal_error("non-constant context outside a function");
507
508         r_constant_result = false;
509         if(composite_access)
510         {
511                 r_composite_base = var.declaration;
512                 r_expression_result_id = 0;
513         }
514         else if(assignment_source_id)
515         {
516                 writer.write_op(content.function_body, OP_STORE, get_id(*var.declaration), assignment_source_id);
517                 variable_load_ids[var.declaration] = assignment_source_id;
518                 r_expression_result_id = assignment_source_id;
519         }
520         else
521                 r_expression_result_id = get_load_id(*var.declaration);
522 }
523
524 void SpirVGenerator::visit(InterfaceBlockReference &iface)
525 {
526         if(!composite_access || !current_function)
527                 throw internal_error("invalid interface block reference");
528
529         r_composite_base = iface.declaration;
530         r_expression_result_id = 0;
531         r_constant_result = false;
532 }
533
534 void SpirVGenerator::generate_composite_access(TypeDeclaration &result_type)
535 {
536         Opcode opcode;
537         Id result_type_id = get_id(result_type);
538         Id access_type_id = result_type_id;
539         if(r_composite_base)
540         {
541                 if(constant_expression)
542                         throw internal_error("composite access through pointer in constant context");
543
544                 Id int32_type_id = get_standard_type_id(BasicTypeDeclaration::INT, 1);
545                 for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
546                         *i = (*i<0x400000 ? get_constant_id(int32_type_id, static_cast<int>(*i)) : *i&0x3FFFFF);
547
548                 /* Find the storage class of the base and obtain appropriate pointer type
549                 for the result. */
550                 const Declaration &base_decl = get_item(declared_ids, r_composite_base);
551                 map<TypeKey, Id>::const_iterator i = pointer_type_ids.begin();
552                 for(; (i!=pointer_type_ids.end() && i->second!=base_decl.type_id); ++i) ;
553                 if(i==pointer_type_ids.end())
554                         throw internal_error("could not find storage class");
555                 access_type_id = get_pointer_type_id(result_type_id, static_cast<StorageClass>(i->first.detail));
556
557                 opcode = OP_ACCESS_CHAIN;
558         }
559         else if(assignment_source_id)
560                 throw internal_error("assignment to temporary composite");
561         else
562         {
563                 for(vector<unsigned>::iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
564                         for(map<ConstantKey, Id>::iterator j=constant_ids.begin(); (*i>=0x400000 && j!=constant_ids.end()); ++j)
565                                 if(j->second==(*i&0x3FFFFF))
566                                         *i = j->first.int_value;
567
568                 opcode = OP_COMPOSITE_EXTRACT;
569         }
570
571         Id access_id = begin_expression(opcode, access_type_id, 1+r_composite_chain.size());
572         writer.write(r_composite_base_id);
573         for(vector<unsigned>::const_iterator i=r_composite_chain.begin(); i!=r_composite_chain.end(); ++i)
574                 writer.write(*i);
575         end_expression(opcode);
576
577         r_constant_result = false;
578         if(r_composite_base)
579         {
580                 if(assignment_source_id)
581                 {
582                         writer.write_op(content.function_body, OP_STORE, access_id, assignment_source_id);
583                         r_expression_result_id = assignment_source_id;
584                 }
585                 else
586                         r_expression_result_id = write_expression(OP_LOAD, result_type_id, access_id);
587         }
588         else
589                 r_expression_result_id = access_id;
590 }
591
592 void SpirVGenerator::visit_composite(Expression &base_expr, unsigned index, TypeDeclaration &type)
593 {
594         if(!composite_access)
595         {
596                 r_composite_base = 0;
597                 r_composite_base_id = 0;
598                 r_composite_chain.clear();
599         }
600
601         {
602                 SetFlag set_composite(composite_access);
603                 base_expr.visit(*this);
604         }
605
606         if(!r_composite_base_id)
607                 r_composite_base_id = (r_composite_base ? get_id(*r_composite_base) : r_expression_result_id);
608
609         r_composite_chain.push_back(index);
610         if(!composite_access)
611                 generate_composite_access(type);
612         else
613                 r_expression_result_id = 0;
614 }
615
616 void SpirVGenerator::visit_isolated(Expression &expr)
617 {
618         SetForScope<Id> clear_assign(assignment_source_id, 0);
619         SetFlag clear_composite(composite_access, false);
620         SetForScope<Node *> clear_base(r_composite_base, 0);
621         SetForScope<Id> clear_base_id(r_composite_base_id, 0);
622         vector<unsigned> saved_chain;
623         swap(saved_chain, r_composite_chain);
624         expr.visit(*this);
625         swap(saved_chain, r_composite_chain);
626 }
627
628 void SpirVGenerator::visit(MemberAccess &memacc)
629 {
630         visit_composite(*memacc.left, memacc.index, *memacc.type);
631 }
632
633 void SpirVGenerator::visit(Swizzle &swizzle)
634 {
635         if(swizzle.count==1)
636                 visit_composite(*swizzle.left, swizzle.components[0], *swizzle.type);
637         else if(assignment_source_id)
638         {
639                 const BasicTypeDeclaration &basic = dynamic_cast<const BasicTypeDeclaration &>(*swizzle.left->type);
640
641                 unsigned mask = 0;
642                 for(unsigned i=0; i<swizzle.count; ++i)
643                         mask |= 1<<swizzle.components[i];
644
645                 visit_isolated(*swizzle.left);
646
647                 Id combined_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.left->type), 2+basic.size);
648                 writer.write(r_expression_result_id);
649                 writer.write(assignment_source_id);
650                 for(unsigned i=0; i<basic.size; ++i)
651                         writer.write(i+((mask>>i)&1)*basic.size);
652                 end_expression(OP_VECTOR_SHUFFLE);
653
654                 SetForScope<Id> set_assign(assignment_source_id, combined_id);
655                 swizzle.left->visit(*this);
656
657                 r_expression_result_id = combined_id;
658         }
659         else
660         {
661                 swizzle.left->visit(*this);
662                 Id left_id = r_expression_result_id;
663
664                 r_expression_result_id = begin_expression(OP_VECTOR_SHUFFLE, get_id(*swizzle.type), 2+swizzle.count);
665                 writer.write(left_id);
666                 writer.write(left_id);
667                 for(unsigned i=0; i<swizzle.count; ++i)
668                         writer.write(swizzle.components[i]);
669                 end_expression(OP_VECTOR_SHUFFLE);
670         }
671         r_constant_result = false;
672 }
673
674 void SpirVGenerator::visit(UnaryExpression &unary)
675 {
676         unary.expression->visit(*this);
677
678         char oper = unary.oper->token[0];
679         char oper2 = unary.oper->token[1];
680         if(oper=='+' && !oper2)
681                 return;
682
683         BasicTypeDeclaration &basic = dynamic_cast<BasicTypeDeclaration &>(*unary.expression->type);
684         BasicTypeDeclaration &elem = *get_element_type(basic);
685
686         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
687                 /* SPIR-V allows constant operations on floating-point values only for
688                 OpenGL kernels. */
689                 throw internal_error("invalid operands for constant unary expression");
690
691         Id result_type_id = get_id(*unary.type);
692         Opcode opcode = OP_NOP;
693
694         r_constant_result = false;
695         if(oper=='!')
696                 opcode = OP_LOGICAL_NOT;
697         else if(oper=='~')
698                 opcode = OP_NOT;
699         else if(oper=='-' && !oper2)
700         {
701                 opcode = (elem.kind==BasicTypeDeclaration::INT ? OP_S_NEGATE : OP_F_NEGATE);
702
703                 if(basic.kind==BasicTypeDeclaration::MATRIX)
704                 {
705                         Id column_type_id = get_id(*basic.base_type);
706                         unsigned n_columns = basic.size&0xFFFF;
707                         Id column_ids[4];
708                         write_deconstruct(column_type_id, r_expression_result_id, column_ids, n_columns);
709                         for(unsigned i=0; i<n_columns; ++i)
710                                 column_ids[i] = write_expression(opcode, column_type_id, column_ids[i]);
711                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
712                         return;
713                 }
714         }
715         else if((oper=='+' || oper=='-') && oper2==oper)
716         {
717                 if(constant_expression)
718                         throw internal_error("increment or decrement in constant expression");
719
720                 Id one_id = 0;
721                 if(elem.kind==BasicTypeDeclaration::INT)
722                 {
723                         opcode = (oper=='+' ? OP_I_ADD : OP_I_SUB);
724                         one_id = get_constant_id(get_id(elem), 1);
725                 }
726                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
727                 {
728                         opcode = (oper=='+' ? OP_F_ADD : OP_F_SUB);
729                         one_id = get_constant_id(get_id(elem), 1.0f);
730                 }
731                 else
732                         throw internal_error("invalid increment/decrement");
733
734                 if(basic.kind==BasicTypeDeclaration::VECTOR)
735                         one_id = get_vector_constant_id(result_type_id, basic.size, one_id);
736
737                 Id post_value_id = write_expression(opcode, result_type_id, r_expression_result_id, one_id);
738
739                 SetForScope<Id> set_assign(assignment_source_id, post_value_id);
740                 unary.expression->visit(*this);
741
742                 r_expression_result_id = (unary.oper->type==Operator::POSTFIX ? r_expression_result_id : post_value_id);
743                 return;
744         }
745
746         if(opcode==OP_NOP)
747                 throw internal_error("unknown unary operator");
748
749         r_expression_result_id = write_expression(opcode, result_type_id, r_expression_result_id);
750 }
751
752 void SpirVGenerator::visit(BinaryExpression &binary)
753 {
754         char oper = binary.oper->token[0];
755         if(oper=='[')
756         {
757                 visit_isolated(*binary.right);
758                 return visit_composite(*binary.left, 0x400000|r_expression_result_id, *binary.type);
759         }
760
761         if(assignment_source_id)
762                 throw internal_error("invalid binary expression in assignment target");
763
764         BasicTypeDeclaration &basic_left = dynamic_cast<BasicTypeDeclaration &>(*binary.left->type);
765         BasicTypeDeclaration &basic_right = dynamic_cast<BasicTypeDeclaration &>(*binary.right->type);
766         // Expression resolver ensures that element types are the same
767         BasicTypeDeclaration &elem = *get_element_type(basic_left);
768
769         if(constant_expression && elem.kind!=BasicTypeDeclaration::BOOL && elem.kind!=BasicTypeDeclaration::INT)
770                 /* SPIR-V allows constant operations on floating-point values only for
771                 OpenGL kernels. */
772                 throw internal_error("invalid operands for constant binary expression");
773
774         binary.left->visit(*this);
775         Id left_id = r_expression_result_id;
776         binary.right->visit(*this);
777         Id right_id = r_expression_result_id;
778
779         Id result_type_id = get_id(*binary.type);
780         Opcode opcode = OP_NOP;
781         bool swap_operands = false;
782
783         r_constant_result = false;
784
785         char oper2 = binary.oper->token[1];
786         if((oper=='<' || oper=='>') && oper2!=oper)
787         {
788                 if(basic_left.kind==BasicTypeDeclaration::INT)
789                 {
790                         if(basic_left.sign)
791                                 opcode = (oper=='<' ? (oper2=='=' ? OP_S_LESS_THAN_EQUAL : OP_S_LESS_THAN) :
792                                         (oper2=='=' ? OP_S_GREATER_THAN_EQUAL : OP_S_GREATER_THAN));
793                         else
794                                 opcode = (oper=='<' ? (oper2=='=' ? OP_U_LESS_THAN_EQUAL : OP_U_LESS_THAN) :
795                                         (oper2=='=' ? OP_U_GREATER_THAN_EQUAL : OP_U_GREATER_THAN));
796                 }
797                 else if(basic_left.kind==BasicTypeDeclaration::FLOAT)
798                         opcode = (oper=='<' ? (oper2=='=' ? OP_F_ORD_LESS_THAN_EQUAL : OP_F_ORD_LESS_THAN) :
799                                 (oper2=='=' ? OP_F_ORD_GREATER_THAN_EQUAL : OP_F_ORD_GREATER_THAN));
800         }
801         else if((oper=='=' || oper=='!') && oper2=='=')
802         {
803                 if(elem.kind==BasicTypeDeclaration::BOOL)
804                         opcode = (oper=='=' ? OP_LOGICAL_EQUAL : OP_LOGICAL_NOT_EQUAL);
805                 else if(elem.kind==BasicTypeDeclaration::INT)
806                         opcode = (oper=='=' ? OP_I_EQUAL : OP_I_NOT_EQUAL);
807                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
808                         opcode = (oper=='=' ? OP_F_ORD_EQUAL : OP_F_ORD_NOT_EQUAL);
809
810                 if(opcode!=OP_NOP && basic_left.base_type)
811                 {
812                         /* The SPIR-V equality operations produce component-wise results, but
813                         GLSL operators return a single boolean.  Use the any/all operations to
814                         combine the results. */
815                         Opcode combine_op = (oper=='!' ? OP_ANY : OP_ALL);
816                         unsigned n_elems = basic_left.size&0xFFFF;
817                         Id bool_vec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, n_elems);
818
819                         Id compare_id = 0;
820                         if(basic_left.kind==BasicTypeDeclaration::VECTOR)
821                                 compare_id = write_expression(opcode, bool_vec_type_id, left_id, right_id);
822                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX)
823                         {
824                                 Id column_type_id = get_id(*basic_left.base_type);
825                                 Id column_ids[8];
826                                 write_deconstruct(column_type_id, left_id, column_ids, n_elems);
827                                 write_deconstruct(column_type_id, right_id, column_ids+4, n_elems);
828
829                                 Id column_bvec_type_id = get_standard_type_id(BasicTypeDeclaration::BOOL, basic_left.size>>16);
830                                 for(unsigned i=0; i<n_elems; ++i)
831                                 {
832                                         compare_id = write_expression(opcode, column_bvec_type_id, column_ids[i], column_ids[4+i]);
833                                         column_ids[i] = write_expression(combine_op, result_type_id, compare_id);;
834                                 }
835
836                                 compare_id = write_construct(bool_vec_type_id, column_ids, n_elems);
837                         }
838
839                         if(compare_id)
840                                 r_expression_result_id = write_expression(combine_op, result_type_id, compare_id);
841                         return;
842                 }
843         }
844         else if(oper2=='&' && elem.kind==BasicTypeDeclaration::BOOL)
845                 opcode = OP_LOGICAL_AND;
846         else if(oper2=='|' && elem.kind==BasicTypeDeclaration::BOOL)
847                 opcode = OP_LOGICAL_OR;
848         else if(oper2=='^' && elem.kind==BasicTypeDeclaration::BOOL)
849                 opcode = OP_LOGICAL_NOT_EQUAL;
850         else if(oper=='&' && elem.kind==BasicTypeDeclaration::INT)
851                 opcode = OP_BITWISE_AND;
852         else if(oper=='|' && elem.kind==BasicTypeDeclaration::INT)
853                 opcode = OP_BITWISE_OR;
854         else if(oper=='^' && elem.kind==BasicTypeDeclaration::INT)
855                 opcode = OP_BITWISE_XOR;
856         else if(oper=='<' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
857                 opcode = OP_SHIFT_LEFT_LOGICAL;
858         else if(oper=='>' && oper2==oper && elem.kind==BasicTypeDeclaration::INT)
859                 opcode = OP_SHIFT_RIGHT_ARITHMETIC;
860         else if(oper=='%' && elem.kind==BasicTypeDeclaration::INT)
861                 opcode = (elem.sign ? OP_S_MOD : OP_U_MOD);
862         else if(oper=='+' || oper=='-' || oper=='*' || oper=='/')
863         {
864                 Opcode elem_op = OP_NOP;
865                 if(elem.kind==BasicTypeDeclaration::INT)
866                 {
867                         if(oper=='/')
868                                 elem_op = (elem.sign ? OP_S_DIV : OP_U_DIV);
869                         else
870                                 elem_op = (oper=='+' ? OP_I_ADD : oper=='-' ? OP_I_SUB : OP_I_MUL);
871                 }
872                 else if(elem.kind==BasicTypeDeclaration::FLOAT)
873                         elem_op = (oper=='+' ? OP_F_ADD : oper=='-' ? OP_F_SUB : oper=='*' ? OP_F_MUL : OP_F_DIV);
874
875                 if(oper=='*' && (basic_left.base_type || basic_right.base_type) && elem.kind==BasicTypeDeclaration::FLOAT)
876                 {
877                         /* Multiplication between floating-point vectors and matrices has
878                         dedicated operations. */
879                         if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
880                                 opcode = OP_MATRIX_TIMES_MATRIX;
881                         else if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
882                         {
883                                 if(basic_left.kind==BasicTypeDeclaration::VECTOR)
884                                         opcode = OP_VECTOR_TIMES_MATRIX;
885                                 else if(basic_right.kind==BasicTypeDeclaration::VECTOR)
886                                         opcode = OP_MATRIX_TIMES_VECTOR;
887                                 else
888                                 {
889                                         opcode = OP_MATRIX_TIMES_SCALAR;
890                                         swap_operands = (basic_right.kind==BasicTypeDeclaration::MATRIX);
891                                 }
892                         }
893                         else if(basic_left.kind==BasicTypeDeclaration::VECTOR && basic_right.kind==BasicTypeDeclaration::VECTOR)
894                                 opcode = elem_op;
895                         else
896                         {
897                                 opcode = OP_VECTOR_TIMES_SCALAR;
898                                 swap_operands = (basic_right.kind==BasicTypeDeclaration::VECTOR);
899                         }
900                 }
901                 else if((basic_left.base_type!=0)!=(basic_right.base_type!=0))
902                 {
903                         /* One operand is scalar and the other is a vector or a matrix.
904                         Expand the scalar to a vector of appropriate size. */
905                         Id &scalar_id = (basic_left.base_type ? right_id : left_id);
906                         BasicTypeDeclaration *vector_type = (basic_left.base_type ? &basic_left : &basic_right);
907                         if(vector_type->kind==BasicTypeDeclaration::MATRIX)
908                                 vector_type = dynamic_cast<BasicTypeDeclaration *>(vector_type->base_type);
909                         Id vector_type_id = get_id(*vector_type);
910
911                         Id expanded_id = begin_expression(OP_COMPOSITE_CONSTRUCT, vector_type_id, vector_type->size);
912                         for(unsigned i=0; i<vector_type->size; ++i)
913                                 writer.write(scalar_id);
914                         end_expression(OP_COMPOSITE_CONSTRUCT);
915
916                         scalar_id = expanded_id;
917
918                         if(basic_left.kind==BasicTypeDeclaration::MATRIX || basic_right.kind==BasicTypeDeclaration::MATRIX)
919                         {
920                                 // Apply matrix operation column-wise.
921                                 Id matrix_id = (basic_left.base_type ? left_id : right_id);
922
923                                 Id column_ids[4];
924                                 unsigned n_columns = (basic_left.base_type ? basic_left.size : basic_right.size)&0xFFFF;
925                                 write_deconstruct(vector_type_id, matrix_id, column_ids, n_columns);
926
927                                 for(unsigned i=0; i<n_columns; ++i)
928                                         column_ids[i] = write_expression(elem_op, vector_type_id, column_ids[i], expanded_id);
929
930                                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
931                                 return;
932                         }
933                         else
934                                 opcode = elem_op;
935                 }
936                 else if(basic_left.kind==BasicTypeDeclaration::MATRIX && basic_right.kind==BasicTypeDeclaration::MATRIX)
937                 {
938                         if(oper=='*')
939                                 throw internal_error("non-float matrix multiplication");
940
941                         /* Other operations involving matrices need to be performed
942                         column-wise. */
943                         Id column_type_id = get_id(*basic_left.base_type);
944                         Id column_ids[8];
945
946                         unsigned n_columns = basic_left.size&0xFFFF;
947                         write_deconstruct(column_type_id, left_id, column_ids, n_columns);
948                         write_deconstruct(column_type_id, right_id, column_ids+4, n_columns);
949
950                         for(unsigned i=0; i<n_columns; ++i)
951                                 column_ids[i] = write_expression(elem_op, column_type_id, column_ids[i], column_ids[4+i]);
952
953                         r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
954                         return;
955                 }
956                 else if(basic_left.kind==basic_right.kind)
957                         // Both operands are either scalars or vectors.
958                         opcode = elem_op;
959         }
960
961         if(opcode==OP_NOP)
962                 throw internal_error("unknown binary operator");
963
964         if(swap_operands)
965                 swap(left_id, right_id);
966
967         r_expression_result_id = write_expression(opcode, result_type_id, left_id, right_id);
968 }
969
970 void SpirVGenerator::visit(Assignment &assign)
971 {
972         if(assign.oper->token[0]!='=')
973                 visit(static_cast<BinaryExpression &>(assign));
974         else
975                 assign.right->visit(*this);
976
977         SetForScope<Id> set_assign(assignment_source_id, r_expression_result_id);
978         assign.left->visit(*this);
979         r_constant_result = false;
980 }
981
void SpirVGenerator::visit(TernaryExpression &ternary)
{
        /* Generates code for the ?: operator.  In a constant expression all three
        operands are evaluated and combined with OpSelect.  Otherwise a structured
        selection is emitted so that only the chosen branch is evaluated, and the
        result is merged with OpPhi. */
        if(constant_expression)
        {
                ternary.condition->visit(*this);
                Id condition_id = r_expression_result_id;
                ternary.true_expr->visit(*this);
                Id true_result_id = r_expression_result_id;
                ternary.false_expr->visit(*this);
                Id false_result_id = r_expression_result_id;

                // OpSelect operands: condition, value if true, value if false.
                r_expression_result_id = begin_expression(OP_SELECT, get_id(*ternary.type), 3);
                writer.write(condition_id);
                writer.write(true_result_id);
                writer.write(false_result_id);
                end_expression(OP_SELECT);

                // NOTE(review): r_constant_result is left as set by the false expression's visit.
                return;
        }

        ternary.condition->visit(*this);
        Id condition_id = r_expression_result_id;

        // Fresh ids for the two branch blocks and the merge block.
        Id true_label_id = next_id++;
        Id false_label_id = next_id++;
        Id merge_block_id = next_id++;
        writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
        writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, condition_id, true_label_id, false_label_id);

        writer.write_op_label(true_label_id);
        ternary.true_expr->visit(*this);
        Id true_result_id = r_expression_result_id;
        writer.write_op(content.function_body, OP_BRANCH, merge_block_id);

        /* NOTE(review): unlike the true branch, no explicit OP_BRANCH to
        merge_block_id is written after the false branch; presumably
        write_op_label terminates the open block with a branch to the new label.
        Confirm in SpirVWriter. */
        writer.write_op_label(false_label_id);
        ternary.false_expr->visit(*this);
        Id false_result_id = r_expression_result_id;

        /* OpPhi takes (value, parent block) pairs.  NOTE(review): the parents are
        given as the branch entry labels; if a branch expression itself emits
        control flow (e.g. a nested ternary), the actual predecessor block would
        differ.  Confirm. */
        writer.write_op_label(merge_block_id);
        r_expression_result_id = begin_expression(OP_PHI, get_id(*ternary.type), 4);
        writer.write(true_result_id);
        writer.write(true_label_id);
        writer.write(false_result_id);
        writer.write(false_label_id);
        end_expression(OP_PHI);

        r_constant_result = false;
}
1030
1031 void SpirVGenerator::visit(FunctionCall &call)
1032 {
1033         if(assignment_source_id)
1034                 throw internal_error("assignment to function call");
1035         else if(composite_access)
1036                 return visit_isolated(call);
1037         else if(call.constructor && call.arguments.size()==1 && call.arguments[0]->type==call.type)
1038                 return call.arguments[0]->visit(*this);
1039
1040         vector<Id> argument_ids;
1041         argument_ids.reserve(call.arguments.size());
1042         bool all_args_const = true;
1043         for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1044         {
1045                 (*i)->visit(*this);
1046                 argument_ids.push_back(r_expression_result_id);
1047                 all_args_const &= r_constant_result;
1048         }
1049
1050         if(constant_expression && (!call.constructor || !all_args_const))
1051                 throw internal_error("function call in constant expression");
1052
1053         Id result_type_id = get_id(*call.type);
1054         r_constant_result = false;
1055
1056         if(call.constructor)
1057                 visit_constructor(call, argument_ids, all_args_const);
1058         else if(call.declaration->source==BUILTIN_SOURCE)
1059         {
1060                 string arg_types;
1061                 for(NodeArray<Expression>::const_iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
1062                         if(BasicTypeDeclaration *basic_arg = dynamic_cast<BasicTypeDeclaration *>((*i)->type))
1063                         {
1064                                 BasicTypeDeclaration &elem_arg = *get_element_type(*basic_arg);
1065                                 switch(elem_arg.kind)
1066                                 {
1067                                 case BasicTypeDeclaration::BOOL: arg_types += 'b'; break;
1068                                 case BasicTypeDeclaration::INT: arg_types += (elem_arg.sign ? 'i' : 'u'); break;
1069                                 case BasicTypeDeclaration::FLOAT: arg_types += 'f'; break;
1070                                 default: arg_types += '?';
1071                                 }
1072                         }
1073
1074                 const BuiltinFunctionInfo *builtin_info;
1075                 for(builtin_info=builtin_functions; builtin_info->function[0]; ++builtin_info)
1076                         if(builtin_info->function==call.name && (!builtin_info->arg_types[0] || builtin_info->arg_types==arg_types))
1077                                 break;
1078
1079                 if(builtin_info->opcode)
1080                 {
1081                         Opcode opcode;
1082                         if(builtin_info->extension[0])
1083                         {
1084                                 opcode = OP_EXT_INST;
1085                                 Id ext_id = import_extension(builtin_info->extension);
1086
1087                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1088                                 writer.write(ext_id);
1089                                 writer.write(builtin_info->opcode);
1090                         }
1091                         else
1092                         {
1093                                 opcode = static_cast<Opcode>(builtin_info->opcode);
1094                                 r_expression_result_id = begin_expression(opcode, result_type_id);
1095                         }
1096
1097                         for(unsigned i=0; i<call.arguments.size(); ++i)
1098                         {
1099                                 if(!builtin_info->arg_order[i] || builtin_info->arg_order[i]>argument_ids.size())
1100                                         throw internal_error("invalid builtin function info");
1101                                 writer.write(argument_ids[builtin_info->arg_order[i]-1]);
1102                         }
1103
1104                         end_expression(opcode);
1105                 }
1106                 else if(builtin_info->handler)
1107                         (this->*(builtin_info->handler))(call, argument_ids);
1108                 else
1109                         throw internal_error("unknown builtin function "+call.name);
1110         }
1111         else
1112         {
1113                 r_expression_result_id = begin_expression(OP_FUNCTION_CALL, result_type_id, 1+call.arguments.size());
1114                 writer.write(get_id(*call.declaration->definition));
1115                 for(vector<Id>::const_iterator i=argument_ids.begin(); i!=argument_ids.end(); ++i)
1116                         writer.write(*i);
1117                 end_expression(OP_FUNCTION_CALL);
1118
1119                 // Any global variables the called function uses might have changed value
1120                 set<Node *> dependencies = DependencyCollector().apply(*call.declaration->definition);
1121                 for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1122                         if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1123                                 variable_load_ids.erase(var);
1124         }
1125 }
1126
1127 void SpirVGenerator::visit_constructor(FunctionCall &call, const vector<Id> &argument_ids, bool all_args_const)
1128 {
1129         Id result_type_id = get_id(*call.type);
1130
1131         BasicTypeDeclaration *basic = dynamic_cast<BasicTypeDeclaration *>(call.type);
1132         if(!basic)
1133         {
1134                 if(dynamic_cast<const StructDeclaration *>(call.type))
1135                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1136                 else
1137                         throw internal_error("unconstructable type "+call.name);
1138                 return;
1139         }
1140
1141         SetFlag set_const(constant_expression, constant_expression || all_args_const);
1142
1143         BasicTypeDeclaration &elem = *get_element_type(*basic);
1144         BasicTypeDeclaration &basic_arg0 = dynamic_cast<BasicTypeDeclaration &>(*call.arguments[0]->type);
1145         BasicTypeDeclaration &elem_arg0 = *get_element_type(basic_arg0);
1146
1147         if(basic->kind==BasicTypeDeclaration::MATRIX)
1148         {
1149                 Id col_type_id = get_id(*basic->base_type);
1150                 unsigned n_columns = basic->size&0xFFFF;
1151                 unsigned n_rows = basic->size>>16;
1152
1153                 Id column_ids[4];
1154                 if(call.arguments.size()==1)
1155                 {
1156                         // Construct diagonal matrix from a single scalar.
1157                         Id zero_id = get_constant_id(get_id(elem), 0.0f);
1158                         for(unsigned i=0; i<n_columns; ++i)
1159                         {
1160                                 column_ids[i] = begin_expression(OP_COMPOSITE_CONSTRUCT, col_type_id, n_rows);;
1161                                 for(unsigned j=0; j<n_rows; ++j)
1162                                         writer.write(j==i ? argument_ids[0] : zero_id);
1163                                 end_expression(OP_COMPOSITE_CONSTRUCT);
1164                         }
1165                 }
1166                 else
1167                         // Construct a matrix from column vectors
1168                         copy(argument_ids.begin(), argument_ids.begin()+n_columns, column_ids);
1169
1170                 r_expression_result_id = write_construct(result_type_id, column_ids, n_columns);
1171         }
1172         else if(basic->kind==BasicTypeDeclaration::VECTOR && (call.arguments.size()>1 || basic_arg0.kind!=BasicTypeDeclaration::VECTOR))
1173         {
1174                 /* There's either a single scalar argument or multiple arguments
1175                 which make up the vector's components. */
1176                 if(call.arguments.size()==1)
1177                 {
1178                         r_expression_result_id = begin_expression(OP_COMPOSITE_CONSTRUCT, result_type_id);
1179                         for(unsigned i=0; i<basic->size; ++i)
1180                                 writer.write(argument_ids[0]);
1181                         end_expression(OP_COMPOSITE_CONSTRUCT);
1182                 }
1183                 else
1184                         r_expression_result_id = write_construct(result_type_id, &argument_ids[0], argument_ids.size());
1185         }
1186         else if(elem.kind==BasicTypeDeclaration::BOOL)
1187         {
1188                 if(constant_expression)
1189                         throw internal_error("unconverted constant");
1190
1191                 // Conversion to boolean is implemented as comparing against zero.
1192                 Id number_type_id = get_id(elem_arg0);
1193                 Id zero_id = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ?
1194                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1195                 if(basic_arg0.kind==BasicTypeDeclaration::VECTOR)
1196                         zero_id = get_vector_constant_id(get_id(basic_arg0), basic_arg0.size, zero_id);
1197
1198                 Opcode opcode = (elem_arg0.kind==BasicTypeDeclaration::FLOAT ? OP_F_ORD_NOT_EQUAL : OP_I_NOT_EQUAL);
1199                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0], zero_id);
1200         }
1201         else if(elem_arg0.kind==BasicTypeDeclaration::BOOL)
1202         {
1203                 if(constant_expression)
1204                         throw internal_error("unconverted constant");
1205
1206                 /* Conversion from boolean is implemented as selecting from zero
1207                 or one. */
1208                 Id number_type_id = get_id(elem);
1209                 Id zero_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1210                         get_constant_id(number_type_id, 0.0f) : get_constant_id(number_type_id, 0));
1211                 Id one_id = (elem.kind==BasicTypeDeclaration::FLOAT ?
1212                         get_constant_id(number_type_id, 1.0f) : get_constant_id(number_type_id, 1));
1213                 if(basic->kind==BasicTypeDeclaration::VECTOR)
1214                 {
1215                         zero_id = get_vector_constant_id(get_id(*basic), basic->size, zero_id);
1216                         one_id = get_vector_constant_id(get_id(*basic), basic->size, one_id);
1217                 }
1218
1219                 r_expression_result_id = begin_expression(OP_SELECT, result_type_id, 3);
1220                 writer.write(argument_ids[0]);
1221                 writer.write(zero_id);
1222                 writer.write(one_id);
1223                 end_expression(OP_SELECT);
1224         }
1225         else
1226         {
1227                 if(constant_expression)
1228                         throw internal_error("unconverted constant");
1229
1230                 // Scalar or vector conversion between types of equal size.
1231                 Opcode opcode;
1232                 if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::FLOAT)
1233                         opcode = (elem.sign ? OP_CONVERT_F_TO_S : OP_CONVERT_F_TO_U);
1234                 else if(elem.kind==BasicTypeDeclaration::FLOAT && elem_arg0.kind==BasicTypeDeclaration::INT)
1235                         opcode = (elem_arg0.sign ? OP_CONVERT_S_TO_F : OP_CONVERT_U_TO_F);
1236                 else if(elem.kind==BasicTypeDeclaration::INT && elem_arg0.kind==BasicTypeDeclaration::INT)
1237                         opcode = OP_BITCAST;
1238                 else
1239                         throw internal_error("invalid conversion");
1240
1241                 r_expression_result_id = write_expression(opcode, result_type_id, argument_ids[0]);
1242         }
1243 }
1244
1245 void SpirVGenerator::visit_builtin_matrix_comp_mult(FunctionCall &call, const vector<Id> &argument_ids)
1246 {
1247         if(argument_ids.size()!=2)
1248                 throw internal_error("invalid matrixCompMult call");
1249
1250         const BasicTypeDeclaration &basic_arg0 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[0]->type);
1251         Id column_type_id = get_id(*basic_arg0.base_type);
1252         Id column_ids[8];
1253
1254         unsigned n_columns = basic_arg0.size&0xFFFF;
1255         write_deconstruct(column_type_id, argument_ids[0], column_ids, n_columns);
1256         write_deconstruct(column_type_id, argument_ids[1], column_ids+4, n_columns);
1257
1258         for(unsigned i=0; i<n_columns; ++i)
1259                 column_ids[i] = write_expression(OP_F_MUL, column_type_id, column_ids[i], column_ids[4+i]);
1260
1261         r_expression_result_id = write_construct(get_id(*call.type), column_ids, n_columns);
1262 }
1263
1264 void SpirVGenerator::visit_builtin_texture(FunctionCall &call, const vector<Id> &argument_ids)
1265 {
1266         if(argument_ids.size()<2)
1267                 throw internal_error("invalid texture sampling call");
1268
1269         bool explicit_lod = (stage->type!=Stage::FRAGMENT || call.name=="textureLod");
1270         Id lod_id = (!explicit_lod ? 0 : call.name=="textureLod" ? argument_ids.back() :
1271                 get_constant_id(get_standard_type_id(BasicTypeDeclaration::FLOAT, 1), 0.0f));
1272
1273         const ImageTypeDeclaration &image = dynamic_cast<const ImageTypeDeclaration &>(*call.arguments[0]->type);
1274
1275         Opcode opcode;
1276         Id result_type_id = get_id(*call.type);
1277         Id dref_id = 0;
1278         if(image.shadow)
1279         {
1280                 if(argument_ids.size()==2)
1281                 {
1282                         const BasicTypeDeclaration &basic_arg1 = dynamic_cast<const BasicTypeDeclaration &>(*call.arguments[1]->type);
1283                         dref_id = begin_expression(OP_COMPOSITE_EXTRACT, get_id(*basic_arg1.base_type), 2);
1284                         writer.write(argument_ids.back());
1285                         writer.write(basic_arg1.size-1);
1286                         end_expression(OP_COMPOSITE_EXTRACT);
1287                 }
1288                 else
1289                         dref_id = argument_ids[2];
1290
1291                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD : OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD);
1292                 r_expression_result_id = begin_expression(opcode, result_type_id, 3+explicit_lod*2);
1293         }
1294         else
1295         {
1296                 opcode = (explicit_lod ? OP_IMAGE_SAMPLE_EXPLICIT_LOD : OP_IMAGE_SAMPLE_IMPLICIT_LOD);
1297                 r_expression_result_id = begin_expression(opcode, result_type_id, 2+explicit_lod*2);
1298         }
1299
1300         for(unsigned i=0; i<2; ++i)
1301                 writer.write(argument_ids[i]);
1302         if(dref_id)
1303                 writer.write(dref_id);
1304         if(explicit_lod)
1305         {
1306                 writer.write(2);  // Lod
1307                 writer.write(lod_id);
1308         }
1309
1310         end_expression(opcode);
1311 }
1312
1313 void SpirVGenerator::visit_builtin_texel_fetch(FunctionCall &call, const vector<Id> &argument_ids)
1314 {
1315         if(argument_ids.size()!=3)
1316                 throw internal_error("invalid texelFetch call");
1317
1318         r_expression_result_id = begin_expression(OP_IMAGE_FETCH, get_id(*call.type), 4);
1319         for(unsigned i=0; i<2; ++i)
1320                 writer.write(argument_ids[i]);
1321         writer.write(2);  // Lod
1322         writer.write(argument_ids.back());
1323         end_expression(OP_IMAGE_FETCH);
1324 }
1325
1326 void SpirVGenerator::visit_builtin_interpolate(FunctionCall &call, const vector<Id> &argument_ids)
1327 {
1328         if(argument_ids.size()<1)
1329                 throw internal_error("invalid interpolate call");
1330         const VariableReference *var = dynamic_cast<const VariableReference *>(call.arguments[0].get());
1331         if(!var || !var->declaration || var->declaration->interface!="in")
1332                 throw internal_error("invalid interpolate call");
1333
1334         SpirVGlslStd450Opcode opcode;
1335         if(call.name=="interpolateAtCentroid")
1336                 opcode = GLSL450_INTERPOLATE_AT_CENTROID;
1337         else if(call.name=="interpolateAtSample")
1338                 opcode = GLSL450_INTERPOLATE_AT_SAMPLE;
1339         else if(call.name=="interpolateAtOffset")
1340                 opcode = GLSL450_INTERPOLATE_AT_OFFSET;
1341         else
1342                 throw internal_error("invalid interpolate call");
1343
1344         use_capability(CAP_INTERPOLATION_FUNCTION);
1345
1346         Id ext_id = import_extension("GLSL.std.450");
1347         r_expression_result_id = begin_expression(OP_EXT_INST, get_id(*call.type));
1348         writer.write(ext_id);
1349         writer.write(opcode);
1350         writer.write(get_id(*var->declaration));
1351         for(vector<Id>::const_iterator i=argument_ids.begin(); ++i!=argument_ids.end(); )
1352                 writer.write(*i);
1353         end_expression(OP_EXT_INST);
1354 }
1355
1356 void SpirVGenerator::visit(ExpressionStatement &expr)
1357 {
1358         expr.expression->visit(*this);
1359 }
1360
1361 void SpirVGenerator::visit(InterfaceLayout &layout)
1362 {
1363         interface_layouts.push_back(&layout);
1364 }
1365
1366 bool SpirVGenerator::check_duplicate_type(TypeDeclaration &type)
1367 {
1368         for(map<Node *, Declaration>::const_iterator i=declared_ids.begin(); i!=declared_ids.end(); ++i)
1369                 if(TypeDeclaration *type2 = dynamic_cast<TypeDeclaration *>(i->first))
1370                         if(TypeComparer().apply(type, *type2))
1371                         {
1372                                 insert_unique(declared_ids, &type, i->second);
1373                                 return true;
1374                         }
1375
1376         return false;
1377 }
1378
1379 bool SpirVGenerator::check_standard_type(BasicTypeDeclaration &basic)
1380 {
1381         const BasicTypeDeclaration *elem = (basic.kind==BasicTypeDeclaration::VECTOR ?
1382                 dynamic_cast<const BasicTypeDeclaration *>(basic.base_type) : &basic);
1383         if(!elem || elem->base_type)
1384                 return false;
1385         if((elem->kind==BasicTypeDeclaration::INT || elem->kind==BasicTypeDeclaration::FLOAT) && elem->size!=32)
1386                 return false;
1387
1388         Id standard_id = get_standard_type_id(elem->kind, (basic.kind==BasicTypeDeclaration::VECTOR ? basic.size : 1), elem->sign);
1389         insert_unique(declared_ids, &basic, Declaration(standard_id, 0));
1390         writer.write_op_name(standard_id, basic.name);
1391
1392         return true;
1393 }
1394
1395 void SpirVGenerator::visit(BasicTypeDeclaration &basic)
1396 {
1397         if(check_standard_type(basic))
1398                 return;
1399         if(check_duplicate_type(basic))
1400                 return;
1401         // Alias types shouldn't exist at this point and arrays are handled elsewhere
1402         if(basic.kind==BasicTypeDeclaration::ALIAS || basic.kind==BasicTypeDeclaration::ARRAY)
1403                 return;
1404
1405         Id type_id = allocate_id(basic, 0);
1406         writer.write_op_name(type_id, basic.name);
1407
1408         switch(basic.kind)
1409         {
1410         case BasicTypeDeclaration::INT:
1411                 writer.write_op(content.globals, OP_TYPE_INT, type_id, basic.size, basic.sign);
1412                 break;
1413         case BasicTypeDeclaration::FLOAT:
1414                 writer.write_op(content.globals, OP_TYPE_FLOAT, type_id, basic.size);
1415                 break;
1416         case BasicTypeDeclaration::VECTOR:
1417                 writer.write_op(content.globals, OP_TYPE_VECTOR, type_id, get_id(*basic.base_type), basic.size);
1418                 break;
1419         case BasicTypeDeclaration::MATRIX:
1420                 writer.write_op(content.globals, OP_TYPE_MATRIX, type_id, get_id(*basic.base_type), basic.size&0xFFFF);
1421                 break;
1422         default:
1423                 throw internal_error("unknown basic type");
1424         }
1425 }
1426
1427 void SpirVGenerator::visit(ImageTypeDeclaration &image)
1428 {
1429         if(check_duplicate_type(image))
1430                 return;
1431
1432         Id type_id = allocate_id(image, 0);
1433
1434         Id image_id = (image.sampled ? next_id++ : type_id);
1435         writer.begin_op(content.globals, OP_TYPE_IMAGE, 9);
1436         writer.write(image_id);
1437         writer.write(get_id(*image.base_type));
1438         writer.write(image.dimensions-1);
1439         writer.write(image.shadow);
1440         writer.write(image.array);
1441         writer.write(false);  // Multisample
1442         writer.write(image.sampled ? 1 : 2);
1443         writer.write(0);  // Format (unknown)
1444         writer.end_op(OP_TYPE_IMAGE);
1445
1446         if(image.sampled)
1447         {
1448                 writer.write_op_name(type_id, image.name);
1449                 writer.write_op(content.globals, OP_TYPE_SAMPLED_IMAGE, type_id, image_id);
1450         }
1451
1452         if(image.dimensions==ImageTypeDeclaration::ONE)
1453                 use_capability(image.sampled ? CAP_SAMPLED_1D : CAP_IMAGE_1D);
1454         else if(image.dimensions==ImageTypeDeclaration::CUBE && image.array)
1455                 use_capability(image.sampled ? CAP_SAMPLED_CUBE_ARRAY : CAP_IMAGE_CUBE_ARRAY);
1456 }
1457
1458 void SpirVGenerator::visit(StructDeclaration &strct)
1459 {
1460         if(check_duplicate_type(strct))
1461                 return;
1462
1463         Id type_id = allocate_id(strct, 0);
1464         writer.write_op_name(type_id, strct.name);
1465
1466         if(strct.interface_block)
1467                 writer.write_op_decorate(type_id, DECO_BLOCK);
1468
1469         bool builtin = (strct.interface_block && !strct.interface_block->block_name.compare(0, 3, "gl_"));
1470         vector<Id> member_type_ids;
1471         member_type_ids.reserve(strct.members.body.size());
1472         for(NodeList<Statement>::const_iterator i=strct.members.body.begin(); i!=strct.members.body.end(); ++i)
1473         {
1474                 const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(i->get());
1475                 if(!var)
1476                         continue;
1477
1478                 unsigned index = member_type_ids.size();
1479                 member_type_ids.push_back(get_variable_type_id(*var));
1480
1481                 writer.write_op_member_name(type_id, index, var->name);
1482
1483                 if(builtin)
1484                 {
1485                         BuiltinSemantic semantic = get_builtin_semantic(var->name);
1486                         writer.write_op_member_decorate(type_id, index, DECO_BUILTIN, semantic);
1487                 }
1488                 else
1489                 {
1490                         if(var->layout)
1491                         {
1492                                 const vector<Layout::Qualifier> &qualifiers = var->layout->qualifiers;
1493                                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1494                                 {
1495                                         if(j->name=="offset")
1496                                                 writer.write_op_member_decorate(type_id, index, DECO_OFFSET, j->value);
1497                                         else if(j->name=="column_major")
1498                                                 writer.write_op_member_decorate(type_id, index, DECO_COL_MAJOR);
1499                                         else if(j->name=="row_major")
1500                                                 writer.write_op_member_decorate(type_id, index, DECO_ROW_MAJOR);
1501                                 }
1502                         }
1503
1504                         const BasicTypeDeclaration *basic = dynamic_cast<const BasicTypeDeclaration *>(var->type_declaration);
1505                         while(basic && basic->kind==BasicTypeDeclaration::ARRAY)
1506                                 basic = dynamic_cast<const BasicTypeDeclaration *>(basic->base_type);
1507                         if(basic && basic->kind==BasicTypeDeclaration::MATRIX)
1508                         {
1509                                 unsigned stride = MemoryRequirementsCalculator().apply(*basic->base_type).stride;
1510                                 writer.write_op_member_decorate(type_id, index, DECO_MATRIX_STRIDE, stride);
1511                         }
1512                 }
1513         }
1514
1515         writer.begin_op(content.globals, OP_TYPE_STRUCT);
1516         writer.write(type_id);
1517         for(vector<Id>::const_iterator i=member_type_ids.begin(); i!=member_type_ids.end(); ++i)
1518                 writer.write(*i);
1519         writer.end_op(OP_TYPE_STRUCT);
1520 }
1521
1522 void SpirVGenerator::visit(VariableDeclaration &var)
1523 {
1524         const vector<Layout::Qualifier> *layout_ql = (var.layout ? &var.layout->qualifiers : 0);
1525
1526         int spec_id = -1;
1527         if(layout_ql)
1528         {
1529                 for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); (spec_id<0 && i!=layout_ql->end()); ++i)
1530                         if(i->name=="constant_id")
1531                                 spec_id = i->value;
1532         }
1533
1534         Id type_id = get_variable_type_id(var);
1535         Id var_id;
1536
1537         if(var.constant)
1538         {
1539                 if(!var.init_expression)
1540                         throw internal_error("const variable without initializer");
1541
1542                 SetFlag set_const(constant_expression);
1543                 SetFlag set_spec(spec_constant, spec_id>=0);
1544                 r_expression_result_id = 0;
1545                 var.init_expression->visit(*this);
1546                 var_id = r_expression_result_id;
1547                 insert_unique(declared_ids, &var, Declaration(var_id, type_id));
1548                 writer.write_op_decorate(var_id, DECO_SPEC_ID, spec_id);
1549
1550                 /* It's unclear what should be done if a specialization constant is
1551                 initialized with anything other than a literal.  GLSL doesn't seem to
1552                 prohibit that but SPIR-V says OpSpecConstantOp can't be updated via
1553                 specialization. */
1554         }
1555         else
1556         {
1557                 StorageClass storage = (current_function ? STORAGE_FUNCTION : get_interface_storage(var.interface, false));
1558                 Id ptr_type_id = get_pointer_type_id(type_id, storage);
1559                 if(var.interface=="uniform")
1560                 {
1561                         Id &uni_id = declared_uniform_ids["v"+var.name];
1562                         if(uni_id)
1563                         {
1564                                 insert_unique(declared_ids, &var, Declaration(uni_id, ptr_type_id));
1565                                 return;
1566                         }
1567
1568                         uni_id = var_id = allocate_id(var, ptr_type_id);
1569                 }
1570                 else
1571                         var_id = allocate_id(var, (var.constant ? type_id : ptr_type_id));
1572
1573                 Id init_id = 0;
1574                 if(var.init_expression)
1575                 {
1576                         SetFlag set_const(constant_expression, !current_function);
1577                         r_expression_result_id = 0;
1578                         r_constant_result = false;
1579                         var.init_expression->visit(*this);
1580                         init_id = r_expression_result_id;
1581                 }
1582
1583                 vector<Word> &target = (current_function ? content.locals : content.globals);
1584                 writer.begin_op(target, OP_VARIABLE, 4+(init_id && !current_function));
1585                 writer.write(ptr_type_id);
1586                 writer.write(var_id);
1587                 writer.write(storage);
1588                 if(init_id && !current_function)
1589                         writer.write(init_id);
1590                 writer.end_op(OP_VARIABLE);
1591
1592                 if(layout_ql)
1593                 {
1594                         for(vector<Layout::Qualifier>::const_iterator i=layout_ql->begin(); i!=layout_ql->end(); ++i)
1595                         {
1596                                 if(i->name=="location")
1597                                         writer.write_op_decorate(var_id, DECO_LOCATION, i->value);
1598                                 else if(i->name=="set")
1599                                         writer.write_op_decorate(var_id, DECO_DESCRIPTOR_SET, i->value);
1600                                 else if(i->name=="binding")
1601                                         writer.write_op_decorate(var_id, DECO_BINDING, i->value);
1602                         }
1603                 }
1604
1605                 if(init_id && current_function)
1606                         writer.write_op(content.function_body, OP_STORE, var_id, init_id);
1607         }
1608
1609         writer.write_op_name(var_id, var.name);
1610 }
1611
1612 void SpirVGenerator::visit(InterfaceBlock &iface)
1613 {
1614         StorageClass storage = get_interface_storage(iface.interface, true);
1615         Id type_id;
1616         if(iface.array)
1617                 type_id = get_array_type_id(*iface.struct_declaration, 0);
1618         else
1619                 type_id = get_id(*iface.struct_declaration);
1620         Id ptr_type_id = get_pointer_type_id(type_id, storage);
1621
1622         Id block_id;
1623         if(iface.interface=="uniform")
1624         {
1625                 Id &uni_id = declared_uniform_ids["b"+iface.block_name];
1626                 if(uni_id)
1627                 {
1628                         insert_unique(declared_ids, &iface, Declaration(uni_id, ptr_type_id));
1629                         return;
1630                 }
1631
1632                 uni_id = block_id = allocate_id(iface, ptr_type_id);
1633         }
1634         else
1635                 block_id = allocate_id(iface, ptr_type_id);
1636         writer.write_op_name(block_id, iface.instance_name);
1637
1638         writer.write_op(content.globals, OP_VARIABLE, ptr_type_id, block_id, storage);
1639
1640         if(iface.layout)
1641         {
1642                 const vector<Layout::Qualifier> &qualifiers = iface.layout->qualifiers;
1643                 for(vector<Layout::Qualifier>::const_iterator i=qualifiers.begin(); i!=qualifiers.end(); ++i)
1644                         if(i->name=="binding")
1645                                 writer.write_op_decorate(block_id, DECO_BINDING, i->value);
1646         }
1647 }
1648
1649 void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id)
1650 {
1651         writer.begin_op(content.entry_points, OP_ENTRY_POINT);
1652         switch(stage->type)
1653         {
1654         case Stage::VERTEX: writer.write(0); break;
1655         case Stage::GEOMETRY: writer.write(3); break;
1656         case Stage::FRAGMENT: writer.write(4); break;
1657         default: throw internal_error("unknown stage");
1658         }
1659         writer.write(func_id);
1660         writer.write_string(func.name);
1661
1662         set<Node *> dependencies = DependencyCollector().apply(func);
1663         for(set<Node *>::const_iterator i=dependencies.begin(); i!=dependencies.end(); ++i)
1664         {
1665                 if(const VariableDeclaration *var = dynamic_cast<const VariableDeclaration *>(*i))
1666                 {
1667                         if(!var->interface.empty())
1668                                 writer.write(get_id(**i));
1669                 }
1670                 else if(dynamic_cast<InterfaceBlock *>(*i))
1671                         writer.write(get_id(**i));
1672         }
1673
1674         writer.end_op(OP_ENTRY_POINT);
1675
1676         if(stage->type==Stage::FRAGMENT)
1677                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_ORIGIN_LOWER_LEFT);
1678         else if(stage->type==Stage::GEOMETRY)
1679                 use_capability(CAP_GEOMETRY);
1680
1681         for(vector<const InterfaceLayout *>::const_iterator i=interface_layouts.begin(); i!=interface_layouts.end(); ++i)
1682         {
1683                 const vector<Layout::Qualifier> &qualifiers = (*i)->layout.qualifiers;
1684                 for(vector<Layout::Qualifier>::const_iterator j=qualifiers.begin(); j!=qualifiers.end(); ++j)
1685                 {
1686                         if(j->name=="point")
1687                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id,
1688                                         ((*i)->interface=="in" ? EXEC_INPUT_POINTS : EXEC_OUTPUT_POINTS));
1689                         else if(j->name=="lines")
1690                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES);
1691                         else if(j->name=="lines_adjacency")
1692                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_LINES_ADJACENCY);
1693                         else if(j->name=="triangles")
1694                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_TRIANGLES);
1695                         else if(j->name=="triangles_adjacency")
1696                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INPUT_TRIANGLES_ADJACENCY);
1697                         else if(j->name=="line_strip")
1698                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_LINE_STRIP);
1699                         else if(j->name=="triangle_strip")
1700                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP);
1701                         else if(j->name=="max_vertices")
1702                                 writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, j->value);
1703                 }
1704         }
1705 }
1706
1707 void SpirVGenerator::visit(FunctionDeclaration &func)
1708 {
1709         if(func.source==BUILTIN_SOURCE || func.definition!=&func)
1710                 return;
1711
1712         Id return_type_id = get_id(*func.return_type_declaration);
1713         vector<unsigned> param_type_ids;
1714         param_type_ids.reserve(func.parameters.size());
1715         for(NodeArray<VariableDeclaration>::const_iterator i=func.parameters.begin(); i!=func.parameters.end(); ++i)
1716                 param_type_ids.push_back(get_variable_type_id(**i));
1717
1718         string sig_with_return = func.return_type+func.signature;
1719         Id &type_id = function_type_ids[sig_with_return];
1720         if(!type_id)
1721         {
1722                 type_id = next_id++;
1723                 writer.begin_op(content.globals, OP_TYPE_FUNCTION);
1724                 writer.write(type_id);
1725                 writer.write(return_type_id);
1726                 for(vector<unsigned>::const_iterator i=param_type_ids.begin(); i!=param_type_ids.end(); ++i)
1727                         writer.write(*i);
1728                 writer.end_op(OP_TYPE_FUNCTION);
1729
1730                 writer.write_op_name(type_id, sig_with_return);
1731         }
1732
1733         Id func_id = allocate_id(func, type_id);
1734         writer.write_op_name(func_id, func.name+func.signature);
1735
1736         if(func.name=="main")
1737                 visit_entry_point(func, func_id);
1738
1739         writer.begin_op(content.functions, OP_FUNCTION, 5);
1740         writer.write(return_type_id);
1741         writer.write(func_id);
1742         writer.write(0);  // Function control flags (none)
1743         writer.write(type_id);
1744         writer.end_op(OP_FUNCTION);
1745
1746         for(unsigned i=0; i<func.parameters.size(); ++i)
1747         {
1748                 Id param_id = allocate_id(*func.parameters[i], param_type_ids[i]);
1749                 writer.write_op(content.functions, OP_FUNCTION_PARAMETER, param_type_ids[i], param_id);
1750                 // TODO This is probably incorrect if the parameter is assigned to.
1751                 variable_load_ids[func.parameters[i].get()] = param_id;
1752         }
1753
1754         writer.begin_function_body(next_id++);
1755         SetForScope<FunctionDeclaration *> set_func(current_function, &func);
1756         func.body.visit(*this);
1757
1758         if(writer.has_current_block())
1759         {
1760                 if(!reachable)
1761                         writer.write_op(content.function_body, OP_UNREACHABLE);
1762                 else
1763                 {
1764                         const BasicTypeDeclaration *basic_return = dynamic_cast<const BasicTypeDeclaration *>(func.return_type_declaration);
1765                         if(basic_return && basic_return->kind==BasicTypeDeclaration::VOID)
1766                                 writer.write_op(content.function_body, OP_RETURN);
1767                         else
1768                                 throw internal_error("missing return in non-void function");
1769                 }
1770         }
1771         writer.end_function_body();
1772         variable_load_ids.clear();
1773 }
1774
1775 void SpirVGenerator::visit(Conditional &cond)
1776 {
1777         cond.condition->visit(*this);
1778
1779         Id true_label_id = next_id++;
1780         Id merge_block_id = next_id++;
1781         Id false_label_id = (cond.else_body.body.empty() ? merge_block_id : next_id++);
1782         writer.write_op(content.function_body, OP_SELECTION_MERGE, merge_block_id, 0);  // Selection control (none)
1783         writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, true_label_id, false_label_id);
1784
1785         writer.write_op_label(true_label_id);
1786         cond.body.visit(*this);
1787         if(writer.has_current_block())
1788                 writer.write_op(content.function_body, OP_BRANCH, merge_block_id);
1789
1790         bool reachable_if_true = reachable;
1791
1792         reachable = true;
1793         if(!cond.else_body.body.empty())
1794         {
1795                 writer.write_op_label(false_label_id);
1796                 cond.else_body.visit(*this);
1797                 reachable |= reachable_if_true;
1798         }
1799
1800         writer.write_op_label(merge_block_id);
1801         prune_loads(true_label_id);
1802 }
1803
1804 void SpirVGenerator::visit(Iteration &iter)
1805 {
1806         if(iter.init_statement)
1807                 iter.init_statement->visit(*this);
1808
1809         Id header_id = next_id++;
1810         Id continue_id = next_id++;
1811         Id merge_block_id = next_id++;
1812
1813         SetForScope<Id> set_merge(loop_merge_block_id, merge_block_id);
1814         SetForScope<Id> set_continue(loop_continue_target_id, continue_id);
1815
1816         writer.write_op_label(header_id);
1817         writer.write_op(content.function_body, OP_LOOP_MERGE, merge_block_id, continue_id, 0);  // Loop control (none)
1818
1819         Id body_id = next_id++;
1820         if(iter.condition)
1821         {
1822                 writer.write_op_label(next_id++);
1823                 iter.condition->visit(*this);
1824                 writer.write_op(content.function_body, OP_BRANCH_CONDITIONAL, r_expression_result_id, body_id, merge_block_id);
1825         }
1826
1827         writer.write_op_label(body_id);
1828         iter.body.visit(*this);
1829
1830         writer.write_op_label(continue_id);
1831         if(iter.loop_expression)
1832                 iter.loop_expression->visit(*this);
1833         writer.write_op(content.function_body, OP_BRANCH, header_id);
1834
1835         writer.write_op_label(merge_block_id);
1836         prune_loads(header_id);
1837         reachable = true;
1838 }
1839
1840 void SpirVGenerator::visit(Return &ret)
1841 {
1842         if(ret.expression)
1843         {
1844                 ret.expression->visit(*this);
1845                 writer.write_op(content.function_body, OP_RETURN_VALUE, r_expression_result_id);
1846         }
1847         else
1848                 writer.write_op(content.function_body, OP_RETURN);
1849         reachable = false;
1850 }
1851
1852 void SpirVGenerator::visit(Jump &jump)
1853 {
1854         if(jump.keyword=="discard")
1855                 writer.write_op(content.function_body, OP_KILL);
1856         else if(jump.keyword=="break")
1857                 writer.write_op(content.function_body, OP_BRANCH, loop_merge_block_id);
1858         else if(jump.keyword=="continue")
1859                 writer.write_op(content.function_body, OP_BRANCH, loop_continue_target_id);
1860         else
1861                 throw internal_error("unknown jump");
1862         reachable = false;
1863 }
1864
1865
1866 SpirVGenerator::TypeKey::TypeKey(BasicTypeDeclaration::Kind kind, bool sign):
1867         type_id(0)
1868 {
1869         switch(kind)
1870         {
1871         case BasicTypeDeclaration::VOID: detail = 'v'; break;
1872         case BasicTypeDeclaration::BOOL: detail = 'b'; break;
1873         case BasicTypeDeclaration::INT: detail = (sign ? 'i' : 'u'); break;
1874         case BasicTypeDeclaration::FLOAT: detail = 'f'; break;
1875         default: throw invalid_argument("TypeKey::TypeKey");
1876         }
1877 }
1878
1879 bool SpirVGenerator::TypeKey::operator<(const TypeKey &other) const
1880 {
1881         if(type_id!=other.type_id)
1882                 return type_id<other.type_id;
1883         return detail<other.detail;
1884 }
1885
1886
1887 bool SpirVGenerator::ConstantKey::operator<(const ConstantKey &other) const
1888 {
1889         if(type_id!=other.type_id)
1890                 return type_id<other.type_id;
1891         return int_value<other.int_value;
1892 }
1893
1894 } // namespace SL
1895 } // namespace GL
1896 } // namespace Msp