]> git.tdb.fi Git - libs/gl.git/blobdiff - tests/glsl/glslcompiler.cpp
Fix a compile error in the GLSL compiler test runner
[libs/gl.git] / tests / glsl / glslcompiler.cpp
index 371df6cc48f0fd1d738681627b4b7399e996c2ba..e4b36a398e6b88ca965098b3d3a8a3ed37902e64 100644 (file)
@@ -1,3 +1,4 @@
+#include <spirv-tools/libspirv.hpp>
 #include <msp/core/algorithm.h>
 #include <msp/fs/dir.h>
 #include <msp/fs/utils.h>
@@ -17,7 +18,10 @@ protected:
                Msp::GL::SL::Compiler::Mode compile_mode;
                std::map<std::string, int> spec_values;
                std::map<Msp::GL::SL::Stage::Type, std::string> expected_output;
-               std::string expected_error;
+               std::string expected_diagnostic;
+               bool expect_success;
+
+               TestCase(): expect_success(true) { }
        };
 
        std::list<TestCase> test_cases;
@@ -26,7 +30,7 @@ protected:
        const TestCase &load_test_case(const std::string &);
 
        void verify_output(const std::string &, const std::string &);
-       void verify_error(const std::string &, const std::string &);
+       void verify_diagnostic(const std::string &, const std::string &);
        std::string extract_line(const std::string &, const std::string::const_iterator &);
        virtual void fail(const std::string &) = 0;
 };
@@ -55,13 +59,29 @@ private:
        virtual void fail(const std::string &m) { Test::fail(m); }
 };
 
+class GlslCompilerSpirV: public Msp::Test::RegisteredTest<GlslCompilerSpirV>, private GlslCompilerHelper
+{
+private:
+       spvtools::SpirvTools spirv_tools;
+
+public:
+       GlslCompilerSpirV();
+
+       static const char *get_name() { return "GLSL to SPIR-V compilation"; }
+
+private:
+       void run_test_case(const TestCase *);
+       void diagnostic(spv_message_level_t, const char *, const spv_position_t &, const char *);
+       virtual void fail(const std::string &m) { Test::fail(m); }
+};
+
 using namespace std;
 using namespace Msp;
 
 void GlslCompilerHelper::load_all_test_cases(const FS::Path &tests_dir)
 {
-       list<string> test_files = FS::list_filtered(tests_dir, "\\.glsl$");
-       test_files.sort();
+       vector<string> test_files = FS::list_filtered(tests_dir, "\\.glsl$");
+       sort(test_files);
        for(const auto &fn: test_files)
                load_test_case((tests_dir/fn).str());
 }
@@ -98,9 +118,12 @@ const GlslCompilerHelper::TestCase &GlslCompilerHelper::load_test_case(const str
                }
 
                pos = line.find("Expected error:");
+               if(pos==string::npos)
+                       pos = line.find("Expected diagnostic:");
                if(pos!=string::npos)
                {
-                       target = &test_case.expected_error;
+                       target = &test_case.expected_diagnostic;
+                       test_case.expect_success = (line[pos+9]!='e');
                        continue;
                }
 
@@ -166,7 +189,7 @@ void GlslCompilerHelper::verify_output(const string &output, const string &expec
        }
 }
 
-void GlslCompilerHelper::verify_error(const string &output, const string &expected)
+void GlslCompilerHelper::verify_diagnostic(const string &output, const string &expected)
 {
        auto i = output.begin();
        auto j = expected.begin();
@@ -192,7 +215,7 @@ void GlslCompilerHelper::verify_error(const string &output, const string &expect
                {
                        string out_line = extract_line(output, i);
                        string expect_line = extract_line(expected, j);
-                       fail(format("Incorrect error line:\n%s\nExpected:\n%s", out_line, expect_line));
+                       fail(format("Incorrect diagnostic line:\n%s\nExpected:\n%s", out_line, expect_line));
                }
        }
 
@@ -202,9 +225,9 @@ void GlslCompilerHelper::verify_error(const string &output, const string &expect
                ++j;
 
        if(i!=output.end())
-               fail(format("Extra error line: %s", extract_line(output, i)));
+               fail(format("Extra diagnostic line: %s", extract_line(output, i)));
        if(j!=expected.end())
-               fail(format("Missing error line: %s", extract_line(expected, j)));
+               fail(format("Missing diagnostic line: %s", extract_line(expected, j)));
 }
 
 string GlslCompilerHelper::extract_line(const string &text, const string::const_iterator &iter)
@@ -228,7 +251,7 @@ GlslCompilerTest::GlslCompilerTest()
 
 void GlslCompilerTest::run_test_case(const TestCase *test_case)
 {
-       GL::SL::Compiler compiler(GL::SL::Features::all());
+       GL::SL::Compiler compiler(GL::SL::Features::latest());
        try
        {
                compiler.set_source(test_case->source, "<test>");
@@ -238,19 +261,21 @@ void GlslCompilerTest::run_test_case(const TestCase *test_case)
        }
        catch(const GL::SL::invalid_shader_source &exc)
        {
-               if(!test_case->expected_error.empty())
+               if(!test_case->expect_success)
                {
                        debug("Errors from compile:");
                        debug(exc.what());
-                       verify_error(exc.what(), test_case->expected_error);
+                       verify_diagnostic(exc.what(), test_case->expected_diagnostic);
                        return;
                }
                throw;
        }
 
-       if(!test_case->expected_error.empty())
+       if(!test_case->expect_success)
                fail("Error expected but none thrown");
 
+       verify_diagnostic(compiler.get_diagnostics(), test_case->expected_diagnostic);
+
        auto stages = compiler.get_stages();
        for(auto s: stages)
        {
@@ -277,19 +302,19 @@ GlslCompilerIdempotence::GlslCompilerIdempotence()
 {
        load_all_test_cases("glsl");
        for(const auto &tc: test_cases)
-               if(tc.expected_error.empty())
+               if(tc.expect_success)
                        add(&GlslCompilerIdempotence::run_test_case, &tc, tc.name);
 }
 
 void GlslCompilerIdempotence::run_test_case(const TestCase *test_case)
 {
-       GL::SL::Compiler compiler(GL::SL::Features::all());
+       GL::SL::Compiler compiler(GL::SL::Features::latest());
        compiler.set_source(test_case->source, "<test>");
        if(test_case->compile_mode==GL::SL::Compiler::PROGRAM)
                compiler.specialize(test_case->spec_values);
        compiler.compile(test_case->compile_mode);
 
-       GL::SL::Compiler compiler2(GL::SL::Features::all());
+       GL::SL::Compiler compiler2(GL::SL::Features::latest());
        compiler2.set_source(compiler.get_combined_glsl(), "<loopback>");
        compiler2.compile(test_case->compile_mode);
 
@@ -310,3 +335,41 @@ void GlslCompilerIdempotence::run_test_case(const TestCase *test_case)
        if(j!=stages2.end())
                fail(format("Second pass produced extra stage %s", GL::SL::Stage::get_stage_name(*j)));
 }
+
+
+GlslCompilerSpirV::GlslCompilerSpirV():
+       spirv_tools(SPV_ENV_UNIVERSAL_1_5)
+{
+       load_all_test_cases("glsl");
+       for(const auto &tc: test_cases)
+               if(tc.expect_success)
+                       add(&GlslCompilerSpirV::run_test_case, &tc, tc.name);
+
+       using namespace std::placeholders;
+       spirv_tools.SetMessageConsumer(std::bind(std::mem_fn(&GlslCompilerSpirV::diagnostic), this, _1, _2, _3, _4));
+}
+
+void GlslCompilerSpirV::run_test_case(const TestCase *test_case)
+{
+       GL::SL::Compiler compiler(GL::SL::Features::latest());
+       compiler.set_source(test_case->source, "<test>");
+       compiler.compile(GL::SL::Compiler::SPIRV);
+
+       vector<uint32_t> code = compiler.get_combined_spirv();
+       if(!spirv_tools.Validate(code))
+               fail("Invalid SPIR-V generated");
+}
+
+void GlslCompilerSpirV::diagnostic(spv_message_level_t level, const char *, const spv_position_t &, const char *message)
+{
+       const char *prefix;
+       switch(level)
+       {
+       case SPV_MSG_DEBUG: prefix = "debug: "; break;
+       case SPV_MSG_INFO: prefix = "info: "; break;
+       case SPV_MSG_WARNING: prefix = "warning: "; break;
+       case SPV_MSG_ERROR: prefix = "error: "; break;
+       default: prefix = "";
+       }
+       info(format("%s%s", prefix, message));
+}