Rework shader module unit tests

This commit is contained in:
Jérôme Leclercq
2021-12-23 17:39:24 +01:00
parent a5cc915948
commit b53d2a0560
6 changed files with 165 additions and 108 deletions

View File

@@ -1,75 +1,36 @@
#include <Engine/Shader/ShaderUtils.hpp>
#include <Nazara/Core/File.hpp>
#include <Nazara/Core/StringExt.hpp>
#include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <catch2/catch.hpp>
#include <spirv-tools/libspirv.hpp>
#include <cctype>
void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{
Nz::GlslWriter writer;
std::string output = writer.Generate(shader);
std::size_t funcOffset = output.find("void main()");
std::string_view subset = Nz::Trim(output).substr(funcOffset);
expectedOutput = Nz::Trim(expectedOutput);
REQUIRE(subset == expectedOutput);
}
void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{
Nz::SpirvWriter writer;
auto spirv = writer.Generate(shader);
Nz::SpirvPrinter printer;
Nz::SpirvPrinter::Settings settings;
settings.printHeader = false;
settings.printParameters = false;
std::string output = printer.Print(spirv.data(), spirv.size(), settings);
std::size_t funcOffset = output.find("OpFunction");
std::string_view subset = Nz::Trim(output).substr(funcOffset);
expectedOutput = Nz::Trim(expectedOutput);
REQUIRE(subset == expectedOutput);
}
SCENARIO("Shader generation", "[Shader]")
{
SECTION("Nested member loading")
{
std::vector<Nz::ShaderAst::StatementPtr> statements;
std::string_view nzslSource = R"(
struct innerStruct
{
field: vec3<f32>
}
Nz::ShaderAst::StructDescription innerStructDesc;
{
innerStructDesc.name = "innerStruct";
auto& member = innerStructDesc.members.emplace_back();
member.name = "field";
member.type = Nz::ShaderAst::VectorType{ 3, Nz::ShaderAst::PrimitiveType::Float32 };
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(innerStructDesc)));
struct outerStruct
{
s: innerStruct
}
Nz::ShaderAst::StructDescription outerStruct;
{
outerStruct.name = "outerStruct";
auto& member = outerStruct.members.emplace_back();
member.name = "s";
member.type = Nz::ShaderAst::IdentifierType{ "innerStruct" };
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(outerStruct)));
external
{
[set(0), binding(0)] ubo: uniform<outerStruct>
}
)";
auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>();
auto& externalVar = external->externalVars.emplace_back();
externalVar.bindingIndex = 0;
externalVar.name = "ubo";
externalVar.type = Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } };
statements.push_back(std::move(external));
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
REQUIRE(shader->GetType() == Nz::ShaderAst::NodeType::MultiStatement);
Nz::ShaderAst::MultiStatement& multiStatement = static_cast<Nz::ShaderAst::MultiStatement&>(*shader);
SECTION("Nested AccessMember")
{
@@ -80,22 +41,24 @@ SCENARIO("Shader generation", "[Shader]")
auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(secondAccess), { 2u });
auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle));
statements.push_back(Nz::ShaderBuilder::DeclareFunction("main", std::move(varDecl)));
multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl)));
Nz::ShaderAst::StatementPtr shader = Nz::ShaderBuilder::MultiStatement(std::move(statements));
SECTION("Generating GLSL")
{
ExpectingGLSL(*shader, R"(
ExpectingGLSL(*shader, R"(
void main()
{
float result = ubo.s.field.z;
}
)");
}
SECTION("Generating Spir-V")
{
ExpectingSpirV(*shader, R"(
ExpectingNZSL(*shader, R"(
[entry(vert)]
fn main()
{
let result: f32 = ubo.s.field.z;
}
)");
ExpectingSpirV(*shader, R"(
OpFunction
OpLabel
OpVariable
@@ -105,7 +68,6 @@ OpCompositeExtract
OpStore
OpReturn
OpFunctionEnd)");
}
}
SECTION("AccessMember with multiples fields")
@@ -116,22 +78,24 @@ OpFunctionEnd)");
auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(access), { 2u });
auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle));
statements.push_back(Nz::ShaderBuilder::DeclareFunction("main", std::move(varDecl)));
multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl)));
Nz::ShaderAst::StatementPtr shader = Nz::ShaderBuilder::MultiStatement(std::move(statements));
SECTION("Generating GLSL")
{
ExpectingGLSL(*shader, R"(
ExpectingGLSL(*shader, R"(
void main()
{
float result = ubo.s.field.z;
}
)");
}
SECTION("Generating Spir-V")
{
ExpectingSpirV(*shader, R"(
ExpectingNZSL(*shader, R"(
[entry(vert)]
fn main()
{
let result: f32 = ubo.s.field.z;
}
)");
ExpectingSpirV(*shader, R"(
OpFunction
OpLabel
OpVariable
@@ -141,7 +105,6 @@ OpCompositeExtract
OpStore
OpReturn
OpFunctionEnd)");
}
}
}
}

View File

@@ -0,0 +1,76 @@
#include <Engine/Shader/ShaderUtils.hpp>
#include <Nazara/Core/StringExt.hpp>
#include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Shader/LangWriter.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <catch2/catch.hpp>
#include <spirv-tools/libspirv.hpp>
void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{
expectedOutput = Nz::Trim(expectedOutput);
Nz::GlslWriter writer;
SECTION("Generating GLSL")
{
std::string output = writer.Generate(shader);
INFO("full GLSL output:\n" << output << "\nexcepted output:\n" << expectedOutput);
REQUIRE(output.find(expectedOutput) != std::string::npos);
}
}
void ExpectingNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{
expectedOutput = Nz::Trim(expectedOutput);
Nz::LangWriter writer;
SECTION("Generating NZSL")
{
std::string output = writer.Generate(shader);
INFO("full NZSL output:\n" << output << "\nexcepted output:\n" << expectedOutput);
REQUIRE(output.find(expectedOutput) != std::string::npos);
// validate NZSL by recompiling it
REQUIRE_NOTHROW(Nz::ShaderLang::Parse(output));
}
}
void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{
expectedOutput = Nz::Trim(expectedOutput);
Nz::SpirvWriter writer;
Nz::SpirvPrinter printer;
Nz::SpirvPrinter::Settings settings;
settings.printHeader = false;
settings.printParameters = false;
SECTION("Generating SPIRV")
{
auto spirv = writer.Generate(shader);
std::string output = printer.Print(spirv.data(), spirv.size(), settings);
INFO("full SPIRV output:\n" << output << "\nexcepted output:\n" << expectedOutput);
REQUIRE(output.find(expectedOutput) != std::string::npos);
// validate SPIRV with libspirv
spvtools::SpirvTools spirvTools(spv_target_env::SPV_ENV_VULKAN_1_0);
spirvTools.SetMessageConsumer([&](spv_message_level_t /*level*/, const char* /*source*/, const spv_position_t& /*position*/, const char* message)
{
std::string fullSpirv;
if (!spirvTools.Disassemble(spirv, &fullSpirv))
fullSpirv = "<failed to disassemble SPIRV>";
UNSCOPED_INFO(fullSpirv + "\n" + message);
});
REQUIRE(spirvTools.Validate(spirv));
}
}

View File

@@ -0,0 +1,13 @@
#pragma once
#ifndef NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
#define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
#include <Nazara/Shader/Ast/Nodes.hpp>
#include <string>
void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput);
void ExpectingNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput);
void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput);
#endif