Shader: Fix unary plus/minus on vector

This commit is contained in:
Jérôme Leclercq 2021-05-19 20:30:29 +02:00
parent 2d3652bba1
commit 0339ea346f
2 changed files with 19 additions and 9 deletions

View File

@ -597,16 +597,12 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<UnaryExpression>(AstCloner::Clone(node));
const ExpressionType& exprType = GetExpressionType(MandatoryExpr(clone->expression));
if (!IsPrimitiveType(exprType))
throw AstError{ "unary expression operand type does not support unary operation" };
PrimitiveType primitiveType = std::get<PrimitiveType>(exprType);
switch (node.op)
{
case UnaryType::LogicalNot:
{
if (primitiveType != PrimitiveType::Boolean)
if (exprType != ExpressionType(PrimitiveType::Boolean))
throw AstError{ "logical not is only supported on booleans" };
break;
@ -615,14 +611,22 @@ namespace Nz::ShaderAst
case UnaryType::Minus:
case UnaryType::Plus:
{
if (primitiveType != PrimitiveType::Float32 && primitiveType != PrimitiveType::Int32 && primitiveType != PrimitiveType::UInt32)
ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
else
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32)
throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" };
break;
}
}
clone->cachedExpressionType = primitiveType;
clone->cachedExpressionType = exprType;
return clone;
}

View File

@ -839,11 +839,17 @@ namespace Nz
case ShaderAst::UnaryType::Minus:
{
assert(IsPrimitiveType(exprType));
ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
else
throw std::runtime_error("unexpected expression type");
UInt32 resultId = m_writer.AllocateResultId();
switch (std::get<ShaderAst::PrimitiveType>(resultType))
switch (basicType)
{
case ShaderAst::PrimitiveType::Float32:
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand);