From 0339ea346f3c447a20cb3c3e0db585c9ff813adc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Wed, 19 May 2021 20:30:29 +0200 Subject: [PATCH] Shader: Fix unary plus/minus on vector --- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 18 +++++++++++------- src/Nazara/Shader/SpirvAstVisitor.cpp | 10 ++++++++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 5e96ddb0b..2eeba481d 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -597,16 +597,12 @@ namespace Nz::ShaderAst auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); const ExpressionType& exprType = GetExpressionType(MandatoryExpr(clone->expression)); - if (!IsPrimitiveType(exprType)) - throw AstError{ "unary expression operand type does not support unary operation" }; - - PrimitiveType primitiveType = std::get(exprType); switch (node.op) { case UnaryType::LogicalNot: { - if (primitiveType != PrimitiveType::Boolean) + if (exprType != ExpressionType(PrimitiveType::Boolean)) throw AstError{ "logical not is only supported on booleans" }; break; @@ -615,14 +611,22 @@ namespace Nz::ShaderAst case UnaryType::Minus: case UnaryType::Plus: { - if (primitiveType != PrimitiveType::Float32 && primitiveType != PrimitiveType::Int32 && primitiveType != PrimitiveType::UInt32) + ShaderAst::PrimitiveType basicType; + if (IsPrimitiveType(exprType)) + basicType = std::get(exprType); + else if (IsVectorType(exprType)) + basicType = std::get(exprType).type; + else + throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" }; + + if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32) throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" }; break; } } - clone->cachedExpressionType = primitiveType; + clone->cachedExpressionType = exprType; return clone; } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index f731cb7c2..4b6f774af 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -839,11 +839,17 @@ namespace Nz case ShaderAst::UnaryType::Minus: { - assert(IsPrimitiveType(exprType)); + ShaderAst::PrimitiveType basicType; + if (IsPrimitiveType(exprType)) + basicType = std::get(exprType); + else if (IsVectorType(exprType)) + basicType = std::get(exprType).type; + else + throw std::runtime_error("unexpected expression type"); UInt32 resultId = m_writer.AllocateResultId(); - switch (std::get(resultType)) + switch (basicType) { case ShaderAst::PrimitiveType::Float32: m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand);