Shader/SPIRV: Fix multiple branch handling (by splitting)

This commit is contained in:
Jérôme Leclercq 2021-12-28 11:48:19 +01:00
parent e21b45946f
commit a0f66d9e88
5 changed files with 168 additions and 55 deletions

View File

@ -43,6 +43,7 @@ namespace Nz::ShaderAst
bool removeCompoundAssignments = false;
bool removeOptionDeclaration = true;
bool removeScalarSwizzling = false;
bool splitMultipleBranches = false;
};
private:

View File

@ -531,18 +531,35 @@ namespace Nz::ShaderAst
if (!m_context->currentFunction)
throw AstError{ "non-const branching statements can only exist inside a function" };
for (auto& cond : node.condStatements)
BranchStatement* root = clone.get();
for (std::size_t condIndex = 0; condIndex < node.condStatements.size(); ++condIndex)
{
auto& cond = node.condStatements[condIndex];
PushScope();
auto& condStatement = clone->condStatements.emplace_back();
condStatement.condition = CloneExpression(MandatoryExpr(cond.condition));
auto BuildCondStatement = [&](BranchStatement::ConditionalStatement& condStatement)
{
condStatement.condition = CloneExpression(MandatoryExpr(cond.condition));
const ExpressionType& condType = GetExpressionType(*condStatement.condition);
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "branch expressions must resolve to boolean type" };
const ExpressionType& condType = GetExpressionType(*condStatement.condition);
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "branch expressions must resolve to boolean type" };
condStatement.statement = CloneStatement(MandatoryStatement(cond.statement));
condStatement.statement = CloneStatement(MandatoryStatement(cond.statement));
};
if (m_context->options.splitMultipleBranches && condIndex > 0)
{
auto currentBranch = std::make_unique<BranchStatement>();
BuildCondStatement(currentBranch->condStatements.emplace_back());
root->elseStatement = std::move(currentBranch);
root = static_cast<BranchStatement*>(root->elseStatement.get());
}
else
BuildCondStatement(clone->condStatements.emplace_back());
PopScope();
}
@ -550,7 +567,7 @@ namespace Nz::ShaderAst
if (node.elseStatement)
{
PushScope();
clone->elseStatement = CloneStatement(node.elseStatement);
root->elseStatement = CloneStatement(node.elseStatement);
PopScope();
}

View File

@ -365,66 +365,36 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node)
{
assert(!node.condStatements.empty());
auto& firstCond = node.condStatements.front();
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
SpirvBlock previousContentBlock(m_writer);
m_currentBlock = &previousContentBlock;
firstCond.statement->Visit(*this);
assert(node.condStatements.size() == 1); //< sanitization splits multiple branches
auto& condStatement = node.condStatements.front();
SpirvBlock mergeBlock(m_writer);
SpirvBlock contentBlock(m_writer);
SpirvBlock elseBlock(m_writer);
if (!previousContentBlock.IsTerminated())
previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
UInt32 conditionId = EvaluateExpression(condStatement.condition);
m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
// FIXME: Can we use merge block directly in OpBranchConditional if no else statement?
m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
m_functionBlocks.emplace_back(std::move(contentBlock));
m_currentBlock = &m_functionBlocks.back();
std::optional<std::size_t> nextBlock;
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
{
auto& statement = node.condStatements[statementIndex];
condStatement.statement->Visit(*this);
SpirvBlock contentBlock(m_writer);
if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
previousConditionId = EvaluateExpression(statement.condition);
m_functionBlocks.emplace_back(std::move(previousContentBlock));
previousContentBlock = std::move(contentBlock);
m_currentBlock = &previousContentBlock;
statement.statement->Visit(*this);
if (!previousContentBlock.IsTerminated())
previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
}
m_functionBlocks.emplace_back(std::move(elseBlock));
m_currentBlock = &m_functionBlocks.back();
if (node.elseStatement)
{
SpirvBlock elseBlock(m_writer);
m_currentBlock = &elseBlock;
node.elseStatement->Visit(*this);
if (!elseBlock.IsTerminated())
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(previousContentBlock));
m_functionBlocks.emplace_back(std::move(elseBlock));
}
else
{
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(previousContentBlock));
}
if (!m_currentBlock->IsTerminated())
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back();
}

View File

@ -488,6 +488,7 @@ namespace Nz
ShaderAst::SanitizeVisitor::Options options;
options.optionValues = states.optionValues;
options.removeCompoundAssignments = true;
options.splitMultipleBranches = true;
sanitizedAst = ShaderAst::Sanitize(shader, options);
targetAst = sanitizedAst.get();

View File

@ -144,6 +144,130 @@ OpBranchConditional
OpLabel
OpKill
OpLabel
OpBranch
OpLabel
OpReturn
OpFunctionEnd)");
}
WHEN("using a complex branch")
{
std::string_view nzslSource = R"(
struct inputStruct
{
value: f32
}
external
{
[set(0), binding(0)] data: uniform<inputStruct>
}
[entry(frag)]
fn main()
{
let value: f32;
if (data.value > 3.0)
value = 3.0;
else if (data.value > 2.0)
value = 2.0;
else if (data.value > 1.0)
value = 1.0;
else
value = 0.0;
}
)";
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
ExpectGLSL(*shader, R"(
void main()
{
float value;
if (data.value > (3.000000))
{
value = 3.000000;
}
else if (data.value > (2.000000))
{
value = 2.000000;
}
else if (data.value > (1.000000))
{
value = 1.000000;
}
else
{
value = 0.000000;
}
}
)");
ExpectNZSL(*shader, R"(
[entry(frag)]
fn main()
{
let value: f32;
if (data.value > (3.000000))
{
value = 3.000000;
}
else if (data.value > (2.000000))
{
value = 2.000000;
}
else if (data.value > (1.000000))
{
value = 1.000000;
}
else
{
value = 0.000000;
}
}
)");
ExpectSpirV(*shader, R"(
OpFunction
OpLabel
OpVariable
OpAccessChain
OpLoad
OpFOrdGreaterThanEqual
OpSelectionMerge
OpBranchConditional
OpLabel
OpStore
OpBranch
OpLabel
OpAccessChain
OpLoad
OpFOrdGreaterThanEqual
OpSelectionMerge
OpBranchConditional
OpLabel
OpStore
OpBranch
OpLabel
OpAccessChain
OpLoad
OpFOrdGreaterThanEqual
OpSelectionMerge
OpBranchConditional
OpLabel
OpStore
OpBranch
OpLabel
OpStore
OpBranch
OpLabel
OpBranch
OpLabel
OpBranch
OpLabel
OpReturn
OpFunctionEnd)");
}