Shader: Add support for for-each statements and improve arrays

This commit is contained in:
Jérôme Leclercq
2022-01-02 22:02:11 +01:00
parent aac6e38da2
commit 4fe44339c5
30 changed files with 712 additions and 93 deletions

View File

@@ -170,6 +170,16 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(ForEachStatement& node)
{
auto clone = std::make_unique<ForEachStatement>();
clone->isConst = node.isConst;
clone->expression = CloneExpression(node.expression);
clone->statement = CloneStatement(node.statement);
return clone;
}
StatementPtr AstCloner::Clone(MultiStatement& node)
{
auto clone = std::make_unique<MultiStatement>();

View File

@@ -85,7 +85,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
node.expression->Visit(*this);
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(VariableExpression& /*node*/)
@@ -95,7 +96,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(UnaryExpression& node)
{
node.expression->Visit(*this);
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(BranchStatement& node)
@@ -159,6 +161,15 @@ namespace Nz::ShaderAst
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(ForEachStatement& node)
{
if (node.expression)
node.expression->Visit(*this);
if (node.statement)
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(MultiStatement& node)
{
for (auto& statement : node.statements)

View File

@@ -301,6 +301,14 @@ namespace Nz::ShaderAst
Node(node.expression);
}
void AstSerializerBase::Serialize(ForEachStatement& node)
{
Value(node.isConst);
Value(node.varName);
Node(node.expression);
Node(node.statement);
}
void AstSerializerBase::Serialize(MultiStatement& node)
{
Container(node.statements);

View File

@@ -13,7 +13,7 @@ namespace Nz::ShaderAst
{
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(length);
length = Clone(array.length);
}
ArrayType& ArrayType::operator=(const ArrayType& array)
@@ -21,7 +21,7 @@ namespace Nz::ShaderAst
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(length);
length = Clone(array.length);
return *this;
}

View File

@@ -278,39 +278,7 @@ namespace Nz::ShaderAst
MandatoryExpr(node.right);
auto clone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));
if (GetExpressionCategory(*clone->left) != ExpressionCategory::LValue)
throw AstError{ "Assignation is only possible with a l-value" };
std::optional<BinaryType> binaryType;
switch (clone->op)
{
case AssignType::Simple:
TypeMustMatch(clone->left, clone->right);
break;
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
}
if (binaryType)
{
ExpressionType expressionType = ValidateBinaryOp(*binaryType, clone->left, clone->right);
TypeMustMatch(GetExpressionType(*clone->left), expressionType);
if (m_context->options.removeCompoundAssignments)
{
clone->op = AssignType::Simple;
clone->right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*clone->left), std::move(clone->right));
clone->right->cachedExpressionType = std::move(expressionType);
}
}
clone->cachedExpressionType = GetExpressionType(*clone->left);
Validate(*clone);
return clone;
}
@@ -318,7 +286,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node)
{
auto clone = static_unique_pointer_cast<BinaryExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = ValidateBinaryOp(clone->op, clone->left, clone->right);
Validate(*clone);
return clone;
}
@@ -861,6 +829,119 @@ namespace Nz::ShaderAst
return AstCloner::Clone(node);
}
StatementPtr SanitizeVisitor::Clone(ForEachStatement& node)
{
auto expr = CloneExpression(node.expression);
const ExpressionType& exprType = GetExpressionType(*expr);
ExpressionType innerType;
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
innerType = arrayType.containedType->type;
}
else
throw AstError{ "for-each is only supported on arrays and range expressions" };
if (node.isConst)
{
// Repeat code
auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
UInt32 length = arrayType.length.GetResultingValue();
for (UInt32 i = 0; i < length; ++i)
{
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
Validate(*accessIndex);
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
Validate(*elementVariable);
multi->statements.emplace_back(std::move(elementVariable));
multi->statements.emplace_back(CloneStatement(node.statement));
}
}
return multi;
}
if (m_context->options.reduceLoopsToWhile)
{
PushScope();
auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
UInt32 length = arrayType.length.GetResultingValue();
multi->statements.reserve(2);
// Counter variable
auto counterVariable = ShaderBuilder::DeclareVariable("i", ShaderBuilder::Constant(0u));
Validate(*counterVariable);
std::size_t counterVarIndex = counterVariable->varIndex.value();
multi->statements.emplace_back(std::move(counterVariable));
auto whileStatement = std::make_unique<WhileStatement>();
// While condition
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length));
Validate(*condition);
whileStatement->condition = std::move(condition);
// While body
auto body = std::make_unique<MultiStatement>();
body->statements.reserve(3);
auto accessIndex = ShaderBuilder::AccessIndex(std::move(expr), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32));
Validate(*accessIndex);
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
Validate(*elementVariable);
body->statements.emplace_back(std::move(elementVariable));
body->statements.emplace_back(CloneStatement(node.statement));
auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(1u));
Validate(*incrCounter);
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));
whileStatement->body = std::move(body);
multi->statements.emplace_back(std::move(whileStatement));
}
PopScope();
return multi;
}
else
{
auto clone = std::make_unique<ForEachStatement>();
clone->expression = std::move(expr);
clone->varName = node.varName;
PushScope();
{
clone->varIndex = RegisterVariable(node.varName, innerType);
clone->statement = CloneStatement(node.statement);
}
PopScope();
SanitizeIdentifier(node.varName);
return clone;
}
}
StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
{
PushScope();
@@ -1206,7 +1287,6 @@ namespace Nz::ShaderAst
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, ArrayType> ||
std::is_same_v<T, PrimitiveType> ||
std::is_same_v<T, MatrixType> ||
std::is_same_v<T, SamplerType> ||
@@ -1215,6 +1295,22 @@ namespace Nz::ShaderAst
{
return exprType;
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
ArrayType resolvedArrayType;
if (arg.length.IsExpression())
{
resolvedArrayType.length = CloneExpression(arg.length.GetExpression());
ComputeAttributeValue(resolvedArrayType.length);
}
else if (arg.length.IsResultingValue())
resolvedArrayType.length = arg.length.GetResultingValue();
resolvedArrayType.containedType = std::make_unique<ContainedType>();
resolvedArrayType.containedType->type = ResolveType(arg.containedType->type);
return resolvedArrayType;
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
const Identifier* identifier = FindIdentifier(arg.name);
@@ -1262,8 +1358,12 @@ namespace Nz::ShaderAst
for (auto& index : node.indices)
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType) || std::get<PrimitiveType>(indexType) != PrimitiveType::Int32)
throw AstError{ "AccessIndex expects Int32 indices" };
if (!IsPrimitiveType(indexType))
throw AstError{ "AccessIndex expects integer indices" };
PrimitiveType primitiveIndexType = std::get<PrimitiveType>(indexType);
if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32)
throw AstError{ "AccessIndex expects integer indices" };
}
ExpressionType exprType = GetExpressionType(*node.expr);
@@ -1272,8 +1372,8 @@ namespace Nz::ShaderAst
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
exprType = arrayType.containedType->type;
ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType
exprType = std::move(containedType);
}
else if (IsStructType(exprType))
{
@@ -1294,7 +1394,7 @@ namespace Nz::ShaderAst
else if (IsMatrixType(exprType))
{
// Matrix index (ex: mat[2])
const MatrixType& matrixType = std::get<MatrixType>(exprType);
MatrixType matrixType = std::get<MatrixType>(exprType);
//TODO: Handle row-major matrices
exprType = VectorType{ matrixType.rowCount, matrixType.type };
@@ -1302,7 +1402,7 @@ namespace Nz::ShaderAst
else if (IsVectorType(exprType))
{
// Swizzle expression with one component (ex: vec[2])
const VectorType& swizzledVec = std::get<VectorType>(exprType);
VectorType swizzledVec = std::get<VectorType>(exprType);
exprType = swizzledVec.type;
}
@@ -1313,6 +1413,47 @@ namespace Nz::ShaderAst
node.cachedExpressionType = std::move(exprType);
}
void SanitizeVisitor::Validate(AssignExpression& node)
{
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError{ "Assignation is only possible with a l-value" };
std::optional<BinaryType> binaryType;
switch (node.op)
{
case AssignType::Simple:
TypeMustMatch(node.left, node.right);
break;
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
}
if (binaryType)
{
ExpressionType expressionType = ValidateBinaryOp(*binaryType, node.left, node.right);
TypeMustMatch(GetExpressionType(*node.left), expressionType);
if (m_context->options.removeCompoundAssignments)
{
node.op = AssignType::Simple;
node.right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*node.left), std::move(node.right));
node.right->cachedExpressionType = std::move(expressionType);
}
}
node.cachedExpressionType = GetExpressionType(*node.left);
}
void SanitizeVisitor::Validate(BinaryExpression& node)
{
node.cachedExpressionType = ValidateBinaryOp(node.op, node.left, node.right);
}
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
{
if (referenceDeclaration->entryStage.HasValue())