From 1f71ec73edacd6c1bd230abace3cb791eb2ca759 Mon Sep 17 00:00:00 2001 From: Mikko Rasa Date: Sat, 24 Apr 2021 19:18:42 +0300 Subject: [PATCH] Unify handling of constants in SpirVModule --- source/core/module.cpp | 50 +++++++++++++++++++---------------------- source/core/module.h | 18 ++++++++------- source/core/program.cpp | 4 ++-- 3 files changed, 35 insertions(+), 37 deletions(-) diff --git a/source/core/module.cpp b/source/core/module.cpp index 72a6abca..ce540505 100644 --- a/source/core/module.cpp +++ b/source/core/module.cpp @@ -224,8 +224,9 @@ void SpirVModule::reflect() } } - for(map::const_iterator i=reflection.spec_constants.begin(); i!=reflection.spec_constants.end(); ++i) - spec_constants.push_back(i->second); + for(map::const_iterator i=reflection.constants.begin(); i!=reflection.constants.end(); ++i) + if(i->second->constant_id>=0) + spec_constants.push_back(i->second); } @@ -329,12 +330,12 @@ void SpirVModule::Reflection::reflect_code(const vector &code) case OP_TYPE_ARRAY: reflect_array_type(op); break; case OP_TYPE_STRUCT: reflect_struct_type(op); break; case OP_TYPE_POINTER: reflect_pointer_type(op); break; - case OP_CONSTANT_TRUE: constants[*(op+2)] = true; break; - case OP_CONSTANT_FALSE: constants[*(op+2)] = false; break; - case OP_CONSTANT: reflect_constant(op); break; + case OP_CONSTANT_TRUE: + case OP_CONSTANT_FALSE: + case OP_CONSTANT: case OP_SPEC_CONSTANT_TRUE: case OP_SPEC_CONSTANT_FALSE: - case OP_SPEC_CONSTANT: reflect_spec_constant(op); break; + case OP_SPEC_CONSTANT: reflect_constant(op); break; case OP_VARIABLE: reflect_variable(op); break; case OP_DECORATE: reflect_decorate(op); break; case OP_MEMBER_DECORATE: reflect_member_decorate(op); break; @@ -440,11 +441,10 @@ void SpirVModule::Reflection::reflect_array_type(CodeIterator op) const TypeInfo &elem = types[*(op+2)]; type.type = elem.type; type.struct_type = elem.struct_type; - const Variant &size = constants[*(op+3)]; - if(size.check_type()) - type.array_size = size.value(); - else if(size.check_type()) - type.array_size = size.value(); + + const Constant &size = constants[*(op+3)]; + if(size.type==INT || size.type==UNSIGNED_INT) + type.array_size = size.i_value; } void SpirVModule::Reflection::reflect_struct_type(CodeIterator op) @@ -476,23 +476,19 @@ void SpirVModule::Reflection::reflect_pointer_type(CodeIterator op) } void SpirVModule::Reflection::reflect_constant(CodeIterator op) -{ - const TypeInfo &type = types[*(op+1)]; - unsigned id = *(op+2); - if(type.type==INT) - constants[id] = static_cast(*(op+3)); - else if(type.type==UNSIGNED_INT) - constants[id] = static_cast(*(op+3)); - else if(type.type==FLOAT) - constants[id] = *reinterpret_cast(&*(op+3)); -} - -void SpirVModule::Reflection::reflect_spec_constant(CodeIterator op) { unsigned id = *(op+2); - SpecConstant &spec = spec_constants[id]; - spec.name = names[id]; - spec.type = types[*(op+1)].type; + Constant &cnst = constants[id]; + cnst.name = names[id]; + cnst.type = types[*(op+1)].type; + if(*op==OP_CONSTANT_TRUE || *op==OP_SPEC_CONSTANT_TRUE) + cnst.i_value = true; + else if(*op==OP_CONSTANT_FALSE || *op==OP_SPEC_CONSTANT_FALSE) + cnst.i_value = false; + else if(cnst.type==INT || cnst.type==UNSIGNED_INT) + cnst.i_value = *(op+3); + else if(cnst.type==FLOAT) + cnst.f_value = *reinterpret_cast(&*(op+3)); } void SpirVModule::Reflection::reflect_variable(CodeIterator op) @@ -516,7 +512,7 @@ void SpirVModule::Reflection::reflect_decorate(CodeIterator op) switch(decoration) { case DECO_SPEC_ID: - spec_constants[id].constant_id = *op; + constants[id].constant_id = *op; break; case DECO_ARRAY_STRIDE: types[id].array_stride = *op; diff --git a/source/core/module.h b/source/core/module.h index 3492a5de..6ca522a5 100644 --- a/source/core/module.h +++ b/source/core/module.h @@ -126,11 +126,16 @@ public: bool operator==(const Variable &) const; }; - struct SpecConstant + struct Constant { std::string name; - unsigned constant_id; + int constant_id; DataType type; + union + { + int i_value; + float f_value; + }; }; private: @@ -150,12 +155,11 @@ private: typedef std::vector::const_iterator CodeIterator; std::map names; - std::map constants; + std::map constants; std::map types; std::map entry_points; std::map structs; std::map variables; - std::map spec_constants; static UInt32 get_opcode(UInt32); static CodeIterator get_op_end(const CodeIterator &); @@ -177,8 +181,6 @@ private: void reflect_struct_type(CodeIterator); void reflect_pointer_type(CodeIterator); void reflect_constant(CodeIterator); - void reflect_spec_constant_bool(CodeIterator); - void reflect_spec_constant(CodeIterator); void reflect_variable(CodeIterator); void reflect_decorate(CodeIterator); void reflect_member_decorate(CodeIterator); @@ -188,7 +190,7 @@ private: std::vector entry_points; std::vector structs; std::vector variables; - std::vector spec_constants; + std::vector spec_constants; public: SpirVModule() { } @@ -209,7 +211,7 @@ public: const std::vector &get_code() const { return code; } const std::vector &get_entry_points() const { return entry_points; } const std::vector &get_variables() const { return variables; } - const std::vector &get_spec_constants() const { return spec_constants; } + const std::vector &get_spec_constants() const { return spec_constants; } }; } // namespace GL diff --git a/source/core/program.cpp b/source/core/program.cpp index 39184bbd..81da5b96 100644 --- a/source/core/program.cpp +++ b/source/core/program.cpp @@ -214,12 +214,12 @@ void Program::add_spirv_stages(const SpirVModule &mod, const map &s const vector &code = mod.get_code(); glShaderBinary(stage_ids.size(), &stage_ids[0], GL_SHADER_BINARY_FORMAT_SPIR_V, &code[0], code.size()*4); - const vector &spec_consts = mod.get_spec_constants(); + const vector &spec_consts = mod.get_spec_constants(); vector spec_id_array; vector spec_value_array; spec_id_array.reserve(spec_consts.size()); spec_value_array.reserve(spec_consts.size()); - for(vector::const_iterator i=spec_consts.begin(); i!=spec_consts.end(); ++i) + for(vector::const_iterator i=spec_consts.begin(); i!=spec_consts.end(); ++i) { map::const_iterator j = spec_values.find(i->name); if(j!=spec_values.end()) -- 2.43.0