Shader: Add support for compound operators
This commit is contained in:
@@ -253,8 +253,35 @@ namespace Nz::ShaderAst
|
||||
|
||||
auto clone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));
|
||||
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
clone->cachedExpressionType = GetExpressionType(*clone->right);
|
||||
std::optional<BinaryType> binaryType;
|
||||
switch (clone->op)
|
||||
{
|
||||
case AssignType::Simple:
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
break;
|
||||
|
||||
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
|
||||
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
|
||||
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
|
||||
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
|
||||
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
|
||||
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
|
||||
}
|
||||
|
||||
if (binaryType)
|
||||
{
|
||||
ExpressionType expressionType = ValidateBinaryOp(*binaryType, clone->left, clone->right);
|
||||
TypeMustMatch(GetExpressionType(*clone->left), expressionType);
|
||||
|
||||
if (m_context->options.removeCompoundAssignments)
|
||||
{
|
||||
clone->op = AssignType::Simple;
|
||||
clone->right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*clone->left), std::move(clone->right));
|
||||
clone->right->cachedExpressionType = std::move(expressionType);
|
||||
}
|
||||
}
|
||||
|
||||
clone->cachedExpressionType = GetExpressionType(*clone->left);
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -262,191 +289,7 @@ namespace Nz::ShaderAst
|
||||
ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<BinaryExpression>(AstCloner::Clone(node));
|
||||
|
||||
const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(clone->left));
|
||||
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
|
||||
throw AstError{ "left expression type does not support binary operation" };
|
||||
|
||||
const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(clone->right));
|
||||
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
|
||||
throw AstError{ "right expression type does not support binary operation" };
|
||||
|
||||
if (IsPrimitiveType(leftExprType))
|
||||
{
|
||||
PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
|
||||
switch (clone->op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
if (leftType == PrimitiveType::Boolean)
|
||||
throw AstError{ "this operation is not supported for booleans" };
|
||||
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
|
||||
clone->cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
|
||||
clone->cachedExpressionType = leftExprType;
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case PrimitiveType::Float32:
|
||||
case PrimitiveType::Int32:
|
||||
case PrimitiveType::UInt32:
|
||||
{
|
||||
if (IsMatrixType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
|
||||
clone->cachedExpressionType = rightExprType;
|
||||
}
|
||||
else if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, rightExprType);
|
||||
clone->cachedExpressionType = leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
|
||||
clone->cachedExpressionType = rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case PrimitiveType::Boolean:
|
||||
throw AstError{ "this operation is not supported for booleans" };
|
||||
|
||||
default:
|
||||
throw AstError{ "incompatible types" };
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
case BinaryType::LogicalOr:
|
||||
if (leftType != PrimitiveType::Boolean)
|
||||
throw AstError{ "logical and/or are only supported on booleans" };
|
||||
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
|
||||
clone->cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else if (IsMatrixType(leftExprType))
|
||||
{
|
||||
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
|
||||
switch (clone->op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
clone->cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
clone->cachedExpressionType = leftExprType;
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
if (IsMatrixType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftExprType, rightExprType);
|
||||
clone->cachedExpressionType = leftExprType; //< FIXME
|
||||
}
|
||||
else if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType.type, rightExprType);
|
||||
clone->cachedExpressionType = leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
const VectorType& rightType = std::get<VectorType>(rightExprType);
|
||||
TypeMustMatch(leftType.type, rightType.type);
|
||||
|
||||
if (leftType.columnCount != rightType.componentCount)
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
clone->cachedExpressionType = rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
case BinaryType::LogicalOr:
|
||||
throw AstError{ "logical and/or are only supported on booleans" };
|
||||
}
|
||||
}
|
||||
else if (IsVectorType(leftExprType))
|
||||
{
|
||||
const VectorType& leftType = std::get<VectorType>(leftExprType);
|
||||
switch (clone->op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
clone->cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(clone->left, clone->right);
|
||||
clone->cachedExpressionType = leftExprType;
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType.type, rightExprType);
|
||||
clone->cachedExpressionType = leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, rightExprType);
|
||||
clone->cachedExpressionType = rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
case BinaryType::LogicalOr:
|
||||
throw AstError{ "logical and/or are only supported on booleans" };
|
||||
}
|
||||
}
|
||||
clone->cachedExpressionType = ValidateBinaryOp(clone->op, clone->left, clone->right);
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -1116,7 +959,7 @@ namespace Nz::ShaderAst
|
||||
return &*it;
|
||||
}
|
||||
|
||||
Expression& SanitizeVisitor::MandatoryExpr(ExpressionPtr& node)
|
||||
Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node)
|
||||
{
|
||||
if (!node)
|
||||
throw AstError{ "Invalid expression" };
|
||||
@@ -1124,7 +967,7 @@ namespace Nz::ShaderAst
|
||||
return *node;
|
||||
}
|
||||
|
||||
Statement& SanitizeVisitor::MandatoryStatement(StatementPtr& node)
|
||||
Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node)
|
||||
{
|
||||
if (!node)
|
||||
throw AstError{ "Invalid statement" };
|
||||
@@ -1640,7 +1483,189 @@ namespace Nz::ShaderAst
|
||||
}
|
||||
}
|
||||
|
||||
void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
|
||||
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)
|
||||
{
|
||||
const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(leftExpr));
|
||||
const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(rightExpr));
|
||||
|
||||
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
|
||||
throw AstError{ "left expression type does not support binary operation" };
|
||||
|
||||
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
|
||||
throw AstError{ "right expression type does not support binary operation" };
|
||||
|
||||
if (IsPrimitiveType(leftExprType))
|
||||
{
|
||||
PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
|
||||
switch (op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
{
|
||||
if (leftType == PrimitiveType::Boolean)
|
||||
throw AstError{ "this operation is not supported for booleans" };
|
||||
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return PrimitiveType::Boolean;
|
||||
}
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return leftExprType;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case PrimitiveType::Float32:
|
||||
case PrimitiveType::Int32:
|
||||
case PrimitiveType::UInt32:
|
||||
{
|
||||
if (IsMatrixType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
|
||||
return rightExprType;
|
||||
}
|
||||
else if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, rightExprType);
|
||||
return leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
|
||||
return rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case PrimitiveType::Boolean:
|
||||
throw AstError{ "this operation is not supported for booleans" };
|
||||
|
||||
default:
|
||||
throw AstError{ "incompatible types" };
|
||||
}
|
||||
}
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
case BinaryType::LogicalOr:
|
||||
{
|
||||
if (leftType != PrimitiveType::Boolean)
|
||||
throw AstError{ "logical and/or are only supported on booleans" };
|
||||
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return PrimitiveType::Boolean;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (IsMatrixType(leftExprType))
|
||||
{
|
||||
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
|
||||
switch (op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return PrimitiveType::Boolean;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return leftExprType;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
if (IsMatrixType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftExprType, rightExprType);
|
||||
return leftExprType; //< FIXME
|
||||
}
|
||||
else if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType.type, rightExprType);
|
||||
return leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
const VectorType& rightType = std::get<VectorType>(rightExprType);
|
||||
TypeMustMatch(leftType.type, rightType.type);
|
||||
|
||||
if (leftType.columnCount != rightType.componentCount)
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
return rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
}
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
case BinaryType::LogicalOr:
|
||||
throw AstError{ "logical and/or are only supported on booleans" };
|
||||
}
|
||||
}
|
||||
else if (IsVectorType(leftExprType))
|
||||
{
|
||||
const VectorType& leftType = std::get<VectorType>(leftExprType);
|
||||
switch (op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return PrimitiveType::Boolean;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(leftExpr, rightExpr);
|
||||
return leftExprType;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType.type, rightExprType);
|
||||
return leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, rightExprType);
|
||||
return rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
case BinaryType::LogicalOr:
|
||||
throw AstError{ "logical and/or are only supported on booleans" };
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "internal error: unchecked operation" };
|
||||
}
|
||||
|
||||
void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right)
|
||||
{
|
||||
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user