Shader: Fix optimization

This commit is contained in:
Jérôme Leclercq 2021-04-17 14:43:29 +02:00
parent 87ce2edc6e
commit 3a7f5c2630
3 changed files with 66 additions and 85 deletions

View File

@ -126,6 +126,8 @@ namespace Nz::ShaderAst
NodeType GetType() const override; NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override; void Visit(AstExpressionVisitor& visitor) override;
ExpressionType GetExpressionType() const;
ShaderAst::ConstantValue value; ShaderAst::ConstantValue value;
}; };

View File

@ -41,7 +41,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct CompEqBase struct CompEqBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs == rhs); return ShaderBuilder::Constant(lhs == rhs);
} }
@ -60,7 +60,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct CompGeBase struct CompGeBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs >= rhs); return ShaderBuilder::Constant(lhs >= rhs);
} }
@ -79,7 +79,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct CompGtBase struct CompGtBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs > rhs); return ShaderBuilder::Constant(lhs > rhs);
} }
@ -98,7 +98,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct CompLeBase struct CompLeBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs <= rhs); return ShaderBuilder::Constant(lhs <= rhs);
} }
@ -117,7 +117,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct CompLtBase struct CompLtBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs < rhs); return ShaderBuilder::Constant(lhs < rhs);
} }
@ -136,7 +136,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct CompNeBase struct CompNeBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs != rhs); return ShaderBuilder::Constant(lhs != rhs);
} }
@ -155,7 +155,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct AdditionBase struct AdditionBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs + rhs); return ShaderBuilder::Constant(lhs + rhs);
} }
@ -174,7 +174,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct DivisionBase struct DivisionBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs / rhs); return ShaderBuilder::Constant(lhs / rhs);
} }
@ -193,7 +193,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct MultiplicationBase struct MultiplicationBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs * rhs); return ShaderBuilder::Constant(lhs * rhs);
} }
@ -212,7 +212,7 @@ namespace Nz::ShaderAst
template<typename T1, typename T2> template<typename T1, typename T2>
struct SubtractionBase struct SubtractionBase
{ {
ExpressionPtr operator()(const T1& lhs, const T2& rhs) std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{ {
return ShaderBuilder::Constant(lhs - rhs); return ShaderBuilder::Constant(lhs - rhs);
} }
@ -382,17 +382,18 @@ namespace Nz::ShaderAst
StatementPtr AstOptimizer::Optimise(StatementPtr& statement) StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
{ {
m_enabledOptions.reset();
return CloneStatement(statement); return CloneStatement(statement);
} }
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions) StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
{ {
m_enabledConditions = enabledConditions; m_enabledOptions = enabledConditions;
return CloneStatement(statement); return CloneStatement(statement);
} }
void AstOptimizer::Visit(BinaryExpression& node) ExpressionPtr AstOptimizer::Clone(BinaryExpression& node)
{ {
auto lhs = CloneExpression(node.left); auto lhs = CloneExpression(node.left);
auto rhs = CloneExpression(node.right); auto rhs = CloneExpression(node.right);
@ -402,44 +403,60 @@ namespace Nz::ShaderAst
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs)); auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs)); auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
ExpressionPtr optimized;
switch (node.op) switch (node.op)
{ {
case BinaryType::Add: case BinaryType::Add:
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::Subtract: case BinaryType::Subtract:
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
case BinaryType::Multiply: case BinaryType::Multiply:
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::Divide: case BinaryType::Divide:
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompEq: case BinaryType::CompEq:
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompGe: case BinaryType::CompGe:
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompGt: case BinaryType::CompGt:
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompLe: case BinaryType::CompLe:
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompLt: case BinaryType::CompLt:
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::CompNe: case BinaryType::CompNe:
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
} break;
} }
AstCloner::Visit(node); if (optimized)
return optimized;
} }
void AstOptimizer::Visit(BranchStatement& node) auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs));
binary->cachedExpressionType = node.cachedExpressionType;
return binary;
}
StatementPtr AstOptimizer::Clone(BranchStatement& node)
{ {
std::vector<BranchStatement::ConditionalStatement> statements; std::vector<BranchStatement::ConditionalStatement> statements;
StatementPtr elseStatement; StatementPtr elseStatement;
@ -465,8 +482,7 @@ namespace Nz::ShaderAst
if (statements.empty()) if (statements.empty())
{ {
// First condition is true, dismiss the branch // First condition is true, dismiss the branch
condStatement.statement->Visit(*this); return AstCloner::Clone(condStatement.statement);
return;
} }
else else
{ {
@ -487,54 +503,43 @@ namespace Nz::ShaderAst
{ {
// All conditions have been removed, replace by else statement or no-op // All conditions have been removed, replace by else statement or no-op
if (node.elseStatement) if (node.elseStatement)
{ return AstCloner::Clone(node.elseStatement);
node.elseStatement->Visit(*this);
return;
}
else else
return PushStatement(ShaderBuilder::NoOp()); return ShaderBuilder::NoOp();
} }
if (!elseStatement) if (!elseStatement)
elseStatement = CloneStatement(node.elseStatement); elseStatement = CloneStatement(node.elseStatement);
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement))); return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement));
} }
void AstOptimizer::Visit(ConditionalExpression& node) ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node)
{ {
return AstCloner::Visit(node); if (!m_enabledOptions)
return AstCloner::Clone(node);
/*if (!m_shaderAst) if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
return ShaderAstCloner::Visit(node); return AstCloner::Clone(node.truePath);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.truePath);
else else
Visit(node.falsePath);*/ return AstCloner::Clone(node.falsePath);
} }
void AstOptimizer::Visit(ConditionalStatement& node) StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
{ {
return AstCloner::Visit(node); if (!m_enabledOptions)
return AstCloner::Clone(node);
/*if (!m_shaderAst) if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
return ShaderAstCloner::Visit(node); return AstCloner::Clone(node);
else
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName); return ShaderBuilder::NoOp();
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.statement);*/
} }
template<BinaryType Type> template<BinaryType Type>
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs) ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
{ {
ExpressionPtr optimized; std::unique_ptr<ConstantExpression> optimized;
std::visit([&](auto&& arg1) std::visit([&](auto&& arg1)
{ {
using T1 = std::decay_t<decltype(arg1)>; using T1 = std::decay_t<decltype(arg1)>;
@ -555,8 +560,8 @@ namespace Nz::ShaderAst
}, lhs->value); }, lhs->value);
if (optimized) if (optimized)
PushExpression(std::move(optimized)); optimized->cachedExpressionType = optimized->GetExpressionType();
else
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs))); return optimized;
} }
} }

View File

@ -358,33 +358,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
{ {
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType clone->cachedExpressionType = clone->GetExpressionType();
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, clone->value);
return clone; return clone;
} }