]> git.tdb.fi Git - libs/gl.git/blob - source/backends/vulkan/program_backend.cpp
Simplify Program by removing transient data
[libs/gl.git] / source / backends / vulkan / program_backend.cpp
1 #include <cstring>
2 #include <msp/core/algorithm.h>
3 #include "device.h"
4 #include "error.h"
5 #include "program.h"
6 #include "program_backend.h"
7 #include "structurebuilder.h"
8 #include "vulkan.h"
9
10 using namespace std;
11
12 namespace Msp {
13 namespace GL {
14
15 VulkanProgram::VulkanProgram():
16         device(Device::get_current())
17 { }
18
19 VulkanProgram::VulkanProgram(VulkanProgram &&other):
20         device(other.device),
21         n_stages(other.n_stages),
22         stage_flags(other.stage_flags),
23         creation_info(move(other.creation_info)),
24         desc_set_layout_handles(move(other.desc_set_layout_handles)),
25         layout_handle(other.layout_handle)
26 {
27         other.desc_set_layout_handles.clear();
28         other.layout_handle = 0;
29 }
30
31 VulkanProgram::~VulkanProgram()
32 {
33         const VulkanFunctions &vk = device.get_functions();
34
35         if(layout_handle)
36                 vk.DestroyPipelineLayout(layout_handle);
37         for(VkDescriptorSetLayout d: desc_set_layout_handles)
38                 vk.DestroyDescriptorSetLayout(d);
39 }
40
41 bool VulkanProgram::has_stages() const
42 {
43         return n_stages;
44 }
45
46 void VulkanProgram::add_glsl_stages(const GlslModule &, const map<string, int> &)
47 {
48         throw invalid_operation("VulkanProgram::add_glsl_stages");
49 }
50
51 void VulkanProgram::add_spirv_stages(const SpirVModule &mod, const map<string, int> &spec_values)
52 {
53         const vector<SpirVModule::EntryPoint> &entry_points = mod.get_entry_points();
54
55         n_stages = entry_points.size();
56         size_t entry_names_size = 0;
57         for(const SpirVModule::EntryPoint &e: entry_points)
58                 entry_names_size += e.name.size()+1;
59
60         StructureBuilder sb(creation_info, 5);
61         VkPipelineShaderStageCreateInfo *&stage_infos = sb.add<VkPipelineShaderStageCreateInfo>(n_stages);
62         char *&name_table = sb.add<char>(entry_names_size);
63         VkSpecializationInfo *&spec_info = sb.add<VkSpecializationInfo>();
64         VkSpecializationMapEntry *&spec_map = sb.add<VkSpecializationMapEntry>(spec_values.size());
65         int *&spec_data = sb.add<int>(spec_values.size());
66
67         unsigned i = 0;
68         for(const SpirVModule::Constant &c: mod.get_spec_constants())
69         {
70                 auto j = spec_values.find(c.name);
71                 if(j!=spec_values.end())
72                 {
73                         spec_map[i].constantID = c.constant_id;
74                         spec_map[i].offset = i*sizeof(int);
75                         spec_map[i].size = sizeof(int);
76                         spec_data[i] = j->second;
77                         ++i;
78                 }
79         }
80
81         spec_info->mapEntryCount = i;
82         spec_info->pMapEntries = spec_map;
83         spec_info->dataSize = spec_values.size()*sizeof(int);
84         spec_info->pData = spec_data;
85
86         char *name_ptr = name_table;
87         i = 0;
88         for(const SpirVModule::EntryPoint &e: entry_points)
89         {
90                 unsigned stage_bit = get_vulkan_stage(e.stage);
91                 stage_flags |= stage_bit;
92
93                 stage_infos[i].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
94                 stage_infos[i].stage = static_cast<VkShaderStageFlagBits>(stage_bit);
95                 stage_infos[i].module = handle_cast<::VkShaderModule>(mod.handle);
96                 strcpy(name_ptr, e.name.c_str());
97                 stage_infos[i].pName = name_ptr;
98                 name_ptr += e.name.size()+1;
99                 stage_infos[i].pSpecializationInfo = spec_info;
100                 ++i;
101         }
102 }
103
104 void VulkanProgram::finalize_uniforms()
105 {
106         const VulkanFunctions &vk = device.get_functions();
107         const ReflectData &rd = static_cast<const Program *>(this)->reflect_data;
108
109         auto i = find_member(rd.uniform_blocks, static_cast<int>(ReflectData::PUSH_CONSTANT), &ReflectData::UniformBlockInfo::bind_point);
110         const ReflectData::UniformBlockInfo *push_const_block = (i!=rd.uniform_blocks.end() ? &*i : 0);
111
112         desc_set_layout_handles.resize(rd.n_descriptor_sets);
113         for(unsigned j=0; j<rd.n_descriptor_sets; ++j)
114         {
115                 std::vector<VkDescriptorSetLayoutBinding> bindings;
116                 for(const ReflectData::UniformBlockInfo &b: rd.uniform_blocks)
117                         if(b.bind_point>=0 && static_cast<unsigned>(b.bind_point>>20)==j)
118                         {
119                                 bindings.emplace_back();
120                                 VkDescriptorSetLayoutBinding &binding = bindings.back();
121                                 binding.binding = b.bind_point;
122                                 binding.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC;
123                                 binding.descriptorCount = 1;
124                                 binding.stageFlags = VK_SHADER_STAGE_ALL;
125                                 binding.pImmutableSamplers = 0;
126                         }
127
128                 for(const ReflectData::UniformInfo &u: rd.uniforms)
129                         if(u.binding>=0 && static_cast<unsigned>(u.binding>>20)==j && is_image(u.type))
130                         {
131                                 bindings.emplace_back();
132                                 VkDescriptorSetLayoutBinding &binding = bindings.back();
133                                 binding.binding = u.binding;
134                                 binding.descriptorType = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
135                                 binding.descriptorCount = 1;
136                                 binding.stageFlags = VK_SHADER_STAGE_ALL;
137                                 binding.pImmutableSamplers = 0;
138                         }
139
140                 VkDescriptorSetLayoutCreateInfo set_layout_info = { };
141                 set_layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
142                 set_layout_info.bindingCount = bindings.size();
143                 set_layout_info.pBindings = bindings.data();
144
145                 vk.CreateDescriptorSetLayout(set_layout_info, desc_set_layout_handles[j]);
146         }
147
148         VkPushConstantRange push_const_range = { };
149         push_const_range.stageFlags = stage_flags;
150         push_const_range.offset = 0;
151         push_const_range.size = (push_const_block ? push_const_block->data_size : 0);
152
153         VkPipelineLayoutCreateInfo layout_info = { };
154         layout_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
155         layout_info.setLayoutCount = rd.n_descriptor_sets;
156         layout_info.pSetLayouts = handle_cast<::VkDescriptorSetLayout *>(desc_set_layout_handles.data());
157         layout_info.pushConstantRangeCount = (push_const_block ? 1 : 0);
158         layout_info.pPushConstantRanges = &push_const_range;
159
160         vk.CreatePipelineLayout(layout_info, layout_handle);
161 }
162
163 } // namespace GL
164 } // namespace Msp