diff --git a/include/Nazara/Shader/Ast/AstOptimizer.hpp b/include/Nazara/Shader/Ast/AstOptimizer.hpp index 070dbee48..30466c8d9 100644 --- a/include/Nazara/Shader/Ast/AstOptimizer.hpp +++ b/include/Nazara/Shader/Ast/AstOptimizer.hpp @@ -48,9 +48,9 @@ namespace Nz::ShaderAst StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override; - template ExpressionPtr PropagateBinaryConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs); - template ExpressionPtr PropagateSingleValueCast(std::unique_ptr&& operand); - template ExpressionPtr PropagateUnaryConstant(std::unique_ptr&& operand); + template ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs); + template ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand); + template ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand); template ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2); template ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3); template ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4); diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index 7c7bed29f..3ae59e30d 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -575,58 +575,58 @@ namespace Nz::ShaderAst if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression) { - auto lhsConstant = static_unique_pointer_cast(std::move(lhs)); - auto rhsConstant = static_unique_pointer_cast(std::move(rhs)); + const ConstantValueExpression& lhsConstant = static_cast(*lhs); + const ConstantValueExpression& rhsConstant = static_cast(*rhs); ExpressionPtr optimized; switch (node.op) { case BinaryType::Add: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::Subtract: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::Multiply: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::Divide: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompEq: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompGe: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompGt: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompLe: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompLt: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompNe: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::LogicalAnd: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::LogicalOr: - optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; } @@ -659,14 +659,14 @@ namespace Nz::ShaderAst { if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression) { - auto constantExpr = static_unique_pointer_cast(std::move(expressions.front())); + const ConstantValueExpression& constantExpr = static_cast(*expressions.front()); switch (std::get(node.targetType)) { - case PrimitiveType::Boolean: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; - case PrimitiveType::Float32: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; - case PrimitiveType::Int32: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; - case PrimitiveType::UInt32: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; + case PrimitiveType::Boolean: optimized = PropagateSingleValueCast(constantExpr); break; + case PrimitiveType::Float32: optimized = PropagateSingleValueCast(constantExpr); break; + case PrimitiveType::Int32: optimized = PropagateSingleValueCast(constantExpr); break; + case PrimitiveType::UInt32: optimized = PropagateSingleValueCast(constantExpr); break; } } } @@ -770,11 +770,9 @@ namespace Nz::ShaderAst { auto& constant = static_cast(*cond); - assert(constant.cachedExpressionType); - const ExpressionType& constantType = constant.cachedExpressionType.value(); - - assert(IsPrimitiveType(constantType)); - assert(std::get(constantType) == PrimitiveType::Boolean); + const ExpressionType& constantType = GetExpressionType(constant); + if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) + continue; bool cValue = std::get(constant.value); if (!cValue) @@ -856,21 +854,21 @@ namespace Nz::ShaderAst if (expr->GetType() == NodeType::ConstantValueExpression) { - auto constantExpr = static_unique_pointer_cast(std::move(expr)); + const ConstantValueExpression& constantExpr = static_cast(*expr); ExpressionPtr optimized; switch (node.op) { case UnaryType::LogicalNot: - optimized = PropagateUnaryConstant(std::move(constantExpr)); + optimized = PropagateUnaryConstant(constantExpr); break; case UnaryType::Minus: - optimized = PropagateUnaryConstant(std::move(constantExpr)); + optimized = PropagateUnaryConstant(constantExpr); break; case UnaryType::Plus: - optimized = PropagateUnaryConstant(std::move(constantExpr)); + optimized = PropagateUnaryConstant(constantExpr); break; } @@ -909,7 +907,7 @@ namespace Nz::ShaderAst } template - ExpressionPtr AstOptimizer::PropagateBinaryConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs) + ExpressionPtr AstOptimizer::PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs) { std::unique_ptr optimized; std::visit([&](auto&& arg1) @@ -928,8 +926,8 @@ namespace Nz::ShaderAst optimized = Op{}(arg1, arg2); } - }, rhs->value); - }, lhs->value); + }, rhs.value); + }, lhs.value); if (optimized) optimized->cachedExpressionType = GetExpressionType(optimized->value); @@ -938,7 +936,7 @@ namespace Nz::ShaderAst } template - ExpressionPtr AstOptimizer::PropagateSingleValueCast(std::unique_ptr&& operand) + ExpressionPtr AstOptimizer::PropagateSingleValueCast(const ConstantValueExpression& operand) { std::unique_ptr optimized; @@ -953,13 +951,13 @@ namespace Nz::ShaderAst if constexpr (is_complete_v) optimized = Op{}(arg); } - }, operand->value); + }, operand.value); return optimized; } template - ExpressionPtr AstOptimizer::PropagateUnaryConstant(std::unique_ptr&& operand) + ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand) { std::unique_ptr optimized; std::visit([&](auto&& arg) @@ -973,7 +971,7 @@ namespace Nz::ShaderAst if constexpr (is_complete_v) optimized = Op{}(arg); } - }, operand->value); + }, operand.value); if (optimized) optimized->cachedExpressionType = GetExpressionType(optimized->value); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 575e7a3f4..8e976679e 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -1168,7 +1168,7 @@ namespace Nz::ShaderAst std::unique_ptr SanitizeVisitor::Optimize(T& node) { AstOptimizer::Options optimizerOptions; - optimizerOptions.constantQueryCallback = [this](std::size_t constantId) + optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& { assert(constantId < m_context->constantValues.size()); return m_context->constantValues[constantId];