]> git.tdb.fi Git - libs/gl.git/blob - source/core/program.cpp
Determine the default uniform block size regardless of module format
[libs/gl.git] / source / core / program.cpp
1 #include <msp/core/algorithm.h>
2 #include "error.h"
3 #include "program.h"
4
5 using namespace std;
6
7 namespace Msp {
8 namespace GL {
9
10 Program::Program(const Module &mod, const map<string, int> &spec_values)
11 {
12         add_stages(mod, spec_values);
13 }
14
15 void Program::add_stages(const Module &mod, const map<string, int> &spec_values)
16 {
17         if(has_stages())
18                 throw invalid_operation("Program::add_stages");
19
20         TransientData transient;
21         switch(mod.get_format())
22         {
23         case Module::GLSL:
24                 add_glsl_stages(static_cast<const GlslModule &>(mod), spec_values, transient);
25                 break;
26         case Module::SPIR_V:
27                 add_spirv_stages(static_cast<const SpirVModule &>(mod), spec_values, transient);
28                 break;
29         default:
30                 throw invalid_argument("Program::add_stages");
31         }
32
33         reflect_data = ReflectData();
34
35         finalize(mod);
36
37         if(mod.get_format()==Module::GLSL)
38         {
39                 query_uniforms();
40                 query_attributes();
41                 apply_bindings(transient);
42         }
43         else if(mod.get_format()==Module::SPIR_V)
44         {
45                 collect_uniforms(static_cast<const SpirVModule &>(mod), transient.spec_values);
46                 collect_attributes(static_cast<const SpirVModule &>(mod));
47         }
48
49         for(ReflectData::UniformBlockInfo &b: reflect_data.uniform_blocks)
50                 if(!b.data_size && !b.uniforms.empty())
51                 {
52                         const ReflectData::UniformInfo &uni = *b.uniforms.back();
53                         b.data_size = uni.location*16+uni.array_size*get_type_size(uni.type);
54                 }
55
56         for(const ReflectData::UniformInfo &u: reflect_data.uniforms)
57                 require_type(u.type);
58         for(const ReflectData::AttributeInfo &a: reflect_data.attributes)
59                 require_type(a.type);
60 }
61
62 void Program::collect_uniforms(const SpirVModule &mod, const map<unsigned, int> &spec_values)
63 {
64         // Prepare the default block
65         reflect_data.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
66         vector<vector<string> > block_uniform_names(1);
67
68         for(const SpirVModule::Variable &v: mod.get_variables())
69         {
70                 if(v.storage==SpirVModule::UNIFORM && v.struct_type)
71                 {
72                         reflect_data.uniform_blocks.push_back(ReflectData::UniformBlockInfo());
73                         ReflectData::UniformBlockInfo &info = reflect_data.uniform_blocks.back();
74                         info.name = v.struct_type->name;
75                         info.bind_point = v.binding;
76                         info.data_size = v.struct_type->size;
77
78                         string prefix;
79                         if(!v.name.empty())
80                                 prefix = v.struct_type->name+".";
81                         block_uniform_names.push_back(vector<string>());
82                         collect_block_uniforms(*v.struct_type, prefix, 0, spec_values, block_uniform_names.back());
83                 }
84                 else if(v.storage==SpirVModule::UNIFORM_CONSTANT && v.location>=0)
85                 {
86                         block_uniform_names[0].push_back(v.name);
87                         reflect_data.uniforms.push_back(ReflectData::UniformInfo());
88                         ReflectData::UniformInfo &info = reflect_data.uniforms.back();
89                         info.name = v.name;
90                         info.tag = v.name;
91                         info.location = v.location;
92                         info.binding = v.binding;
93                         info.array_size = v.array_size;
94                         info.type = v.type;
95                 }
96         }
97
98         sort_member(reflect_data.uniforms, &ReflectData::UniformInfo::tag);
99
100         for(unsigned i=0; i<reflect_data.uniform_blocks.size(); ++i)
101         {
102                 ReflectData::UniformBlockInfo &block = reflect_data.uniform_blocks[i];
103                 for(const string &n: block_uniform_names[i])
104                 {
105                         // The element is already known to be present
106                         ReflectData::UniformInfo &uni = *lower_bound_member(reflect_data.uniforms, Tag(n), &ReflectData::UniformInfo::tag);
107                         block.uniforms.push_back(&uni);
108                         uni.block = &block;
109                 }
110                 block.sort_uniforms();
111                 block.update_layout_hash();
112         }
113
114         reflect_data.update_layout_hash();
115 }
116
117 void Program::collect_block_uniforms(const SpirVModule::Structure &strct, const string &prefix, unsigned base_offset, const map<unsigned, int> &spec_values, vector<string> &uniform_names)
118 {
119         for(const SpirVModule::StructMember &m: strct.members)
120         {
121                 unsigned offset = base_offset+m.offset;
122                 if(m.struct_type)
123                 {
124                         unsigned array_size = m.array_size;
125                         if(m.array_size_spec)
126                         {
127                                 array_size = m.array_size_spec->i_value;
128                                 auto j = spec_values.find(m.array_size_spec->constant_id);
129                                 if(j!=spec_values.end())
130                                         array_size = j->second;
131                         }
132
133                         if(array_size)
134                         {
135                                 for(unsigned j=0; j<array_size; ++j, offset+=m.array_stride)
136                                         collect_block_uniforms(*m.struct_type, format("%s%s[%d].", prefix, m.name, j), offset, spec_values, uniform_names);
137                         }
138                         else
139                                 collect_block_uniforms(*m.struct_type, prefix+m.name+".", offset, spec_values, uniform_names);
140                 }
141                 else
142                 {
143                         string name = prefix+m.name;
144                         uniform_names.push_back(name);
145                         reflect_data.uniforms.push_back(ReflectData::UniformInfo());
146                         ReflectData::UniformInfo &info = reflect_data.uniforms.back();
147                         info.name = name;
148                         info.tag = name;
149                         info.offset = offset;
150                         info.array_size = m.array_size;
151                         info.array_stride = m.array_stride;
152                         info.matrix_stride = m.matrix_stride;
153                         info.type = m.type;
154                 }
155         }
156 }
157
158 void Program::collect_attributes(const SpirVModule &mod)
159 {
160         for(const SpirVModule::EntryPoint &e: mod.get_entry_points())
161                 if(e.stage==SpirVModule::VERTEX && e.name=="main")
162                 {
163                         for(const SpirVModule::Variable *v: e.globals)
164                                 if(v->storage==SpirVModule::INPUT)
165                                 {
166                                         reflect_data.attributes.push_back(ReflectData::AttributeInfo());
167                                         ReflectData::AttributeInfo &info = reflect_data.attributes.back();
168                                         info.name = v->name;
169                                         info.location = v->location;
170                                         info.array_size = v->array_size;
171                                         info.type = v->type;
172                                 }
173                 }
174 }
175
176 const ReflectData::UniformBlockInfo &Program::get_uniform_block_info(const string &name) const
177 {
178         auto i = find_member(reflect_data.uniform_blocks, name, &ReflectData::UniformBlockInfo::name);
179         if(i==reflect_data.uniform_blocks.end())
180                 throw key_error(name);
181         return *i;
182 }
183
184 const ReflectData::UniformInfo &Program::get_uniform_info(const string &name) const
185 {
186         auto i = lower_bound_member(reflect_data.uniforms, Tag(name), &ReflectData::UniformInfo::tag);
187         if(i==reflect_data.uniforms.end() || i->name!=name)
188                 throw key_error(name);
189         return *i;
190 }
191
192 const ReflectData::UniformInfo &Program::get_uniform_info(Tag tag) const
193 {
194         auto i = lower_bound_member(reflect_data.uniforms, tag, &ReflectData::UniformInfo::tag);
195         if(i==reflect_data.uniforms.end() || i->tag!=tag)
196                 throw key_error(tag);
197         return *i;
198 }
199
200 int Program::get_uniform_location(const string &name) const
201 {
202         if(name[name.size()-1]==']')
203                 throw invalid_argument("Program::get_uniform_location");
204
205         auto i = lower_bound_member(reflect_data.uniforms, Tag(name), &ReflectData::UniformInfo::tag);
206         return i!=reflect_data.uniforms.end() && i->name==name && i->block->bind_point<0 ? i->location : -1;
207 }
208
209 int Program::get_uniform_location(Tag tag) const
210 {
211         auto i = lower_bound_member(reflect_data.uniforms, tag, &ReflectData::UniformInfo::tag);
212         return i!=reflect_data.uniforms.end() && i->tag==tag && i->block->bind_point<0 ? i->location : -1;
213 }
214
215 int Program::get_uniform_binding(Tag tag) const
216 {
217         auto i = lower_bound_member(reflect_data.uniforms, tag, &ReflectData::UniformInfo::tag);
218         return i!=reflect_data.uniforms.end() && i->tag==tag ? i->binding : -1;
219 }
220
221 const ReflectData::AttributeInfo &Program::get_attribute_info(const string &name) const
222 {
223         auto i = lower_bound_member(reflect_data.attributes, name, &ReflectData::AttributeInfo::name);
224         if(i==reflect_data.attributes.end() || i->name!=name)
225                 throw key_error(name);
226         return *i;
227 }
228
229 int Program::get_attribute_location(const string &name) const
230 {
231         if(name[name.size()-1]==']')
232                 throw invalid_argument("Program::get_attribute_location");
233
234         auto i = lower_bound_member(reflect_data.attributes, name, &ReflectData::AttributeInfo::name);
235         return i!=reflect_data.attributes.end() && i->name==name ? i->location : -1;
236 }
237
238
239 Program::Loader::Loader(Program &p, Collection &c):
240         DataFile::CollectionObjectLoader<Program>(p, &c)
241 {
242         add("module", &Loader::module);
243 }
244
245 void Program::Loader::module(const string &n)
246 {
247         map<string, int> spec_values;
248         SpecializationLoader ldr(spec_values);
249         load_sub_with(ldr);
250         obj.add_stages(get_collection().get<Module>(n), spec_values);
251 }
252
253
254 DataFile::Loader::ActionMap Program::SpecializationLoader::shared_actions;
255
256 Program::SpecializationLoader::SpecializationLoader(map<string, int> &sv):
257         spec_values(sv)
258 {
259         set_actions(shared_actions);
260 }
261
262 void Program::SpecializationLoader::init_actions()
263 {
264         add("specialize", &SpecializationLoader::specialize_bool);
265         add("specialize", &SpecializationLoader::specialize_int);
266 }
267
268 void Program::SpecializationLoader::specialize_bool(const string &name, bool value)
269 {
270         spec_values[name] = value;
271 }
272
273 void Program::SpecializationLoader::specialize_int(const string &name, int value)
274 {
275         spec_values[name] = value;
276 }
277
278 } // namespace GL
279 } // namespace Msp