Shader: Fix optimization

This commit is contained in:
Jérôme Leclercq 2021-04-17 14:43:29 +02:00
parent 87ce2edc6e
commit 3a7f5c2630
3 changed files with 66 additions and 85 deletions

View File

@ -126,6 +126,8 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ExpressionType GetExpressionType() const;
ShaderAst::ConstantValue value;
};

View File

@ -41,7 +41,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct CompEqBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs == rhs);
}
@ -60,7 +60,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct CompGeBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs >= rhs);
}
@ -79,7 +79,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct CompGtBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs > rhs);
}
@ -98,7 +98,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct CompLeBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs <= rhs);
}
@ -117,7 +117,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct CompLtBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs < rhs);
}
@ -136,7 +136,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct CompNeBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs != rhs);
}
@ -155,7 +155,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct AdditionBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs + rhs);
}
@ -174,7 +174,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct DivisionBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs / rhs);
}
@ -193,7 +193,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct MultiplicationBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs * rhs);
}
@ -212,7 +212,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2>
struct SubtractionBase
{
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs - rhs);
}
@ -382,17 +382,18 @@ namespace Nz::ShaderAst
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
{
m_enabledOptions.reset();
return CloneStatement(statement);
}
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
{
m_enabledConditions = enabledConditions;
m_enabledOptions = enabledConditions;
return CloneStatement(statement);
}
void AstOptimizer::Visit(BinaryExpression& node)
ExpressionPtr AstOptimizer::Clone(BinaryExpression& node)
{
auto lhs = CloneExpression(node.left);
auto rhs = CloneExpression(node.right);
@ -402,44 +403,60 @@ namespace Nz::ShaderAst
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
ExpressionPtr optimized;
switch (node.op)
{
case BinaryType::Add:
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::Subtract:
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
case BinaryType::Multiply:
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::Divide:
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompEq:
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompGe:
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompGt:
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompLe:
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompLt:
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
optimized = PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompNe:
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
}
optimized = PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
break;
}
AstCloner::Visit(node);
if (optimized)
return optimized;
}
void AstOptimizer::Visit(BranchStatement& node)
auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs));
binary->cachedExpressionType = node.cachedExpressionType;
return binary;
}
StatementPtr AstOptimizer::Clone(BranchStatement& node)
{
std::vector<BranchStatement::ConditionalStatement> statements;
StatementPtr elseStatement;
@ -465,8 +482,7 @@ namespace Nz::ShaderAst
if (statements.empty())
{
// First condition is true, dismiss the branch
condStatement.statement->Visit(*this);
return;
return AstCloner::Clone(condStatement.statement);
}
else
{
@ -487,54 +503,43 @@ namespace Nz::ShaderAst
{
// All conditions have been removed, replace by else statement or no-op
if (node.elseStatement)
{
node.elseStatement->Visit(*this);
return;
}
return AstCloner::Clone(node.elseStatement);
else
return PushStatement(ShaderBuilder::NoOp());
return ShaderBuilder::NoOp();
}
if (!elseStatement)
elseStatement = CloneStatement(node.elseStatement);
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)));
return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement));
}
void AstOptimizer::Visit(ConditionalExpression& node)
ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node)
{
return AstCloner::Visit(node);
if (!m_enabledOptions)
return AstCloner::Clone(node);
/*if (!m_shaderAst)
return ShaderAstCloner::Visit(node);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.truePath);
if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
return AstCloner::Clone(node.truePath);
else
Visit(node.falsePath);*/
return AstCloner::Clone(node.falsePath);
}
void AstOptimizer::Visit(ConditionalStatement& node)
StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
{
return AstCloner::Visit(node);
if (!m_enabledOptions)
return AstCloner::Clone(node);
/*if (!m_shaderAst)
return ShaderAstCloner::Visit(node);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.statement);*/
if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
return AstCloner::Clone(node);
else
return ShaderBuilder::NoOp();
}
template<BinaryType Type>
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
{
ExpressionPtr optimized;
std::unique_ptr<ConstantExpression> optimized;
std::visit([&](auto&& arg1)
{
using T1 = std::decay_t<decltype(arg1)>;
@ -555,8 +560,8 @@ namespace Nz::ShaderAst
}, lhs->value);
if (optimized)
PushExpression(std::move(optimized));
else
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs)));
optimized->cachedExpressionType = optimized->GetExpressionType();
return optimized;
}
}

View File

@ -358,33 +358,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
{
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, clone->value);
clone->cachedExpressionType = clone->GetExpressionType();
return clone;
}