Shader/SPIRV: Fix multiple branch handling (by splitting)

This commit is contained in:
Jérôme Leclercq
2021-12-28 11:48:19 +01:00
parent e21b45946f
commit a0f66d9e88
5 changed files with 168 additions and 55 deletions

View File

@@ -531,18 +531,35 @@ namespace Nz::ShaderAst
if (!m_context->currentFunction)
throw AstError{ "non-const branching statements can only exist inside a function" };
for (auto& cond : node.condStatements)
BranchStatement* root = clone.get();
for (std::size_t condIndex = 0; condIndex < node.condStatements.size(); ++condIndex)
{
auto& cond = node.condStatements[condIndex];
PushScope();
auto& condStatement = clone->condStatements.emplace_back();
condStatement.condition = CloneExpression(MandatoryExpr(cond.condition));
auto BuildCondStatement = [&](BranchStatement::ConditionalStatement& condStatement)
{
condStatement.condition = CloneExpression(MandatoryExpr(cond.condition));
const ExpressionType& condType = GetExpressionType(*condStatement.condition);
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "branch expressions must resolve to boolean type" };
const ExpressionType& condType = GetExpressionType(*condStatement.condition);
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "branch expressions must resolve to boolean type" };
condStatement.statement = CloneStatement(MandatoryStatement(cond.statement));
condStatement.statement = CloneStatement(MandatoryStatement(cond.statement));
};
if (m_context->options.splitMultipleBranches && condIndex > 0)
{
auto currentBranch = std::make_unique<BranchStatement>();
BuildCondStatement(currentBranch->condStatements.emplace_back());
root->elseStatement = std::move(currentBranch);
root = static_cast<BranchStatement*>(root->elseStatement.get());
}
else
BuildCondStatement(clone->condStatements.emplace_back());
PopScope();
}
@@ -550,7 +567,7 @@ namespace Nz::ShaderAst
if (node.elseStatement)
{
PushScope();
clone->elseStatement = CloneStatement(node.elseStatement);
root->elseStatement = CloneStatement(node.elseStatement);
PopScope();
}

View File

@@ -365,66 +365,36 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node)
{
assert(!node.condStatements.empty());
auto& firstCond = node.condStatements.front();
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
SpirvBlock previousContentBlock(m_writer);
m_currentBlock = &previousContentBlock;
firstCond.statement->Visit(*this);
assert(node.condStatements.size() == 1); //< sanitization splits multiple branches
auto& condStatement = node.condStatements.front();
SpirvBlock mergeBlock(m_writer);
SpirvBlock contentBlock(m_writer);
SpirvBlock elseBlock(m_writer);
if (!previousContentBlock.IsTerminated())
previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
UInt32 conditionId = EvaluateExpression(condStatement.condition);
m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
// FIXME: Can we use merge block directly in OpBranchConditional if no else statement?
m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
m_functionBlocks.emplace_back(std::move(contentBlock));
m_currentBlock = &m_functionBlocks.back();
std::optional<std::size_t> nextBlock;
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
{
auto& statement = node.condStatements[statementIndex];
condStatement.statement->Visit(*this);
SpirvBlock contentBlock(m_writer);
if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
previousConditionId = EvaluateExpression(statement.condition);
m_functionBlocks.emplace_back(std::move(previousContentBlock));
previousContentBlock = std::move(contentBlock);
m_currentBlock = &previousContentBlock;
statement.statement->Visit(*this);
if (!previousContentBlock.IsTerminated())
previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
}
m_functionBlocks.emplace_back(std::move(elseBlock));
m_currentBlock = &m_functionBlocks.back();
if (node.elseStatement)
{
SpirvBlock elseBlock(m_writer);
m_currentBlock = &elseBlock;
node.elseStatement->Visit(*this);
if (!elseBlock.IsTerminated())
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(previousContentBlock));
m_functionBlocks.emplace_back(std::move(elseBlock));
}
else
{
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(previousContentBlock));
}
if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back();
}

View File

@@ -488,6 +488,7 @@ namespace Nz
ShaderAst::SanitizeVisitor::Options options;
options.optionValues = states.optionValues;
options.removeCompoundAssignments = true;
options.splitMultipleBranches = true;
sanitizedAst = ShaderAst::Sanitize(shader, options);
targetAst = sanitizedAst.get();