UnitTests/Shader: Improve GLSL handling (detect shader point type instead of assuming fragment)

This commit is contained in:
Jérôme Leclercq 2021-12-28 11:49:20 +01:00
parent a0f66d9e88
commit 47e2ec35e3
1 changed files with 40 additions and 9 deletions

View File

@ -5,6 +5,7 @@
#include <Nazara/Shader/ShaderLangParser.hpp> #include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/SpirvPrinter.hpp> #include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvWriter.hpp> #include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Ast/AstReflect.hpp>
#include <catch2/catch.hpp> #include <catch2/catch.hpp>
#include <glslang/Public/ShaderLang.h> #include <glslang/Public/ShaderLang.h>
#include <spirv-tools/libspirv.hpp> #include <spirv-tools/libspirv.hpp>
@ -126,8 +127,26 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu
SECTION("Generating GLSL") SECTION("Generating GLSL")
{ {
// Retrieve entry-point to get shader type
std::optional<Nz::ShaderStageType> entryShaderStage;
Nz::ShaderAst::AstReflect::Callbacks callbacks;
callbacks.onEntryPointDeclaration = [&](Nz::ShaderStageType stageType, const std::string& functionName)
{
INFO("multiple entry points found! (" << functionName << ")");
REQUIRE((!entryShaderStage.has_value() || stageType == entryShaderStage));
entryShaderStage = stageType;
};
Nz::ShaderAst::AstReflect reflectVisitor;
reflectVisitor.Reflect(shader, callbacks);
INFO("no entry point found");
REQUIRE(entryShaderStage.has_value());
Nz::GlslWriter writer; Nz::GlslWriter writer;
std::string output = writer.Generate(shader); std::string output = writer.Generate(entryShaderStage, shader);
WHEN("Validating expected code") WHEN("Validating expected code")
{ {
@ -137,18 +156,30 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu
WHEN("Validating full GLSL code (using glslang)") WHEN("Validating full GLSL code (using glslang)")
{ {
glslang::TShader shader(EShLangFragment); EShLanguage stage = EShLangVertex;
shader.setEnvInput(glslang::EShSourceGlsl, EShLangFragment, glslang::EShClientOpenGL, 300); switch (*entryShaderStage)
shader.setEnvClient(glslang::EShClientOpenGL, glslang::EShTargetOpenGL_450); {
shader.setEnvTarget(glslang::EShTargetNone, static_cast<glslang::EShTargetLanguageVersion>(0)); case Nz::ShaderStageType::Fragment:
shader.setEntryPoint("main"); stage = EShLangFragment;
break;
case Nz::ShaderStageType::Vertex:
stage = EShLangVertex;
break;
}
glslang::TShader glslangShader(stage);
glslangShader.setEnvInput(glslang::EShSourceGlsl, stage, glslang::EShClientOpenGL, 300);
glslangShader.setEnvClient(glslang::EShClientOpenGL, glslang::EShTargetOpenGL_450);
glslangShader.setEnvTarget(glslang::EShTargetNone, static_cast<glslang::EShTargetLanguageVersion>(0));
glslangShader.setEntryPoint("main");
const char* source = output.c_str(); const char* source = output.c_str();
shader.setStrings(&source, 1); glslangShader.setStrings(&source, 1);
if (!shader.parse(&s_minResources, 100, false, static_cast<EShMessages>(EShMsgDefault | EShMsgKeepUncalled))) if (!glslangShader.parse(&s_minResources, 100, false, static_cast<EShMessages>(EShMsgDefault | EShMsgKeepUncalled)))
{ {
INFO("full GLSL output:\n" << output << "\nerror:\n" << shader.getInfoLog()); INFO("full GLSL output:\n" << output << "\nerror:\n" << glslangShader.getInfoLog());
REQUIRE(false); REQUIRE(false);
} }
} }