]> git.tdb.fi Git - libs/gl.git/blob - tests/glsl/glslcompiler.cpp
dbac0f516a02d7334feb37fb0dc03d97bd734947
[libs/gl.git] / tests / glsl / glslcompiler.cpp
1 #include <spirv-tools/libspirv.hpp>
2 #include <msp/core/algorithm.h>
3 #include <msp/fs/dir.h>
4 #include <msp/fs/utils.h>
5 #include <msp/gl/glsl/compiler.h>
6 #include <msp/gl/glsl/glsl_error.h>
7 #include <msp/gl/glsl/tokenizer.h>
8 #include <msp/strings/utils.h>
9 #include <msp/test/test.h>
10
11 class GlslCompilerHelper
12 {
13 protected:
14         struct TestCase
15         {
16                 std::string name;
17                 std::string source;
18                 Msp::GL::SL::Compiler::Mode compile_mode;
19                 std::map<std::string, int> spec_values;
20                 std::map<Msp::GL::SL::Stage::Type, std::string> expected_output;
21                 std::string expected_diagnostic;
22                 bool expect_success;
23
24                 TestCase(): expect_success(true) { }
25         };
26
27         std::list<TestCase> test_cases;
28
29         void load_all_test_cases(const Msp::FS::Path &);
30         const TestCase &load_test_case(const std::string &);
31
32         void verify_output(const std::string &, const std::string &);
33         void verify_diagnostic(const std::string &, const std::string &);
34         std::string extract_line(const std::string &, const std::string::const_iterator &);
35         virtual void fail(const std::string &) = 0;
36 };
37
38 class GlslCompilerTest: public Msp::Test::RegisteredTest<GlslCompilerTest>, private GlslCompilerHelper
39 {
40 public:
41         GlslCompilerTest();
42
43         static const char *get_name() { return "GLSL compiler"; }
44
45 private:
46         void run_test_case(const TestCase *);
47         virtual void fail(const std::string &m) { Test::fail(m); }
48 };
49
50 class GlslCompilerIdempotence: public Msp::Test::RegisteredTest<GlslCompilerIdempotence>, private GlslCompilerHelper
51 {
52 public:
53         GlslCompilerIdempotence();
54
55         static const char *get_name() { return "GLSL compiler idempotence"; }
56
57 private:
58         void run_test_case(const TestCase *);
59         virtual void fail(const std::string &m) { Test::fail(m); }
60 };
61
62 class GlslCompilerSpirV: public Msp::Test::RegisteredTest<GlslCompilerSpirV>, private GlslCompilerHelper
63 {
64 private:
65         spvtools::SpirvTools spirv_tools;
66
67 public:
68         GlslCompilerSpirV();
69
70         static const char *get_name() { return "GLSL to SPIR-V compilation"; }
71
72 private:
73         void run_test_case(const TestCase *);
74         void diagnostic(spv_message_level_t, const char *, const spv_position_t &, const char *);
75         virtual void fail(const std::string &m) { Test::fail(m); }
76 };
77
78 using namespace std;
79 using namespace Msp;
80
81 void GlslCompilerHelper::load_all_test_cases(const FS::Path &tests_dir)
82 {
83         vector<string> test_files = FS::list_filtered(tests_dir, "\\.glsl$");
84         sort(test_files);
85         for(const auto &fn: test_files)
86                 load_test_case((tests_dir/fn).str());
87 }
88
89 const GlslCompilerHelper::TestCase &GlslCompilerHelper::load_test_case(const string &fn)
90 {
91         IO::BufferedFile file(fn);
92         TestCase test_case;
93         test_case.name = FS::basename(fn);
94         test_case.compile_mode = GL::SL::Compiler::PROGRAM;
95         string *target = &test_case.source;
96         while(!file.eof())
97         {
98                 string line;
99                 if(!file.getline(line))
100                         break;
101
102                 if(line=="*/")
103                         continue;
104
105                 string::size_type pos = line.find("Expected output:");
106                 if(pos!=string::npos)
107                 {
108                         string stage = strip(line.substr(pos+16));
109                         if(stage=="vertex")
110                                 target = &test_case.expected_output[GL::SL::Stage::VERTEX];
111                         else if(stage=="geometry")
112                                 target = &test_case.expected_output[GL::SL::Stage::GEOMETRY];
113                         else if(stage=="fragment")
114                                 target = &test_case.expected_output[GL::SL::Stage::FRAGMENT];
115                         else
116                                 throw runtime_error("Unknown stage "+stage);
117                         continue;
118                 }
119
120                 pos = line.find("Expected error:");
121                 if(pos==string::npos)
122                         pos = line.find("Expected diagnostic:");
123                 if(pos!=string::npos)
124                 {
125                         target = &test_case.expected_diagnostic;
126                         test_case.expect_success = (line[pos+9]!='e');
127                         continue;
128                 }
129
130                 pos = line.find("Compile mode:");
131                 if(pos!=string::npos)
132                 {
133                         string mode = strip(line.substr(pos+13));
134                         if(mode=="module")
135                                 test_case.compile_mode = GL::SL::Compiler::MODULE;
136                         else if(mode=="program")
137                                 test_case.compile_mode = GL::SL::Compiler::PROGRAM;
138                         else
139                                 throw runtime_error("Unknown compile mode "+mode);
140                         continue;
141                 }
142
143                 pos = line.find("Specialize:");
144                 if(pos!=string::npos)
145                 {
146                         vector<string> parts = split(line.substr(pos+11));
147                         int value = 0;
148                         if(parts[1]=="true")
149                                 value = 1;
150                         else if(parts[1]=="false")
151                                 value = 0;
152                         else
153                                 value = lexical_cast<int>(parts[1]);
154                         test_case.spec_values[parts[0]] = value;
155                         continue;
156                 }
157
158                 *target += line;
159                 *target += '\n';
160         }
161         test_cases.push_back(test_case);
162
163         return test_cases.back();
164 }
165
166 void GlslCompilerHelper::verify_output(const string &output, const string &expected)
167 {
168         GL::SL::Tokenizer tokenizer;
169         tokenizer.begin(output, "<output>");
170
171         GL::SL::Tokenizer expected_tkn;
172         expected_tkn.begin(expected, "<expected>");
173
174         while(1)
175         {
176                 string token = expected_tkn.parse_token();
177
178                 try
179                 {
180                         tokenizer.expect(token);
181                 }
182                 catch(const GL::SL::invalid_shader_source &exc)
183                 {
184                         fail(exc.what());
185                 }
186
187                 if(token.empty())
188                         break;
189         }
190 }
191
192 void GlslCompilerHelper::verify_diagnostic(const string &output, const string &expected)
193 {
194         auto i = output.begin();
195         auto j = expected.begin();
196         bool space = true;
197         while(i!=output.end() && j!=expected.end())
198         {
199                 if(*i==*j)
200                 {
201                         ++i;
202                         ++j;
203                 }
204                 else if(isspace(*i) && isspace(*j))
205                 {
206                         ++i;
207                         ++j;
208                         space = true;
209                 }
210                 else if(space && isspace(*i))
211                         ++i;
212                 else if(space && isspace(*j))
213                         ++j;
214                 else
215                 {
216                         string out_line = extract_line(output, i);
217                         string expect_line = extract_line(expected, j);
218                         fail(format("Incorrect diagnostic line:\n%s\nExpected:\n%s", out_line, expect_line));
219                 }
220         }
221
222         while(i!=output.end() && isspace(*i))
223                 ++i;
224         while(j!=expected.end() && isspace(*j))
225                 ++j;
226
227         if(i!=output.end())
228                 fail(format("Extra diagnostic line: %s", extract_line(output, i)));
229         if(j!=expected.end())
230                 fail(format("Missing diagnostic line: %s", extract_line(expected, j)));
231 }
232
233 string GlslCompilerHelper::extract_line(const string &text, const string::const_iterator &iter)
234 {
235         string::const_iterator begin = iter;
236         for(; (begin!=text.begin() && *begin!='\n'); --begin) ;
237         if(*begin=='\n')
238                 ++begin;
239         string::const_iterator end = iter;
240         for(; (end!=text.end() && *end!='\n'); ++end) ;
241         return string(begin, end);
242 }
243
244
245 GlslCompilerTest::GlslCompilerTest()
246 {
247         load_all_test_cases("glsl");
248         for(const auto &tc: test_cases)
249                 add(&GlslCompilerTest::run_test_case, &tc, tc.name);
250 }
251
252 void GlslCompilerTest::run_test_case(const TestCase *test_case)
253 {
254         GL::SL::Compiler compiler(GL::SL::Features::latest());
255         try
256         {
257                 compiler.set_source(test_case->source, "<test>");
258                 if(test_case->compile_mode==GL::SL::Compiler::PROGRAM)
259                         compiler.specialize(test_case->spec_values);
260                 compiler.compile(test_case->compile_mode);
261         }
262         catch(const GL::SL::invalid_shader_source &exc)
263         {
264                 if(!test_case->expect_success)
265                 {
266                         debug("Errors from compile:");
267                         debug(exc.what());
268                         verify_diagnostic(exc.what(), test_case->expected_diagnostic);
269                         return;
270                 }
271                 throw;
272         }
273
274         if(!test_case->expect_success)
275                 fail("Error expected but none thrown");
276
277         verify_diagnostic(compiler.get_diagnostics(), test_case->expected_diagnostic);
278
279         auto stages = compiler.get_stages();
280         for(auto s: stages)
281         {
282                 auto i = test_case->expected_output.find(s);
283                 if(i==test_case->expected_output.end())
284                         fail(format("Compiler produced extra stage %s", GL::SL::Stage::get_stage_name(s)));
285
286                 string output = compiler.get_stage_glsl(s);
287                 debug(format("Output for stage %s:", GL::SL::Stage::get_stage_name(s)));
288                 auto lines = split_fields(output, '\n');
289                 for(unsigned j=0; j<lines.size(); ++j)
290                         debug(format("%3d: %s", j+1, lines[j]));
291
292                 verify_output(output, i->second);
293         }
294
295         for(const auto &s: test_case->expected_output)
296                 if(find(stages, s.first)==stages.end())
297                         fail(format("Compiler didn't produce stage %s", GL::SL::Stage::get_stage_name(s.first)));
298 }
299
300
301 GlslCompilerIdempotence::GlslCompilerIdempotence()
302 {
303         load_all_test_cases("glsl");
304         for(const auto &tc: test_cases)
305                 if(tc.expect_success)
306                         add(&GlslCompilerIdempotence::run_test_case, &tc, tc.name);
307 }
308
309 void GlslCompilerIdempotence::run_test_case(const TestCase *test_case)
310 {
311         GL::SL::Compiler compiler(GL::SL::Features::latest());
312         compiler.set_source(test_case->source, "<test>");
313         if(test_case->compile_mode==GL::SL::Compiler::PROGRAM)
314                 compiler.specialize(test_case->spec_values);
315         compiler.compile(test_case->compile_mode);
316
317         GL::SL::Compiler compiler2(GL::SL::Features::latest());
318         compiler2.set_source(compiler.get_combined_glsl(), "<loopback>");
319         compiler2.compile(test_case->compile_mode);
320
321         auto stages = compiler.get_stages();
322         auto stages2 = compiler2.get_stages();
323         auto i = stages.begin();
324         auto j = stages2.begin();
325         for(; (i!=stages.end() && j!=stages2.end() && *i==*j); ++i, ++j)
326         {
327                 string output = compiler.get_stage_glsl(*i);
328                 string output2 = compiler2.get_stage_glsl(*j);
329
330                 verify_output(output2, output);
331         }
332
333         if(i!=stages.end())
334                 fail(format("Second pass didn't produce stage %s", GL::SL::Stage::get_stage_name(*i)));
335         if(j!=stages2.end())
336                 fail(format("Second pass produced extra stage %s", GL::SL::Stage::get_stage_name(*j)));
337 }
338
339
340 GlslCompilerSpirV::GlslCompilerSpirV():
341         spirv_tools(SPV_ENV_UNIVERSAL_1_5)
342 {
343         load_all_test_cases("glsl");
344         for(const auto &tc: test_cases)
345                 if(tc.expect_success)
346                         add(&GlslCompilerSpirV::run_test_case, &tc, tc.name);
347
348         using namespace std::placeholders;
349         spirv_tools.SetMessageConsumer(std::bind(std::mem_fn(&GlslCompilerSpirV::diagnostic), this, _1, _2, _3, _4));
350 }
351
352 void GlslCompilerSpirV::run_test_case(const TestCase *test_case)
353 {
354         GL::SL::Compiler compiler(GL::SL::Features::latest());
355         compiler.set_source(test_case->source, "<test>");
356         compiler.compile(GL::SL::Compiler::SPIRV);
357
358         vector<UInt32> code = compiler.get_combined_spirv();
359         if(!spirv_tools.Validate(code))
360                 fail("Invalid SPIR-V generated");
361 }
362
363 void GlslCompilerSpirV::diagnostic(spv_message_level_t level, const char *, const spv_position_t &, const char *message)
364 {
365         const char *prefix;
366         switch(level)
367         {
368         case SPV_MSG_DEBUG: prefix = "debug: "; break;
369         case SPV_MSG_INFO: prefix = "info: "; break;
370         case SPV_MSG_WARNING: prefix = "warning: "; break;
371         case SPV_MSG_ERROR: prefix = "error: "; break;
372         default: prefix = "";
373         }
374         info(format("%s%s", prefix, message));
375 }