From 2463e471cc6235b8eeb447353c2f936b5700d283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sun, 23 Jan 2022 19:59:10 +0100 Subject: [PATCH] Shader/SPIRV: Fix issues with loops containing branches and cross product --- include/Nazara/Shader/SpirvAstVisitor.hpp | 2 +- src/Nazara/Shader/SpirvAstVisitor.cpp | 81 +++++++++++++++-------- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 297160404..28e98b157 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -152,7 +152,7 @@ namespace Nz std::unordered_map m_structs; std::unordered_map m_variables; std::vector m_scopeSizes; - std::vector m_functionBlocks; + std::vector> m_functionBlocks; std::vector m_resultIds; SpirvBlock* m_currentBlock; SpirvSection& m_instructions; diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 3a187f1fb..224b223d7 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -368,34 +368,34 @@ namespace Nz assert(node.condStatements.size() == 1); //< sanitization splits multiple branches auto& condStatement = node.condStatements.front(); - SpirvBlock mergeBlock(m_writer); - SpirvBlock contentBlock(m_writer); - SpirvBlock elseBlock(m_writer); + auto mergeBlock = std::make_unique(m_writer); + auto contentBlock = std::make_unique(m_writer); + auto elseBlock = std::make_unique(m_writer); UInt32 conditionId = EvaluateExpression(condStatement.condition); - m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); + m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock->GetLabelId(), SpirvSelectionControl::None); // FIXME: Can we use merge block directly in OpBranchConditional if no else statement? - m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId()); + m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock->GetLabelId(), elseBlock->GetLabelId()); m_functionBlocks.emplace_back(std::move(contentBlock)); - m_currentBlock = &m_functionBlocks.back(); + m_currentBlock = m_functionBlocks.back().get(); condStatement.statement->Visit(*this); if (!m_currentBlock->IsTerminated()) - m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); + m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock->GetLabelId()); m_functionBlocks.emplace_back(std::move(elseBlock)); - m_currentBlock = &m_functionBlocks.back(); + m_currentBlock = m_functionBlocks.back().get(); if (node.elseStatement) node.elseStatement->Visit(*this); if (!m_currentBlock->IsTerminated()) - m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); + m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock->GetLabelId()); m_functionBlocks.emplace_back(std::move(mergeBlock)); - m_currentBlock = &m_functionBlocks.back(); + m_currentBlock = m_functionBlocks.back().get(); } void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node) @@ -609,9 +609,12 @@ namespace Nz } } - m_functionBlocks.clear(); + auto contentBlock = std::make_unique(m_writer); + m_currentBlock = contentBlock.get(); + + m_functionBlocks.clear(); + m_functionBlocks.emplace_back(std::move(contentBlock)); - m_currentBlock = &m_functionBlocks.emplace_back(m_writer); CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; }); for (auto& var : func.variables) @@ -647,11 +650,11 @@ namespace Nz statementPtr->Visit(*this); // Add implicit return - if (!m_functionBlocks.back().IsTerminated()) - m_functionBlocks.back().Append(SpirvOp::OpReturn); + if (!m_functionBlocks.back()->IsTerminated()) + m_functionBlocks.back()->Append(SpirvOp::OpReturn); - for (SpirvBlock& block : m_functionBlocks) - m_instructions.AppendSection(block); + for (std::unique_ptr& block : m_functionBlocks) + m_instructions.AppendSection(*block); m_instructions.Append(SpirvOp::OpFunctionEnd); } @@ -702,6 +705,23 @@ namespace Nz { switch (node.intrinsic) { + case ShaderAst::IntrinsicType::CrossProduct: + { + UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); + + const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); + assert(IsVectorType(parameterType)); + UInt32 typeId = m_writer.GetTypeId(parameterType); + + UInt32 firstParam = EvaluateExpression(node.parameters[0]); + UInt32 secondParam = EvaluateExpression(node.parameters[1]); + UInt32 resultId = m_writer.AllocateResultId(); + + m_currentBlock->Append(SpirvOp::OpExtInst, typeId, resultId, glslInstructionSet, GLSLstd450Cross, firstParam, secondParam); + PushResultId(resultId); + break; + } + case ShaderAst::IntrinsicType::DotProduct: { const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); @@ -867,7 +887,6 @@ namespace Nz break; } - case ShaderAst::IntrinsicType::CrossProduct: default: throw std::runtime_error("not yet implemented"); } @@ -1057,12 +1076,12 @@ namespace Nz assert(node.condition); assert(node.body); - SpirvBlock headerBlock(m_writer); - SpirvBlock bodyBlock(m_writer); - SpirvBlock mergeBlock(m_writer); + auto headerBlock = std::make_unique(m_writer); + auto bodyBlock = std::make_unique(m_writer); + auto mergeBlock = std::make_unique(m_writer); - m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId()); - m_currentBlock = &headerBlock; + m_currentBlock->Append(SpirvOp::OpBranch, headerBlock->GetLabelId()); + m_currentBlock = headerBlock.get(); UInt32 expressionId = EvaluateExpression(node.condition); @@ -1087,18 +1106,22 @@ namespace Nz else loopControl = SpirvLoopControl::None; - m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), loopControl); - m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId()); + m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock->GetLabelId(), bodyBlock->GetLabelId(), loopControl); + m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock->GetLabelId(), mergeBlock->GetLabelId()); - m_currentBlock = &bodyBlock; - node.body->Visit(*this); - - m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId()); + UInt32 headerLabelId = headerBlock->GetLabelId(); + m_currentBlock = bodyBlock.get(); m_functionBlocks.emplace_back(std::move(headerBlock)); m_functionBlocks.emplace_back(std::move(bodyBlock)); + + node.body->Visit(*this); + + // Jump back to header block to test condition + m_currentBlock->Append(SpirvOp::OpBranch, headerLabelId); + m_functionBlocks.emplace_back(std::move(mergeBlock)); - m_currentBlock = &m_functionBlocks.back(); + m_currentBlock = m_functionBlocks.back().get(); } void SpirvAstVisitor::PushResultId(UInt32 value)