]> git.tdb.fi Git - libs/gl.git/commitdiff
Support compute shaders and compute operations
authorMikko Rasa <tdb@tdb.fi>
Fri, 15 Apr 2022 21:17:05 +0000 (00:17 +0300)
committerMikko Rasa <tdb@tdb.fi>
Fri, 15 Apr 2022 21:17:05 +0000 (00:17 +0300)
23 files changed:
extensions/arb_compute_shader.glext [new file with mode: 0644]
source/backends/opengl/commands_backend.cpp
source/backends/opengl/commands_backend.h
source/backends/opengl/program_backend.cpp
source/backends/opengl/program_backend.h
source/backends/vulkan/commands_backend.cpp
source/backends/vulkan/commands_backend.h
source/backends/vulkan/module_backend.cpp
source/backends/vulkan/pipelinecache.cpp
source/backends/vulkan/pipelinestate_backend.cpp
source/backends/vulkan/pipelinestate_backend.h
source/backends/vulkan/program_backend.cpp
source/backends/vulkan/program_backend.h
source/backends/vulkan/vulkan.cpp
source/backends/vulkan/vulkan.h
source/core/commands.h
source/core/module.cpp
source/core/module.h
source/core/program.cpp
source/core/program.h
source/core/reflectdata.h
source/render/renderer.cpp
source/render/renderer.h

diff --git a/extensions/arb_compute_shader.glext b/extensions/arb_compute_shader.glext
new file mode 100644 (file)
index 0000000..b8683f4
--- /dev/null
@@ -0,0 +1 @@
+extension ARB_compute_shader
index 7d599153e74e10ecec77bd603af376783b021286..76469ab395be70ef4da6c8273eacf9604af91159 100644 (file)
@@ -1,4 +1,5 @@
 #include <algorithm>
+#include <msp/gl/extensions/arb_compute_shader.h>
 #include <msp/gl/extensions/arb_direct_state_access.h>
 #include <msp/gl/extensions/arb_draw_instanced.h>
 #include <msp/gl/extensions/arb_occlusion_query.h>
@@ -78,6 +79,17 @@ void OpenGLCommands::draw_instanced(const Batch &batch, unsigned count)
        glDrawElementsInstanced(batch.gl_prim_type, batch.size(), batch.gl_index_type, data_ptr, count);
 }
 
+void OpenGLCommands::dispatch(unsigned count_x, unsigned count_y, unsigned count_z)
+{
+       if(!pipeline_state)
+               throw invalid_operation("OpenGLCommands::dispatch_compute");
+
+       static Require req(ARB_compute_shader);
+
+       pipeline_state->apply();
+       glDispatchCompute(count_x, count_y, count_z);
+}
+
 void OpenGLCommands::resolve_multisample(Framebuffer &target)
 {
        const Framebuffer *source = (pipeline_state ? pipeline_state->get_framebuffer() : 0);
index f4976f3f7b48bdf38b837ce0fb8f79c67c34aa29..e6a3161ad105f464e2d825baab1360d1ad0a7eb6 100644 (file)
@@ -24,6 +24,7 @@ protected:
        void clear(const ClearValue *);
        void draw(const Batch &);
        void draw_instanced(const Batch &, unsigned);
+       void dispatch(unsigned, unsigned, unsigned);
        void resolve_multisample(Framebuffer &);
 
        void begin_query(const QueryPool &, unsigned);
index f90315c0dd3e5373a6b3cb363f079e981087564c..398d35ce7a6c364aba56634cb7aa0be67ad174d6 100644 (file)
@@ -1,5 +1,6 @@
 #include <cstring>
 #include <msp/core/algorithm.h>
+#include <msp/gl/extensions/arb_compute_shader.h>
 #include <msp/gl/extensions/arb_es2_compatibility.h>
 #include <msp/gl/extensions/arb_fragment_shader.h>
 #include <msp/gl/extensions/arb_gl_spirv.h>
@@ -81,6 +82,7 @@ unsigned OpenGLProgram::add_stage(Stage type)
        case VERTEX: { static Require _req(ARB_vertex_shader); gl_type = GL_VERTEX_SHADER; } break;
        case GEOMETRY: { static Require _req(ARB_geometry_shader4); gl_type = GL_GEOMETRY_SHADER; } break;
        case FRAGMENT: { static Require _req(ARB_fragment_shader); gl_type = GL_FRAGMENT_SHADER; } break;
+       case COMPUTE: { static Require _req(ARB_compute_shader); gl_type = GL_COMPUTE_SHADER; } break;
        default: throw invalid_argument("OpenGLProgram::add_stage");
        }
 
@@ -123,6 +125,7 @@ void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map<string, int
                case SL::Stage::VERTEX: stage_id = add_stage(VERTEX); break;
                case SL::Stage::GEOMETRY: stage_id = add_stage(GEOMETRY); break;
                case SL::Stage::FRAGMENT: stage_id = add_stage(FRAGMENT); break;
+               case SL::Stage::COMPUTE: stage_id = add_stage(COMPUTE); break;
                default: throw invalid_operation("OpenGLProgram::add_glsl_stages");
                }
 
@@ -152,6 +155,12 @@ void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map<string, int
        link(mod);
        query_uniforms();
        query_attributes();
+       if(is_compute())
+       {
+               int wg_size[3];
+               glGetProgramiv(id, GL_COMPUTE_WORK_GROUP_SIZE, wg_size);
+               rd.compute_wg_size = LinAl::Vector<unsigned, 3>(wg_size[0], wg_size[1], wg_size[2]);
+       }
 
        const map<string, unsigned> &block_bindings = compiler.get_uniform_block_bindings();
        if(!block_bindings.empty())
@@ -222,6 +231,7 @@ void OpenGLProgram::add_spirv_stages(const SpirVModule &mod, const map<string, i
                case SpirVModule::VERTEX: stage_id = add_stage(VERTEX); break;
                case SpirVModule::GEOMETRY: stage_id = add_stage(GEOMETRY); break;
                case SpirVModule::FRAGMENT: stage_id = add_stage(FRAGMENT); break;
+               case SpirVModule::COMPUTE: stage_id = add_stage(COMPUTE); break;
                default: throw invalid_operation("OpenGLProgram::add_spirv_stages");
                }
 
@@ -533,7 +543,7 @@ void OpenGLProgram::set_debug_name(const string &name)
 void OpenGLProgram::set_stage_debug_name(unsigned stage_id, Stage type)
 {
 #ifdef DEBUG
-       static const char *const suffixes[] = { " [VS]", " [GS]", " [FS]" };
+       static const char *const suffixes[] = { " [VS]", " [GS]", " [FS]", " [CS]" };
        string name = debug_name+suffixes[type];
        glObjectLabel(GL_SHADER, stage_id, name.size(), name.c_str());
 #else
index 14650708e32abe77ee8f3ac3e851212e4b2ce895..1a548ed69116d66d96bdd28e5194abb1cd441dfd 100644 (file)
@@ -20,6 +20,7 @@ protected:
                VERTEX,
                GEOMETRY,
                FRAGMENT,
+               COMPUTE,
                MAX_STAGES
        };
 
@@ -56,6 +57,8 @@ protected:
        void query_attributes();
        void finalize_uniforms();
 
+       bool is_compute() const { return stage_ids[COMPUTE]; }
+
        void set_debug_name(const std::string &);
        void set_stage_debug_name(unsigned, Stage);
 };
index ff9e6cf62da0ae0b851e11f572dd416cabddd968..bca1a678940d75b1625a59ac03a62c719efd86e6 100644 (file)
@@ -233,6 +233,23 @@ void VulkanCommands::draw_instanced(const Batch &batch, unsigned count)
        vkCmd.DrawIndexed(batch.size(), count, first_index, 0, 0);
 }
 
+void VulkanCommands::dispatch(unsigned count_x, unsigned count_y, unsigned count_z)
+{
+       if(!pipeline_state)
+               throw invalid_operation("VulkanCommands::draw_instanced");
+
+       if(framebuffer)
+               end_render_pass();
+
+       VulkanCommandRecorder vkCmd(device.get_functions(), primary_buffer);
+
+       pipeline_state->refresh();
+       pipeline_state->synchronize_resources();
+       device.get_synchronizer().barrier(vkCmd);
+       pipeline_state->apply(vkCmd, 0, frame_index, false);
+       vkCmd.Dispatch(count_x, count_y, count_z);
+}
+
 void VulkanCommands::resolve_multisample(Framebuffer &)
 {
        throw logic_error("VulkanCommands::resolve_multisample is unimplemented");
index 7e1554413bb10acebd50e840a540ed2c7dbd2777..7c98b15f0d061bbe7ea43bce73e8ae9559c1996a 100644 (file)
@@ -69,6 +69,7 @@ protected:
        void clear(const ClearValue *);
        void draw(const Batch &);
        void draw_instanced(const Batch &, unsigned);
+       void dispatch(unsigned, unsigned, unsigned);
        void resolve_multisample(Framebuffer &);
 
        void begin_query(const QueryPool &, unsigned);
index 7ec198e0ec3f52de1696523e83a92eddeb129cb1..492de9a6311ac9ac5d00207724fba93b1c47d32e 100644 (file)
@@ -77,6 +77,7 @@ unsigned get_vulkan_stage(unsigned stage)
        case SpirVModule::VERTEX: return VK_SHADER_STAGE_VERTEX_BIT;
        case SpirVModule::GEOMETRY: return VK_SHADER_STAGE_GEOMETRY_BIT;
        case SpirVModule::FRAGMENT: return VK_SHADER_STAGE_FRAGMENT_BIT;
+       case SpirVModule::COMPUTE: return VK_SHADER_STAGE_COMPUTE_BIT;
        default: throw invalid_argument("get_vulkan_stage");
        }
 }
index 951d33f3bea801fe72a2516dd984edbd1e50a669..0ec5fcc78fbd12f7e8a2a693ae07fd6bbadf64b5 100644 (file)
@@ -55,10 +55,19 @@ VkPipeline PipelineCache::get_pipeline(const PipelineState &ps)
 
        vector<char> buffer;
        ps.fill_creation_info(buffer);
-       const VkGraphicsPipelineCreateInfo *creation_info = reinterpret_cast<const VkGraphicsPipelineCreateInfo *>(buffer.data());
 
+       VkStructureType type = *reinterpret_cast<const VkStructureType *>(buffer.data());
        VkPipeline pipeline;
-       vk.CreateGraphicsPipelines(0, 1, creation_info, &pipeline);
+       if(type==VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO)
+       {
+               const VkComputePipelineCreateInfo *creation_info = reinterpret_cast<const VkComputePipelineCreateInfo *>(buffer.data());
+               vk.CreateComputePipelines(0, 1, creation_info, &pipeline);
+       }
+       else
+       {
+               const VkGraphicsPipelineCreateInfo *creation_info = reinterpret_cast<const VkGraphicsPipelineCreateInfo *>(buffer.data());
+               vk.CreateGraphicsPipelines(0, 1, creation_info, &pipeline);
+       }
 
        pipelines.insert(make_pair(key, pipeline));
 
index 5f9e7eb9185a4e812c1826edf0e071677e2b20d7..7af36eedf82ceadced50b3f576b103ceeff584f8 100644 (file)
@@ -48,8 +48,11 @@ void VulkanPipelineState::update() const
                push_const_compat = hash_update<32>(push_const_compat, self.shprog->get_push_constants_size());
        }
 
-       constexpr unsigned pipeline_mask = PipelineState::SHPROG|PipelineState::VERTEX_SETUP|PipelineState::FACE_CULL|
+       constexpr unsigned graphics_mask = PipelineState::VERTEX_SETUP|PipelineState::FACE_CULL|
                PipelineState::DEPTH_TEST|PipelineState::STENCIL_TEST|PipelineState::BLEND|PipelineState::PRIMITIVE_TYPE;
+       unsigned pipeline_mask = PipelineState::SHPROG;
+       if(!self.shprog->is_compute())
+               pipeline_mask |= graphics_mask;
        if(changes&pipeline_mask)
        {
                handle = device.get_pipeline_cache().get_pipeline(self);
@@ -98,50 +101,63 @@ void VulkanPipelineState::update() const
 uint64_t VulkanPipelineState::compute_hash() const
 {
        const PipelineState &self = *static_cast<const PipelineState *>(this);
-       const FrameFormat &format = self.framebuffer->get_format();
 
        uint64_t result = hash<64>(self.shprog);
-       result = hash_update<64>(result, self.vertex_setup->compute_hash());
-       result = hash_round<64>(result, self.primitive_type);
 
-       if(self.front_face!=NON_MANIFOLD && self.face_cull!=NO_CULL)
+       if(!self.shprog->is_compute())
        {
-               result = hash_round<64>(result, self.front_face);
-               result = hash_round<64>(result, self.face_cull);
-       }
+               const FrameFormat &format = self.framebuffer->get_format();
 
-       result = hash_round<64>(result, format.get_samples());
+               result = hash_update<64>(result, self.vertex_setup->compute_hash());
+               result = hash_round<64>(result, self.primitive_type);
 
-       if(self.depth_test.enabled)
-       {
-               result = hash_round<64>(result, self.depth_test.compare);
-               result = hash_update<64>(result, self.depth_test.write);
-       }
+               if(self.front_face!=NON_MANIFOLD && self.face_cull!=NO_CULL)
+               {
+                       result = hash_round<64>(result, self.front_face);
+                       result = hash_round<64>(result, self.face_cull);
+               }
 
-       if(self.stencil_test.enabled)
-       {
-               result = hash_round<64>(result, self.stencil_test.compare);
-               result = hash_round<64>(result, self.stencil_test.stencil_fail_op);
-               result = hash_round<64>(result, self.stencil_test.depth_fail_op);
-               result = hash_round<64>(result, self.stencil_test.depth_pass_op);
-               result = hash_update<64>(result, self.stencil_test.reference);
-       }
+               result = hash_round<64>(result, format.get_samples());
 
-       if(self.blend.enabled)
-       {
-               result = hash_round<64>(result, self.blend.equation);
-               result = hash_round<64>(result, self.blend.src_factor);
-               result = hash_round<64>(result, self.blend.dst_factor);
-               result = hash_round<64>(result, self.blend.write_mask);
-       }
+               if(self.depth_test.enabled)
+               {
+                       result = hash_round<64>(result, self.depth_test.compare);
+                       result = hash_update<64>(result, self.depth_test.write);
+               }
 
-       for(FrameAttachment a: format)
-               result = hash_update<64>(result, a);
+               if(self.stencil_test.enabled)
+               {
+                       result = hash_round<64>(result, self.stencil_test.compare);
+                       result = hash_round<64>(result, self.stencil_test.stencil_fail_op);
+                       result = hash_round<64>(result, self.stencil_test.depth_fail_op);
+                       result = hash_round<64>(result, self.stencil_test.depth_pass_op);
+                       result = hash_update<64>(result, self.stencil_test.reference);
+               }
+
+               if(self.blend.enabled)
+               {
+                       result = hash_round<64>(result, self.blend.equation);
+                       result = hash_round<64>(result, self.blend.src_factor);
+                       result = hash_round<64>(result, self.blend.dst_factor);
+                       result = hash_round<64>(result, self.blend.write_mask);
+               }
+
+               for(FrameAttachment a: format)
+                       result = hash_update<64>(result, a);
+       }
 
        return result;
 }
 
 void VulkanPipelineState::fill_creation_info(vector<char> &buffer) const
+{
+       if(static_cast<const PipelineState *>(this)->shprog->is_compute())
+               fill_compute_creation_info(buffer);
+       else
+               fill_graphics_creation_info(buffer);
+}
+
+void VulkanPipelineState::fill_graphics_creation_info(vector<char> &buffer) const
 {
        const PipelineState &self = *static_cast<const PipelineState *>(this);
 
@@ -266,6 +282,22 @@ void VulkanPipelineState::fill_creation_info(vector<char> &buffer) const
                pipeline_info->pVertexInputState = reinterpret_cast<const VkPipelineVertexInputStateCreateInfo *>(self.vertex_setup->creation_info.data());
 }
 
+void VulkanPipelineState::fill_compute_creation_info(vector<char> &buffer) const
+{
+       const PipelineState &self = *static_cast<const PipelineState *>(this);
+
+       StructureBuilder sb(buffer, 1);
+       VkComputePipelineCreateInfo *const &pipeline_info = sb.add<VkComputePipelineCreateInfo>();
+
+       pipeline_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+
+       if(self.shprog)
+       {
+               pipeline_info->stage = *reinterpret_cast<const VkPipelineShaderStageCreateInfo *>(self.shprog->creation_info.data());
+               pipeline_info->layout = handle_cast<::VkPipelineLayout>(self.shprog->layout_handle);
+       }
+}
+
 uint64_t VulkanPipelineState::compute_descriptor_set_hash(unsigned index) const
 {
        const PipelineState &self = *static_cast<const PipelineState *>(this);
@@ -446,10 +478,11 @@ void VulkanPipelineState::apply(const VulkanCommandRecorder &vkCmd, const Vulkan
                        unapplied |= PipelineState::SCISSOR;
        }
 
+       VkPipelineBindPoint bind_point = (self.shprog->is_compute() ? VK_PIPELINE_BIND_POINT_COMPUTE : VK_PIPELINE_BIND_POINT_GRAPHICS);
        if(unapplied&PipelineState::SHPROG)
-               vkCmd.BindPipeline(VK_PIPELINE_BIND_POINT_GRAPHICS, handle);
+               vkCmd.BindPipeline(bind_point, handle);
 
-       if(unapplied&PipelineState::VERTEX_SETUP)
+       if(!self.shprog->is_compute() && (unapplied&PipelineState::VERTEX_SETUP))
                if(const VertexSetup *vs = self.vertex_setup)
                {
                        vkCmd.BindVertexBuffers(0, vs->n_bindings, vs->buffers, vs->offsets);
@@ -476,11 +509,11 @@ void VulkanPipelineState::apply(const VulkanCommandRecorder &vkCmd, const Vulkan
                        descriptor_set_handles.push_back(device.get_descriptor_pool().get_descriptor_set(
                                self.descriptor_set_slots[i], self, i, frame));
 
-               vkCmd.BindDescriptorSets(VK_PIPELINE_BIND_POINT_GRAPHICS, self.shprog->layout_handle,
+               vkCmd.BindDescriptorSets(bind_point, self.shprog->layout_handle,
                        first_changed_desc_set, descriptor_set_handles.size(), descriptor_set_handles.data(), 0, 0);
        }
 
-       if(unapplied&(PipelineState::VIEWPORT|PipelineState::SCISSOR))
+       if(!self.shprog->is_compute() && (unapplied&(PipelineState::VIEWPORT|PipelineState::SCISSOR)))
        {
                Rect fb_rect = self.framebuffer->get_rect();
 
index 37c96a4e2d394f967f0f75d0f0b4f802bea6bab2..175c291383b825bcd2e348dbebb00f129ede3da6 100644 (file)
@@ -34,6 +34,8 @@ protected:
        void refresh() const { if(changes) update(); }
        std::uint64_t compute_hash() const;
        void fill_creation_info(std::vector<char> &) const;
+       void fill_graphics_creation_info(std::vector<char> &) const;
+       void fill_compute_creation_info(std::vector<char> &) const;
        std::uint64_t compute_descriptor_set_hash(unsigned) const;
        bool is_descriptor_set_dynamic(unsigned) const;
        VkDescriptorSetLayout get_descriptor_set_layout(unsigned) const;
index a962a78e61a5f3a9d1ccbd08310b6b15faea4f7d..993a3e38ef965d11b838fb9b41653cb2fe5b3b29 100644 (file)
@@ -175,6 +175,11 @@ void VulkanProgram::finalize_uniforms()
 #endif
 }
 
+bool VulkanProgram::is_compute() const
+{
+       return stage_flags&VK_SHADER_STAGE_COMPUTE_BIT;
+}
+
 void VulkanProgram::set_debug_name(const string &name)
 {
 #ifdef DEBUG
index 4f9d9c4730103e8188a8c2e1c3422bfe0458056f..0e48b14bcb9aa5371e0fcdf0b077006d0262fa01 100644 (file)
@@ -35,6 +35,8 @@ protected:
 
        void finalize_uniforms();
 
+       bool is_compute() const;
+
        void set_debug_name(const std::string &);
        void set_vulkan_object_name() const;
 };
index c80546953cc26030661739f50e159bf69d6eb1c7..31349e8fad7a76e60af472e14835bb1bcd526f56 100644 (file)
@@ -41,6 +41,7 @@ VulkanFunctions::VulkanFunctions(const Graphics::VulkanContext &c):
        vkCreateShaderModule(context.get_function<PFN_vkCreateShaderModule>("vkCreateShaderModule")),
        vkDestroyShaderModule(context.get_function<PFN_vkDestroyShaderModule>("vkDestroyShaderModule")),
        // 10
+       vkCreateComputePipelines(context.get_function<PFN_vkCreateComputePipelines>("vkCreateComputePipelines")),
        vkCreateGraphicsPipelines(context.get_function<PFN_vkCreateGraphicsPipelines>("vkCreateGraphicsPipelines")),
        vkDestroyPipeline(context.get_function<PFN_vkDestroyPipeline>("vkDestroyPipeline")),
        vkCmdBindPipeline(context.get_function<PFN_vkCmdBindPipeline>("vkCmdBindPipeline")),
@@ -88,6 +89,8 @@ VulkanFunctions::VulkanFunctions(const Graphics::VulkanContext &c):
        vkCmdSetViewport(context.get_function<PFN_vkCmdSetViewport>("vkCmdSetViewport")),
        // 26
        vkCmdSetScissor(context.get_function<PFN_vkCmdSetScissor>("vkCmdSetScissor")),
+       // 28
+       vkCmdDispatch(context.get_function<PFN_vkCmdDispatch>("vkCmdDispatch")),
        // 30
        vkGetPhysicalDeviceSurfaceCapabilities(context.get_function<PFN_vkGetPhysicalDeviceSurfaceCapabilitiesKHR>("vkGetPhysicalDeviceSurfaceCapabilitiesKHR")),
        vkGetPhysicalDeviceSurfaceFormats(context.get_function<PFN_vkGetPhysicalDeviceSurfaceFormatsKHR>("vkGetPhysicalDeviceSurfaceFormatsKHR")),
index d77e2cc424e852852adb7ac31e75eeb7b6d3569c..2b42550c24657ad2b493c8497f99ee22f6d83d76 100644 (file)
@@ -124,6 +124,7 @@ private:
        PFN_vkCmdEndRenderPass vkCmdEndRenderPass = 0;  // 8.4
        PFN_vkCreateShaderModule vkCreateShaderModule = 0;  // 9.1
        PFN_vkDestroyShaderModule vkDestroyShaderModule = 0;  // 9.1
+       PFN_vkCreateComputePipelines vkCreateComputePipelines = 0;  // 10.1
        PFN_vkCreateGraphicsPipelines vkCreateGraphicsPipelines = 0;  // 10.2
        PFN_vkDestroyPipeline vkDestroyPipeline = 0;  // 10.4
        PFN_vkCmdBindPipeline vkCmdBindPipeline = 0;  // 10.10
@@ -162,6 +163,7 @@ private:
        PFN_vkCmdBindVertexBuffers vkCmdBindVertexBuffers = 0;  // 21.2
        PFN_vkCmdSetViewport vkCmdSetViewport = 0;  // 24.5
        PFN_vkCmdSetScissor vkCmdSetScissor = 0;  // 26.1
+       PFN_vkCmdDispatch vkCmdDispatch = 0;  // 28
        PFN_vkGetPhysicalDeviceSurfaceCapabilitiesKHR vkGetPhysicalDeviceSurfaceCapabilities = 0;  // 30.5.1
        PFN_vkGetPhysicalDeviceSurfaceFormatsKHR vkGetPhysicalDeviceSurfaceFormats = 0;  // 30.5.2
        PFN_vkGetPhysicalDeviceSurfacePresentModesKHR vkGetPhysicalDeviceSurfacePresentModes = 0;  // 30.5.3
@@ -260,6 +262,9 @@ public:
        { vkDestroyShaderModule(device, handle_cast<::VkShaderModule>(shaderModule), 0); }
 
        // Chapter 10: Pipelines
+       Result CreateComputePipelines(VkPipelineCache pipelineCache, std::uint32_t createInfoCount, const VkComputePipelineCreateInfo *pCreateInfos, VkPipeline *pPipelines) const
+       { return { vkCreateComputePipelines(device, handle_cast<::VkPipelineCache>(pipelineCache), createInfoCount, pCreateInfos, 0, handle_cast<::VkPipeline *>(pPipelines)), "vkCreateComputePipelines" }; }
+
        Result CreateGraphicsPipelines(VkPipelineCache pipelineCache, std::uint32_t createInfoCount, const VkGraphicsPipelineCreateInfo *pCreateInfos, VkPipeline *pPipelines) const
        { return { vkCreateGraphicsPipelines(device, handle_cast<::VkPipelineCache>(pipelineCache), createInfoCount, pCreateInfos, 0, handle_cast<::VkPipeline *>(pPipelines)), "vkCreateGraphicsPipelines" }; }
 
@@ -383,6 +388,10 @@ public:
        void CmdSetScissor(VkCommandBuffer commandBuffer, std::uint32_t firstScissor, std::uint32_t scissorCount, const VkRect2D *pScissors) const
        { vkCmdSetScissor(handle_cast<::VkCommandBuffer>(commandBuffer), firstScissor, scissorCount, pScissors); }
 
+       // Chapter 28: Dispatching Commands
+       void CmdDispatch(VkCommandBuffer commandBuffer, std::uint32_t groupCountX, std::uint32_t groupCountY, std::uint32_t groupCountZ) const
+       { vkCmdDispatch(handle_cast<::VkCommandBuffer>(commandBuffer), groupCountX, groupCountY, groupCountZ); }
+
        // Chapter 30: Window System Integration (WSI)
        Result GetPhysicalDeviceSurfaceCapabilities(VkSurface surface, VkSurfaceCapabilitiesKHR &rSurfaceCapabilities) const
        { return { vkGetPhysicalDeviceSurfaceCapabilities(physicalDevice, handle_cast<::VkSurfaceKHR>(surface), &rSurfaceCapabilities), "vkGetPhysicalDeviceSurfaceCapabilities" }; }
@@ -470,6 +479,9 @@ public:
 
        void SetScissor(std::uint32_t firstScissor, std::uint32_t scissorCount, const VkRect2D *pScissors) const
        { vk.CmdSetScissor(commandBuffer, firstScissor, scissorCount, pScissors); }
+
+       void Dispatch(std::uint32_t groupCountX, std::uint32_t groupCountY, std::uint32_t groupCountZ) const
+       { vk.CmdDispatch(commandBuffer, groupCountX, groupCountY, groupCountZ); }
 };
 
 } // namespace GL
index e7dce6a162690e6300437f9377cd6bb167e9cac1..793b258901b0b706482cc3cf72a5ed2c1caa0448 100644 (file)
@@ -22,6 +22,7 @@ public:
        using CommandsBackend::clear;
        using CommandsBackend::draw;
        using CommandsBackend::draw_instanced;
+       using CommandsBackend::dispatch;
        using CommandsBackend::resolve_multisample;
 
        using CommandsBackend::begin_query;
index 0870c4782c9f90d620b8ecda500d985f8bd363ab..8c69cd7e256673bb1887c61a9f82da31539f2cfb 100644 (file)
@@ -48,6 +48,8 @@ enum SpirVConstants
        OP_RETURN_VALUE = 254,
        OP_UNREACHABLE = 255,
 
+       EXEC_LOCAL_SIZE = 17,
+
        DECO_SPEC_ID = 1,
        DECO_ARRAY_STRIDE = 6,
        DECO_MATRIX_STRIDE = 7,
@@ -513,6 +515,7 @@ void SpirVModule::Reflection::reflect_code(const vector<uint32_t> &code)
                case OP_NAME: reflect_name(op); break;
                case OP_MEMBER_NAME: reflect_member_name(op); break;
                case OP_ENTRY_POINT: reflect_entry_point(op); break;
+               case OP_EXECUTION_MODE: reflect_execution_mode(op); break;
                case OP_TYPE_VOID: reflect_void_type(op); break;
                case OP_TYPE_BOOL: reflect_bool_type(op); break;
                case OP_TYPE_INT: reflect_int_type(op); break;
@@ -579,6 +582,18 @@ void SpirVModule::Reflection::reflect_entry_point(CodeIterator op)
                entry.globals.push_back(&variables[*op]);
 }
 
+void SpirVModule::Reflection::reflect_execution_mode(CodeIterator op)
+{
+       EntryPoint &entry = entry_points[*(op+1)];
+       unsigned mode = *(op+2);
+       if(mode==EXEC_LOCAL_SIZE)
+       {
+               entry.compute_local_size.x = *(op+3);
+               entry.compute_local_size.y = *(op+4);
+               entry.compute_local_size.z = *(op+5);
+       }
+}
+
 void SpirVModule::Reflection::reflect_void_type(CodeIterator op)
 {
        types[*(op+1)].type = VOID;
index 26895ee0ba2fe9355e397026b5a068c3f7985fe4..480a0669ffbb6275324a7ab2b3c5bb660fdba2eb 100644 (file)
@@ -104,7 +104,8 @@ public:
        {
                VERTEX = 0,
                GEOMETRY = 3,
-               FRAGMENT = 4
+               FRAGMENT = 4,
+               COMPUTE = 5
        };
 
        enum StorageClass
@@ -135,6 +136,7 @@ public:
                unsigned id = 0;
                Stage stage = VERTEX;
                std::vector<const Variable *> globals;
+               LinAl::Vector<unsigned, 3> compute_local_size;
        };
 
        struct StructMember
@@ -228,6 +230,7 @@ private:
                void reflect_name(CodeIterator);
                void reflect_member_name(CodeIterator);
                void reflect_entry_point(CodeIterator);
+               void reflect_execution_mode(CodeIterator);
                void reflect_void_type(CodeIterator);
                void reflect_bool_type(CodeIterator);
                void reflect_int_type(CodeIterator);
index 8ac13bbd0ce21454cb5af1b269dacc0fc6770205..91f215651b57df64b834315c13aa89b86d57d546 100644 (file)
@@ -56,6 +56,10 @@ void Program::add_stages(const Module &mod, const map<string, int> &spec_values)
                collect_uniforms(spirv_mod);
                collect_attributes(spirv_mod);
                collect_builtins(spirv_mod);
+
+               for(const SpirVModule::EntryPoint &e: spirv_mod.get_entry_points())
+                       if(e.stage==SpirVModule::COMPUTE)
+                               reflect_data.compute_wg_size = e.compute_local_size;
        }
 
        finalize_uniforms();
index ca5437f3a6970e6d626bf36520ebc828d0697ab1..c652f468b00cd3855a1b3729a358895dac998c80 100644 (file)
@@ -72,6 +72,8 @@ private:
        void collect_builtins(const SpirVModule::Structure &);
 
 public:
+       using ProgramBackend::is_compute;
+
        ReflectData::LayoutHash get_uniform_layout_hash() const { return reflect_data.layout_hash; }
        unsigned get_n_descriptor_sets() const { return reflect_data.n_descriptor_sets; }
        unsigned get_push_constants_size() const { return reflect_data.push_constants_size; }
@@ -92,6 +94,7 @@ public:
        const ReflectData::AttributeInfo &get_attribute_info(const std::string &) const;
        int get_attribute_location(const std::string &) const;
        unsigned get_n_clip_distances() const { return reflect_data.n_clip_distances; }
+       const LinAl::Vector<unsigned, 3> &get_compute_workgroup_size() const { return reflect_data.compute_wg_size; }
 
        using ProgramBackend::set_debug_name;
 };
index 1e82a61bfadef9ef4b75910f40c44d4ee24afbcc..981f45f000f7b02d926af39d64dd19efaf990508 100644 (file)
@@ -70,6 +70,7 @@ struct ReflectData
        unsigned n_descriptor_sets = 0;
        unsigned push_constants_size = 0;
        std::vector<int> used_bindings;
+       LinAl::Vector<unsigned, 3> compute_wg_size;
 
        void update_layout_hash();
        void update_used_bindings();
index 828efd70ab0695dfc8773dd478061e0eb826aebf..ce0b3447b2b02a43df1029db2b358cd4f8aa2cb8 100644 (file)
@@ -296,6 +296,14 @@ void Renderer::draw_instanced(const Batch &batch, unsigned count)
        commands.draw_instanced(batch, count);
 }
 
+void Renderer::dispatch(unsigned count_x, unsigned count_y, unsigned count_z)
+{
+       apply_state();
+       PipelineState &ps = get_pipeline_state();
+       commands.use_pipeline(&ps);
+       commands.dispatch(count_x, count_y, count_z);
+}
+
 void Renderer::resolve_multisample(Framebuffer &target)
 {
        const State &state = get_state();
index fb460cf07824cfea911c651d897c2aa88bce053b..3341d9aa061e038f3acb59c5515cc6f776c9d5d7 100644 (file)
@@ -231,6 +231,9 @@ public:
        /** Draws multiple instances of a batch of primitives.  A shader must be active. */
        void draw_instanced(const Batch &, unsigned);
 
+       /** Dispatches a compute operation. */
+       void dispatch(unsigned, unsigned = 1, unsigned = 1);
+
        /** Resolves multisample attachments from the active framebuffer into
        target. */
        void resolve_multisample(Framebuffer &target);