From: Mikko Rasa Date: Mon, 11 Apr 2022 11:44:16 +0000 (+0300) Subject: Support compute shaders in the shader compiler X-Git-Url: http://git.tdb.fi/?p=libs%2Fgl.git;a=commitdiff_plain;h=30c7ba8f7fd08c13562c86bf651bdc3ec8d30ab5 Support compute shaders in the shader compiler --- diff --git a/builtin_data/_builtin.glsl b/builtin_data/_builtin.glsl index 157e05a7..51035d32 100644 --- a/builtin_data/_builtin.glsl +++ b/builtin_data/_builtin.glsl @@ -664,3 +664,10 @@ in int gl_SampleID; in vec2 gl_SamplePosition; in int gl_Layer; out float gl_FragDepth; + +#pragma MSP stage(compute) +in uvec3 gl_NumWorkGroups; +in uvec3 gl_WorkGroupID; +in uvec3 gl_LocalInvocationID; +in uvec3 gl_GlobalInvocationID; +in uint gl_LocalInvocationIndex; diff --git a/source/glsl/finalize.cpp b/source/glsl/finalize.cpp index 85a6b07a..94052432 100644 --- a/source/glsl/finalize.cpp +++ b/source/glsl/finalize.cpp @@ -473,6 +473,13 @@ bool StructuralFeatureConverter::supports_stage(Stage::Type st) const else return check_version(Version(1, 50)); } + else if(st==Stage::COMPUTE) + { + if(features.target_api==OPENGL_ES) + return check_version(Version(3, 10)); + else + return check_version(Version(4, 30)); + } else return true; } diff --git a/source/glsl/preprocessor.cpp b/source/glsl/preprocessor.cpp index ea6dd6de..f2becd91 100644 --- a/source/glsl/preprocessor.cpp +++ b/source/glsl/preprocessor.cpp @@ -93,6 +93,8 @@ void Preprocessor::preprocess_stage() stage = Stage::GEOMETRY; else if(token=="fragment") stage = Stage::FRAGMENT; + else if(token=="compute") + stage = Stage::COMPUTE; else throw parse_error(tokenizer.get_location(), token, "stage identifier"); tokenizer.expect(")"); diff --git a/source/glsl/spirv.cpp b/source/glsl/spirv.cpp index 9cb1b7f2..8cb485c9 100644 --- a/source/glsl/spirv.cpp +++ b/source/glsl/spirv.cpp @@ -199,6 +199,18 @@ SpirVGenerator::BuiltinSemantic SpirVGenerator::get_builtin_semantic(const strin return BUILTIN_SAMPLE_POSITION; else if(name=="gl_FragDepth") return BUILTIN_FRAG_DEPTH; + else if(name=="gl_NumWorkGroups") + return BUILTIN_NUM_WORKGROUPS; + else if(name=="gl_WorkGroupSize") + return BUILTIN_WORKGROUP_SIZE; + else if(name=="gl_WorkGroupID") + return BUILTIN_WORKGROUP_ID; + else if(name=="gl_LocalInvocationID") + return BUILTIN_LOCAL_INVOCATION_ID; + else if(name=="gl_GlobalInvocationID") + return BUILTIN_GLOBAL_INVOCATION_ID; + else if(name=="gl_LocalInvocationIndex") + return BUILTIN_LOCAL_INVOCATION_INDEX; else throw invalid_argument("SpirVGenerator::get_builtin_semantic"); } @@ -1779,6 +1791,7 @@ void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id) case Stage::VERTEX: writer.write(0); break; case Stage::GEOMETRY: writer.write(3); break; case Stage::FRAGMENT: writer.write(4); break; + case Stage::COMPUTE: writer.write(5); break; default: throw internal_error("unknown stage"); } writer.write(func_id); @@ -1803,6 +1816,8 @@ void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id) writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_INVOCATIONS, 1); } + unsigned local_size[3] = { 0, 1, 1 }; + for(const InterfaceLayout *i: interface_layouts) { for(const Layout::Qualifier &q: i->layout.qualifiers) @@ -1824,8 +1839,24 @@ void SpirVGenerator::visit_entry_point(FunctionDeclaration &func, Id func_id) writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_TRIANGLE_STRIP); else if(q.name=="max_vertices") writer.write_op(content.exec_modes, OP_EXECUTION_MODE, func_id, EXEC_OUTPUT_VERTICES, q.value); + else if(q.name=="local_size_x") + local_size[0] = q.value; + else if(q.name=="local_size_y") + local_size[1] = q.value; + else if(q.name=="local_size_z") + local_size[2] = q.value; } } + + if(stage->type==Stage::COMPUTE && local_size[0]) + { + writer.begin_op(content.exec_modes, OP_EXECUTION_MODE); + writer.write(func_id); + writer.write(EXEC_LOCAL_SIZE); + for(unsigned j=0; j<3; ++j) + writer.write(local_size[j]); + writer.end_op(OP_EXECUTION_MODE); + } } void SpirVGenerator::visit(FunctionDeclaration &func) diff --git a/source/glsl/spirvconstants.h b/source/glsl/spirvconstants.h index ac934eae..d60b809c 100644 --- a/source/glsl/spirvconstants.h +++ b/source/glsl/spirvconstants.h @@ -177,6 +177,7 @@ enum SpirVExecutionMode EXEC_INVOCATIONS = 0, EXEC_ORIGIN_UPPER_LEFT = 7, EXEC_ORIGIN_LOWER_LEFT = 8, + EXEC_LOCAL_SIZE = 17, EXEC_INPUT_POINTS = 19, EXEC_INPUT_LINES = 20, EXEC_INPUT_LINES_ADJACENCY = 21, @@ -229,7 +230,13 @@ enum SpirVBuiltin BUILTIN_FRONT_FACING = 17, BUILTIN_SAMPLE_ID = 18, BUILTIN_SAMPLE_POSITION = 19, - BUILTIN_FRAG_DEPTH = 22 + BUILTIN_FRAG_DEPTH = 22, + BUILTIN_NUM_WORKGROUPS = 24, + BUILTIN_WORKGROUP_SIZE = 25, + BUILTIN_WORKGROUP_ID = 26, + BUILTIN_LOCAL_INVOCATION_ID = 27, + BUILTIN_GLOBAL_INVOCATION_ID = 28, + BUILTIN_LOCAL_INVOCATION_INDEX = 29 }; enum SpirVFormat diff --git a/source/glsl/syntax.cpp b/source/glsl/syntax.cpp index 38a7f003..b805a95a 100644 --- a/source/glsl/syntax.cpp +++ b/source/glsl/syntax.cpp @@ -332,7 +332,7 @@ Stage::Stage(Stage::Type t): const char *Stage::get_stage_name(Type type) { - static const char *const names[] = { "shared", "vertex", "geometry", "fragment" }; + static const char *const names[] = { "shared", "vertex", "geometry", "fragment", "compute" }; return names[type]; } diff --git a/source/glsl/syntax.h b/source/glsl/syntax.h index 63e41921..1f43eac3 100644 --- a/source/glsl/syntax.h +++ b/source/glsl/syntax.h @@ -490,7 +490,8 @@ struct Stage SHARED, VERTEX, GEOMETRY, - FRAGMENT + FRAGMENT, + COMPUTE }; Type type; diff --git a/source/glsl/validate.cpp b/source/glsl/validate.cpp index 6151483c..d669b237 100644 --- a/source/glsl/validate.cpp +++ b/source/glsl/validate.cpp @@ -156,6 +156,8 @@ void DeclarationValidator::visit(Layout &layout) allowed = (iface_block && !variable && iface_block->interface=="uniform"); value = false; } + else if(q.name=="local_size_x" || q.name=="local_size_y" || q.name=="local_size_z") + allowed = (stage->type==Stage::COMPUTE && iface_layout && iface_layout->interface=="in"); else if(q.name=="rgba32f" || q.name=="rgba16f" || q.name=="rg32f" || q.name=="rg16f" || q.name=="r32f" || q.name=="r16f" || q.name=="rgba16" || q.name=="rgba8" || q.name=="rg16" || q.name=="rg8" || q.name=="r16" || q.name=="r8" || q.name=="rgba16_snorm" || q.name=="rgba8_snorm" || q.name=="rg16_snorm" || q.name=="rg8_snorm" || q.name=="r16_snorm" || q.name=="r8_snorm")