From dfa46ebaa5946a30feb80042965b846e28437501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Wed, 16 Jun 2021 15:46:14 +0200 Subject: [PATCH] Fix shader generation unit tests --- include/Nazara/Shader/GlslWriter.hpp | 4 +- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 5 +-- src/Nazara/Shader/GlslWriter.cpp | 48 ++++++++++------------- src/Nazara/Shader/ShaderLangParser.cpp | 4 +- src/Nazara/Shader/SpirvWriter.cpp | 17 ++++---- tests/Engine/Shader/AccessMemberTest.cpp | 13 +++--- 6 files changed, 41 insertions(+), 50 deletions(-) diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index c5480be34..988489bac 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -31,8 +31,8 @@ namespace Nz GlslWriter(GlslWriter&&) = delete; ~GlslWriter() = default; - inline std::string Generate(ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states = {}); - std::string Generate(std::optional shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states = {}); + inline std::string Generate(ShaderAst::Statement& shader, const BindingMapping& bindingMapping = {}, const States& states = {}); + std::string Generate(std::optional shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping = {}, const States& states = {}); void SetEnv(Environment environment); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index c52117a7c..00be4ace0 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -707,11 +707,8 @@ namespace Nz::ShaderAst if (!extVar.bindingIndex) throw AstError{ "external variable " + extVar.name + " requires a binding index" }; - if (!extVar.bindingSet) - throw AstError{ "external variable " + extVar.name + " requires a binding set" }; - UInt64 bindingIndex = *extVar.bindingIndex; - UInt64 bindingSet = *extVar.bindingSet; + UInt64 bindingSet = extVar.bindingSet.value_or(0); UInt64 bindingKey = bindingSet << 32 | bindingIndex; if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 668963604..723feaf92 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -52,22 +52,6 @@ namespace Nz node.statement->Visit(*this); } - void Visit(ShaderAst::DeclareExternalStatement& node) override - { - AstRecursiveVisitor::Visit(node); - - for (auto& extVar : node.externalVars) - { - assert(extVar.bindingIndex); - assert(extVar.bindingSet); - - UInt64 set = *extVar.bindingSet; - UInt64 binding = *extVar.bindingIndex; - - bindings.insert(set << 32 | binding); - } - } - void Visit(ShaderAst::DeclareFunctionStatement& node) override { // Dismiss function if it's an entry point of another type than the one selected @@ -113,7 +97,6 @@ namespace Nz FunctionData* currentFunction = nullptr; - std::set bindings; std::optional selectedStage; std::unordered_map functions; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; @@ -867,21 +850,32 @@ namespace Nz isStd140 = structInfo.layout == StructLayout::Std140; } - assert(externalVar.bindingIndex); - assert(externalVar.bindingSet); + if (!m_currentState->bindingMapping.empty() || isStd140) + Append("layout("); - UInt64 bindingIndex = *externalVar.bindingIndex; - UInt64 bindingSet = *externalVar.bindingSet; + if (!m_currentState->bindingMapping.empty()) + { + assert(externalVar.bindingIndex); - auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex); - if (bindingIt == m_currentState->bindingMapping.end()) - throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")"); + UInt64 bindingIndex = *externalVar.bindingIndex; + UInt64 bindingSet = externalVar.bindingSet.value_or(0); + + auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex); + if (bindingIt == m_currentState->bindingMapping.end()) + throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")"); + + Append("binding = ", bindingIt->second); + if (isStd140) + Append(", "); + } - Append("layout(binding = ", bindingIt->second); if (isStd140) - Append(", std140"); + Append("std140"); - Append(") uniform "); + if (!m_currentState->bindingMapping.empty() || isStd140) + Append(") "); + + Append("uniform "); if (IsUniformType(externalVar.type)) { diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 4af04eef1..7e495a646 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -482,8 +482,8 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Colon); extVar.type = ParseType(); - if (!extVar.bindingSet) - extVar.bindingSet = blockSetIndex.value_or(0); + if (!extVar.bindingSet && blockSetIndex) + extVar.bindingSet = *blockSetIndex; RegisterVariable(extVar.name); } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index f4532e05e..0dbb15dfc 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -46,8 +46,8 @@ namespace Nz public: struct UniformVar { - std::optional bindingIndex; - std::optional descriptorSet; + UInt32 bindingIndex; + UInt32 descriptorSet; UInt32 pointerId; }; @@ -123,10 +123,12 @@ namespace Nz variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform; variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass); + assert(extVar.bindingIndex); + UniformVar& uniformVar = extVars[varIndex++]; uniformVar.pointerId = m_constantCache.Register(variable); - uniformVar.bindingIndex = extVar.bindingIndex; - uniformVar.descriptorSet = extVar.bindingSet; + uniformVar.bindingIndex = *extVar.bindingIndex; + uniformVar.descriptorSet = extVar.bindingSet.value_or(0); } } @@ -491,11 +493,8 @@ namespace Nz for (auto&& [varIndex, extVar] : preVisitor.extVars) { - if (extVar.bindingIndex) - { - state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex); - state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, *extVar.descriptorSet); - } + state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex); + state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet); } for (auto&& [varId, builtin] : preVisitor.builtinDecorations) diff --git a/tests/Engine/Shader/AccessMemberTest.cpp b/tests/Engine/Shader/AccessMemberTest.cpp index 96b7373d0..bdb40fff9 100644 --- a/tests/Engine/Shader/AccessMemberTest.cpp +++ b/tests/Engine/Shader/AccessMemberTest.cpp @@ -7,7 +7,7 @@ #include #include -void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput) +void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) { Nz::GlslWriter writer; @@ -19,7 +19,7 @@ void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expecte REQUIRE(subset == expectedOutput); } -void ExpectingSpirV(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput) +void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) { Nz::SpirvWriter writer; auto spirv = writer.Generate(shader); @@ -64,6 +64,7 @@ SCENARIO("Shader generation", "[Shader]") auto external = std::make_unique(); external->externalVars.push_back({ + 0, std::nullopt, "ubo", Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } } @@ -85,7 +86,7 @@ SCENARIO("Shader generation", "[Shader]") SECTION("Generating GLSL") { - ExpectingGLSL(shader, R"( + ExpectingGLSL(*shader, R"( void main() { float result = ubo.s.field.z; @@ -94,7 +95,7 @@ void main() } SECTION("Generating Spir-V") { - ExpectingSpirV(shader, R"( + ExpectingSpirV(*shader, R"( OpFunction OpLabel OpVariable @@ -121,7 +122,7 @@ OpFunctionEnd)"); SECTION("Generating GLSL") { - ExpectingGLSL(shader, R"( + ExpectingGLSL(*shader, R"( void main() { float result = ubo.s.field.z; @@ -130,7 +131,7 @@ void main() } SECTION("Generating Spir-V") { - ExpectingSpirV(shader, R"( + ExpectingSpirV(*shader, R"( OpFunction OpLabel OpVariable