Shader: Fix unary plus/minus on vector

This commit is contained in:
Jérôme Leclercq 2021-05-19 20:30:29 +02:00
parent 2d3652bba1
commit 0339ea346f
2 changed files with 19 additions and 9 deletions

View File

@ -597,16 +597,12 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<UnaryExpression>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<UnaryExpression>(AstCloner::Clone(node));
const ExpressionType& exprType = GetExpressionType(MandatoryExpr(clone->expression)); const ExpressionType& exprType = GetExpressionType(MandatoryExpr(clone->expression));
if (!IsPrimitiveType(exprType))
throw AstError{ "unary expression operand type does not support unary operation" };
PrimitiveType primitiveType = std::get<PrimitiveType>(exprType);
switch (node.op) switch (node.op)
{ {
case UnaryType::LogicalNot: case UnaryType::LogicalNot:
{ {
if (primitiveType != PrimitiveType::Boolean) if (exprType != ExpressionType(PrimitiveType::Boolean))
throw AstError{ "logical not is only supported on booleans" }; throw AstError{ "logical not is only supported on booleans" };
break; break;
@ -615,14 +611,22 @@ namespace Nz::ShaderAst
case UnaryType::Minus: case UnaryType::Minus:
case UnaryType::Plus: case UnaryType::Plus:
{ {
if (primitiveType != PrimitiveType::Float32 && primitiveType != PrimitiveType::Int32 && primitiveType != PrimitiveType::UInt32) ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
else
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32)
throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" }; throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" };
break; break;
} }
} }
clone->cachedExpressionType = primitiveType; clone->cachedExpressionType = exprType;
return clone; return clone;
} }

View File

@ -839,11 +839,17 @@ namespace Nz
case ShaderAst::UnaryType::Minus: case ShaderAst::UnaryType::Minus:
{ {
assert(IsPrimitiveType(exprType)); ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
else
throw std::runtime_error("unexpected expression type");
UInt32 resultId = m_writer.AllocateResultId(); UInt32 resultId = m_writer.AllocateResultId();
switch (std::get<ShaderAst::PrimitiveType>(resultType)) switch (basicType)
{ {
case ShaderAst::PrimitiveType::Float32: case ShaderAst::PrimitiveType::Float32:
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand); m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand);