Shader: Add support for compound operators

This commit is contained in:
Jérôme Leclercq
2021-09-24 15:39:03 +02:00
parent 601ed047ba
commit 0442db1c53
15 changed files with 350 additions and 234 deletions

View File

@@ -253,8 +253,35 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = GetExpressionType(*clone->right);
std::optional<BinaryType> binaryType;
switch (clone->op)
{
case AssignType::Simple:
TypeMustMatch(clone->left, clone->right);
break;
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
}
if (binaryType)
{
ExpressionType expressionType = ValidateBinaryOp(*binaryType, clone->left, clone->right);
TypeMustMatch(GetExpressionType(*clone->left), expressionType);
if (m_context->options.removeCompoundAssignments)
{
clone->op = AssignType::Simple;
clone->right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*clone->left), std::move(clone->right));
clone->right->cachedExpressionType = std::move(expressionType);
}
}
clone->cachedExpressionType = GetExpressionType(*clone->left);
return clone;
}
@@ -262,191 +289,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node)
{
auto clone = static_unique_pointer_cast<BinaryExpression>(AstCloner::Clone(node));
const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(clone->left));
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(clone->right));
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
if (IsPrimitiveType(leftExprType))
{
PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
switch (clone->op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == PrimitiveType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = PrimitiveType::Boolean;
break;
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = leftExprType;
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
switch (leftType)
{
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
{
if (IsMatrixType(rightExprType))
{
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
clone->cachedExpressionType = rightExprType;
}
else if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType, rightExprType);
clone->cachedExpressionType = leftExprType;
}
else if (IsVectorType(rightExprType))
{
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
clone->cachedExpressionType = rightExprType;
}
else
throw AstError{ "incompatible types" };
break;
}
case PrimitiveType::Boolean:
throw AstError{ "this operation is not supported for booleans" };
default:
throw AstError{ "incompatible types" };
}
break;
}
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
if (leftType != PrimitiveType::Boolean)
throw AstError{ "logical and/or are only supported on booleans" };
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = PrimitiveType::Boolean;
break;
}
}
else if (IsMatrixType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
switch (clone->op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = PrimitiveType::Boolean;
break;
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = leftExprType;
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsMatrixType(rightExprType))
{
TypeMustMatch(leftExprType, rightExprType);
clone->cachedExpressionType = leftExprType; //< FIXME
}
else if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType.type, rightExprType);
clone->cachedExpressionType = leftExprType;
}
else if (IsVectorType(rightExprType))
{
const VectorType& rightType = std::get<VectorType>(rightExprType);
TypeMustMatch(leftType.type, rightType.type);
if (leftType.columnCount != rightType.componentCount)
throw AstError{ "incompatible types" };
clone->cachedExpressionType = rightExprType;
}
else
throw AstError{ "incompatible types" };
break;
}
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
throw AstError{ "logical and/or are only supported on booleans" };
}
}
else if (IsVectorType(leftExprType))
{
const VectorType& leftType = std::get<VectorType>(leftExprType);
switch (clone->op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = PrimitiveType::Boolean;
break;
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = leftExprType;
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType.type, rightExprType);
clone->cachedExpressionType = leftExprType;
}
else if (IsVectorType(rightExprType))
{
TypeMustMatch(leftType, rightExprType);
clone->cachedExpressionType = rightExprType;
}
else
throw AstError{ "incompatible types" };
break;
}
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
throw AstError{ "logical and/or are only supported on booleans" };
}
}
clone->cachedExpressionType = ValidateBinaryOp(clone->op, clone->left, clone->right);
return clone;
}
@@ -1116,7 +959,7 @@ namespace Nz::ShaderAst
return &*it;
}
Expression& SanitizeVisitor::MandatoryExpr(ExpressionPtr& node)
Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node)
{
if (!node)
throw AstError{ "Invalid expression" };
@@ -1124,7 +967,7 @@ namespace Nz::ShaderAst
return *node;
}
Statement& SanitizeVisitor::MandatoryStatement(StatementPtr& node)
Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node)
{
if (!node)
throw AstError{ "Invalid statement" };
@@ -1640,7 +1483,189 @@ namespace Nz::ShaderAst
}
}
void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)
{
const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(leftExpr));
const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(rightExpr));
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
if (IsPrimitiveType(leftExprType))
{
PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
switch (op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
{
if (leftType == PrimitiveType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
TypeMustMatch(leftExpr, rightExpr);
return PrimitiveType::Boolean;
}
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(leftExpr, rightExpr);
return leftExprType;
case BinaryType::Multiply:
case BinaryType::Divide:
{
switch (leftType)
{
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
{
if (IsMatrixType(rightExprType))
{
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
return rightExprType;
}
else if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType, rightExprType);
return leftExprType;
}
else if (IsVectorType(rightExprType))
{
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
return rightExprType;
}
else
throw AstError{ "incompatible types" };
break;
}
case PrimitiveType::Boolean:
throw AstError{ "this operation is not supported for booleans" };
default:
throw AstError{ "incompatible types" };
}
}
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
{
if (leftType != PrimitiveType::Boolean)
throw AstError{ "logical and/or are only supported on booleans" };
TypeMustMatch(leftExpr, rightExpr);
return PrimitiveType::Boolean;
}
}
}
else if (IsMatrixType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
switch (op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
TypeMustMatch(leftExpr, rightExpr);
return PrimitiveType::Boolean;
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(leftExpr, rightExpr);
return leftExprType;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsMatrixType(rightExprType))
{
TypeMustMatch(leftExprType, rightExprType);
return leftExprType; //< FIXME
}
else if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType.type, rightExprType);
return leftExprType;
}
else if (IsVectorType(rightExprType))
{
const VectorType& rightType = std::get<VectorType>(rightExprType);
TypeMustMatch(leftType.type, rightType.type);
if (leftType.columnCount != rightType.componentCount)
throw AstError{ "incompatible types" };
return rightExprType;
}
else
throw AstError{ "incompatible types" };
}
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
throw AstError{ "logical and/or are only supported on booleans" };
}
}
else if (IsVectorType(leftExprType))
{
const VectorType& leftType = std::get<VectorType>(leftExprType);
switch (op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
TypeMustMatch(leftExpr, rightExpr);
return PrimitiveType::Boolean;
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(leftExpr, rightExpr);
return leftExprType;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType.type, rightExprType);
return leftExprType;
}
else if (IsVectorType(rightExprType))
{
TypeMustMatch(leftType, rightExprType);
return rightExprType;
}
else
throw AstError{ "incompatible types" };
break;
}
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
throw AstError{ "logical and/or are only supported on booleans" };
}
}
throw AstError{ "internal error: unchecked operation" };
}
void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right)
{
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
}