Shader/SPIRV: Fix multiple branch handling (by splitting)
This commit is contained in:
parent
e21b45946f
commit
a0f66d9e88
|
|
@ -43,6 +43,7 @@ namespace Nz::ShaderAst
|
|||
bool removeCompoundAssignments = false;
|
||||
bool removeOptionDeclaration = true;
|
||||
bool removeScalarSwizzling = false;
|
||||
bool splitMultipleBranches = false;
|
||||
};
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -531,18 +531,35 @@ namespace Nz::ShaderAst
|
|||
if (!m_context->currentFunction)
|
||||
throw AstError{ "non-const branching statements can only exist inside a function" };
|
||||
|
||||
for (auto& cond : node.condStatements)
|
||||
BranchStatement* root = clone.get();
|
||||
for (std::size_t condIndex = 0; condIndex < node.condStatements.size(); ++condIndex)
|
||||
{
|
||||
auto& cond = node.condStatements[condIndex];
|
||||
|
||||
PushScope();
|
||||
|
||||
auto& condStatement = clone->condStatements.emplace_back();
|
||||
condStatement.condition = CloneExpression(MandatoryExpr(cond.condition));
|
||||
auto BuildCondStatement = [&](BranchStatement::ConditionalStatement& condStatement)
|
||||
{
|
||||
condStatement.condition = CloneExpression(MandatoryExpr(cond.condition));
|
||||
|
||||
const ExpressionType& condType = GetExpressionType(*condStatement.condition);
|
||||
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
|
||||
throw AstError{ "branch expressions must resolve to boolean type" };
|
||||
const ExpressionType& condType = GetExpressionType(*condStatement.condition);
|
||||
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
|
||||
throw AstError{ "branch expressions must resolve to boolean type" };
|
||||
|
||||
condStatement.statement = CloneStatement(MandatoryStatement(cond.statement));
|
||||
condStatement.statement = CloneStatement(MandatoryStatement(cond.statement));
|
||||
};
|
||||
|
||||
if (m_context->options.splitMultipleBranches && condIndex > 0)
|
||||
{
|
||||
auto currentBranch = std::make_unique<BranchStatement>();
|
||||
|
||||
BuildCondStatement(currentBranch->condStatements.emplace_back());
|
||||
|
||||
root->elseStatement = std::move(currentBranch);
|
||||
root = static_cast<BranchStatement*>(root->elseStatement.get());
|
||||
}
|
||||
else
|
||||
BuildCondStatement(clone->condStatements.emplace_back());
|
||||
|
||||
PopScope();
|
||||
}
|
||||
|
|
@ -550,7 +567,7 @@ namespace Nz::ShaderAst
|
|||
if (node.elseStatement)
|
||||
{
|
||||
PushScope();
|
||||
clone->elseStatement = CloneStatement(node.elseStatement);
|
||||
root->elseStatement = CloneStatement(node.elseStatement);
|
||||
PopScope();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -365,66 +365,36 @@ namespace Nz
|
|||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node)
|
||||
{
|
||||
assert(!node.condStatements.empty());
|
||||
auto& firstCond = node.condStatements.front();
|
||||
|
||||
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
|
||||
SpirvBlock previousContentBlock(m_writer);
|
||||
m_currentBlock = &previousContentBlock;
|
||||
|
||||
firstCond.statement->Visit(*this);
|
||||
assert(node.condStatements.size() == 1); //< sanitization splits multiple branches
|
||||
auto& condStatement = node.condStatements.front();
|
||||
|
||||
SpirvBlock mergeBlock(m_writer);
|
||||
SpirvBlock contentBlock(m_writer);
|
||||
SpirvBlock elseBlock(m_writer);
|
||||
|
||||
if (!previousContentBlock.IsTerminated())
|
||||
previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
|
||||
UInt32 conditionId = EvaluateExpression(condStatement.condition);
|
||||
m_currentBlock->Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
|
||||
// FIXME: Can we use merge block directly in OpBranchConditional if no else statement?
|
||||
m_currentBlock->Append(SpirvOp::OpBranchConditional, conditionId, contentBlock.GetLabelId(), elseBlock.GetLabelId());
|
||||
|
||||
m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
|
||||
m_functionBlocks.emplace_back(std::move(contentBlock));
|
||||
m_currentBlock = &m_functionBlocks.back();
|
||||
|
||||
std::optional<std::size_t> nextBlock;
|
||||
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
|
||||
{
|
||||
auto& statement = node.condStatements[statementIndex];
|
||||
condStatement.statement->Visit(*this);
|
||||
|
||||
SpirvBlock contentBlock(m_writer);
|
||||
if (!m_currentBlock->IsTerminated())
|
||||
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
|
||||
|
||||
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
|
||||
|
||||
previousConditionId = EvaluateExpression(statement.condition);
|
||||
m_functionBlocks.emplace_back(std::move(previousContentBlock));
|
||||
previousContentBlock = std::move(contentBlock);
|
||||
|
||||
m_currentBlock = &previousContentBlock;
|
||||
|
||||
statement.statement->Visit(*this);
|
||||
|
||||
if (!previousContentBlock.IsTerminated())
|
||||
previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
|
||||
}
|
||||
m_functionBlocks.emplace_back(std::move(elseBlock));
|
||||
m_currentBlock = &m_functionBlocks.back();
|
||||
|
||||
if (node.elseStatement)
|
||||
{
|
||||
SpirvBlock elseBlock(m_writer);
|
||||
|
||||
m_currentBlock = &elseBlock;
|
||||
|
||||
node.elseStatement->Visit(*this);
|
||||
|
||||
if (!elseBlock.IsTerminated())
|
||||
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
|
||||
|
||||
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
|
||||
m_functionBlocks.emplace_back(std::move(previousContentBlock));
|
||||
m_functionBlocks.emplace_back(std::move(elseBlock));
|
||||
}
|
||||
else
|
||||
{
|
||||
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
|
||||
m_functionBlocks.emplace_back(std::move(previousContentBlock));
|
||||
}
|
||||
if (!m_currentBlock->IsTerminated())
|
||||
m_currentBlock->Append(SpirvOp::OpBranch, mergeBlock.GetLabelId());
|
||||
|
||||
m_functionBlocks.emplace_back(std::move(mergeBlock));
|
||||
|
||||
m_currentBlock = &m_functionBlocks.back();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -488,6 +488,7 @@ namespace Nz
|
|||
ShaderAst::SanitizeVisitor::Options options;
|
||||
options.optionValues = states.optionValues;
|
||||
options.removeCompoundAssignments = true;
|
||||
options.splitMultipleBranches = true;
|
||||
|
||||
sanitizedAst = ShaderAst::Sanitize(shader, options);
|
||||
targetAst = sanitizedAst.get();
|
||||
|
|
|
|||
|
|
@ -144,6 +144,130 @@ OpBranchConditional
|
|||
OpLabel
|
||||
OpKill
|
||||
OpLabel
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd)");
|
||||
}
|
||||
|
||||
|
||||
WHEN("using a complex branch")
|
||||
{
|
||||
std::string_view nzslSource = R"(
|
||||
struct inputStruct
|
||||
{
|
||||
value: f32
|
||||
}
|
||||
|
||||
external
|
||||
{
|
||||
[set(0), binding(0)] data: uniform<inputStruct>
|
||||
}
|
||||
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: f32;
|
||||
if (data.value > 3.0)
|
||||
value = 3.0;
|
||||
else if (data.value > 2.0)
|
||||
value = 2.0;
|
||||
else if (data.value > 1.0)
|
||||
value = 1.0;
|
||||
else
|
||||
value = 0.0;
|
||||
}
|
||||
)";
|
||||
|
||||
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
ExpectGLSL(*shader, R"(
|
||||
void main()
|
||||
{
|
||||
float value;
|
||||
if (data.value > (3.000000))
|
||||
{
|
||||
value = 3.000000;
|
||||
}
|
||||
else if (data.value > (2.000000))
|
||||
{
|
||||
value = 2.000000;
|
||||
}
|
||||
else if (data.value > (1.000000))
|
||||
{
|
||||
value = 1.000000;
|
||||
}
|
||||
else
|
||||
{
|
||||
value = 0.000000;
|
||||
}
|
||||
|
||||
}
|
||||
)");
|
||||
|
||||
ExpectNZSL(*shader, R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: f32;
|
||||
if (data.value > (3.000000))
|
||||
{
|
||||
value = 3.000000;
|
||||
}
|
||||
else if (data.value > (2.000000))
|
||||
{
|
||||
value = 2.000000;
|
||||
}
|
||||
else if (data.value > (1.000000))
|
||||
{
|
||||
value = 1.000000;
|
||||
}
|
||||
else
|
||||
{
|
||||
value = 0.000000;
|
||||
}
|
||||
|
||||
}
|
||||
)");
|
||||
|
||||
ExpectSpirV(*shader, R"(
|
||||
OpFunction
|
||||
OpLabel
|
||||
OpVariable
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFOrdGreaterThanEqual
|
||||
OpSelectionMerge
|
||||
OpBranchConditional
|
||||
OpLabel
|
||||
OpStore
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFOrdGreaterThanEqual
|
||||
OpSelectionMerge
|
||||
OpBranchConditional
|
||||
OpLabel
|
||||
OpStore
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFOrdGreaterThanEqual
|
||||
OpSelectionMerge
|
||||
OpBranchConditional
|
||||
OpLabel
|
||||
OpStore
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpStore
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd)");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue