]> git.tdb.fi Git - libs/gl.git/blob - source/backends/vulkan/program_backend.cpp
Support compute shaders and compute operations
[libs/gl.git] / source / backends / vulkan / program_backend.cpp
1 #include <cstring>
2 #include <msp/core/algorithm.h>
3 #include "device.h"
4 #include "error.h"
5 #include "program.h"
6 #include "program_backend.h"
7 #include "structurebuilder.h"
8 #include "vulkan.h"
9
10 using namespace std;
11
12 namespace Msp {
13 namespace GL {
14
15 VulkanProgram::VulkanProgram():
16         device(Device::get_current())
17 { }
18
19 VulkanProgram::VulkanProgram(VulkanProgram &&other):
20         device(other.device),
21         n_stages(other.n_stages),
22         stage_flags(other.stage_flags),
23         creation_info(move(other.creation_info)),
24         desc_set_layout_handles(move(other.desc_set_layout_handles)),
25         layout_handle(other.layout_handle),
26         debug_name(move(other.debug_name))
27 {
28         other.desc_set_layout_handles.clear();
29         other.layout_handle = 0;
30 }
31
32 VulkanProgram::~VulkanProgram()
33 {
34         const VulkanFunctions &vk = device.get_functions();
35
36         if(layout_handle)
37                 vk.DestroyPipelineLayout(layout_handle);
38         for(VkDescriptorSetLayout d: desc_set_layout_handles)
39                 vk.DestroyDescriptorSetLayout(d);
40 }
41
42 bool VulkanProgram::has_stages() const
43 {
44         return n_stages;
45 }
46
47 void VulkanProgram::add_glsl_stages(const GlslModule &, const map<string, int> &)
48 {
49         throw invalid_operation("VulkanProgram::add_glsl_stages");
50 }
51
52 void VulkanProgram::add_spirv_stages(const SpirVModule &mod, const map<string, int> &spec_values)
53 {
54         const vector<SpirVModule::EntryPoint> &entry_points = mod.get_entry_points();
55
56         n_stages = entry_points.size();
57         size_t entry_names_size = 0;
58         for(const SpirVModule::EntryPoint &e: entry_points)
59                 entry_names_size += e.name.size()+1;
60
61         StructureBuilder sb(creation_info, 5);
62         VkPipelineShaderStageCreateInfo *const &stage_infos = sb.add<VkPipelineShaderStageCreateInfo>(n_stages);
63         char *const &name_table = sb.add<char>(entry_names_size);
64         VkSpecializationInfo *const &spec_info = sb.add<VkSpecializationInfo>();
65         VkSpecializationMapEntry *const &spec_map = sb.add<VkSpecializationMapEntry>(spec_values.size());
66         int *const &spec_data = sb.add<int>(spec_values.size());
67
68         unsigned i = 0;
69         for(const SpirVModule::Constant &c: mod.get_spec_constants())
70         {
71                 auto j = spec_values.find(c.name);
72                 if(j!=spec_values.end())
73                 {
74                         spec_map[i].constantID = c.constant_id;
75                         spec_map[i].offset = i*sizeof(int);
76                         spec_map[i].size = sizeof(int);
77                         spec_data[i] = j->second;
78                         ++i;
79                 }
80         }
81
82         spec_info->mapEntryCount = i;
83         spec_info->pMapEntries = spec_map;
84         spec_info->dataSize = spec_values.size()*sizeof(int);
85         spec_info->pData = spec_data;
86
87         char *name_ptr = name_table;
88         i = 0;
89         for(const SpirVModule::EntryPoint &e: entry_points)
90         {
91                 unsigned stage_bit = get_vulkan_stage(e.stage);
92                 stage_flags |= stage_bit;
93
94                 stage_infos[i].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
95                 stage_infos[i].stage = static_cast<VkShaderStageFlagBits>(stage_bit);
96                 stage_infos[i].module = handle_cast<::VkShaderModule>(mod.handle);
97                 strcpy(name_ptr, e.name.c_str());
98                 stage_infos[i].pName = name_ptr;
99                 name_ptr += e.name.size()+1;
100                 stage_infos[i].pSpecializationInfo = spec_info;
101                 ++i;
102         }
103
104 #if DEBUG
105         if(!debug_name.empty())
106                 if(SpirVModule *spirv = static_cast<Program *>(this)->specialized_spirv)
107                         spirv->set_debug_name(debug_name);
108 #endif
109 }
110
111 void VulkanProgram::finalize_uniforms()
112 {
113         const VulkanFunctions &vk = device.get_functions();
114         const ReflectData &rd = static_cast<const Program *>(this)->reflect_data;
115
116         auto i = find_member(rd.uniform_blocks, static_cast<int>(ReflectData::PUSH_CONSTANT), &ReflectData::UniformBlockInfo::bind_point);
117         const ReflectData::UniformBlockInfo *push_const_block = (i!=rd.uniform_blocks.end() ? &*i : 0);
118
119         desc_set_layout_handles.resize(rd.n_descriptor_sets);
120         for(unsigned j=0; j<rd.n_descriptor_sets; ++j)
121         {
122                 std::vector<VkDescriptorSetLayoutBinding> bindings;
123                 for(const ReflectData::UniformBlockInfo &b: rd.uniform_blocks)
124                         if(b.bind_point>=0 && static_cast<unsigned>(b.bind_point>>20)==j)
125                         {
126                                 bindings.emplace_back();
127                                 VkDescriptorSetLayoutBinding &binding = bindings.back();
128                                 binding.binding = b.bind_point&0xFFFFF;
129                                 binding.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
130                                 binding.descriptorCount = 1;
131                                 binding.stageFlags = stage_flags;
132                                 binding.pImmutableSamplers = 0;
133                         }
134
135                 for(const ReflectData::UniformInfo &u: rd.uniforms)
136                         if(u.binding>=0 && static_cast<unsigned>(u.binding>>20)==j && is_image(u.type))
137                         {
138                                 bindings.emplace_back();
139                                 VkDescriptorSetLayoutBinding &binding = bindings.back();
140                                 binding.binding = u.binding&0xFFFFF;
141                                 if(is_sampled_image(u.type))
142                                         binding.descriptorType = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
143                                 else
144                                         binding.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
145                                 binding.descriptorCount = 1;
146                                 binding.stageFlags = stage_flags;
147                                 binding.pImmutableSamplers = 0;
148                         }
149
150                 VkDescriptorSetLayoutCreateInfo set_layout_info = { };
151                 set_layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
152                 set_layout_info.bindingCount = bindings.size();
153                 set_layout_info.pBindings = bindings.data();
154
155                 vk.CreateDescriptorSetLayout(set_layout_info, desc_set_layout_handles[j]);
156         }
157
158         VkPushConstantRange push_const_range = { };
159         push_const_range.stageFlags = stage_flags;
160         push_const_range.offset = 0;
161         push_const_range.size = (push_const_block ? push_const_block->data_size : 0);
162
163         VkPipelineLayoutCreateInfo layout_info = { };
164         layout_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
165         layout_info.setLayoutCount = rd.n_descriptor_sets;
166         layout_info.pSetLayouts = handle_cast<::VkDescriptorSetLayout *>(desc_set_layout_handles.data());
167         layout_info.pushConstantRangeCount = (push_const_block ? 1 : 0);
168         layout_info.pPushConstantRanges = &push_const_range;
169
170         vk.CreatePipelineLayout(layout_info, layout_handle);
171
172 #if DEBUG
173         if(!debug_name.empty())
174                 set_vulkan_object_name();
175 #endif
176 }
177
178 bool VulkanProgram::is_compute() const
179 {
180         return stage_flags&VK_SHADER_STAGE_COMPUTE_BIT;
181 }
182
183 void VulkanProgram::set_debug_name(const string &name)
184 {
185 #ifdef DEBUG
186         debug_name = name;
187         set_vulkan_object_name();
188         if(SpirVModule *spirv = static_cast<Program *>(this)->specialized_spirv)
189                 spirv->set_debug_name(debug_name);
190 #else
191         (void)name;
192 #endif
193 }
194
195 void VulkanProgram::set_vulkan_object_name() const
196 {
197 #ifdef DEBUG
198         const VulkanFunctions &vk = device.get_functions();
199
200         string layout_name = debug_name+" [layout]";
201
202         VkDebugUtilsObjectNameInfoEXT name_info = { };
203         name_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT;
204         name_info.objectType = VK_OBJECT_TYPE_PIPELINE_LAYOUT;
205         name_info.objectHandle = reinterpret_cast<uint64_t>(layout_handle);
206         name_info.pObjectName = layout_name.c_str();
207         vk.SetDebugUtilsObjectName(name_info);
208
209         name_info.objectType = VK_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT;
210         for(unsigned i=0; i<desc_set_layout_handles.size(); ++i)
211         {
212                 layout_name = format("%s [layout:%d]", debug_name, i);
213         name_info.objectHandle = reinterpret_cast<uint64_t>(desc_set_layout_handles[i]);
214                 name_info.pObjectName = layout_name.c_str();
215                 vk.SetDebugUtilsObjectName(name_info);
216         }
217 #endif
218 }
219
220 } // namespace GL
221 } // namespace Msp