From cebf1330ef6773b7b4496dc279ec02a7ca4351bb Mon Sep 17 00:00:00 2001 From: Mikko Rasa Date: Sat, 16 Apr 2022 00:17:05 +0300 Subject: [PATCH] Support compute shaders and compute operations --- extensions/arb_compute_shader.glext | 1 + source/backends/opengl/commands_backend.cpp | 12 ++ source/backends/opengl/commands_backend.h | 1 + source/backends/opengl/program_backend.cpp | 12 +- source/backends/opengl/program_backend.h | 3 + source/backends/vulkan/commands_backend.cpp | 17 +++ source/backends/vulkan/commands_backend.h | 1 + source/backends/vulkan/module_backend.cpp | 1 + source/backends/vulkan/pipelinecache.cpp | 13 ++- .../backends/vulkan/pipelinestate_backend.cpp | 103 ++++++++++++------ .../backends/vulkan/pipelinestate_backend.h | 2 + source/backends/vulkan/program_backend.cpp | 5 + source/backends/vulkan/program_backend.h | 2 + source/backends/vulkan/vulkan.cpp | 3 + source/backends/vulkan/vulkan.h | 12 ++ source/core/commands.h | 1 + source/core/module.cpp | 15 +++ source/core/module.h | 5 +- source/core/program.cpp | 4 + source/core/program.h | 3 + source/core/reflectdata.h | 1 + source/render/renderer.cpp | 8 ++ source/render/renderer.h | 3 + 23 files changed, 189 insertions(+), 39 deletions(-) create mode 100644 extensions/arb_compute_shader.glext diff --git a/extensions/arb_compute_shader.glext b/extensions/arb_compute_shader.glext new file mode 100644 index 00000000..b8683f43 --- /dev/null +++ b/extensions/arb_compute_shader.glext @@ -0,0 +1 @@ +extension ARB_compute_shader diff --git a/source/backends/opengl/commands_backend.cpp b/source/backends/opengl/commands_backend.cpp index 7d599153..76469ab3 100644 --- a/source/backends/opengl/commands_backend.cpp +++ b/source/backends/opengl/commands_backend.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -78,6 +79,17 @@ void OpenGLCommands::draw_instanced(const Batch &batch, unsigned count) glDrawElementsInstanced(batch.gl_prim_type, batch.size(), batch.gl_index_type, data_ptr, count); } +void OpenGLCommands::dispatch(unsigned count_x, unsigned count_y, unsigned count_z) +{ + if(!pipeline_state) + throw invalid_operation("OpenGLCommands::dispatch_compute"); + + static Require req(ARB_compute_shader); + + pipeline_state->apply(); + glDispatchCompute(count_x, count_y, count_z); +} + void OpenGLCommands::resolve_multisample(Framebuffer &target) { const Framebuffer *source = (pipeline_state ? pipeline_state->get_framebuffer() : 0); diff --git a/source/backends/opengl/commands_backend.h b/source/backends/opengl/commands_backend.h index f4976f3f..e6a3161a 100644 --- a/source/backends/opengl/commands_backend.h +++ b/source/backends/opengl/commands_backend.h @@ -24,6 +24,7 @@ protected: void clear(const ClearValue *); void draw(const Batch &); void draw_instanced(const Batch &, unsigned); + void dispatch(unsigned, unsigned, unsigned); void resolve_multisample(Framebuffer &); void begin_query(const QueryPool &, unsigned); diff --git a/source/backends/opengl/program_backend.cpp b/source/backends/opengl/program_backend.cpp index f90315c0..398d35ce 100644 --- a/source/backends/opengl/program_backend.cpp +++ b/source/backends/opengl/program_backend.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -81,6 +82,7 @@ unsigned OpenGLProgram::add_stage(Stage type) case VERTEX: { static Require _req(ARB_vertex_shader); gl_type = GL_VERTEX_SHADER; } break; case GEOMETRY: { static Require _req(ARB_geometry_shader4); gl_type = GL_GEOMETRY_SHADER; } break; case FRAGMENT: { static Require _req(ARB_fragment_shader); gl_type = GL_FRAGMENT_SHADER; } break; + case COMPUTE: { static Require _req(ARB_compute_shader); gl_type = GL_COMPUTE_SHADER; } break; default: throw invalid_argument("OpenGLProgram::add_stage"); } @@ -123,6 +125,7 @@ void OpenGLProgram::add_glsl_stages(const GlslModule &mod, const map(wg_size[0], wg_size[1], wg_size[2]); + } const map &block_bindings = compiler.get_uniform_block_bindings(); if(!block_bindings.empty()) @@ -222,6 +231,7 @@ void OpenGLProgram::add_spirv_stages(const SpirVModule &mod, const maprefresh(); + pipeline_state->synchronize_resources(); + device.get_synchronizer().barrier(vkCmd); + pipeline_state->apply(vkCmd, 0, frame_index, false); + vkCmd.Dispatch(count_x, count_y, count_z); +} + void VulkanCommands::resolve_multisample(Framebuffer &) { throw logic_error("VulkanCommands::resolve_multisample is unimplemented"); diff --git a/source/backends/vulkan/commands_backend.h b/source/backends/vulkan/commands_backend.h index 7e155441..7c98b15f 100644 --- a/source/backends/vulkan/commands_backend.h +++ b/source/backends/vulkan/commands_backend.h @@ -69,6 +69,7 @@ protected: void clear(const ClearValue *); void draw(const Batch &); void draw_instanced(const Batch &, unsigned); + void dispatch(unsigned, unsigned, unsigned); void resolve_multisample(Framebuffer &); void begin_query(const QueryPool &, unsigned); diff --git a/source/backends/vulkan/module_backend.cpp b/source/backends/vulkan/module_backend.cpp index 7ec198e0..492de9a6 100644 --- a/source/backends/vulkan/module_backend.cpp +++ b/source/backends/vulkan/module_backend.cpp @@ -77,6 +77,7 @@ unsigned get_vulkan_stage(unsigned stage) case SpirVModule::VERTEX: return VK_SHADER_STAGE_VERTEX_BIT; case SpirVModule::GEOMETRY: return VK_SHADER_STAGE_GEOMETRY_BIT; case SpirVModule::FRAGMENT: return VK_SHADER_STAGE_FRAGMENT_BIT; + case SpirVModule::COMPUTE: return VK_SHADER_STAGE_COMPUTE_BIT; default: throw invalid_argument("get_vulkan_stage"); } } diff --git a/source/backends/vulkan/pipelinecache.cpp b/source/backends/vulkan/pipelinecache.cpp index 951d33f3..0ec5fcc7 100644 --- a/source/backends/vulkan/pipelinecache.cpp +++ b/source/backends/vulkan/pipelinecache.cpp @@ -55,10 +55,19 @@ VkPipeline PipelineCache::get_pipeline(const PipelineState &ps) vector buffer; ps.fill_creation_info(buffer); - const VkGraphicsPipelineCreateInfo *creation_info = reinterpret_cast(buffer.data()); + VkStructureType type = *reinterpret_cast(buffer.data()); VkPipeline pipeline; - vk.CreateGraphicsPipelines(0, 1, creation_info, &pipeline); + if(type==VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO) + { + const VkComputePipelineCreateInfo *creation_info = reinterpret_cast(buffer.data()); + vk.CreateComputePipelines(0, 1, creation_info, &pipeline); + } + else + { + const VkGraphicsPipelineCreateInfo *creation_info = reinterpret_cast(buffer.data()); + vk.CreateGraphicsPipelines(0, 1, creation_info, &pipeline); + } pipelines.insert(make_pair(key, pipeline)); diff --git a/source/backends/vulkan/pipelinestate_backend.cpp b/source/backends/vulkan/pipelinestate_backend.cpp index 5f9e7eb9..7af36eed 100644 --- a/source/backends/vulkan/pipelinestate_backend.cpp +++ b/source/backends/vulkan/pipelinestate_backend.cpp @@ -48,8 +48,11 @@ void VulkanPipelineState::update() const push_const_compat = hash_update<32>(push_const_compat, self.shprog->get_push_constants_size()); } - constexpr unsigned pipeline_mask = PipelineState::SHPROG|PipelineState::VERTEX_SETUP|PipelineState::FACE_CULL| + constexpr unsigned graphics_mask = PipelineState::VERTEX_SETUP|PipelineState::FACE_CULL| PipelineState::DEPTH_TEST|PipelineState::STENCIL_TEST|PipelineState::BLEND|PipelineState::PRIMITIVE_TYPE; + unsigned pipeline_mask = PipelineState::SHPROG; + if(!self.shprog->is_compute()) + pipeline_mask |= graphics_mask; if(changes&pipeline_mask) { handle = device.get_pipeline_cache().get_pipeline(self); @@ -98,50 +101,63 @@ void VulkanPipelineState::update() const uint64_t VulkanPipelineState::compute_hash() const { const PipelineState &self = *static_cast(this); - const FrameFormat &format = self.framebuffer->get_format(); uint64_t result = hash<64>(self.shprog); - result = hash_update<64>(result, self.vertex_setup->compute_hash()); - result = hash_round<64>(result, self.primitive_type); - if(self.front_face!=NON_MANIFOLD && self.face_cull!=NO_CULL) + if(!self.shprog->is_compute()) { - result = hash_round<64>(result, self.front_face); - result = hash_round<64>(result, self.face_cull); - } + const FrameFormat &format = self.framebuffer->get_format(); - result = hash_round<64>(result, format.get_samples()); + result = hash_update<64>(result, self.vertex_setup->compute_hash()); + result = hash_round<64>(result, self.primitive_type); - if(self.depth_test.enabled) - { - result = hash_round<64>(result, self.depth_test.compare); - result = hash_update<64>(result, self.depth_test.write); - } + if(self.front_face!=NON_MANIFOLD && self.face_cull!=NO_CULL) + { + result = hash_round<64>(result, self.front_face); + result = hash_round<64>(result, self.face_cull); + } - if(self.stencil_test.enabled) - { - result = hash_round<64>(result, self.stencil_test.compare); - result = hash_round<64>(result, self.stencil_test.stencil_fail_op); - result = hash_round<64>(result, self.stencil_test.depth_fail_op); - result = hash_round<64>(result, self.stencil_test.depth_pass_op); - result = hash_update<64>(result, self.stencil_test.reference); - } + result = hash_round<64>(result, format.get_samples()); - if(self.blend.enabled) - { - result = hash_round<64>(result, self.blend.equation); - result = hash_round<64>(result, self.blend.src_factor); - result = hash_round<64>(result, self.blend.dst_factor); - result = hash_round<64>(result, self.blend.write_mask); - } + if(self.depth_test.enabled) + { + result = hash_round<64>(result, self.depth_test.compare); + result = hash_update<64>(result, self.depth_test.write); + } - for(FrameAttachment a: format) - result = hash_update<64>(result, a); + if(self.stencil_test.enabled) + { + result = hash_round<64>(result, self.stencil_test.compare); + result = hash_round<64>(result, self.stencil_test.stencil_fail_op); + result = hash_round<64>(result, self.stencil_test.depth_fail_op); + result = hash_round<64>(result, self.stencil_test.depth_pass_op); + result = hash_update<64>(result, self.stencil_test.reference); + } + + if(self.blend.enabled) + { + result = hash_round<64>(result, self.blend.equation); + result = hash_round<64>(result, self.blend.src_factor); + result = hash_round<64>(result, self.blend.dst_factor); + result = hash_round<64>(result, self.blend.write_mask); + } + + for(FrameAttachment a: format) + result = hash_update<64>(result, a); + } return result; } void VulkanPipelineState::fill_creation_info(vector &buffer) const +{ + if(static_cast(this)->shprog->is_compute()) + fill_compute_creation_info(buffer); + else + fill_graphics_creation_info(buffer); +} + +void VulkanPipelineState::fill_graphics_creation_info(vector &buffer) const { const PipelineState &self = *static_cast(this); @@ -266,6 +282,22 @@ void VulkanPipelineState::fill_creation_info(vector &buffer) const pipeline_info->pVertexInputState = reinterpret_cast(self.vertex_setup->creation_info.data()); } +void VulkanPipelineState::fill_compute_creation_info(vector &buffer) const +{ + const PipelineState &self = *static_cast(this); + + StructureBuilder sb(buffer, 1); + VkComputePipelineCreateInfo *const &pipeline_info = sb.add(); + + pipeline_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + + if(self.shprog) + { + pipeline_info->stage = *reinterpret_cast(self.shprog->creation_info.data()); + pipeline_info->layout = handle_cast<::VkPipelineLayout>(self.shprog->layout_handle); + } +} + uint64_t VulkanPipelineState::compute_descriptor_set_hash(unsigned index) const { const PipelineState &self = *static_cast(this); @@ -446,10 +478,11 @@ void VulkanPipelineState::apply(const VulkanCommandRecorder &vkCmd, const Vulkan unapplied |= PipelineState::SCISSOR; } + VkPipelineBindPoint bind_point = (self.shprog->is_compute() ? VK_PIPELINE_BIND_POINT_COMPUTE : VK_PIPELINE_BIND_POINT_GRAPHICS); if(unapplied&PipelineState::SHPROG) - vkCmd.BindPipeline(VK_PIPELINE_BIND_POINT_GRAPHICS, handle); + vkCmd.BindPipeline(bind_point, handle); - if(unapplied&PipelineState::VERTEX_SETUP) + if(!self.shprog->is_compute() && (unapplied&PipelineState::VERTEX_SETUP)) if(const VertexSetup *vs = self.vertex_setup) { vkCmd.BindVertexBuffers(0, vs->n_bindings, vs->buffers, vs->offsets); @@ -476,11 +509,11 @@ void VulkanPipelineState::apply(const VulkanCommandRecorder &vkCmd, const Vulkan descriptor_set_handles.push_back(device.get_descriptor_pool().get_descriptor_set( self.descriptor_set_slots[i], self, i, frame)); - vkCmd.BindDescriptorSets(VK_PIPELINE_BIND_POINT_GRAPHICS, self.shprog->layout_handle, + vkCmd.BindDescriptorSets(bind_point, self.shprog->layout_handle, first_changed_desc_set, descriptor_set_handles.size(), descriptor_set_handles.data(), 0, 0); } - if(unapplied&(PipelineState::VIEWPORT|PipelineState::SCISSOR)) + if(!self.shprog->is_compute() && (unapplied&(PipelineState::VIEWPORT|PipelineState::SCISSOR))) { Rect fb_rect = self.framebuffer->get_rect(); diff --git a/source/backends/vulkan/pipelinestate_backend.h b/source/backends/vulkan/pipelinestate_backend.h index 37c96a4e..175c2913 100644 --- a/source/backends/vulkan/pipelinestate_backend.h +++ b/source/backends/vulkan/pipelinestate_backend.h @@ -34,6 +34,8 @@ protected: void refresh() const { if(changes) update(); } std::uint64_t compute_hash() const; void fill_creation_info(std::vector &) const; + void fill_graphics_creation_info(std::vector &) const; + void fill_compute_creation_info(std::vector &) const; std::uint64_t compute_descriptor_set_hash(unsigned) const; bool is_descriptor_set_dynamic(unsigned) const; VkDescriptorSetLayout get_descriptor_set_layout(unsigned) const; diff --git a/source/backends/vulkan/program_backend.cpp b/source/backends/vulkan/program_backend.cpp index a962a78e..993a3e38 100644 --- a/source/backends/vulkan/program_backend.cpp +++ b/source/backends/vulkan/program_backend.cpp @@ -175,6 +175,11 @@ void VulkanProgram::finalize_uniforms() #endif } +bool VulkanProgram::is_compute() const +{ + return stage_flags&VK_SHADER_STAGE_COMPUTE_BIT; +} + void VulkanProgram::set_debug_name(const string &name) { #ifdef DEBUG diff --git a/source/backends/vulkan/program_backend.h b/source/backends/vulkan/program_backend.h index 4f9d9c47..0e48b14b 100644 --- a/source/backends/vulkan/program_backend.h +++ b/source/backends/vulkan/program_backend.h @@ -35,6 +35,8 @@ protected: void finalize_uniforms(); + bool is_compute() const; + void set_debug_name(const std::string &); void set_vulkan_object_name() const; }; diff --git a/source/backends/vulkan/vulkan.cpp b/source/backends/vulkan/vulkan.cpp index c8054695..31349e8f 100644 --- a/source/backends/vulkan/vulkan.cpp +++ b/source/backends/vulkan/vulkan.cpp @@ -41,6 +41,7 @@ VulkanFunctions::VulkanFunctions(const Graphics::VulkanContext &c): vkCreateShaderModule(context.get_function("vkCreateShaderModule")), vkDestroyShaderModule(context.get_function("vkDestroyShaderModule")), // 10 + vkCreateComputePipelines(context.get_function("vkCreateComputePipelines")), vkCreateGraphicsPipelines(context.get_function("vkCreateGraphicsPipelines")), vkDestroyPipeline(context.get_function("vkDestroyPipeline")), vkCmdBindPipeline(context.get_function("vkCmdBindPipeline")), @@ -88,6 +89,8 @@ VulkanFunctions::VulkanFunctions(const Graphics::VulkanContext &c): vkCmdSetViewport(context.get_function("vkCmdSetViewport")), // 26 vkCmdSetScissor(context.get_function("vkCmdSetScissor")), + // 28 + vkCmdDispatch(context.get_function("vkCmdDispatch")), // 30 vkGetPhysicalDeviceSurfaceCapabilities(context.get_function("vkGetPhysicalDeviceSurfaceCapabilitiesKHR")), vkGetPhysicalDeviceSurfaceFormats(context.get_function("vkGetPhysicalDeviceSurfaceFormatsKHR")), diff --git a/source/backends/vulkan/vulkan.h b/source/backends/vulkan/vulkan.h index d77e2cc4..2b42550c 100644 --- a/source/backends/vulkan/vulkan.h +++ b/source/backends/vulkan/vulkan.h @@ -124,6 +124,7 @@ private: PFN_vkCmdEndRenderPass vkCmdEndRenderPass = 0; // 8.4 PFN_vkCreateShaderModule vkCreateShaderModule = 0; // 9.1 PFN_vkDestroyShaderModule vkDestroyShaderModule = 0; // 9.1 + PFN_vkCreateComputePipelines vkCreateComputePipelines = 0; // 10.1 PFN_vkCreateGraphicsPipelines vkCreateGraphicsPipelines = 0; // 10.2 PFN_vkDestroyPipeline vkDestroyPipeline = 0; // 10.4 PFN_vkCmdBindPipeline vkCmdBindPipeline = 0; // 10.10 @@ -162,6 +163,7 @@ private: PFN_vkCmdBindVertexBuffers vkCmdBindVertexBuffers = 0; // 21.2 PFN_vkCmdSetViewport vkCmdSetViewport = 0; // 24.5 PFN_vkCmdSetScissor vkCmdSetScissor = 0; // 26.1 + PFN_vkCmdDispatch vkCmdDispatch = 0; // 28 PFN_vkGetPhysicalDeviceSurfaceCapabilitiesKHR vkGetPhysicalDeviceSurfaceCapabilities = 0; // 30.5.1 PFN_vkGetPhysicalDeviceSurfaceFormatsKHR vkGetPhysicalDeviceSurfaceFormats = 0; // 30.5.2 PFN_vkGetPhysicalDeviceSurfacePresentModesKHR vkGetPhysicalDeviceSurfacePresentModes = 0; // 30.5.3 @@ -260,6 +262,9 @@ public: { vkDestroyShaderModule(device, handle_cast<::VkShaderModule>(shaderModule), 0); } // Chapter 10: Pipelines + Result CreateComputePipelines(VkPipelineCache pipelineCache, std::uint32_t createInfoCount, const VkComputePipelineCreateInfo *pCreateInfos, VkPipeline *pPipelines) const + { return { vkCreateComputePipelines(device, handle_cast<::VkPipelineCache>(pipelineCache), createInfoCount, pCreateInfos, 0, handle_cast<::VkPipeline *>(pPipelines)), "vkCreateComputePipelines" }; } + Result CreateGraphicsPipelines(VkPipelineCache pipelineCache, std::uint32_t createInfoCount, const VkGraphicsPipelineCreateInfo *pCreateInfos, VkPipeline *pPipelines) const { return { vkCreateGraphicsPipelines(device, handle_cast<::VkPipelineCache>(pipelineCache), createInfoCount, pCreateInfos, 0, handle_cast<::VkPipeline *>(pPipelines)), "vkCreateGraphicsPipelines" }; } @@ -383,6 +388,10 @@ public: void CmdSetScissor(VkCommandBuffer commandBuffer, std::uint32_t firstScissor, std::uint32_t scissorCount, const VkRect2D *pScissors) const { vkCmdSetScissor(handle_cast<::VkCommandBuffer>(commandBuffer), firstScissor, scissorCount, pScissors); } + // Chapter 28: Dispatching Commands + void CmdDispatch(VkCommandBuffer commandBuffer, std::uint32_t groupCountX, std::uint32_t groupCountY, std::uint32_t groupCountZ) const + { vkCmdDispatch(handle_cast<::VkCommandBuffer>(commandBuffer), groupCountX, groupCountY, groupCountZ); } + // Chapter 30: Window System Integration (WSI) Result GetPhysicalDeviceSurfaceCapabilities(VkSurface surface, VkSurfaceCapabilitiesKHR &rSurfaceCapabilities) const { return { vkGetPhysicalDeviceSurfaceCapabilities(physicalDevice, handle_cast<::VkSurfaceKHR>(surface), &rSurfaceCapabilities), "vkGetPhysicalDeviceSurfaceCapabilities" }; } @@ -470,6 +479,9 @@ public: void SetScissor(std::uint32_t firstScissor, std::uint32_t scissorCount, const VkRect2D *pScissors) const { vk.CmdSetScissor(commandBuffer, firstScissor, scissorCount, pScissors); } + + void Dispatch(std::uint32_t groupCountX, std::uint32_t groupCountY, std::uint32_t groupCountZ) const + { vk.CmdDispatch(commandBuffer, groupCountX, groupCountY, groupCountZ); } }; } // namespace GL diff --git a/source/core/commands.h b/source/core/commands.h index e7dce6a1..793b2589 100644 --- a/source/core/commands.h +++ b/source/core/commands.h @@ -22,6 +22,7 @@ public: using CommandsBackend::clear; using CommandsBackend::draw; using CommandsBackend::draw_instanced; + using CommandsBackend::dispatch; using CommandsBackend::resolve_multisample; using CommandsBackend::begin_query; diff --git a/source/core/module.cpp b/source/core/module.cpp index 0870c478..8c69cd7e 100644 --- a/source/core/module.cpp +++ b/source/core/module.cpp @@ -48,6 +48,8 @@ enum SpirVConstants OP_RETURN_VALUE = 254, OP_UNREACHABLE = 255, + EXEC_LOCAL_SIZE = 17, + DECO_SPEC_ID = 1, DECO_ARRAY_STRIDE = 6, DECO_MATRIX_STRIDE = 7, @@ -513,6 +515,7 @@ void SpirVModule::Reflection::reflect_code(const vector &code) case OP_NAME: reflect_name(op); break; case OP_MEMBER_NAME: reflect_member_name(op); break; case OP_ENTRY_POINT: reflect_entry_point(op); break; + case OP_EXECUTION_MODE: reflect_execution_mode(op); break; case OP_TYPE_VOID: reflect_void_type(op); break; case OP_TYPE_BOOL: reflect_bool_type(op); break; case OP_TYPE_INT: reflect_int_type(op); break; @@ -579,6 +582,18 @@ void SpirVModule::Reflection::reflect_entry_point(CodeIterator op) entry.globals.push_back(&variables[*op]); } +void SpirVModule::Reflection::reflect_execution_mode(CodeIterator op) +{ + EntryPoint &entry = entry_points[*(op+1)]; + unsigned mode = *(op+2); + if(mode==EXEC_LOCAL_SIZE) + { + entry.compute_local_size.x = *(op+3); + entry.compute_local_size.y = *(op+4); + entry.compute_local_size.z = *(op+5); + } +} + void SpirVModule::Reflection::reflect_void_type(CodeIterator op) { types[*(op+1)].type = VOID; diff --git a/source/core/module.h b/source/core/module.h index 26895ee0..480a0669 100644 --- a/source/core/module.h +++ b/source/core/module.h @@ -104,7 +104,8 @@ public: { VERTEX = 0, GEOMETRY = 3, - FRAGMENT = 4 + FRAGMENT = 4, + COMPUTE = 5 }; enum StorageClass @@ -135,6 +136,7 @@ public: unsigned id = 0; Stage stage = VERTEX; std::vector globals; + LinAl::Vector compute_local_size; }; struct StructMember @@ -228,6 +230,7 @@ private: void reflect_name(CodeIterator); void reflect_member_name(CodeIterator); void reflect_entry_point(CodeIterator); + void reflect_execution_mode(CodeIterator); void reflect_void_type(CodeIterator); void reflect_bool_type(CodeIterator); void reflect_int_type(CodeIterator); diff --git a/source/core/program.cpp b/source/core/program.cpp index 8ac13bbd..91f21565 100644 --- a/source/core/program.cpp +++ b/source/core/program.cpp @@ -56,6 +56,10 @@ void Program::add_stages(const Module &mod, const map &spec_values) collect_uniforms(spirv_mod); collect_attributes(spirv_mod); collect_builtins(spirv_mod); + + for(const SpirVModule::EntryPoint &e: spirv_mod.get_entry_points()) + if(e.stage==SpirVModule::COMPUTE) + reflect_data.compute_wg_size = e.compute_local_size; } finalize_uniforms(); diff --git a/source/core/program.h b/source/core/program.h index ca5437f3..c652f468 100644 --- a/source/core/program.h +++ b/source/core/program.h @@ -72,6 +72,8 @@ private: void collect_builtins(const SpirVModule::Structure &); public: + using ProgramBackend::is_compute; + ReflectData::LayoutHash get_uniform_layout_hash() const { return reflect_data.layout_hash; } unsigned get_n_descriptor_sets() const { return reflect_data.n_descriptor_sets; } unsigned get_push_constants_size() const { return reflect_data.push_constants_size; } @@ -92,6 +94,7 @@ public: const ReflectData::AttributeInfo &get_attribute_info(const std::string &) const; int get_attribute_location(const std::string &) const; unsigned get_n_clip_distances() const { return reflect_data.n_clip_distances; } + const LinAl::Vector &get_compute_workgroup_size() const { return reflect_data.compute_wg_size; } using ProgramBackend::set_debug_name; }; diff --git a/source/core/reflectdata.h b/source/core/reflectdata.h index 1e82a61b..981f45f0 100644 --- a/source/core/reflectdata.h +++ b/source/core/reflectdata.h @@ -70,6 +70,7 @@ struct ReflectData unsigned n_descriptor_sets = 0; unsigned push_constants_size = 0; std::vector used_bindings; + LinAl::Vector compute_wg_size; void update_layout_hash(); void update_used_bindings(); diff --git a/source/render/renderer.cpp b/source/render/renderer.cpp index 828efd70..ce0b3447 100644 --- a/source/render/renderer.cpp +++ b/source/render/renderer.cpp @@ -296,6 +296,14 @@ void Renderer::draw_instanced(const Batch &batch, unsigned count) commands.draw_instanced(batch, count); } +void Renderer::dispatch(unsigned count_x, unsigned count_y, unsigned count_z) +{ + apply_state(); + PipelineState &ps = get_pipeline_state(); + commands.use_pipeline(&ps); + commands.dispatch(count_x, count_y, count_z); +} + void Renderer::resolve_multisample(Framebuffer &target) { const State &state = get_state(); diff --git a/source/render/renderer.h b/source/render/renderer.h index fb460cf0..3341d9aa 100644 --- a/source/render/renderer.h +++ b/source/render/renderer.h @@ -231,6 +231,9 @@ public: /** Draws multiple instances of a batch of primitives. A shader must be active. */ void draw_instanced(const Batch &, unsigned); + /** Dispatches a compute operation. */ + void dispatch(unsigned, unsigned = 1, unsigned = 1); + /** Resolves multisample attachments from the active framebuffer into target. */ void resolve_multisample(Framebuffer &target); -- 2.43.0