diff --git a/tests/Engine/Shader/ShaderUtils.cpp b/tests/Engine/Shader/ShaderUtils.cpp index 8d170c8b9..27f20cb46 100644 --- a/tests/Engine/Shader/ShaderUtils.cpp +++ b/tests/Engine/Shader/ShaderUtils.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -126,8 +127,26 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu SECTION("Generating GLSL") { + // Retrieve entry-point to get shader type + std::optional entryShaderStage; + + Nz::ShaderAst::AstReflect::Callbacks callbacks; + callbacks.onEntryPointDeclaration = [&](Nz::ShaderStageType stageType, const std::string& functionName) + { + INFO("multiple entry points found! (" << functionName << ")"); + REQUIRE((!entryShaderStage.has_value() || stageType == entryShaderStage)); + + entryShaderStage = stageType; + }; + + Nz::ShaderAst::AstReflect reflectVisitor; + reflectVisitor.Reflect(shader, callbacks); + + INFO("no entry point found"); + REQUIRE(entryShaderStage.has_value()); + Nz::GlslWriter writer; - std::string output = writer.Generate(shader); + std::string output = writer.Generate(entryShaderStage, shader); WHEN("Validating expected code") { @@ -137,18 +156,30 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu WHEN("Validating full GLSL code (using glslang)") { - glslang::TShader shader(EShLangFragment); - shader.setEnvInput(glslang::EShSourceGlsl, EShLangFragment, glslang::EShClientOpenGL, 300); - shader.setEnvClient(glslang::EShClientOpenGL, glslang::EShTargetOpenGL_450); - shader.setEnvTarget(glslang::EShTargetNone, static_cast(0)); - shader.setEntryPoint("main"); + EShLanguage stage = EShLangVertex; + switch (*entryShaderStage) + { + case Nz::ShaderStageType::Fragment: + stage = EShLangFragment; + break; + + case Nz::ShaderStageType::Vertex: + stage = EShLangVertex; + break; + } + + glslang::TShader glslangShader(stage); + glslangShader.setEnvInput(glslang::EShSourceGlsl, stage, glslang::EShClientOpenGL, 300); + glslangShader.setEnvClient(glslang::EShClientOpenGL, glslang::EShTargetOpenGL_450); + glslangShader.setEnvTarget(glslang::EShTargetNone, static_cast(0)); + glslangShader.setEntryPoint("main"); const char* source = output.c_str(); - shader.setStrings(&source, 1); + glslangShader.setStrings(&source, 1); - if (!shader.parse(&s_minResources, 100, false, static_cast(EShMsgDefault | EShMsgKeepUncalled))) + if (!glslangShader.parse(&s_minResources, 100, false, static_cast(EShMsgDefault | EShMsgKeepUncalled))) { - INFO("full GLSL output:\n" << output << "\nerror:\n" << shader.getInfoLog()); + INFO("full GLSL output:\n" << output << "\nerror:\n" << glslangShader.getInfoLog()); REQUIRE(false); } }