|
|
|
|
@@ -278,39 +278,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
MandatoryExpr(node.right);
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
if (GetExpressionCategory(*clone->left) != ExpressionCategory::LValue)
|
|
|
|
|
throw AstError{ "Assignation is only possible with a l-value" };
|
|
|
|
|
|
|
|
|
|
std::optional<BinaryType> binaryType;
|
|
|
|
|
switch (clone->op)
|
|
|
|
|
{
|
|
|
|
|
case AssignType::Simple:
|
|
|
|
|
TypeMustMatch(clone->left, clone->right);
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
|
|
|
|
|
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
|
|
|
|
|
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
|
|
|
|
|
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
|
|
|
|
|
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
|
|
|
|
|
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (binaryType)
|
|
|
|
|
{
|
|
|
|
|
ExpressionType expressionType = ValidateBinaryOp(*binaryType, clone->left, clone->right);
|
|
|
|
|
TypeMustMatch(GetExpressionType(*clone->left), expressionType);
|
|
|
|
|
|
|
|
|
|
if (m_context->options.removeCompoundAssignments)
|
|
|
|
|
{
|
|
|
|
|
clone->op = AssignType::Simple;
|
|
|
|
|
clone->right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*clone->left), std::move(clone->right));
|
|
|
|
|
clone->right->cachedExpressionType = std::move(expressionType);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = GetExpressionType(*clone->left);
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -318,7 +286,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node)
|
|
|
|
|
{
|
|
|
|
|
auto clone = static_unique_pointer_cast<BinaryExpression>(AstCloner::Clone(node));
|
|
|
|
|
clone->cachedExpressionType = ValidateBinaryOp(clone->op, clone->left, clone->right);
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -861,6 +829,119 @@ namespace Nz::ShaderAst
|
|
|
|
|
return AstCloner::Clone(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(ForEachStatement& node)
|
|
|
|
|
{
|
|
|
|
|
auto expr = CloneExpression(node.expression);
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*expr);
|
|
|
|
|
ExpressionType innerType;
|
|
|
|
|
if (IsArrayType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
|
|
|
|
innerType = arrayType.containedType->type;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "for-each is only supported on arrays and range expressions" };
|
|
|
|
|
|
|
|
|
|
if (node.isConst)
|
|
|
|
|
{
|
|
|
|
|
// Repeat code
|
|
|
|
|
auto multi = std::make_unique<MultiStatement>();
|
|
|
|
|
if (IsArrayType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
|
|
|
|
UInt32 length = arrayType.length.GetResultingValue();
|
|
|
|
|
|
|
|
|
|
for (UInt32 i = 0; i < length; ++i)
|
|
|
|
|
{
|
|
|
|
|
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
|
|
|
|
|
Validate(*accessIndex);
|
|
|
|
|
|
|
|
|
|
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
|
|
|
|
Validate(*elementVariable);
|
|
|
|
|
|
|
|
|
|
multi->statements.emplace_back(std::move(elementVariable));
|
|
|
|
|
multi->statements.emplace_back(CloneStatement(node.statement));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return multi;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (m_context->options.reduceLoopsToWhile)
|
|
|
|
|
{
|
|
|
|
|
PushScope();
|
|
|
|
|
|
|
|
|
|
auto multi = std::make_unique<MultiStatement>();
|
|
|
|
|
|
|
|
|
|
if (IsArrayType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
|
|
|
|
UInt32 length = arrayType.length.GetResultingValue();
|
|
|
|
|
|
|
|
|
|
multi->statements.reserve(2);
|
|
|
|
|
|
|
|
|
|
// Counter variable
|
|
|
|
|
auto counterVariable = ShaderBuilder::DeclareVariable("i", ShaderBuilder::Constant(0u));
|
|
|
|
|
Validate(*counterVariable);
|
|
|
|
|
|
|
|
|
|
std::size_t counterVarIndex = counterVariable->varIndex.value();
|
|
|
|
|
|
|
|
|
|
multi->statements.emplace_back(std::move(counterVariable));
|
|
|
|
|
|
|
|
|
|
auto whileStatement = std::make_unique<WhileStatement>();
|
|
|
|
|
|
|
|
|
|
// While condition
|
|
|
|
|
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length));
|
|
|
|
|
Validate(*condition);
|
|
|
|
|
whileStatement->condition = std::move(condition);
|
|
|
|
|
|
|
|
|
|
// While body
|
|
|
|
|
auto body = std::make_unique<MultiStatement>();
|
|
|
|
|
body->statements.reserve(3);
|
|
|
|
|
|
|
|
|
|
auto accessIndex = ShaderBuilder::AccessIndex(std::move(expr), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32));
|
|
|
|
|
Validate(*accessIndex);
|
|
|
|
|
|
|
|
|
|
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
|
|
|
|
Validate(*elementVariable);
|
|
|
|
|
body->statements.emplace_back(std::move(elementVariable));
|
|
|
|
|
|
|
|
|
|
body->statements.emplace_back(CloneStatement(node.statement));
|
|
|
|
|
|
|
|
|
|
auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(1u));
|
|
|
|
|
Validate(*incrCounter);
|
|
|
|
|
|
|
|
|
|
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));
|
|
|
|
|
|
|
|
|
|
whileStatement->body = std::move(body);
|
|
|
|
|
|
|
|
|
|
multi->statements.emplace_back(std::move(whileStatement));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PopScope();
|
|
|
|
|
|
|
|
|
|
return multi;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
auto clone = std::make_unique<ForEachStatement>();
|
|
|
|
|
clone->expression = std::move(expr);
|
|
|
|
|
clone->varName = node.varName;
|
|
|
|
|
|
|
|
|
|
PushScope();
|
|
|
|
|
{
|
|
|
|
|
clone->varIndex = RegisterVariable(node.varName, innerType);
|
|
|
|
|
clone->statement = CloneStatement(node.statement);
|
|
|
|
|
}
|
|
|
|
|
PopScope();
|
|
|
|
|
|
|
|
|
|
SanitizeIdentifier(node.varName);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
|
|
|
|
|
{
|
|
|
|
|
PushScope();
|
|
|
|
|
@@ -1206,7 +1287,6 @@ namespace Nz::ShaderAst
|
|
|
|
|
using T = std::decay_t<decltype(arg)>;
|
|
|
|
|
|
|
|
|
|
if constexpr (std::is_same_v<T, NoType> ||
|
|
|
|
|
std::is_same_v<T, ArrayType> ||
|
|
|
|
|
std::is_same_v<T, PrimitiveType> ||
|
|
|
|
|
std::is_same_v<T, MatrixType> ||
|
|
|
|
|
std::is_same_v<T, SamplerType> ||
|
|
|
|
|
@@ -1215,6 +1295,22 @@ namespace Nz::ShaderAst
|
|
|
|
|
{
|
|
|
|
|
return exprType;
|
|
|
|
|
}
|
|
|
|
|
else if constexpr (std::is_same_v<T, ArrayType>)
|
|
|
|
|
{
|
|
|
|
|
ArrayType resolvedArrayType;
|
|
|
|
|
if (arg.length.IsExpression())
|
|
|
|
|
{
|
|
|
|
|
resolvedArrayType.length = CloneExpression(arg.length.GetExpression());
|
|
|
|
|
ComputeAttributeValue(resolvedArrayType.length);
|
|
|
|
|
}
|
|
|
|
|
else if (arg.length.IsResultingValue())
|
|
|
|
|
resolvedArrayType.length = arg.length.GetResultingValue();
|
|
|
|
|
|
|
|
|
|
resolvedArrayType.containedType = std::make_unique<ContainedType>();
|
|
|
|
|
resolvedArrayType.containedType->type = ResolveType(arg.containedType->type);
|
|
|
|
|
|
|
|
|
|
return resolvedArrayType;
|
|
|
|
|
}
|
|
|
|
|
else if constexpr (std::is_same_v<T, IdentifierType>)
|
|
|
|
|
{
|
|
|
|
|
const Identifier* identifier = FindIdentifier(arg.name);
|
|
|
|
|
@@ -1262,8 +1358,12 @@ namespace Nz::ShaderAst
|
|
|
|
|
for (auto& index : node.indices)
|
|
|
|
|
{
|
|
|
|
|
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
|
|
|
|
|
if (!IsPrimitiveType(indexType) || std::get<PrimitiveType>(indexType) != PrimitiveType::Int32)
|
|
|
|
|
throw AstError{ "AccessIndex expects Int32 indices" };
|
|
|
|
|
if (!IsPrimitiveType(indexType))
|
|
|
|
|
throw AstError{ "AccessIndex expects integer indices" };
|
|
|
|
|
|
|
|
|
|
PrimitiveType primitiveIndexType = std::get<PrimitiveType>(indexType);
|
|
|
|
|
if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32)
|
|
|
|
|
throw AstError{ "AccessIndex expects integer indices" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionType exprType = GetExpressionType(*node.expr);
|
|
|
|
|
@@ -1272,8 +1372,8 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (IsArrayType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
|
|
|
|
|
|
|
|
|
exprType = arrayType.containedType->type;
|
|
|
|
|
ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType
|
|
|
|
|
exprType = std::move(containedType);
|
|
|
|
|
}
|
|
|
|
|
else if (IsStructType(exprType))
|
|
|
|
|
{
|
|
|
|
|
@@ -1294,7 +1394,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
else if (IsMatrixType(exprType))
|
|
|
|
|
{
|
|
|
|
|
// Matrix index (ex: mat[2])
|
|
|
|
|
const MatrixType& matrixType = std::get<MatrixType>(exprType);
|
|
|
|
|
MatrixType matrixType = std::get<MatrixType>(exprType);
|
|
|
|
|
|
|
|
|
|
//TODO: Handle row-major matrices
|
|
|
|
|
exprType = VectorType{ matrixType.rowCount, matrixType.type };
|
|
|
|
|
@@ -1302,7 +1402,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
{
|
|
|
|
|
// Swizzle expression with one component (ex: vec[2])
|
|
|
|
|
const VectorType& swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
VectorType swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
|
|
|
|
|
exprType = swizzledVec.type;
|
|
|
|
|
}
|
|
|
|
|
@@ -1313,6 +1413,47 @@ namespace Nz::ShaderAst
|
|
|
|
|
node.cachedExpressionType = std::move(exprType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(AssignExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
|
|
|
|
|
throw AstError{ "Assignation is only possible with a l-value" };
|
|
|
|
|
|
|
|
|
|
std::optional<BinaryType> binaryType;
|
|
|
|
|
switch (node.op)
|
|
|
|
|
{
|
|
|
|
|
case AssignType::Simple:
|
|
|
|
|
TypeMustMatch(node.left, node.right);
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
|
|
|
|
|
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
|
|
|
|
|
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
|
|
|
|
|
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
|
|
|
|
|
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
|
|
|
|
|
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (binaryType)
|
|
|
|
|
{
|
|
|
|
|
ExpressionType expressionType = ValidateBinaryOp(*binaryType, node.left, node.right);
|
|
|
|
|
TypeMustMatch(GetExpressionType(*node.left), expressionType);
|
|
|
|
|
|
|
|
|
|
if (m_context->options.removeCompoundAssignments)
|
|
|
|
|
{
|
|
|
|
|
node.op = AssignType::Simple;
|
|
|
|
|
node.right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*node.left), std::move(node.right));
|
|
|
|
|
node.right->cachedExpressionType = std::move(expressionType);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = GetExpressionType(*node.left);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(BinaryExpression& node)
|
|
|
|
|
{
|
|
|
|
|
node.cachedExpressionType = ValidateBinaryOp(node.op, node.left, node.right);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
|
|
|
|
|
{
|
|
|
|
|
if (referenceDeclaration->entryStage.HasValue())
|
|
|
|
|
|