]> git.tdb.fi Git - libs/gl.git/commitdiff
Reflect the control flow graph from SPIR-V and check variable accesses
authorMikko Rasa <tdb@tdb.fi>
Wed, 17 Nov 2021 13:25:15 +0000 (15:25 +0200)
committerMikko Rasa <tdb@tdb.fi>
Wed, 17 Nov 2021 13:29:34 +0000 (15:29 +0200)
Specialization constants may cause some declared variables to become
unused, and those shouldn't appear in Vulkan descriptor set layouts.

source/core/module.cpp
source/core/module.h
source/core/program.cpp
source/core/program.h

index 116c6241a184d321b6408122713873b16657412e..08b3a128d4e7c49dac863b7047e4d568713f08c6 100644 (file)
@@ -32,8 +32,14 @@ enum SpirVConstants
        OP_SPEC_CONSTANT_FALSE = 49,
        OP_SPEC_CONSTANT = 50,
        OP_VARIABLE = 59,
+       OP_LOAD = 61,
+       OP_STORE = 62,
+       OP_ACCESS_CHAIN = 65,
        OP_DECORATE = 71,
        OP_MEMBER_DECORATE = 72,
+       OP_LABEL = 248,
+       OP_BRANCH = 249,
+       OP_BRANCH_CONDITIONAL = 250,
 
        DECO_SPEC_ID = 1,
        DECO_ARRAY_STRIDE = 6,
@@ -210,6 +216,32 @@ void SpirVModule::reflect()
                        v = (i!=var_indices.end() ? &variables[i->second] : 0);
                }
        }
+
+       map<const InstructionBlock *, unsigned> block_indices;
+       blocks.reserve(reflection.blocks.size());
+       for(const auto &kvp: reflection.blocks)
+       {
+               block_indices[&kvp.second] = blocks.size();
+               blocks.push_back(kvp.second);
+       }
+
+       for(InstructionBlock &b: blocks)
+       {
+               auto i = spec_indices.find(b.condition);
+               b.condition = (i!=spec_indices.end() ? &spec_constants[i->second] : 0);
+
+               for(const Variable *&v: b.accessed_variables)
+               {
+                       auto j = var_indices.find(v);
+                       v = (j!=var_indices.end() ? &variables[j->second] : 0);
+               }
+
+               for(const InstructionBlock *&s: b.successors)
+               {
+                       auto j = block_indices.find(s);
+                       s = (j!=block_indices.end() ? &blocks[j->second] : 0);
+               }
+       }
 }
 
 
@@ -288,8 +320,14 @@ void SpirVModule::Reflection::reflect_code(const vector<uint32_t> &code)
                case OP_SPEC_CONSTANT_FALSE:
                case OP_SPEC_CONSTANT: reflect_constant(op); break;
                case OP_VARIABLE: reflect_variable(op); break;
+               case OP_LOAD:
+               case OP_STORE: reflect_access(op); break;
+               case OP_ACCESS_CHAIN: reflect_access_chain(op); break;
                case OP_DECORATE: reflect_decorate(op); break;
                case OP_MEMBER_DECORATE: reflect_member_decorate(op); break;
+               case OP_LABEL: reflect_label(op); break;
+               case OP_BRANCH: reflect_branch(op); break;
+               case OP_BRANCH_CONDITIONAL: reflect_branch_conditional(op); break;
                }
 
                op += word_count;
@@ -455,6 +493,26 @@ void SpirVModule::Reflection::reflect_variable(CodeIterator op)
        var.array_size = type.array_size;
 }
 
+void SpirVModule::Reflection::reflect_access(CodeIterator op)
+{
+       if(current_block)
+       {
+               unsigned id = (get_opcode(*op)==OP_LOAD ? *(op+3) : *(op+1));
+               auto i = access_chain_bases.find(id);
+               if(i!=access_chain_bases.end())
+                       id = i->second;
+               Variable &var = variables[id];
+               auto j = find(current_block->accessed_variables, &var);
+               if(j==current_block->accessed_variables.end())
+                       current_block->accessed_variables.push_back(&var);
+       }
+}
+
+void SpirVModule::Reflection::reflect_access_chain(CodeIterator op)
+{
+       access_chain_bases[*(op+2)] = *(op+3);
+}
+
 void SpirVModule::Reflection::reflect_decorate(CodeIterator op)
 {
        unsigned id = *(op+1);
@@ -508,5 +566,38 @@ void SpirVModule::Reflection::reflect_member_decorate(CodeIterator op)
        }
 }
 
+void SpirVModule::Reflection::reflect_label(CodeIterator op)
+{
+       current_block = &blocks[*(op+1)];
+}
+
+void SpirVModule::Reflection::reflect_branch(CodeIterator op)
+{
+       InstructionBlock &block = blocks[*(op+1)];
+       block.condition = &true_condition;
+       current_block->successors.push_back(&block);
+}
+
+void SpirVModule::Reflection::reflect_branch_conditional(CodeIterator op)
+{
+       InstructionBlock &true_block = blocks[*(op+2)];
+       InstructionBlock &false_block = blocks[*(op+3)];
+
+       auto i = constants.find(*(op+1));
+       if(i!=constants.end() && i->second.constant_id)
+       {
+               if(!true_block.condition)
+                       true_block.condition = &i->second;
+               if(!false_block.condition)
+               {
+                       false_block.condition = &i->second;
+                       false_block.negate_condition = true;
+               }
+       }
+
+       current_block->successors.push_back(&true_block);
+       current_block->successors.push_back(&false_block);
+}
+
 } // namespace GL
 } // namespace Msp
index c669af0506ab64d26b25bf504ec28533e7e1cf92..15c07fb0366bd79b0e2ee737f4c21e42b3ed58f8 100644 (file)
@@ -177,6 +177,14 @@ public:
                };
        };
 
+       struct InstructionBlock
+       {
+               const Constant *condition = 0;
+               bool negate_condition = false;
+               std::vector<const Variable *> accessed_variables;
+               std::vector<const InstructionBlock *> successors;
+       };
+
 private:
        struct TypeInfo
        {
@@ -197,6 +205,10 @@ private:
                std::map<unsigned, EntryPoint> entry_points;
                std::map<unsigned, Structure> structs;
                std::map<unsigned, Variable> variables;
+               std::map<unsigned, InstructionBlock> blocks;
+               std::map<unsigned, unsigned> access_chain_bases;
+               Constant true_condition;
+               InstructionBlock *current_block = 0;
 
                static std::uint32_t get_opcode(std::uint32_t);
                static CodeIterator get_op_end(const CodeIterator &);
@@ -219,8 +231,13 @@ private:
                void reflect_pointer_type(CodeIterator);
                void reflect_constant(CodeIterator);
                void reflect_variable(CodeIterator);
+               void reflect_access(CodeIterator);
+               void reflect_access_chain(CodeIterator);
                void reflect_decorate(CodeIterator);
                void reflect_member_decorate(CodeIterator);
+               void reflect_label(CodeIterator);
+               void reflect_branch(CodeIterator);
+               void reflect_branch_conditional(CodeIterator);
        };
 
        std::vector<std::uint32_t> code;
@@ -228,6 +245,7 @@ private:
        std::vector<Structure> structs;
        std::vector<Variable> variables;
        std::vector<Constant> spec_constants;
+       std::vector<InstructionBlock> blocks;
 
 public:
        virtual Format get_format() const { return SPIR_V; }
@@ -243,6 +261,7 @@ public:
        const std::vector<EntryPoint> &get_entry_points() const { return entry_points; }
        const std::vector<Variable> &get_variables() const { return variables; }
        const std::vector<Constant> &get_spec_constants() const { return spec_constants; }
+       const std::vector<InstructionBlock> &get_blocks() const { return blocks; }
 };
 
 } // namespace GL
index acdb20e88a5488063c3778a598af10f411b1b541..fdbfa6d4be08004f38be801c3a46c69ae0745b04 100644 (file)
@@ -36,9 +36,11 @@ void Program::add_stages(const Module &mod, const map<string, int> &spec_values)
 
        if(mod.get_format()==Module::SPIR_V)
        {
-               collect_uniforms(static_cast<const SpirVModule &>(mod));
-               collect_attributes(static_cast<const SpirVModule &>(mod));
-               collect_builtins(static_cast<const SpirVModule &>(mod));
+               const SpirVModule &spirv_mod = static_cast<const SpirVModule &>(mod);
+               vector<uint8_t> used_variables = collect_used_variables(spirv_mod, spec_values);
+               collect_uniforms(spirv_mod, used_variables);
+               collect_attributes(spirv_mod, used_variables);
+               collect_builtins(spirv_mod);
        }
 
        finalize_uniforms();
@@ -49,15 +51,79 @@ void Program::add_stages(const Module &mod, const map<string, int> &spec_values)
                require_type(a.type);
 }
 
-void Program::collect_uniforms(const SpirVModule &mod)
+vector<uint8_t> Program::collect_used_variables(const SpirVModule &mod, const map<string, int> &spec_values)
+{
+       std::map<unsigned, int> spec_values_by_id;
+       for(const SpirVModule::Constant &c: mod.get_spec_constants())
+       {
+               auto i = spec_values.find(c.name);
+               if(i!=spec_values.end())
+                       spec_values_by_id[c.constant_id] = i->second;
+       }
+
+       const vector<SpirVModule::InstructionBlock> &blocks = mod.get_blocks();
+       vector<uint8_t> visited(blocks.size(), 4);
+       for(unsigned i=0; i<blocks.size(); ++i)
+       {
+               const SpirVModule::InstructionBlock &b = blocks[i];
+
+               bool cond = true;
+               if(b.condition)
+               {
+                       cond = b.condition->i_value;
+                       auto j = spec_values_by_id.find(b.condition->constant_id);
+                       if(j!=spec_values_by_id.end())
+                               cond = j->second;
+                       if(b.negate_condition)
+                               cond = !cond;
+               }
+
+               visited[i] |= cond*2;
+               for(const SpirVModule::InstructionBlock *s: b.successors)
+                       visited[s-blocks.data()] &= 3;
+       }
+
+       for(unsigned i=0; i<blocks.size(); ++i)
+               if(visited[i]&4)
+                       collect_visited_blocks(blocks, i, visited);
+
+       const vector<SpirVModule::Variable> &variables = mod.get_variables();
+       vector<uint8_t> used(variables.size());
+       for(unsigned i=0; i<blocks.size(); ++i)
+               if(visited[i]&1)
+               {
+                       for(const SpirVModule::Variable *v: blocks[i].accessed_variables)
+                               used[v-variables.data()] = 1;
+               }
+
+       return used;
+}
+
+void Program::collect_visited_blocks(const vector<SpirVModule::InstructionBlock> &blocks, unsigned i, vector<uint8_t> &visited)
+{
+       visited[i] |= 1;
+       for(const SpirVModule::InstructionBlock *s: blocks[i].successors)
+       {
+               unsigned j = s-blocks.data();
+               if((visited[j]&3)==2)
+                       collect_visited_blocks(blocks, j, visited);
+       }
+}
+
+void Program::collect_uniforms(const SpirVModule &mod, const vector<uint8_t> &used_variables)
 {
        // Prepare the default block
        reflect_data.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
        vector<vector<string> > block_uniform_names(1);
 
+       const vector<SpirVModule::Variable> &variables = mod.get_variables();
        unsigned n_descriptor_sets = 0;
-       for(const SpirVModule::Variable &v: mod.get_variables())
+       for(unsigned i=0; i<variables.size(); ++i)
        {
+               if(!used_variables[i])
+                       continue;
+
+               const SpirVModule::Variable &v = variables[i];
                if((v.storage==SpirVModule::UNIFORM || v.storage==SpirVModule::PUSH_CONSTANT) && v.struct_type)
                {
                        reflect_data.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
@@ -155,13 +221,14 @@ void Program::collect_block_uniforms(const SpirVModule::Structure &strct, const
        }
 }
 
-void Program::collect_attributes(const SpirVModule &mod)
+void Program::collect_attributes(const SpirVModule &mod, const vector<uint8_t> &used_variables)
 {
+       const vector<SpirVModule::Variable> &variables = mod.get_variables();
        for(const SpirVModule::EntryPoint &e: mod.get_entry_points())
                if(e.stage==SpirVModule::VERTEX && e.name=="main")
                {
                        for(const SpirVModule::Variable *v: e.globals)
-                               if(v->storage==SpirVModule::INPUT)
+                               if(v->storage==SpirVModule::INPUT && used_variables[v-variables.data()])
                                {
                                        reflect_data.attributes.push_back(ReflectData::AttributeInfo());
                                        ReflectData::AttributeInfo &info = reflect_data.attributes.back();
index 9a4eb27a2dcb94f7fbfadf0102a290e9460002a1..7f6ad54e3a6c9a14c3cfb4224d1d0d83ce07c69a 100644 (file)
@@ -61,9 +61,11 @@ public:
 
        void add_stages(const Module &, const std::map<std::string, int> & = std::map<std::string, int>());
 private:
-       void collect_uniforms(const SpirVModule &);
+       static std::vector<std::uint8_t> collect_used_variables(const SpirVModule &, const std::map<std::string, int> &);
+       static void collect_visited_blocks(const std::vector<SpirVModule::InstructionBlock> &, unsigned, std::vector<std::uint8_t> &);
+       void collect_uniforms(const SpirVModule &, const std::vector<std::uint8_t> &);
        void collect_block_uniforms(const SpirVModule::Structure &, const std::string &, unsigned, std::vector<std::string> &);
-       void collect_attributes(const SpirVModule &);
+       void collect_attributes(const SpirVModule &, const std::vector<std::uint8_t> &);
        void collect_builtins(const SpirVModule &);
        void collect_builtins(const SpirVModule::Structure &);