Shader/SPIRV: Fix issues with loops containing branches and cross product

This commit is contained in:
Jérôme Leclercq 2022-01-23 19:59:10 +01:00
parent 64efd81bf8
commit 2463e471cc
2 changed files with 53 additions and 30 deletions

View File

@ -152,7 +152,7 @@ namespace Nz
std::unordered_map<std::size_t, ShaderAst::StructDescription*> m_structs; std::unordered_map<std::size_t, ShaderAst::StructDescription*> m_structs;
std::unordered_map<std::size_t, Variable> m_variables; std::unordered_map<std::size_t, Variable> m_variables;
std::vector<std::size_t> m_scopeSizes; std::vector<std::size_t> m_scopeSizes;
std::vector<SpirvBlock> m_functionBlocks; std::vector<std::unique_ptr<SpirvBlock>> m_functionBlocks;
std::vector<UInt32> m_resultIds; std::vector<UInt32> m_resultIds;
SpirvBlock* m_currentBlock; SpirvBlock* m_currentBlock;
SpirvSection& m_instructions; SpirvSection& m_instructions;

View File

@ -368,34 +368,34 @@ namespace Nz
assert(node.condStatements.size() == 1); //< sanitization splits multiple branches assert(node.condStatements.size() == 1); //< sanitization splits multiple branches
auto& condStatement = node.condStatements.front(); auto& condStatement = node.condStatements.front();
SpirvBlock mergeBlock(m_writer); auto mergeBlock = std::make_unique<SpirvBlock>(m_writer);
SpirvBlock contentBlock(m_writer); auto contentBlock = std::make_unique<SpirvBlock>(m_writer);
SpirvBlock elseBlock(m_writer); auto elseBlock = std::make_unique<SpirvBlock>(m_writer);
UInt32 conditionId = EvaluateExpression(condStatement.condition); UInt32 conditionId = EvaluateExpression(condStatement.condition);
m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock->GetLabelId(), SpirvSelectionControl::None);
// FIXME: Can we use merge block directly in OpBranchConditional if no else statement? // FIXME: Can we use merge block directly in OpBranchConditional if no else statement?
m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId()); m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock->GetLabelId(), elseBlock->GetLabelId());
m_functionBlocks.emplace_back(std::move(contentBlock)); m_functionBlocks.emplace_back(std::move(contentBlock));
m_currentBlock = &m_functionBlocks.back(); m_currentBlock = m_functionBlocks.back().get();
condStatement.statement->Visit(*this); condStatement.statement->Visit(*this);
if (!m_currentBlock->IsTerminated()) if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock->GetLabelId());
m_functionBlocks.emplace_back(std::move(elseBlock)); m_functionBlocks.emplace_back(std::move(elseBlock));
m_currentBlock = &m_functionBlocks.back(); m_currentBlock = m_functionBlocks.back().get();
if (node.elseStatement) if (node.elseStatement)
node.elseStatement->Visit(*this); node.elseStatement->Visit(*this);
if (!m_currentBlock->IsTerminated()) if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock->GetLabelId());
m_functionBlocks.emplace_back(std::move(mergeBlock)); m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back(); m_currentBlock = m_functionBlocks.back().get();
} }
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node) void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
@ -609,9 +609,12 @@ namespace Nz
} }
} }
m_functionBlocks.clear(); auto contentBlock = std::make_unique<SpirvBlock>(m_writer);
m_currentBlock = contentBlock.get();
m_functionBlocks.clear();
m_functionBlocks.emplace_back(std::move(contentBlock));
m_currentBlock = &m_functionBlocks.emplace_back(m_writer);
CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; }); CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; });
for (auto& var : func.variables) for (auto& var : func.variables)
@ -647,11 +650,11 @@ namespace Nz
statementPtr->Visit(*this); statementPtr->Visit(*this);
// Add implicit return // Add implicit return
if (!m_functionBlocks.back().IsTerminated()) if (!m_functionBlocks.back()->IsTerminated())
m_functionBlocks.back().Append(SpirvOp::OpReturn); m_functionBlocks.back()->Append(SpirvOp::OpReturn);
for (SpirvBlock& block : m_functionBlocks) for (std::unique_ptr<SpirvBlock>& block : m_functionBlocks)
m_instructions.AppendSection(block); m_instructions.AppendSection(*block);
m_instructions.Append(SpirvOp::OpFunctionEnd); m_instructions.Append(SpirvOp::OpFunctionEnd);
} }
@ -702,6 +705,23 @@ namespace Nz
{ {
switch (node.intrinsic) switch (node.intrinsic)
{ {
case ShaderAst::IntrinsicType::CrossProduct:
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpExtInst, typeId, resultId, glslInstructionSet, GLSLstd450Cross, firstParam, secondParam);
PushResultId(resultId);
break;
}
case ShaderAst::IntrinsicType::DotProduct: case ShaderAst::IntrinsicType::DotProduct:
{ {
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
@ -867,7 +887,6 @@ namespace Nz
break; break;
} }
case ShaderAst::IntrinsicType::CrossProduct:
default: default:
throw std::runtime_error("not yet implemented"); throw std::runtime_error("not yet implemented");
} }
@ -1057,12 +1076,12 @@ namespace Nz
assert(node.condition); assert(node.condition);
assert(node.body); assert(node.body);
SpirvBlock headerBlock(m_writer); auto headerBlock = std::make_unique<SpirvBlock>(m_writer);
SpirvBlock bodyBlock(m_writer); auto bodyBlock = std::make_unique<SpirvBlock>(m_writer);
SpirvBlock mergeBlock(m_writer); auto mergeBlock = std::make_unique<SpirvBlock>(m_writer);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId()); m_currentBlock->Append(SpirvOp::OpBranch, headerBlock->GetLabelId());
m_currentBlock = &headerBlock; m_currentBlock = headerBlock.get();
UInt32 expressionId = EvaluateExpression(node.condition); UInt32 expressionId = EvaluateExpression(node.condition);
@ -1087,18 +1106,22 @@ namespace Nz
else else
loopControl = SpirvLoopControl::None; loopControl = SpirvLoopControl::None;
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), loopControl); m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock->GetLabelId(), bodyBlock->GetLabelId(), loopControl);
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId()); m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock->GetLabelId(), mergeBlock->GetLabelId());
m_currentBlock = &bodyBlock; UInt32 headerLabelId = headerBlock->GetLabelId();
node.body->Visit(*this);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
m_currentBlock = bodyBlock.get();
m_functionBlocks.emplace_back(std::move(headerBlock)); m_functionBlocks.emplace_back(std::move(headerBlock));
m_functionBlocks.emplace_back(std::move(bodyBlock)); m_functionBlocks.emplace_back(std::move(bodyBlock));
node.body->Visit(*this);
// Jump back to header block to test condition
m_currentBlock->Append(SpirvOp::OpBranch, headerLabelId);
m_functionBlocks.emplace_back(std::move(mergeBlock)); m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back(); m_currentBlock = m_functionBlocks.back().get();
} }
void SpirvAstVisitor::PushResultId(UInt32 value) void SpirvAstVisitor::PushResultId(UInt32 value)