Shader/SPIRV: Fix issues with loops containing branches and cross product

This commit is contained in:
Jérôme Leclercq 2022-01-23 19:59:10 +01:00
parent 64efd81bf8
commit 2463e471cc
2 changed files with 53 additions and 30 deletions

View File

@ -152,7 +152,7 @@ namespace Nz
std::unordered_map<std::size_t, ShaderAst::StructDescription*> m_structs;
std::unordered_map<std::size_t, Variable> m_variables;
std::vector<std::size_t> m_scopeSizes;
std::vector<SpirvBlock> m_functionBlocks;
std::vector<std::unique_ptr<SpirvBlock>> m_functionBlocks;
std::vector<UInt32> m_resultIds;
SpirvBlock* m_currentBlock;
SpirvSection& m_instructions;

View File

@ -368,34 +368,34 @@ namespace Nz
assert(node.condStatements.size() == 1); //< sanitization splits multiple branches
auto& condStatement = node.condStatements.front();
SpirvBlock mergeBlock(m_writer);
SpirvBlock contentBlock(m_writer);
SpirvBlock elseBlock(m_writer);
auto mergeBlock = std::make_unique<SpirvBlock>(m_writer);
auto contentBlock = std::make_unique<SpirvBlock>(m_writer);
auto elseBlock = std::make_unique<SpirvBlock>(m_writer);
UInt32 conditionId = EvaluateExpression(condStatement.condition);
m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock->GetLabelId(), SpirvSelectionControl::None);
// FIXME: Can we use merge block directly in OpBranchConditional if no else statement?
m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId());
m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock->GetLabelId(), elseBlock->GetLabelId());
m_functionBlocks.emplace_back(std::move(contentBlock));
m_currentBlock = &m_functionBlocks.back();
m_currentBlock = m_functionBlocks.back().get();
condStatement.statement->Visit(*this);
if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock->GetLabelId());
m_functionBlocks.emplace_back(std::move(elseBlock));
m_currentBlock = &m_functionBlocks.back();
m_currentBlock = m_functionBlocks.back().get();
if (node.elseStatement)
node.elseStatement->Visit(*this);
if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock->GetLabelId());
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back();
m_currentBlock = m_functionBlocks.back().get();
}
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
@ -609,9 +609,12 @@ namespace Nz
}
}
m_functionBlocks.clear();
auto contentBlock = std::make_unique<SpirvBlock>(m_writer);
m_currentBlock = contentBlock.get();
m_functionBlocks.clear();
m_functionBlocks.emplace_back(std::move(contentBlock));
m_currentBlock = &m_functionBlocks.emplace_back(m_writer);
CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; });
for (auto& var : func.variables)
@ -647,11 +650,11 @@ namespace Nz
statementPtr->Visit(*this);
// Add implicit return
if (!m_functionBlocks.back().IsTerminated())
m_functionBlocks.back().Append(SpirvOp::OpReturn);
if (!m_functionBlocks.back()->IsTerminated())
m_functionBlocks.back()->Append(SpirvOp::OpReturn);
for (SpirvBlock& block : m_functionBlocks)
m_instructions.AppendSection(block);
for (std::unique_ptr<SpirvBlock>& block : m_functionBlocks)
m_instructions.AppendSection(*block);
m_instructions.Append(SpirvOp::OpFunctionEnd);
}
@ -702,6 +705,23 @@ namespace Nz
{
switch (node.intrinsic)
{
case ShaderAst::IntrinsicType::CrossProduct:
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpExtInst, typeId, resultId, glslInstructionSet, GLSLstd450Cross, firstParam, secondParam);
PushResultId(resultId);
break;
}
case ShaderAst::IntrinsicType::DotProduct:
{
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
@ -867,7 +887,6 @@ namespace Nz
break;
}
case ShaderAst::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
}
@ -1057,12 +1076,12 @@ namespace Nz
assert(node.condition);
assert(node.body);
SpirvBlock headerBlock(m_writer);
SpirvBlock bodyBlock(m_writer);
SpirvBlock mergeBlock(m_writer);
auto headerBlock = std::make_unique<SpirvBlock>(m_writer);
auto bodyBlock = std::make_unique<SpirvBlock>(m_writer);
auto mergeBlock = std::make_unique<SpirvBlock>(m_writer);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
m_currentBlock = &headerBlock;
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock->GetLabelId());
m_currentBlock = headerBlock.get();
UInt32 expressionId = EvaluateExpression(node.condition);
@ -1087,18 +1106,22 @@ namespace Nz
else
loopControl = SpirvLoopControl::None;
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), loopControl);
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId());
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock->GetLabelId(), bodyBlock->GetLabelId(), loopControl);
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock->GetLabelId(), mergeBlock->GetLabelId());
m_currentBlock = &bodyBlock;
node.body->Visit(*this);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
UInt32 headerLabelId = headerBlock->GetLabelId();
m_currentBlock = bodyBlock.get();
m_functionBlocks.emplace_back(std::move(headerBlock));
m_functionBlocks.emplace_back(std::move(bodyBlock));
node.body->Visit(*this);
// Jump back to header block to test condition
m_currentBlock->Append(SpirvOp::OpBranch, headerLabelId);
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back();
m_currentBlock = m_functionBlocks.back().get();
}
void SpirvAstVisitor::PushResultId(UInt32 value)