Shader: Fix optimization issues

Moving the unique_ptrs but failing to optimize the expression could cause issues
Also the constant query callback defaulted to ConstantValue (without const ref), thanks auto.
This commit is contained in:
Jérôme Leclercq 2021-07-11 11:17:48 +02:00
parent ae364934bb
commit 863fb3ea7e
3 changed files with 37 additions and 39 deletions

View File

@ -48,9 +48,9 @@ namespace Nz::ShaderAst
StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(BranchStatement& node) override;
StatementPtr Clone(ConditionalStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override;
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(std::unique_ptr<ConstantValueExpression>&& lhs, std::unique_ptr<ConstantValueExpression>&& rhs); template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(std::unique_ptr<ConstantValueExpression>&& operand); template<typename TargetType> ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand);
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(std::unique_ptr<ConstantValueExpression>&& operand); template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand);
template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2); template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2);
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3); template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);
template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4); template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4);

View File

@ -575,58 +575,58 @@ namespace Nz::ShaderAst
if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression) if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression)
{ {
auto lhsConstant = static_unique_pointer_cast<ConstantValueExpression>(std::move(lhs)); const ConstantValueExpression& lhsConstant = static_cast<const ConstantValueExpression&>(*lhs);
auto rhsConstant = static_unique_pointer_cast<ConstantValueExpression>(std::move(rhs)); const ConstantValueExpression& rhsConstant = static_cast<const ConstantValueExpression&>(*rhs);
ExpressionPtr optimized; ExpressionPtr optimized;
switch (node.op) switch (node.op)
{ {
case BinaryType::Add: case BinaryType::Add:
optimized = PropagateBinaryConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::Add>(lhsConstant, rhsConstant);
break; break;
case BinaryType::Subtract: case BinaryType::Subtract:
optimized = PropagateBinaryConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::Subtract>(lhsConstant, rhsConstant);
break; break;
case BinaryType::Multiply: case BinaryType::Multiply:
optimized = PropagateBinaryConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::Multiply>(lhsConstant, rhsConstant);
break; break;
case BinaryType::Divide: case BinaryType::Divide:
optimized = PropagateBinaryConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::Divide>(lhsConstant, rhsConstant);
break; break;
case BinaryType::CompEq: case BinaryType::CompEq:
optimized = PropagateBinaryConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompEq>(lhsConstant, rhsConstant);
break; break;
case BinaryType::CompGe: case BinaryType::CompGe:
optimized = PropagateBinaryConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompGe>(lhsConstant, rhsConstant);
break; break;
case BinaryType::CompGt: case BinaryType::CompGt:
optimized = PropagateBinaryConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompGt>(lhsConstant, rhsConstant);
break; break;
case BinaryType::CompLe: case BinaryType::CompLe:
optimized = PropagateBinaryConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompLe>(lhsConstant, rhsConstant);
break; break;
case BinaryType::CompLt: case BinaryType::CompLt:
optimized = PropagateBinaryConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompLt>(lhsConstant, rhsConstant);
break; break;
case BinaryType::CompNe: case BinaryType::CompNe:
optimized = PropagateBinaryConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompNe>(lhsConstant, rhsConstant);
break; break;
case BinaryType::LogicalAnd: case BinaryType::LogicalAnd:
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(lhsConstant, rhsConstant);
break; break;
case BinaryType::LogicalOr: case BinaryType::LogicalOr:
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(lhsConstant, rhsConstant);
break; break;
} }
@ -659,14 +659,14 @@ namespace Nz::ShaderAst
{ {
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression) if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression)
{ {
auto constantExpr = static_unique_pointer_cast<ConstantValueExpression>(std::move(expressions.front())); const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expressions.front());
switch (std::get<PrimitiveType>(node.targetType)) switch (std::get<PrimitiveType>(node.targetType))
{ {
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(std::move(constantExpr)); break; case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(constantExpr); break;
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(std::move(constantExpr)); break; case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(std::move(constantExpr)); break; case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(constantExpr); break;
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(std::move(constantExpr)); break; case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(constantExpr); break;
} }
} }
} }
@ -770,11 +770,9 @@ namespace Nz::ShaderAst
{ {
auto& constant = static_cast<ConstantValueExpression&>(*cond); auto& constant = static_cast<ConstantValueExpression&>(*cond);
assert(constant.cachedExpressionType); const ExpressionType& constantType = GetExpressionType(constant);
const ExpressionType& constantType = constant.cachedExpressionType.value(); if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
continue;
assert(IsPrimitiveType(constantType));
assert(std::get<PrimitiveType>(constantType) == PrimitiveType::Boolean);
bool cValue = std::get<bool>(constant.value); bool cValue = std::get<bool>(constant.value);
if (!cValue) if (!cValue)
@ -856,21 +854,21 @@ namespace Nz::ShaderAst
if (expr->GetType() == NodeType::ConstantValueExpression) if (expr->GetType() == NodeType::ConstantValueExpression)
{ {
auto constantExpr = static_unique_pointer_cast<ConstantValueExpression>(std::move(expr)); const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);
ExpressionPtr optimized; ExpressionPtr optimized;
switch (node.op) switch (node.op)
{ {
case UnaryType::LogicalNot: case UnaryType::LogicalNot:
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(std::move(constantExpr)); optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(constantExpr);
break; break;
case UnaryType::Minus: case UnaryType::Minus:
optimized = PropagateUnaryConstant<UnaryType::Minus>(std::move(constantExpr)); optimized = PropagateUnaryConstant<UnaryType::Minus>(constantExpr);
break; break;
case UnaryType::Plus: case UnaryType::Plus:
optimized = PropagateUnaryConstant<UnaryType::Plus>(std::move(constantExpr)); optimized = PropagateUnaryConstant<UnaryType::Plus>(constantExpr);
break; break;
} }
@ -909,7 +907,7 @@ namespace Nz::ShaderAst
} }
template<BinaryType Type> template<BinaryType Type>
ExpressionPtr AstOptimizer::PropagateBinaryConstant(std::unique_ptr<ConstantValueExpression>&& lhs, std::unique_ptr<ConstantValueExpression>&& rhs) ExpressionPtr AstOptimizer::PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs)
{ {
std::unique_ptr<ConstantValueExpression> optimized; std::unique_ptr<ConstantValueExpression> optimized;
std::visit([&](auto&& arg1) std::visit([&](auto&& arg1)
@ -928,8 +926,8 @@ namespace Nz::ShaderAst
optimized = Op{}(arg1, arg2); optimized = Op{}(arg1, arg2);
} }
}, rhs->value); }, rhs.value);
}, lhs->value); }, lhs.value);
if (optimized) if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value); optimized->cachedExpressionType = GetExpressionType(optimized->value);
@ -938,7 +936,7 @@ namespace Nz::ShaderAst
} }
template<typename TargetType> template<typename TargetType>
ExpressionPtr AstOptimizer::PropagateSingleValueCast(std::unique_ptr<ConstantValueExpression>&& operand) ExpressionPtr AstOptimizer::PropagateSingleValueCast(const ConstantValueExpression& operand)
{ {
std::unique_ptr<ConstantValueExpression> optimized; std::unique_ptr<ConstantValueExpression> optimized;
@ -953,13 +951,13 @@ namespace Nz::ShaderAst
if constexpr (is_complete_v<Op>) if constexpr (is_complete_v<Op>)
optimized = Op{}(arg); optimized = Op{}(arg);
} }
}, operand->value); }, operand.value);
return optimized; return optimized;
} }
template<UnaryType Type> template<UnaryType Type>
ExpressionPtr AstOptimizer::PropagateUnaryConstant(std::unique_ptr<ConstantValueExpression>&& operand) ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
{ {
std::unique_ptr<ConstantValueExpression> optimized; std::unique_ptr<ConstantValueExpression> optimized;
std::visit([&](auto&& arg) std::visit([&](auto&& arg)
@ -973,7 +971,7 @@ namespace Nz::ShaderAst
if constexpr (is_complete_v<Op>) if constexpr (is_complete_v<Op>)
optimized = Op{}(arg); optimized = Op{}(arg);
} }
}, operand->value); }, operand.value);
if (optimized) if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value); optimized->cachedExpressionType = GetExpressionType(optimized->value);

View File

@ -1168,7 +1168,7 @@ namespace Nz::ShaderAst
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node) std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
{ {
AstOptimizer::Options optimizerOptions; AstOptimizer::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
{ {
assert(constantId < m_context->constantValues.size()); assert(constantId < m_context->constantValues.size());
return m_context->constantValues[constantId]; return m_context->constantValues[constantId];