]> git.tdb.fi Git - libs/gl.git/commitdiff
Restructure Program to remove linking from the public interface
authorMikko Rasa <tdb@tdb.fi>
Sun, 26 Sep 2021 13:41:03 +0000 (16:41 +0300)
committerMikko Rasa <tdb@tdb.fi>
Sun, 26 Sep 2021 14:56:52 +0000 (17:56 +0300)
It doesn't map to Vulkan at all.  Doing it from add_stages also
simplifies handling some temporary data.

source/core/program.cpp
source/core/program.h
source/resources/resources.cpp

index 82f438cdbf02112ff3362018b7587b066f251fbd..4e09038074613d64f0d5559127e51df305a156c2 100644 (file)
@@ -35,7 +35,6 @@ Program::Program(const Module &mod, const map<string, int> &spec_values)
 {
        init();
        add_stages(mod, spec_values);
-       link();
 }
 
 void Program::init()
@@ -44,8 +43,6 @@ void Program::init()
 
        id = glCreateProgram();
        fill(stage_ids, stage_ids+MAX_STAGES, 0);
-       module = 0;
-       transient = 0;
        linked = false;
 }
 
@@ -62,12 +59,20 @@ void Program::add_stages(const Module &mod, const map<string, int> &spec_values)
        if(has_stages())
                throw invalid_operation("Program::add_stages");
 
+       TransientData transient;
        switch(mod.get_format())
        {
-       case Module::GLSL: return add_glsl_stages(static_cast<const GlslModule &>(mod), spec_values);
-       case Module::SPIR_V: return add_spirv_stages(static_cast<const SpirVModule &>(mod), spec_values);
-       default: throw invalid_argument("Program::add_stages");
+       case Module::GLSL:
+               add_glsl_stages(static_cast<const GlslModule &>(mod), spec_values, transient);
+               break;
+       case Module::SPIR_V:
+               add_spirv_stages(static_cast<const SpirVModule &>(mod), spec_values, transient);
+               break;
+       default:
+               throw invalid_argument("Program::add_stages");
        }
+
+       finalize(mod, transient);
 }
 
 bool Program::has_stages() const
@@ -104,10 +109,8 @@ unsigned Program::add_stage(Stage type)
        return stage_id;
 }
 
-void Program::add_glsl_stages(const GlslModule &mod, const map<string, int> &spec_values)
+void Program::add_glsl_stages(const GlslModule &mod, const map<string, int> &spec_values, TransientData &transient)
 {
-       module = &mod;
-
        SL::Compiler compiler;
        compiler.set_source(mod.get_prepared_source(), "<module>");
        compiler.specialize(spec_values);
@@ -119,6 +122,9 @@ void Program::add_glsl_stages(const GlslModule &mod, const map<string, int> &spe
 #endif
 
        vector<SL::Stage::Type> stages = compiler.get_stages();
+       if(stages.empty())
+               throw invalid_argument("Program::add_glsl_stages");
+
        for(SL::Stage::Type st: stages)
        {
                unsigned stage_id = 0;
@@ -147,16 +153,14 @@ void Program::add_glsl_stages(const GlslModule &mod, const map<string, int> &spe
                                glBindFragDataLocation(id, kvp.second, kvp.first.c_str());
                }
 
-               compile_glsl_stage(stage_id);
+               compile_glsl_stage(mod, stage_id);
        }
 
-       if(!transient)
-               transient = new TransientData;
-       transient->textures = compiler.get_texture_bindings();
-       transient->blocks = compiler.get_uniform_block_bindings();
+       transient.textures = compiler.get_texture_bindings();
+       transient.blocks = compiler.get_uniform_block_bindings();
 }
 
-void Program::compile_glsl_stage(unsigned stage_id)
+void Program::compile_glsl_stage(const GlslModule &mod, unsigned stage_id)
 {
        glCompileShader(stage_id);
        int status = 0;
@@ -167,8 +171,7 @@ void Program::compile_glsl_stage(unsigned stage_id)
        string info_log(info_log_len+1, 0);
        glGetShaderInfoLog(stage_id, info_log_len+1, &info_log_len, &info_log[0]);
        info_log.erase(info_log_len);
-       if(module && module->get_format()==Module::GLSL)
-               info_log = static_cast<const GlslModule *>(module)->get_source_map().translate_errors(info_log);
+       info_log = mod.get_source_map().translate_errors(info_log);
 
        if(!status)
                throw compile_error(info_log);
@@ -178,13 +181,11 @@ void Program::compile_glsl_stage(unsigned stage_id)
 #endif
 }
 
-void Program::add_spirv_stages(const SpirVModule &mod, const map<string, int> &spec_values)
+void Program::add_spirv_stages(const SpirVModule &mod, const map<string, int> &spec_values, TransientData &transient)
 {
        static Require _req(ARB_gl_spirv);
        static Require _req2(ARB_ES2_compatibility);
 
-       module = &mod;
-
        unsigned n_stages = 0;
        unsigned used_stage_ids[MAX_STAGES];
        for(const SpirVModule::EntryPoint &e: mod.get_entry_points())
@@ -201,12 +202,12 @@ void Program::add_spirv_stages(const SpirVModule &mod, const map<string, int> &s
                used_stage_ids[n_stages++] = stage_id;
        }
 
+       if(!n_stages)
+               throw invalid_argument("Program::add_spirv_stages");
+
        const vector<uint32_t> &code = mod.get_code();
        glShaderBinary(n_stages, used_stage_ids, GL_SHADER_BINARY_FORMAT_SPIR_V, &code[0], code.size()*4);
 
-       if(!spec_values.empty() && !transient)
-               transient = new TransientData;
-
        const vector<SpirVModule::Constant> &spec_consts = mod.get_spec_constants();
        vector<unsigned> spec_id_array;
        vector<unsigned> spec_value_array;
@@ -219,7 +220,7 @@ void Program::add_spirv_stages(const SpirVModule &mod, const map<string, int> &s
                {
                        spec_id_array.push_back(c.constant_id);
                        spec_value_array.push_back(i->second);
-                       transient->spec_values[c.constant_id] = i->second;
+                       transient.spec_values[c.constant_id] = i->second;
                }
        }
 
@@ -229,11 +230,8 @@ void Program::add_spirv_stages(const SpirVModule &mod, const map<string, int> &s
                        glSpecializeShader(stage_ids[i], j->name.c_str(), spec_id_array.size(), &spec_id_array[0], &spec_value_array[0]);
 }
 
-void Program::link()
+void Program::finalize(const Module &mod, const TransientData &transient)
 {
-       if(!has_stages())
-               throw invalid_operation("Program::link");
-
        reflect_data = ReflectData();
 
        glLinkProgram(id);
@@ -246,8 +244,8 @@ void Program::link()
        string info_log(info_log_len+1, 0);
        glGetProgramInfoLog(id, info_log_len+1, &info_log_len, &info_log[0]);
        info_log.erase(info_log_len);
-       if(module && module->get_format()==Module::GLSL)
-               info_log = static_cast<const GlslModule *>(module)->get_source_map().translate_errors(info_log);
+       if(mod.get_format()==Module::GLSL)
+               info_log = static_cast<const GlslModule &>(mod).get_source_map().translate_errors(info_log);
 
        if(!linked)
                throw compile_error(info_log);
@@ -256,46 +254,40 @@ void Program::link()
                IO::print("Program link info log:\n%s", info_log);
 #endif
 
-       if(module->get_format()==Module::GLSL)
+       if(mod.get_format()==Module::GLSL)
        {
                query_uniforms();
                query_attributes();
-               if(transient)
+               for(unsigned i=0; i<reflect_data.uniform_blocks.size(); ++i)
                {
-                       for(unsigned i=0; i<reflect_data.uniform_blocks.size(); ++i)
+                       auto j = transient.blocks.find(reflect_data.uniform_blocks[i].name);
+                       if(j!=transient.blocks.end())
                        {
-                               auto j = transient->blocks.find(reflect_data.uniform_blocks[i].name);
-                               if(j!=transient->blocks.end())
-                               {
-                                       glUniformBlockBinding(id, i, j->second);
-                                       reflect_data.uniform_blocks[i].bind_point = j->second;
-                               }
+                               glUniformBlockBinding(id, i, j->second);
+                               reflect_data.uniform_blocks[i].bind_point = j->second;
                        }
+               }
 
-                       if(!ARB_separate_shader_objects)
-                               glUseProgram(id);
-                       for(const auto &kvp: transient->textures)
+               if(!ARB_separate_shader_objects)
+                       glUseProgram(id);
+               for(const auto &kvp: transient.textures)
+               {
+                       int location = get_uniform_location(kvp.first);
+                       if(location>=0)
                        {
-                               int location = get_uniform_location(kvp.first);
-                               if(location>=0)
-                               {
-                                       if(ARB_separate_shader_objects)
-                                               glProgramUniform1i(id, location, kvp.second);
-                                       else
-                                               glUniform1i(location, kvp.second);
-                               }
+                               if(ARB_separate_shader_objects)
+                                       glProgramUniform1i(id, location, kvp.second);
+                               else
+                                       glUniform1i(location, kvp.second);
                        }
                }
        }
-       else if(module->get_format()==Module::SPIR_V)
+       else if(mod.get_format()==Module::SPIR_V)
        {
-               collect_uniforms();
-               collect_attributes();
+               collect_uniforms(static_cast<const SpirVModule &>(mod), transient.spec_values);
+               collect_attributes(static_cast<const SpirVModule &>(mod));
        }
 
-       delete transient;
-       transient = 0;
-
        for(const ReflectData::UniformInfo &u: reflect_data.uniforms)
                require_type(u.type);
        for(const ReflectData::AttributeInfo &a: reflect_data.attributes)
@@ -458,10 +450,8 @@ void Program::query_attributes()
        }
 }
 
-void Program::collect_uniforms()
+void Program::collect_uniforms(const SpirVModule &mod, const map<unsigned, int> &spec_values)
 {
-       const SpirVModule &mod = static_cast<const SpirVModule &>(*module);
-
        // Prepare the default block
        reflect_data.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
        vector<vector<string> > block_uniform_names(1);
@@ -480,7 +470,7 @@ void Program::collect_uniforms()
                        if(!v.name.empty())
                                prefix = v.struct_type->name+".";
                        block_uniform_names.push_back(vector<string>());
-                       collect_block_uniforms(*v.struct_type, prefix, 0, block_uniform_names.back());
+                       collect_block_uniforms(*v.struct_type, prefix, 0, spec_values, block_uniform_names.back());
                }
                else if(v.storage==SpirVModule::UNIFORM_CONSTANT && v.location>=0)
                {
@@ -515,7 +505,7 @@ void Program::collect_uniforms()
        reflect_data.update_layout_hash();
 }
 
-void Program::collect_block_uniforms(const SpirVModule::Structure &strct, const string &prefix, unsigned base_offset, vector<string> &uniform_names)
+void Program::collect_block_uniforms(const SpirVModule::Structure &strct, const string &prefix, unsigned base_offset, const map<unsigned, int> &spec_values, vector<string> &uniform_names)
 {
        for(const SpirVModule::StructMember &m: strct.members)
        {
@@ -526,21 +516,18 @@ void Program::collect_block_uniforms(const SpirVModule::Structure &strct, const
                        if(m.array_size_spec)
                        {
                                array_size = m.array_size_spec->i_value;
-                               if(transient)
-                               {
-                                       auto j = transient->spec_values.find(m.array_size_spec->constant_id);
-                                       if(j!=transient->spec_values.end())
-                                               array_size = j->second;
-                               }
+                               auto j = spec_values.find(m.array_size_spec->constant_id);
+                               if(j!=spec_values.end())
+                                       array_size = j->second;
                        }
 
                        if(array_size)
                        {
                                for(unsigned j=0; j<array_size; ++j, offset+=m.array_stride)
-                                       collect_block_uniforms(*m.struct_type, format("%s%s[%d].", prefix, m.name, j), offset, uniform_names);
+                                       collect_block_uniforms(*m.struct_type, format("%s%s[%d].", prefix, m.name, j), offset, spec_values, uniform_names);
                        }
                        else
-                               collect_block_uniforms(*m.struct_type, prefix+m.name+".", offset, uniform_names);
+                               collect_block_uniforms(*m.struct_type, prefix+m.name+".", offset, spec_values, uniform_names);
                }
                else
                {
@@ -559,10 +546,8 @@ void Program::collect_block_uniforms(const SpirVModule::Structure &strct, const
        }
 }
 
-void Program::collect_attributes()
+void Program::collect_attributes(const SpirVModule &mod)
 {
-       const SpirVModule &mod = static_cast<const SpirVModule &>(*module);
-
        for(const SpirVModule::EntryPoint &e: mod.get_entry_points())
                if(e.stage==SpirVModule::VERTEX && e.name=="main")
                {
@@ -672,12 +657,7 @@ void Program::set_stage_debug_name(unsigned stage_id, Stage type)
 Program::Loader::Loader(Program &p, Collection &c):
        DataFile::CollectionObjectLoader<Program>(p, &c)
 {
-       add("module",          &Loader::module);
-}
-
-void Program::Loader::finish()
-{
-       obj.link();
+       add("module", &Loader::module);
 }
 
 void Program::Loader::module(const string &n)
index b7f46ee4082a74a17ff68aa24328b6c5c01cf9f5..6301edbe2c6e23ded094519e38f30db842ebbb55 100644 (file)
@@ -27,8 +27,6 @@ public:
                Loader(Program &, Collection &);
 
        private:
-               virtual void finish();
-
                void module(const std::string &);
        };
 
@@ -67,8 +65,6 @@ private:
 
        unsigned id;
        unsigned stage_ids[MAX_STAGES];
-       const Module *module;
-       TransientData *transient;
        bool linked;
        ReflectData reflect_data;
        std::string debug_name;
@@ -89,23 +85,19 @@ public:
 private:
        bool has_stages() const;
        unsigned add_stage(Stage);
-       void add_glsl_stages(const GlslModule &, const std::map<std::string, int> &);
-       void compile_glsl_stage(unsigned);
-       void add_spirv_stages(const SpirVModule &, const std::map<std::string, int> &);
+       void add_glsl_stages(const GlslModule &, const std::map<std::string, int> &, TransientData &);
+       void compile_glsl_stage(const GlslModule &, unsigned);
+       void add_spirv_stages(const SpirVModule &, const std::map<std::string, int> &, TransientData &);
 
-public:
-       void link();
-private:
+       void finalize(const Module &, const TransientData &);
        void query_uniforms();
        void query_uniform_blocks(const std::vector<ReflectData::UniformInfo *> &);
        void query_attributes();
-       void collect_uniforms();
-       void collect_block_uniforms(const SpirVModule::Structure &, const std::string &, unsigned, std::vector<std::string> &);
-       void collect_attributes();
+       void collect_uniforms(const SpirVModule &, const std::map<unsigned, int> &);
+       void collect_block_uniforms(const SpirVModule::Structure &, const std::string &, unsigned, const std::map<unsigned, int> &, std::vector<std::string> &);
+       void collect_attributes(const SpirVModule &);
 
 public:
-       bool is_linked() const { return linked; }
-
        ReflectData::LayoutHash get_uniform_layout_hash() const { return reflect_data.layout_hash; }
        const std::vector<ReflectData::UniformBlockInfo> &get_uniform_blocks() const { return reflect_data.uniform_blocks; }
        const ReflectData::UniformBlockInfo &get_uniform_block_info(const std::string &) const;
index 8a7780cd4016521c2b0cb39428270ad0f50c1609..62bdabf60594e796439833b110e512bcd3544b24 100644 (file)
@@ -244,7 +244,6 @@ Program *Resources::create_program(const string &name)
                Module &module = get<Module>(base);
                RefPtr<Program> shprog = new Program;
                shprog->add_stages(module);
-               shprog->link();
                return shprog.release();
        }