1 #include <spirv-tools/libspirv.hpp>
2 #include <msp/core/algorithm.h>
3 #include <msp/fs/dir.h>
4 #include <msp/fs/utils.h>
5 #include <msp/gl/glsl/compiler.h>
6 #include <msp/gl/glsl/glsl_error.h>
7 #include <msp/gl/glsl/tokenizer.h>
8 #include <msp/strings/utils.h>
9 #include <msp/test/test.h>
// Shared base of the GLSL compiler test suites.  It owns the test cases read
// from disk and provides the comparison helpers the suites use.
11 class GlslCompilerHelper
// Graphics API passed to GL::SL::Features::latest() when a compiler is created.
18 Msp::GL::GraphicsApi target_api;
// Mode handed to Compiler::compile() by GlslCompilerTest and GlslCompilerIdempotence.
19 Msp::GL::SL::Compiler::Mode compile_mode;
// Values given to Compiler::specialize(), keyed by the name read from a
// "Specialize:" directive.
20 std::map<std::string, int> spec_values;
// Expected GLSL text per shader stage; checked with verify_output().
21 std::map<Msp::GL::SL::Stage::Type, std::string> expected_output;
// Expected diagnostic text; checked with verify_diagnostic().
22 std::string expected_diagnostic;
// A test case expects compilation to succeed unless its file says otherwise.
25 TestCase(): expect_success(true) { }
// Every loaded test case.  The derived suites register their test functions
// with pointers to these elements.
28 std::list<TestCase> test_cases;
// Loads every test case file found in the given directory.
30 void load_all_test_cases(const Msp::FS::Path &);
// Parses one test case file, stores it in test_cases and returns the stored entry.
31 const TestCase &load_test_case(const std::string &);
// Checks compiler output (first argument) against expected text (second
// argument) using the GLSL tokenizer.
33 void verify_output(const std::string &, const std::string &);
// Checks diagnostic output (first argument) against expected text (second argument).
34 void verify_diagnostic(const std::string &, const std::string &);
// Extracts the text surrounding the given iterator, bounded by newlines; used
// when building failure messages.
35 std::string extract_line(const std::string &, const std::string::const_iterator &);
// Failure reporting hook; each derived suite forwards it to Test::fail().
36 virtual void fail(const std::string &) = 0;
// Test suite that compiles each loaded test case and compares the resulting
// GLSL and diagnostics with the expectations stored in the case.
39 class GlslCompilerTest: public Msp::Test::RegisteredTest<GlslCompilerTest>, private GlslCompilerHelper
// Name under which this suite is registered.
44 static const char *get_name() { return "GLSL compiler"; }
// Runs a single test case; registered once per loaded case by the constructor.
47 void run_test_case(const TestCase *);
// Routes helper failures to the test framework.
48 virtual void fail(const std::string &m) { Test::fail(m); }
// Test suite that feeds the compiler's combined GLSL output back into a second
// compiler and checks that both passes produce matching stage output.
51 class GlslCompilerIdempotence: public Msp::Test::RegisteredTest<GlslCompilerIdempotence>, private GlslCompilerHelper
// Loads the test cases and registers those that expect success.
54 GlslCompilerIdempotence();
// Name under which this suite is registered.
56 static const char *get_name() { return "GLSL compiler idempotence"; }
// Runs a single test case; registered per case by the constructor.
59 void run_test_case(const TestCase *);
// Routes helper failures to the test framework.
60 virtual void fail(const std::string &m) { Test::fail(m); }
// Test suite that compiles each test case to SPIR-V and validates the result
// with SPIRV-Tools.
63 class GlslCompilerSpirV: public Msp::Test::RegisteredTest<GlslCompilerSpirV>, private GlslCompilerHelper
// Validator instance; its message consumer is set to diagnostic() in the constructor.
66 spvtools::SpirvTools spirv_tools;
// Name under which this suite is registered.
71 static const char *get_name() { return "GLSL to SPIR-V compilation"; }
// Runs a single test case; registered per case by the constructor.
74 void run_test_case(const TestCase *);
// Message consumer for spirv_tools; logs the message with a severity prefix.
75 void diagnostic(spv_message_level_t, const char *, const spv_position_t &, const char *);
// Routes helper failures to the test framework.
76 virtual void fail(const std::string &m) { Test::fail(m); }
// Loads every test case file in tests_dir.  The file names are obtained from
// FS::list_filtered() with the pattern "\.glsl$" and each one is passed to
// load_test_case() as a path inside tests_dir.
82 void GlslCompilerHelper::load_all_test_cases(const FS::Path &tests_dir)
84 vector<string> test_files = FS::list_filtered(tests_dir, "\\.glsl$");
86 for(const auto &fn: test_files)
// The returned reference is not needed here; the case is kept in test_cases.
87 load_test_case((tests_dir/fn).str());
// Parses the test file fn into a TestCase, appends a copy to test_cases and
// returns a reference to that stored copy.  Throws runtime_error when a
// directive names an unknown stage, API or compile mode.
//
// Directives recognised in a line (searched for anywhere in the line):
//   "Expected output:" <stage>      - switches the text target to that stage's expected output
//   "Expected error:"               - switches the text target to the expected diagnostic; failure expected
//   "Expected diagnostic:"          - switches the text target to the expected diagnostic; success expected
//   "Target API:" <api>             - selects the graphics API
//   "Compile mode:" <mode>          - selects the compile mode
//   "Specialize:" <name> <value>    - records a specialization value
90 const GlslCompilerHelper::TestCase &GlslCompilerHelper::load_test_case(const string &fn)
92 IO::BufferedFile file(fn);
// The test is named after the file; API and mode default to OpenGL / PROGRAM
// unless a directive overrides them below.
94 test_case.name = FS::basename(fn);
95 test_case.target_api = GL::OPENGL;
96 test_case.compile_mode = GL::SL::Compiler::PROGRAM;
// Current text target, initially the shader source; directives repoint it.
// NOTE(review): presumably ordinary lines are appended through this pointer - confirm.
97 string *target = &test_case.source;
// NOTE(review): a failed getline() presumably ends the read loop - the handling is not shown here.
101 if(!file.getline(line))
107 string::size_type pos = line.find("Expected output:");
108 if(pos!=string::npos)
// 16 is the length of "Expected output:"; the rest of the line is the stage name.
110 string stage = strip(line.substr(pos+16));
112 target = &test_case.expected_output[GL::SL::Stage::VERTEX];
113 else if(stage=="geometry")
114 target = &test_case.expected_output[GL::SL::Stage::GEOMETRY];
115 else if(stage=="fragment")
116 target = &test_case.expected_output[GL::SL::Stage::FRAGMENT];
118 throw runtime_error("Unknown stage "+stage);
// "Expected error:" takes precedence; "Expected diagnostic:" is only searched
// for when the former is absent.
122 pos = line.find("Expected error:");
123 if(pos==string::npos)
124 pos = line.find("Expected diagnostic:");
125 if(pos!=string::npos)
127 target = &test_case.expected_diagnostic;
// pos+9 is the first character after "Expected ": 'e' for "error", 'd' for
// "diagnostic".  Only the "error" form marks the case as expected to fail.
128 test_case.expect_success = (line[pos+9]!='e');
132 pos = line.find("Target API:");
133 if(pos!=string::npos)
// 11 is the length of "Target API:".
135 string api = strip(line.substr(pos+11));
137 test_case.target_api = GL::OPENGL;
138 else if(api=="OpenGL ES")
139 test_case.target_api = GL::OPENGL_ES;
140 else if(api=="Vulkan")
141 test_case.target_api = GL::VULKAN;
143 throw runtime_error("Unknown API "+api);
147 pos = line.find("Compile mode:");
148 if(pos!=string::npos)
// 13 is the length of "Compile mode:".
150 string mode = strip(line.substr(pos+13));
152 test_case.compile_mode = GL::SL::Compiler::MODULE;
153 else if(mode=="program")
154 test_case.compile_mode = GL::SL::Compiler::PROGRAM;
156 throw runtime_error("Unknown compile mode "+mode);
160 pos = line.find("Specialize:");
161 if(pos!=string::npos)
// 11 is the length of "Specialize:"; the remainder is split into parts,
// where parts[0] is the name and parts[1] the value.
163 vector<string> parts = split(line.substr(pos+11));
// The literal "false" gets special handling; the last branch converts the
// value with lexical_cast<int>.
167 else if(parts[1]=="false")
170 value = lexical_cast<int>(parts[1]);
171 test_case.spec_values[parts[0]] = value;
// Store a copy and hand back a reference to the element inside test_cases.
178 test_cases.push_back(test_case);
180 return test_cases.back();
// Compares compiler output with the expected text token by token.  Both
// strings are run through GL::SL::Tokenizer; each token parsed from expected
// is then required from the output tokenizer with expect().
// NOTE(review): because the comparison is on tokens, whitespace differences
// are presumably ignored - confirm against the tokenizer.
183 void GlslCompilerHelper::verify_output(const string &output, const string &expected)
185 GL::SL::Tokenizer tokenizer;
// The second argument names the source in tokenizer error messages (assumed - TODO confirm).
186 tokenizer.begin(output, "<output>");
188 GL::SL::Tokenizer expected_tkn;
189 expected_tkn.begin(expected, "<expected>");
193 string token = expected_tkn.parse_token();
197 tokenizer.expect(token);
// A token mismatch surfaces as invalid_shader_source; the handler body is not
// visible in this listing.
199 catch(const GL::SL::invalid_shader_source &exc)
// Compares diagnostic output with the expected text by walking both strings
// in parallel.  Whitespace gets special treatment (see the isspace branches);
// any other difference is reported through fail() together with the offending
// line from each string.  After the walk, leftover non-whitespace text on
// either side is reported as an extra or a missing diagnostic line.
// NOTE(review): isspace() is called with plain char values; bytes outside the
// unsigned char range would be undefined behaviour - consider casting.
209 void GlslCompilerHelper::verify_diagnostic(const string &output, const string &expected)
// i walks the actual output, j walks the expected text.
211 auto i = output.begin();
212 auto j = expected.begin();
214 while(i!=output.end() && j!=expected.end())
// Whitespace on both sides at the same time.
221 else if(isspace(*i) && isspace(*j))
// Whitespace on one side only, taken while the space flag is set.
// NOTE(review): space presumably records that whitespace was just matched - confirm.
227 else if(space && isspace(*i))
229 else if(space && isspace(*j))
// Mismatch: report the line around each position.
233 string out_line = extract_line(output, i);
234 string expect_line = extract_line(expected, j);
235 fail(format("Incorrect diagnostic line:\n%s\nExpected:\n%s", out_line, expect_line));
// Trailing whitespace on either side is skipped before checking for leftovers.
239 while(i!=output.end() && isspace(*i))
241 while(j!=expected.end() && isspace(*j))
245 fail(format("Extra diagnostic line: %s", extract_line(output, i)));
246 if(j!=expected.end())
247 fail(format("Missing diagnostic line: %s", extract_line(expected, j)));
// Returns the part of text around iter that is bounded by newlines (or by the
// start/end of text).  iter must point into text.
// NOTE(review): *begin and *end are read at iter itself, so iter is expected
// not to be text.end(); the guarded calls visible in verify_diagnostic() comply.
250 string GlslCompilerHelper::extract_line(const string &text, const string::const_iterator &iter)
252 string::const_iterator begin = iter;
// Walk backwards until the start of text or a newline character.
253 for(; (begin!=text.begin() && *begin!='\n'); --begin) ;
256 string::const_iterator end = iter;
// Walk forwards until the end of text or a newline character; the newline
// itself is not included in the result.
257 for(; (end!=text.end() && *end!='\n'); ++end) ;
258 return string(begin, end);
// Loads the test cases from the "glsl" directory and registers one test
// function per case, named after the case.
262 GlslCompilerTest::GlslCompilerTest()
264 load_all_test_cases("glsl");
265 for(const auto &tc: test_cases)
// The address of the element inside test_cases is passed to the test function.
266 add(&GlslCompilerTest::run_test_case, &tc, tc.name);
// Compiles one test case and checks the result against its expectations:
// the diagnostics, the set of stages produced and each stage's GLSL text.
269 void GlslCompilerTest::run_test_case(const TestCase *test_case)
// The compiler is created with the newest feature set of the case's target API.
271 GL::SL::Compiler compiler(GL::SL::Features::latest(test_case->target_api));
274 compiler.set_source(test_case->source, "<test>");
// Specialization values are only applied in PROGRAM mode.
275 if(test_case->compile_mode==GL::SL::Compiler::PROGRAM)
276 compiler.specialize(test_case->spec_values);
277 compiler.compile(test_case->compile_mode);
// Compile errors arrive as invalid_shader_source.
279 catch(const GL::SL::invalid_shader_source &exc)
// For a case that expects failure, the exception text is checked against the
// expected diagnostic.
281 if(!test_case->expect_success)
283 debug("Errors from compile:");
285 verify_diagnostic(exc.what(), test_case->expected_diagnostic);
// Reaching this point without an exception is a failure for cases that expect an error.
291 if(!test_case->expect_success)
292 fail("Error expected but none thrown");
// Diagnostics collected by a successful compile are checked against the expectation.
294 verify_diagnostic(compiler.get_diagnostics(), test_case->expected_diagnostic);
296 auto stages = compiler.get_stages();
// Every produced stage must have an expected output entry.
299 auto i = test_case->expected_output.find(s);
300 if(i==test_case->expected_output.end())
301 fail(format("Compiler produced extra stage %s", GL::SL::Stage::get_stage_name(s)));
// Log the stage output with 1-based line numbers, then compare it.
303 string output = compiler.get_stage_glsl(s);
304 debug(format("Output for stage %s:", GL::SL::Stage::get_stage_name(s)));
305 auto lines = split_fields(output, '\n');
306 for(unsigned j=0; j<lines.size(); ++j)
307 debug(format("%3d: %s", j+1, lines[j]));
309 verify_output(output, i->second);
// Every expected stage must have been produced.
312 for(const auto &s: test_case->expected_output)
313 if(find(stages, s.first)==stages.end())
314 fail(format("Compiler didn't produce stage %s", GL::SL::Stage::get_stage_name(s.first)));
// Loads the test cases from the "glsl" directory and registers a test function
// for each case that expects compilation to succeed.
318 GlslCompilerIdempotence::GlslCompilerIdempotence()
320 load_all_test_cases("glsl");
321 for(const auto &tc: test_cases)
// Cases that expect an error are not registered for this suite.
322 if(tc.expect_success)
323 add(&GlslCompilerIdempotence::run_test_case, &tc, tc.name);
// Compiles the test case, then compiles the first pass's combined GLSL output
// with a second compiler and checks that both passes produce the same stages
// with matching GLSL.
326 void GlslCompilerIdempotence::run_test_case(const TestCase *test_case)
// First pass: same setup as the plain compiler test.
328 GL::SL::Compiler compiler(GL::SL::Features::latest(test_case->target_api));
329 compiler.set_source(test_case->source, "<test>");
330 if(test_case->compile_mode==GL::SL::Compiler::PROGRAM)
331 compiler.specialize(test_case->spec_values);
332 compiler.compile(test_case->compile_mode);
// Second pass: the first pass's combined output is the source.  specialize()
// is not called on this compiler.
334 GL::SL::Compiler compiler2(GL::SL::Features::latest(test_case->target_api));
335 compiler2.set_source(compiler.get_combined_glsl(), "<loopback>");
336 compiler2.compile(test_case->compile_mode);
338 auto stages = compiler.get_stages();
339 auto stages2 = compiler2.get_stages();
340 auto i = stages.begin();
341 auto j = stages2.begin();
// Walk both stage lists in lockstep; stop at the first position where they
// differ or where either list ends.
342 for(; (i!=stages.end() && j!=stages2.end() && *i==*j); ++i, ++j)
344 string output = compiler.get_stage_glsl(*i);
345 string output2 = compiler2.get_stage_glsl(*j);
// The first pass's output serves as the expected text for the second pass.
347 verify_output(output2, output);
// Report a stage left over on either side (the guarding conditions are not
// visible in this listing).
351 fail(format("Second pass didn't produce stage %s", GL::SL::Stage::get_stage_name(*i)));
353 fail(format("Second pass produced extra stage %s", GL::SL::Stage::get_stage_name(*j)));
// Creates the SPIR-V validator for the SPV_ENV_UNIVERSAL_1_5 environment,
// loads the test cases from the "glsl" directory, registers a test function
// for each case that expects success, and routes validator messages to
// diagnostic().
357 GlslCompilerSpirV::GlslCompilerSpirV():
358 spirv_tools(SPV_ENV_UNIVERSAL_1_5)
360 load_all_test_cases("glsl");
361 for(const auto &tc: test_cases)
// Cases that expect an error are not registered for this suite.
362 if(tc.expect_success)
363 add(&GlslCompilerSpirV::run_test_case, &tc, tc.name);
// Bind this object's diagnostic() as the message consumer; the four
// placeholders forward level, source, position and message.
365 using namespace std::placeholders;
366 spirv_tools.SetMessageConsumer(std::bind(std::mem_fn(&GlslCompilerSpirV::diagnostic), this, _1, _2, _3, _4));
// Compiles the test case to SPIR-V and fails the test if SPIRV-Tools rejects
// the generated code.
369 void GlslCompilerSpirV::run_test_case(const TestCase *test_case)
371 GL::SL::Compiler compiler(GL::SL::Features::latest(test_case->target_api));
372 compiler.set_source(test_case->source, "<test>");
// Always compiled in SPIRV mode; the case's compile_mode and spec_values are
// not used by this suite.
373 compiler.compile(GL::SL::Compiler::SPIRV);
375 vector<uint32_t> code = compiler.get_combined_spirv();
// Validation messages are delivered through diagnostic().
376 if(!spirv_tools.Validate(code))
377 fail("Invalid SPIR-V generated");
// Message consumer for the SPIR-V validator.  Logs the message through info()
// with a prefix derived from its severity level; the source and position
// arguments are ignored.
380 void GlslCompilerSpirV::diagnostic(spv_message_level_t level, const char *, const spv_position_t &, const char *message)
385 case SPV_MSG_DEBUG: prefix = "debug: "; break;
386 case SPV_MSG_INFO: prefix = "info: "; break;
387 case SPV_MSG_WARNING: prefix = "warning: "; break;
388 case SPV_MSG_ERROR: prefix = "error: "; break;
// Any other level is logged without a prefix.
389 default: prefix = "";
391 info(format("%s%s", prefix, message));