Shader: Fix optimization issues

Moving the unique_ptrs but failing to optimize the expression could cause issues
Also the constant query callback defaulted to ConstantValue (without const ref), thanks auto.
This commit is contained in:
Jérôme Leclercq 2021-07-11 11:17:48 +02:00
parent ae364934bb
commit 863fb3ea7e
3 changed files with 37 additions and 39 deletions

View File

@ -48,9 +48,9 @@ namespace Nz::ShaderAst
StatementPtr Clone(BranchStatement& node) override;
StatementPtr Clone(ConditionalStatement& node) override;
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(std::unique_ptr<ConstantValueExpression>&& lhs, std::unique_ptr<ConstantValueExpression>&& rhs);
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(std::unique_ptr<ConstantValueExpression>&& operand);
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(std::unique_ptr<ConstantValueExpression>&& operand);
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand);
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand);
template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2);
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);
template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4);

View File

@ -575,58 +575,58 @@ namespace Nz::ShaderAst
if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression)
{
auto lhsConstant = static_unique_pointer_cast<ConstantValueExpression>(std::move(lhs));
auto rhsConstant = static_unique_pointer_cast<ConstantValueExpression>(std::move(rhs));
const ConstantValueExpression& lhsConstant = static_cast<const ConstantValueExpression&>(*lhs);
const ConstantValueExpression& rhsConstant = static_cast<const ConstantValueExpression&>(*rhs);
ExpressionPtr optimized;
switch (node.op)
{
case BinaryType::Add:
optimized = PropagateBinaryConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::Add>(lhsConstant, rhsConstant);
break;
case BinaryType::Subtract:
optimized = PropagateBinaryConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::Subtract>(lhsConstant, rhsConstant);
break;
case BinaryType::Multiply:
optimized = PropagateBinaryConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::Multiply>(lhsConstant, rhsConstant);
break;
case BinaryType::Divide:
optimized = PropagateBinaryConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::Divide>(lhsConstant, rhsConstant);
break;
case BinaryType::CompEq:
optimized = PropagateBinaryConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::CompEq>(lhsConstant, rhsConstant);
break;
case BinaryType::CompGe:
optimized = PropagateBinaryConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::CompGe>(lhsConstant, rhsConstant);
break;
case BinaryType::CompGt:
optimized = PropagateBinaryConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::CompGt>(lhsConstant, rhsConstant);
break;
case BinaryType::CompLe:
optimized = PropagateBinaryConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::CompLe>(lhsConstant, rhsConstant);
break;
case BinaryType::CompLt:
optimized = PropagateBinaryConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::CompLt>(lhsConstant, rhsConstant);
break;
case BinaryType::CompNe:
optimized = PropagateBinaryConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::CompNe>(lhsConstant, rhsConstant);
break;
case BinaryType::LogicalAnd:
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(lhsConstant, rhsConstant);
break;
case BinaryType::LogicalOr:
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(lhsConstant, rhsConstant);
break;
}
@ -659,14 +659,14 @@ namespace Nz::ShaderAst
{
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression)
{
auto constantExpr = static_unique_pointer_cast<ConstantValueExpression>(std::move(expressions.front()));
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expressions.front());
switch (std::get<PrimitiveType>(node.targetType))
{
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(std::move(constantExpr)); break;
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(std::move(constantExpr)); break;
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(std::move(constantExpr)); break;
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(std::move(constantExpr)); break;
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(constantExpr); break;
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(constantExpr); break;
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(constantExpr); break;
}
}
}
@ -770,11 +770,9 @@ namespace Nz::ShaderAst
{
auto& constant = static_cast<ConstantValueExpression&>(*cond);
assert(constant.cachedExpressionType);
const ExpressionType& constantType = constant.cachedExpressionType.value();
assert(IsPrimitiveType(constantType));
assert(std::get<PrimitiveType>(constantType) == PrimitiveType::Boolean);
const ExpressionType& constantType = GetExpressionType(constant);
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
continue;
bool cValue = std::get<bool>(constant.value);
if (!cValue)
@ -856,21 +854,21 @@ namespace Nz::ShaderAst
if (expr->GetType() == NodeType::ConstantValueExpression)
{
auto constantExpr = static_unique_pointer_cast<ConstantValueExpression>(std::move(expr));
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);
ExpressionPtr optimized;
switch (node.op)
{
case UnaryType::LogicalNot:
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(std::move(constantExpr));
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(constantExpr);
break;
case UnaryType::Minus:
optimized = PropagateUnaryConstant<UnaryType::Minus>(std::move(constantExpr));
optimized = PropagateUnaryConstant<UnaryType::Minus>(constantExpr);
break;
case UnaryType::Plus:
optimized = PropagateUnaryConstant<UnaryType::Plus>(std::move(constantExpr));
optimized = PropagateUnaryConstant<UnaryType::Plus>(constantExpr);
break;
}
@ -909,7 +907,7 @@ namespace Nz::ShaderAst
}
template<BinaryType Type>
ExpressionPtr AstOptimizer::PropagateBinaryConstant(std::unique_ptr<ConstantValueExpression>&& lhs, std::unique_ptr<ConstantValueExpression>&& rhs)
ExpressionPtr AstOptimizer::PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs)
{
std::unique_ptr<ConstantValueExpression> optimized;
std::visit([&](auto&& arg1)
@ -928,8 +926,8 @@ namespace Nz::ShaderAst
optimized = Op{}(arg1, arg2);
}
}, rhs->value);
}, lhs->value);
}, rhs.value);
}, lhs.value);
if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value);
@ -938,7 +936,7 @@ namespace Nz::ShaderAst
}
template<typename TargetType>
ExpressionPtr AstOptimizer::PropagateSingleValueCast(std::unique_ptr<ConstantValueExpression>&& operand)
ExpressionPtr AstOptimizer::PropagateSingleValueCast(const ConstantValueExpression& operand)
{
std::unique_ptr<ConstantValueExpression> optimized;
@ -953,13 +951,13 @@ namespace Nz::ShaderAst
if constexpr (is_complete_v<Op>)
optimized = Op{}(arg);
}
}, operand->value);
}, operand.value);
return optimized;
}
template<UnaryType Type>
ExpressionPtr AstOptimizer::PropagateUnaryConstant(std::unique_ptr<ConstantValueExpression>&& operand)
ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
{
std::unique_ptr<ConstantValueExpression> optimized;
std::visit([&](auto&& arg)
@ -973,7 +971,7 @@ namespace Nz::ShaderAst
if constexpr (is_complete_v<Op>)
optimized = Op{}(arg);
}
}, operand->value);
}, operand.value);
if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value);

View File

@ -1168,7 +1168,7 @@ namespace Nz::ShaderAst
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
{
AstOptimizer::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
{
assert(constantId < m_context->constantValues.size());
return m_context->constantValues[constantId];