Fix Shader unit tests

This commit is contained in:
Jérôme Leclercq 2021-04-14 20:11:41 +02:00
parent 3e704b9ea6
commit 54c34869a4
3 changed files with 68 additions and 45 deletions

View File

@ -29,7 +29,8 @@ namespace Nz
GlslWriter(GlslWriter&&) = delete; GlslWriter(GlslWriter&&) = delete;
~GlslWriter() = default; ~GlslWriter() = default;
std::string Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {}); std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {});
std::string Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {});
void SetEnv(Environment environment); void SetEnv(Environment environment);

View File

@ -43,19 +43,27 @@ namespace Nz
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get()); ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
// Remove function if it's an entry point of another type than the one selected // Remove function if it's an entry point of another type than the one selected
if (node.entryStage) if (selectedStage)
{ {
ShaderStageType stage = *node.entryStage; if (node.entryStage)
if (stage != selectedStage) {
return ShaderBuilder::NoOp(); ShaderStageType stage = *node.entryStage;
if (stage != *selectedStage)
return ShaderBuilder::NoOp();
entryPoint = func;
}
}
else
{
assert(!entryPoint);
entryPoint = func; entryPoint = func;
} }
return clone; return clone;
} }
ShaderStageType selectedStage; std::optional<ShaderStageType> selectedStage;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
}; };
@ -79,7 +87,7 @@ namespace Nz
std::string targetName; std::string targetName;
}; };
ShaderStageType stage; std::optional<ShaderStageType> stage;
const States* states = nullptr; const States* states = nullptr;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream; std::stringstream stream;
@ -97,7 +105,12 @@ namespace Nz
{ {
} }
std::string GlslWriter::Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions) std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
return Generate(std::nullopt, shader, conditions);
}
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::StatementPtr& shader, const States& conditions)
{ {
State state; State state;
state.stage = shaderStage; state.stage = shaderStage;
@ -394,7 +407,7 @@ namespace Nz
assert(it != s_builtinMapping.end()); assert(it != s_builtinMapping.end());
const Builtin& builtin = it->second; const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(m_currentState->stage)) if (m_currentState->stage && !builtin.stageFlags.Test(*m_currentState->stage))
continue; //< This builtin is not active in this stage, skip it continue; //< This builtin is not active in this stage, skip it
fields.push_back({ fields.push_back({
@ -852,8 +865,6 @@ namespace Nz
{ {
bool isOutputPosition = (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition && targetName == "gl_Position"); bool isOutputPosition = (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition && targetName == "gl_Position");
AppendLine();
Append(targetName, " = ", outputStructVarName, ".", name); Append(targetName, " = ", outputStructVarName, ".", name);
if (isOutputPosition) if (isOutputPosition)
Append(" * vec4(1.0, ", s_flipYUniformName, ", 1.0, 1.0)"); Append(" * vec4(1.0, ", s_flipYUniformName, ", 1.0, 1.0)");
@ -861,7 +872,6 @@ namespace Nz
AppendLine(";"); AppendLine(";");
} }
AppendLine();
Append("return;"); //< TODO: Don't return if it's the last statement of the function Append("return;"); //< TODO: Don't return if it's the last statement of the function
} }
else else

View File

@ -1,6 +1,5 @@
#include <Nazara/Core/File.hpp> #include <Nazara/Core/File.hpp>
#include <Nazara/Shader/GlslWriter.hpp> #include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp> #include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/SpirvPrinter.hpp> #include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvWriter.hpp> #include <Nazara/Shader/SpirvWriter.hpp>
@ -18,7 +17,7 @@ std::string_view Trim(std::string_view str)
return str; return str;
} }
void ExpectingGLSL(const Nz::ShaderAst& shader, std::string_view expectedOutput) void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput)
{ {
Nz::GlslWriter writer; Nz::GlslWriter writer;
@ -30,7 +29,7 @@ void ExpectingGLSL(const Nz::ShaderAst& shader, std::string_view expectedOutput)
REQUIRE(subset == expectedOutput); REQUIRE(subset == expectedOutput);
} }
void ExpectingSpirV(const Nz::ShaderAst& shader, std::string_view expectedOutput) void ExpectingSpirV(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput)
{ {
Nz::SpirvWriter writer; Nz::SpirvWriter writer;
auto spirv = writer.Generate(shader); auto spirv = writer.Generate(shader);
@ -53,42 +52,53 @@ SCENARIO("Shader generation", "[Shader]")
{ {
SECTION("Nested member loading") SECTION("Nested member loading")
{ {
Nz::ShaderAst baseShader(Nz::ShaderStageType::Vertex); std::vector<Nz::ShaderAst::StatementPtr> statements;
baseShader.AddStruct("innerStruct", {
{
"field",
Nz::ShaderNodes::BasicType::Float3
}
});
baseShader.AddStruct("outerStruct", { Nz::ShaderAst::StructDescription innerStructDesc;
{ {
"s", innerStructDesc.name = "innerStruct";
"innerStruct" auto& member = innerStructDesc.members.emplace_back();
} member.name = "field";
}); member.type = Nz::ShaderAst::VectorType{ 3, Nz::ShaderAst::PrimitiveType::Float32 };
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(innerStructDesc)));
baseShader.AddUniform("ubo", "outerStruct"); Nz::ShaderAst::StructDescription outerStruct;
baseShader.AddOutput("result", Nz::ShaderNodes::BasicType::Float1); {
outerStruct.name = "outerStruct";
auto& member = outerStruct.members.emplace_back();
member.name = "s";
member.type = Nz::ShaderAst::IdentifierType{ "innerStruct" };
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(outerStruct)));
auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>();
external->externalVars.push_back({
std::nullopt,
"ubo",
Nz::ShaderAst::IdentifierType{ "outerStruct" }
});
statements.push_back(std::move(external));
SECTION("Nested AccessMember") SECTION("Nested AccessMember")
{ {
Nz::ShaderAst shader = baseShader; auto ubo = Nz::ShaderBuilder::Identifier("ubo");
auto firstAccess = Nz::ShaderBuilder::AccessMember(std::move(ubo), { "s" });
auto secondAccess = Nz::ShaderBuilder::AccessMember(std::move(firstAccess), { "field" });
auto uniform = Nz::ShaderBuilder::Uniform("ubo", "outerStruct"); auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(secondAccess), { Nz::ShaderAst::SwizzleComponent::Third });
auto output = Nz::ShaderBuilder::Output("result", Nz::ShaderNodes::BasicType::Float1); auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle));
auto access = Nz::ShaderBuilder::Swizzle(Nz::ShaderBuilder::AccessMember(Nz::ShaderBuilder::AccessMember(Nz::ShaderBuilder::Identifier(uniform), 0, "innerStruct"), 0, Nz::ShaderNodes::BasicType::Float3), Nz::ShaderNodes::SwizzleComponent::Third); statements.push_back(Nz::ShaderBuilder::DeclareFunction("main", std::move(varDecl)));
auto assign = Nz::ShaderBuilder::Assign(Nz::ShaderBuilder::Identifier(output), access);
shader.AddFunction("main", Nz::ShaderBuilder::ExprStatement(assign)); Nz::ShaderAst::StatementPtr shader = Nz::ShaderBuilder::MultiStatement(std::move(statements));
SECTION("Generating GLSL") SECTION("Generating GLSL")
{ {
ExpectingGLSL(shader, R"( ExpectingGLSL(shader, R"(
void main() void main()
{ {
result = ubo.s.field.z; float result = ubo.s.field.z;
} }
)"); )");
} }
@ -97,6 +107,7 @@ void main()
ExpectingSpirV(shader, R"( ExpectingSpirV(shader, R"(
OpFunction OpFunction
OpLabel OpLabel
OpVariable
OpAccessChain OpAccessChain
OpAccessChain OpAccessChain
OpLoad OpLoad
@ -109,22 +120,22 @@ OpFunctionEnd)");
SECTION("AccessMember with multiples fields") SECTION("AccessMember with multiples fields")
{ {
Nz::ShaderAst shader = baseShader; auto ubo = Nz::ShaderBuilder::Identifier("ubo");
auto access = Nz::ShaderBuilder::AccessMember(std::move(ubo), { "s", "field" });
auto uniform = Nz::ShaderBuilder::Uniform("ubo", "outerStruct"); auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(access), { Nz::ShaderAst::SwizzleComponent::Third });
auto output = Nz::ShaderBuilder::Output("result", Nz::ShaderNodes::BasicType::Float1); auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle));
auto access = Nz::ShaderBuilder::Swizzle(Nz::ShaderBuilder::AccessMember(Nz::ShaderBuilder::Identifier(uniform), std::vector<std::size_t>{ 0, 0 }, Nz::ShaderNodes::BasicType::Float3), Nz::ShaderNodes::SwizzleComponent::Third); statements.push_back(Nz::ShaderBuilder::DeclareFunction("main", std::move(varDecl)));
auto assign = Nz::ShaderBuilder::Assign(Nz::ShaderBuilder::Identifier(output), access);
shader.AddFunction("main", Nz::ShaderBuilder::ExprStatement(assign)); Nz::ShaderAst::StatementPtr shader = Nz::ShaderBuilder::MultiStatement(std::move(statements));
SECTION("Generating GLSL") SECTION("Generating GLSL")
{ {
ExpectingGLSL(shader, R"( ExpectingGLSL(shader, R"(
void main() void main()
{ {
result = ubo.s.field.z; float result = ubo.s.field.z;
} }
)"); )");
} }
@ -133,6 +144,7 @@ void main()
ExpectingSpirV(shader, R"( ExpectingSpirV(shader, R"(
OpFunction OpFunction
OpLabel OpLabel
OpVariable
OpAccessChain OpAccessChain
OpLoad OpLoad
OpCompositeExtract OpCompositeExtract