#include <cstring>
#include <msp/core/algorithm.h>
+#include <msp/gl/extensions/arb_compute_shader.h>
#include <msp/gl/extensions/arb_es2_compatibility.h>
#include <msp/gl/extensions/arb_fragment_shader.h>
#include <msp/gl/extensions/arb_gl_spirv.h>
case VERTEX: { static Require _req(ARB_vertex_shader); gl_type = GL_VERTEX_SHADER; } break;
case GEOMETRY: { static Require _req(ARB_geometry_shader4); gl_type = GL_GEOMETRY_SHADER; } break;
case FRAGMENT: { static Require _req(ARB_fragment_shader); gl_type = GL_FRAGMENT_SHADER; } break;
+ case COMPUTE: { static Require _req(ARB_compute_shader); gl_type = GL_COMPUTE_SHADER; } break;
default: throw invalid_argument("OpenGLProgram::add_stage");
}
case SL::Stage::VERTEX: stage_id = add_stage(VERTEX); break;
case SL::Stage::GEOMETRY: stage_id = add_stage(GEOMETRY); break;
case SL::Stage::FRAGMENT: stage_id = add_stage(FRAGMENT); break;
+ case SL::Stage::COMPUTE: stage_id = add_stage(COMPUTE); break;
default: throw invalid_operation("OpenGLProgram::add_glsl_stages");
}
link(mod);
query_uniforms();
query_attributes();
+ if(is_compute())
+ {
+ int wg_size[3];
+ glGetProgramiv(id, GL_COMPUTE_WORK_GROUP_SIZE, wg_size);
+ rd.compute_wg_size = LinAl::Vector<unsigned, 3>(wg_size[0], wg_size[1], wg_size[2]);
+ }
const map<string, unsigned> &block_bindings = compiler.get_uniform_block_bindings();
if(!block_bindings.empty())
case SpirVModule::VERTEX: stage_id = add_stage(VERTEX); break;
case SpirVModule::GEOMETRY: stage_id = add_stage(GEOMETRY); break;
case SpirVModule::FRAGMENT: stage_id = add_stage(FRAGMENT); break;
+ case SpirVModule::COMPUTE: stage_id = add_stage(COMPUTE); break;
default: throw invalid_operation("OpenGLProgram::add_spirv_stages");
}
void OpenGLProgram::set_stage_debug_name(unsigned stage_id, Stage type)
{
#ifdef DEBUG
- static const char *const suffixes[] = { " [VS]", " [GS]", " [FS]" };
+ static const char *const suffixes[] = { " [VS]", " [GS]", " [FS]", " [CS]" };
string name = debug_name+suffixes[type];
glObjectLabel(GL_SHADER, stage_id, name.size(), name.c_str());
#else