Shader: Replace const for with [unroll] attribute
This commit is contained in:
@@ -173,9 +173,9 @@ namespace Nz::ShaderAst
|
||||
StatementPtr AstCloner::Clone(ForEachStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<ForEachStatement>();
|
||||
clone->isConst = node.isConst;
|
||||
clone->expression = CloneExpression(node.expression);
|
||||
clone->statement = CloneStatement(node.statement);
|
||||
clone->unroll = Clone(node.unroll);
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -208,6 +208,7 @@ namespace Nz::ShaderAst
|
||||
auto clone = std::make_unique<WhileStatement>();
|
||||
clone->condition = CloneExpression(node.condition);
|
||||
clone->body = CloneStatement(node.body);
|
||||
clone->unroll = Clone(node.unroll);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstSerializerBase::Serialize(ForEachStatement& node)
|
||||
{
|
||||
Value(node.isConst);
|
||||
Attribute(node.unroll);
|
||||
Value(node.varName);
|
||||
Node(node.expression);
|
||||
Node(node.statement);
|
||||
@@ -328,6 +328,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstSerializerBase::Serialize(WhileStatement& node)
|
||||
{
|
||||
Attribute(node.unroll);
|
||||
Node(node.condition);
|
||||
Node(node.body);
|
||||
}
|
||||
|
||||
@@ -843,29 +843,34 @@ namespace Nz::ShaderAst
|
||||
else
|
||||
throw AstError{ "for-each is only supported on arrays and range expressions" };
|
||||
|
||||
if (node.isConst)
|
||||
AttributeValue<LoopUnroll> unrollValue;
|
||||
if (node.unroll.HasValue())
|
||||
{
|
||||
// Repeat code
|
||||
auto multi = std::make_unique<MultiStatement>();
|
||||
if (IsArrayType(exprType))
|
||||
unrollValue = ComputeAttributeValue(node.unroll);
|
||||
if (unrollValue.GetResultingValue() == LoopUnroll::Always)
|
||||
{
|
||||
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
||||
UInt32 length = arrayType.length.GetResultingValue();
|
||||
|
||||
for (UInt32 i = 0; i < length; ++i)
|
||||
// Repeat code
|
||||
auto multi = std::make_unique<MultiStatement>();
|
||||
if (IsArrayType(exprType))
|
||||
{
|
||||
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
|
||||
Validate(*accessIndex);
|
||||
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
||||
UInt32 length = arrayType.length.GetResultingValue();
|
||||
|
||||
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
||||
Validate(*elementVariable);
|
||||
for (UInt32 i = 0; i < length; ++i)
|
||||
{
|
||||
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
|
||||
Validate(*accessIndex);
|
||||
|
||||
multi->statements.emplace_back(std::move(elementVariable));
|
||||
multi->statements.emplace_back(CloneStatement(node.statement));
|
||||
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
||||
Validate(*elementVariable);
|
||||
|
||||
multi->statements.emplace_back(std::move(elementVariable));
|
||||
multi->statements.emplace_back(CloneStatement(node.statement));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return multi;
|
||||
return multi;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_context->options.reduceLoopsToWhile)
|
||||
@@ -890,6 +895,7 @@ namespace Nz::ShaderAst
|
||||
multi->statements.emplace_back(std::move(counterVariable));
|
||||
|
||||
auto whileStatement = std::make_unique<WhileStatement>();
|
||||
whileStatement->unroll = std::move(unrollValue);
|
||||
|
||||
// While condition
|
||||
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length));
|
||||
@@ -928,6 +934,7 @@ namespace Nz::ShaderAst
|
||||
auto clone = std::make_unique<ForEachStatement>();
|
||||
clone->expression = std::move(expr);
|
||||
clone->varName = node.varName;
|
||||
clone->unroll = std::move(unrollValue);
|
||||
|
||||
PushScope();
|
||||
{
|
||||
@@ -968,9 +975,15 @@ namespace Nz::ShaderAst
|
||||
MandatoryStatement(node.body);
|
||||
|
||||
auto clone = static_unique_pointer_cast<WhileStatement>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
|
||||
if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean })
|
||||
throw AstError{ "expected a boolean value" };
|
||||
AttributeValue<LoopUnroll> unrollValue;
|
||||
if (node.unroll.HasValue())
|
||||
{
|
||||
clone->unroll = ComputeAttributeValue(node.unroll);
|
||||
if (clone->unroll.GetResultingValue() == LoopUnroll::Always)
|
||||
throw AstError{ "unroll(always) is not yet supported on while" };
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -1350,6 +1363,12 @@ namespace Nz::ShaderAst
|
||||
}
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(WhileStatement& node)
|
||||
{
|
||||
if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean })
|
||||
throw AstError{ "expected a boolean value" };
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(AccessIndexExpression& node)
|
||||
{
|
||||
if (node.indices.empty())
|
||||
|
||||
Reference in New Issue
Block a user