Shader: Fix optimization issues
Moving the unique_ptrs but failing to optimize the expression could cause issues Also the constant query callback defaulted to ConstantValue (without const ref), thanks auto.
This commit is contained in:
parent
ae364934bb
commit
863fb3ea7e
|
|
@ -48,9 +48,9 @@ namespace Nz::ShaderAst
|
|||
StatementPtr Clone(BranchStatement& node) override;
|
||||
StatementPtr Clone(ConditionalStatement& node) override;
|
||||
|
||||
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(std::unique_ptr<ConstantValueExpression>&& lhs, std::unique_ptr<ConstantValueExpression>&& rhs);
|
||||
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(std::unique_ptr<ConstantValueExpression>&& operand);
|
||||
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(std::unique_ptr<ConstantValueExpression>&& operand);
|
||||
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
|
||||
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand);
|
||||
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand);
|
||||
template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2);
|
||||
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);
|
||||
template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4);
|
||||
|
|
|
|||
|
|
@ -575,58 +575,58 @@ namespace Nz::ShaderAst
|
|||
|
||||
if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression)
|
||||
{
|
||||
auto lhsConstant = static_unique_pointer_cast<ConstantValueExpression>(std::move(lhs));
|
||||
auto rhsConstant = static_unique_pointer_cast<ConstantValueExpression>(std::move(rhs));
|
||||
const ConstantValueExpression& lhsConstant = static_cast<const ConstantValueExpression&>(*lhs);
|
||||
const ConstantValueExpression& rhsConstant = static_cast<const ConstantValueExpression&>(*rhs);
|
||||
|
||||
ExpressionPtr optimized;
|
||||
switch (node.op)
|
||||
{
|
||||
case BinaryType::Add:
|
||||
optimized = PropagateBinaryConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::Add>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::Subtract:
|
||||
optimized = PropagateBinaryConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::Subtract>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
optimized = PropagateBinaryConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::Multiply>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::Divide:
|
||||
optimized = PropagateBinaryConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::Divide>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::CompEq:
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompEq>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::CompGe:
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompGe>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::CompGt:
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompGt>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::CompLe:
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompLe>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::CompLt:
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompLt>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::CompNe:
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::CompNe>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::LogicalAnd:
|
||||
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
|
||||
case BinaryType::LogicalOr:
|
||||
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(lhsConstant, rhsConstant);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -659,14 +659,14 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression)
|
||||
{
|
||||
auto constantExpr = static_unique_pointer_cast<ConstantValueExpression>(std::move(expressions.front()));
|
||||
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expressions.front());
|
||||
|
||||
switch (std::get<PrimitiveType>(node.targetType))
|
||||
{
|
||||
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(std::move(constantExpr)); break;
|
||||
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(std::move(constantExpr)); break;
|
||||
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(std::move(constantExpr)); break;
|
||||
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(std::move(constantExpr)); break;
|
||||
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(constantExpr); break;
|
||||
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
|
||||
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(constantExpr); break;
|
||||
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(constantExpr); break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -770,11 +770,9 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
||||
|
||||
assert(constant.cachedExpressionType);
|
||||
const ExpressionType& constantType = constant.cachedExpressionType.value();
|
||||
|
||||
assert(IsPrimitiveType(constantType));
|
||||
assert(std::get<PrimitiveType>(constantType) == PrimitiveType::Boolean);
|
||||
const ExpressionType& constantType = GetExpressionType(constant);
|
||||
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
|
||||
continue;
|
||||
|
||||
bool cValue = std::get<bool>(constant.value);
|
||||
if (!cValue)
|
||||
|
|
@ -856,21 +854,21 @@ namespace Nz::ShaderAst
|
|||
|
||||
if (expr->GetType() == NodeType::ConstantValueExpression)
|
||||
{
|
||||
auto constantExpr = static_unique_pointer_cast<ConstantValueExpression>(std::move(expr));
|
||||
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);
|
||||
|
||||
ExpressionPtr optimized;
|
||||
switch (node.op)
|
||||
{
|
||||
case UnaryType::LogicalNot:
|
||||
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(std::move(constantExpr));
|
||||
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(constantExpr);
|
||||
break;
|
||||
|
||||
case UnaryType::Minus:
|
||||
optimized = PropagateUnaryConstant<UnaryType::Minus>(std::move(constantExpr));
|
||||
optimized = PropagateUnaryConstant<UnaryType::Minus>(constantExpr);
|
||||
break;
|
||||
|
||||
case UnaryType::Plus:
|
||||
optimized = PropagateUnaryConstant<UnaryType::Plus>(std::move(constantExpr));
|
||||
optimized = PropagateUnaryConstant<UnaryType::Plus>(constantExpr);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -909,7 +907,7 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
|
||||
template<BinaryType Type>
|
||||
ExpressionPtr AstOptimizer::PropagateBinaryConstant(std::unique_ptr<ConstantValueExpression>&& lhs, std::unique_ptr<ConstantValueExpression>&& rhs)
|
||||
ExpressionPtr AstOptimizer::PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs)
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> optimized;
|
||||
std::visit([&](auto&& arg1)
|
||||
|
|
@ -928,8 +926,8 @@ namespace Nz::ShaderAst
|
|||
optimized = Op{}(arg1, arg2);
|
||||
}
|
||||
|
||||
}, rhs->value);
|
||||
}, lhs->value);
|
||||
}, rhs.value);
|
||||
}, lhs.value);
|
||||
|
||||
if (optimized)
|
||||
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
||||
|
|
@ -938,7 +936,7 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
|
||||
template<typename TargetType>
|
||||
ExpressionPtr AstOptimizer::PropagateSingleValueCast(std::unique_ptr<ConstantValueExpression>&& operand)
|
||||
ExpressionPtr AstOptimizer::PropagateSingleValueCast(const ConstantValueExpression& operand)
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> optimized;
|
||||
|
||||
|
|
@ -953,13 +951,13 @@ namespace Nz::ShaderAst
|
|||
if constexpr (is_complete_v<Op>)
|
||||
optimized = Op{}(arg);
|
||||
}
|
||||
}, operand->value);
|
||||
}, operand.value);
|
||||
|
||||
return optimized;
|
||||
}
|
||||
|
||||
template<UnaryType Type>
|
||||
ExpressionPtr AstOptimizer::PropagateUnaryConstant(std::unique_ptr<ConstantValueExpression>&& operand)
|
||||
ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> optimized;
|
||||
std::visit([&](auto&& arg)
|
||||
|
|
@ -973,7 +971,7 @@ namespace Nz::ShaderAst
|
|||
if constexpr (is_complete_v<Op>)
|
||||
optimized = Op{}(arg);
|
||||
}
|
||||
}, operand->value);
|
||||
}, operand.value);
|
||||
|
||||
if (optimized)
|
||||
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
||||
|
|
|
|||
|
|
@ -1168,7 +1168,7 @@ namespace Nz::ShaderAst
|
|||
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
|
||||
{
|
||||
AstOptimizer::Options optimizerOptions;
|
||||
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
|
||||
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
|
||||
{
|
||||
assert(constantId < m_context->constantValues.size());
|
||||
return m_context->constantValues[constantId];
|
||||
|
|
|
|||
Loading…
Reference in New Issue