From a0f66d9e88eac8b8cc97649ac863f7cfae13681f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Tue, 28 Dec 2021 11:48:19 +0100 Subject: [PATCH] Shader/SPIRV: Fix multiple branch handling (by splitting) --- include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 1 + src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 33 +++-- src/Nazara/Shader/SpirvAstVisitor.cpp | 64 +++------ src/Nazara/Shader/SpirvWriter.cpp | 1 + tests/Engine/Shader/Branch.cpp | 124 ++++++++++++++++++ 5 files changed, 168 insertions(+), 55 deletions(-) diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 9c36388b6..bedf14ca6 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -43,6 +43,7 @@ namespace Nz::ShaderAst bool removeCompoundAssignments = false; bool removeOptionDeclaration = true; bool removeScalarSwizzling = false; + bool splitMultipleBranches = false; }; private: diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 99c1e11c3..75016981c 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -531,18 +531,35 @@ namespace Nz::ShaderAst if (!m_context->currentFunction) throw AstError{ "non-const branching statements can only exist inside a function" }; - for (auto& cond : node.condStatements) + BranchStatement* root = clone.get(); + for (std::size_t condIndex = 0; condIndex < node.condStatements.size(); ++condIndex) { + auto& cond = node.condStatements[condIndex]; + PushScope(); - auto& condStatement = clone->condStatements.emplace_back(); - condStatement.condition = CloneExpression(MandatoryExpr(cond.condition)); + auto BuildCondStatement = [&](BranchStatement::ConditionalStatement& condStatement) + { + condStatement.condition = CloneExpression(MandatoryExpr(cond.condition)); - const ExpressionType& condType = GetExpressionType(*condStatement.condition); - if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) - throw AstError{ "branch expressions must resolve to boolean type" }; + const ExpressionType& condType = GetExpressionType(*condStatement.condition); + if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) + throw AstError{ "branch expressions must resolve to boolean type" }; - condStatement.statement = CloneStatement(MandatoryStatement(cond.statement)); + condStatement.statement = CloneStatement(MandatoryStatement(cond.statement)); + }; + + if (m_context->options.splitMultipleBranches && condIndex > 0) + { + auto currentBranch = std::make_unique(); + + BuildCondStatement(currentBranch->condStatements.emplace_back()); + + root->elseStatement = std::move(currentBranch); + root = static_cast(root->elseStatement.get()); + } + else + BuildCondStatement(clone->condStatements.emplace_back()); PopScope(); } @@ -550,7 +567,7 @@ namespace Nz::ShaderAst if (node.elseStatement) { PushScope(); - clone->elseStatement = CloneStatement(node.elseStatement); + root->elseStatement = CloneStatement(node.elseStatement); PopScope(); } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 1346e15c0..06c35bf05 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -365,66 +365,36 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node) { - assert(!node.condStatements.empty()); - auto& firstCond = node.condStatements.front(); - - UInt32 previousConditionId = EvaluateExpression(firstCond.condition); - SpirvBlock previousContentBlock(m_writer); - m_currentBlock = &previousContentBlock; - - firstCond.statement->Visit(*this); + assert(node.condStatements.size() == 1); //< sanitization splits multiple branches + auto& condStatement = node.condStatements.front(); SpirvBlock mergeBlock(m_writer); + SpirvBlock contentBlock(m_writer); + SpirvBlock elseBlock(m_writer); - if (!previousContentBlock.IsTerminated()) - previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); + UInt32 conditionId = EvaluateExpression(condStatement.condition); + m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); + // FIXME: Can we use merge block directly in OpBranchConditional if no else statement? + m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId()); - m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); + m_functionBlocks.emplace_back(std::move(contentBlock)); + m_currentBlock = &m_functionBlocks.back(); - std::optional nextBlock; - for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex) - { - auto& statement = node.condStatements[statementIndex]; + condStatement.statement->Visit(*this); - SpirvBlock contentBlock(m_writer); + if (!m_currentBlock->IsTerminated()) + m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); - m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId()); - - previousConditionId = EvaluateExpression(statement.condition); - m_functionBlocks.emplace_back(std::move(previousContentBlock)); - previousContentBlock = std::move(contentBlock); - - m_currentBlock = &previousContentBlock; - - statement.statement->Visit(*this); - - if (!previousContentBlock.IsTerminated()) - previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); - } + m_functionBlocks.emplace_back(std::move(elseBlock)); + m_currentBlock = &m_functionBlocks.back(); if (node.elseStatement) - { - SpirvBlock elseBlock(m_writer); - - m_currentBlock = &elseBlock; - node.elseStatement->Visit(*this); - if (!elseBlock.IsTerminated()) - elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); - - m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId()); - m_functionBlocks.emplace_back(std::move(previousContentBlock)); - m_functionBlocks.emplace_back(std::move(elseBlock)); - } - else - { - m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId()); - m_functionBlocks.emplace_back(std::move(previousContentBlock)); - } + if (!m_currentBlock->IsTerminated()) + m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); m_functionBlocks.emplace_back(std::move(mergeBlock)); - m_currentBlock = &m_functionBlocks.back(); } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index fff1cbfb8..33b4a42bd 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -488,6 +488,7 @@ namespace Nz ShaderAst::SanitizeVisitor::Options options; options.optionValues = states.optionValues; options.removeCompoundAssignments = true; + options.splitMultipleBranches = true; sanitizedAst = ShaderAst::Sanitize(shader, options); targetAst = sanitizedAst.get(); diff --git a/tests/Engine/Shader/Branch.cpp b/tests/Engine/Shader/Branch.cpp index 269fb3ba9..85044c8c3 100644 --- a/tests/Engine/Shader/Branch.cpp +++ b/tests/Engine/Shader/Branch.cpp @@ -144,6 +144,130 @@ OpBranchConditional OpLabel OpKill OpLabel +OpBranch +OpLabel +OpReturn +OpFunctionEnd)"); + } + + + WHEN("using a complex branch") + { + std::string_view nzslSource = R"( +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > 3.0) + value = 3.0; + else if (data.value > 2.0) + value = 2.0; + else if (data.value > 1.0) + value = 1.0; + else + value = 0.0; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + float value; + if (data.value > (3.000000)) + { + value = 3.000000; + } + else if (data.value > (2.000000)) + { + value = 2.000000; + } + else if (data.value > (1.000000)) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > (3.000000)) + { + value = 3.000000; + } + else if (data.value > (2.000000)) + { + value = 2.000000; + } + else if (data.value > (1.000000)) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpAccessChain +OpLoad +OpFOrdGreaterThanEqual +OpSelectionMerge +OpBranchConditional +OpLabel +OpStore +OpBranch +OpLabel +OpAccessChain +OpLoad +OpFOrdGreaterThanEqual +OpSelectionMerge +OpBranchConditional +OpLabel +OpStore +OpBranch +OpLabel +OpAccessChain +OpLoad +OpFOrdGreaterThanEqual +OpSelectionMerge +OpBranchConditional +OpLabel +OpStore +OpBranch +OpLabel +OpStore +OpBranch +OpLabel +OpBranch +OpLabel +OpBranch +OpLabel OpReturn OpFunctionEnd)"); }