]> git.tdb.fi Git - libs/gl.git/blobdiff - source/backends/opengl/program_backend.cpp
Support compute shaders and compute operations
[libs/gl.git] / source / backends / opengl / program_backend.cpp
index 32c6d6e137c36e4789b47ccb2b381a7823396be2..398d35ce7a6c364aba56634cb7aa0be67ad174d6 100644 (file)
@@ -1,5 +1,6 @@
 #include <cstring>
 #include <msp/core/algorithm.h>
+#include <msp/gl/extensions/arb_compute_shader.h>
 #include <msp/gl/extensions/arb_es2_compatibility.h>
 #include <msp/gl/extensions/arb_fragment_shader.h>
 #include <msp/gl/extensions/arb_gl_spirv.h>
@@ -12,7 +13,7 @@
 #include <msp/gl/extensions/khr_debug.h>
 #include <msp/gl/extensions/nv_non_square_matrices.h>
 #include <msp/io/print.h>
-#include "deviceinfo.h"
+#include "device.h"
 #include "error.h"
 #include "program.h"
 #include "program_backend.h"
@@ -46,6 +47,17 @@ OpenGLProgram::OpenGLProgram()
        id = glCreateProgram();
 }
 
+OpenGLProgram::OpenGLProgram(OpenGLProgram &&other):
+       id(other.id),
+       linked(other.linked),
+       uniform_calls(move(other.uniform_calls)),
+       debug_name(move(other.debug_name))
+{
+       move(other.stage_ids, other.stage_ids+MAX_STAGES, stage_ids);
+       other.id = 0;
+       fill(other.stage_ids, other.stage_ids+MAX_STAGES, 0);
+}
+
 OpenGLProgram::~OpenGLProgram()
 {
        for(unsigned i=0; i<MAX_STAGES; ++i)
@@ -70,6 +82,7 @@ unsigned OpenGLProgram::add_stage(Stage type)
        case VERTEX: { static Require _req(ARB_vertex_shader); gl_type = GL_VERTEX_SHADER; } break;
        case GEOMETRY: { static Require _req(ARB_geometry_shader4); gl_type = GL_GEOMETRY_SHADER; } break;
        case FRAGMENT: { static Require _req(ARB_fragment_shader); gl_type = GL_FRAGMENT_SHADER; } break;
+       case COMPUTE: { static Require _req(ARB_compute_shader); gl_type = GL_COMPUTE_SHADER; } break;
        default: throw invalid_argument("OpenGLProgram::add_stage");
        }
 
@@ -88,9 +101,9 @@ unsigned OpenGLProgram::add_stage(Stage type)
        return stage_id;
 }
 
-void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map<string, int> &spec_values, TransientData &transient)
+void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map<string, int> &spec_values)
 {
-       SL::Compiler compiler(DeviceInfo::get_global().glsl_features);
+       SL::Compiler compiler(Device::get_current().get_info().glsl_features);
        compiler.set_source(mod.get_prepared_source(), "<module>");
        compiler.specialize(spec_values);
        compiler.compile(SL::Compiler::PROGRAM);
@@ -112,6 +125,7 @@ void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map<string, int
                case SL::Stage::VERTEX: stage_id = add_stage(VERTEX); break;
                case SL::Stage::GEOMETRY: stage_id = add_stage(GEOMETRY); break;
                case SL::Stage::FRAGMENT: stage_id = add_stage(FRAGMENT); break;
+               case SL::Stage::COMPUTE: stage_id = add_stage(COMPUTE); break;
                default: throw invalid_operation("OpenGLProgram::add_glsl_stages");
                }
 
@@ -135,11 +149,50 @@ void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map<string, int
                compile_glsl_stage(mod, stage_id);
        }
 
-       transient.textures = compiler.get_texture_bindings();
-       transient.blocks = compiler.get_uniform_block_bindings();
-
        ReflectData &rd = static_cast<Program *>(this)->reflect_data;
        rd.n_clip_distances = compiler.get_n_clip_distances();
+
+       link(mod);
+       query_uniforms();
+       query_attributes();
+       if(is_compute())
+       {
+               int wg_size[3];
+               glGetProgramiv(id, GL_COMPUTE_WORK_GROUP_SIZE, wg_size);
+               rd.compute_wg_size = LinAl::Vector<unsigned, 3>(wg_size[0], wg_size[1], wg_size[2]);
+       }
+
+       const map<string, unsigned> &block_bindings = compiler.get_uniform_block_bindings();
+       if(!block_bindings.empty())
+       {
+               for(unsigned i=0; i<rd.uniform_blocks.size(); ++i)
+               {
+                       auto j = block_bindings.find(rd.uniform_blocks[i].name);
+                       if(j!=block_bindings.end())
+                       {
+                               glUniformBlockBinding(id, i, j->second);
+                               rd.uniform_blocks[i].bind_point = j->second;
+                       }
+               }
+       }
+
+       const map<string, unsigned> &tex_bindings = compiler.get_texture_bindings();
+       if(!tex_bindings.empty())
+       {
+               if(!ARB_separate_shader_objects)
+                       glUseProgram(id);
+               for(const auto &kvp: tex_bindings)
+               {
+                       int location = static_cast<const Program *>(this)->get_uniform_location(kvp.first);
+                       if(location>=0)
+                       {
+                               if(ARB_separate_shader_objects)
+                                       glProgramUniform1i(id, location, kvp.second);
+                               else
+                                       glUniform1i(location, kvp.second);
+                       }
+               }
+       }
 }
 
 void OpenGLProgram::compile_glsl_stage(const GlslModule &mod, unsigned stage_id)
@@ -178,6 +231,7 @@ void OpenGLProgram::add_spirv_stages(const SpirVModule &mod, const map<string, i
                case SpirVModule::VERTEX: stage_id = add_stage(VERTEX); break;
                case SpirVModule::GEOMETRY: stage_id = add_stage(GEOMETRY); break;
                case SpirVModule::FRAGMENT: stage_id = add_stage(FRAGMENT); break;
+               case SpirVModule::COMPUTE: stage_id = add_stage(COMPUTE); break;
                default: throw invalid_operation("OpenGLProgram::add_spirv_stages");
                }
 
@@ -209,9 +263,11 @@ void OpenGLProgram::add_spirv_stages(const SpirVModule &mod, const map<string, i
        for(unsigned i=0; i<MAX_STAGES; ++i)
                if(stage_ids[i])
                        glSpecializeShader(stage_ids[i], j->name.c_str(), spec_id_array.size(), &spec_id_array[0], &spec_value_array[0]);
+
+       link(mod);
 }
 
-void OpenGLProgram::finalize(const Module &mod, TransientData &transient)
+void OpenGLProgram::link(const Module &mod)
 {
        glLinkProgram(id);
        int status = 0;
@@ -232,13 +288,6 @@ void OpenGLProgram::finalize(const Module &mod, TransientData &transient)
        if(!info_log.empty())
                IO::print("Program link info log:\n%s", info_log);
 #endif
-
-       if(mod.get_format()==Module::GLSL)
-       {
-               query_uniforms();
-               query_attributes();
-               apply_bindings(transient);
-       }
 }
 
 void OpenGLProgram::query_uniforms()
@@ -263,7 +312,7 @@ void OpenGLProgram::query_uniforms()
                        if(len>3 && !strcmp(name+len-3, "[0]"))
                                name[len-3] = 0;
 
-                       rd.uniforms.push_back(ReflectData::UniformInfo());
+                       rd.uniforms.emplace_back();
                        ReflectData::UniformInfo &info = rd.uniforms.back();
                        info.name = name;
                        info.tag = name;
@@ -285,7 +334,7 @@ void OpenGLProgram::query_uniforms()
                query_uniform_blocks(uniforms_by_index);
        }
 
-       rd.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
+       rd.uniform_blocks.emplace_back();
        ReflectData::UniformBlockInfo &default_block = rd.uniform_blocks.back();
 
        for(ReflectData::UniformInfo &u: rd.uniforms)
@@ -320,7 +369,7 @@ void OpenGLProgram::query_uniform_blocks(const vector<ReflectData::UniformInfo *
                char name[128];
                int len;
                glGetActiveUniformBlockName(id, i, sizeof(name), &len, name);
-               rd.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
+               rd.uniform_blocks.emplace_back();
                ReflectData::UniformBlockInfo &info = rd.uniform_blocks.back();
                info.name = name;
 
@@ -397,7 +446,7 @@ void OpenGLProgram::query_attributes()
                        if(len>3 && !strcmp(name+len-3, "[0]"))
                                name[len-3] = 0;
 
-                       rd.attributes.push_back(ReflectData::AttributeInfo());
+                       rd.attributes.emplace_back();
                        ReflectData::AttributeInfo &info = rd.attributes.back();
                        info.name = name;
                        info.location = glGetAttribLocation(id, name);
@@ -405,35 +454,8 @@ void OpenGLProgram::query_attributes()
                        info.type = from_gl_type(type);
                }
        }
-}
-
-void OpenGLProgram::apply_bindings(const TransientData &transient)
-{
-       ReflectData &rd = static_cast<Program *>(this)->reflect_data;
 
-       for(unsigned i=0; i<rd.uniform_blocks.size(); ++i)
-       {
-               auto j = transient.blocks.find(rd.uniform_blocks[i].name);
-               if(j!=transient.blocks.end())
-               {
-                       glUniformBlockBinding(id, i, j->second);
-                       rd.uniform_blocks[i].bind_point = j->second;
-               }
-       }
-
-       if(!ARB_separate_shader_objects)
-               glUseProgram(id);
-       for(const auto &kvp: transient.textures)
-       {
-               int location = static_cast<const Program *>(this)->get_uniform_location(kvp.first);
-               if(location>=0)
-               {
-                       if(ARB_separate_shader_objects)
-                               glProgramUniform1i(id, location, kvp.second);
-                       else
-                               glUniform1i(location, kvp.second);
-               }
-       }
+       sort_member(rd.attributes, &ReflectData::AttributeInfo::name);
 }
 
 void OpenGLProgram::finalize_uniforms()
@@ -521,7 +543,7 @@ void OpenGLProgram::set_debug_name(const string &name)
 void OpenGLProgram::set_stage_debug_name(unsigned stage_id, Stage type)
 {
 #ifdef DEBUG
-       static const char *const suffixes[] = { " [VS]", " [GS]", " [FS]" };
+       static const char *const suffixes[] = { " [VS]", " [GS]", " [FS]", " [CS]" };
        string name = debug_name+suffixes[type];
        glObjectLabel(GL_SHADER, stage_id, name.size(), name.c_str());
 #else