Shader: Fix optimization
This commit is contained in:
parent
87ce2edc6e
commit
3a7f5c2630
|
|
@ -126,6 +126,8 @@ namespace Nz::ShaderAst
|
||||||
NodeType GetType() const override;
|
NodeType GetType() const override;
|
||||||
void Visit(AstExpressionVisitor& visitor) override;
|
void Visit(AstExpressionVisitor& visitor) override;
|
||||||
|
|
||||||
|
ExpressionType GetExpressionType() const;
|
||||||
|
|
||||||
ShaderAst::ConstantValue value;
|
ShaderAst::ConstantValue value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct CompEqBase
|
struct CompEqBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs == rhs);
|
return ShaderBuilder::Constant(lhs == rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -60,7 +60,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct CompGeBase
|
struct CompGeBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs >= rhs);
|
return ShaderBuilder::Constant(lhs >= rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -79,7 +79,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct CompGtBase
|
struct CompGtBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs > rhs);
|
return ShaderBuilder::Constant(lhs > rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -98,7 +98,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct CompLeBase
|
struct CompLeBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs <= rhs);
|
return ShaderBuilder::Constant(lhs <= rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -117,7 +117,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct CompLtBase
|
struct CompLtBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs < rhs);
|
return ShaderBuilder::Constant(lhs < rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -136,7 +136,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct CompNeBase
|
struct CompNeBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs != rhs);
|
return ShaderBuilder::Constant(lhs != rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -155,7 +155,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct AdditionBase
|
struct AdditionBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs + rhs);
|
return ShaderBuilder::Constant(lhs + rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -174,7 +174,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct DivisionBase
|
struct DivisionBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs / rhs);
|
return ShaderBuilder::Constant(lhs / rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -193,7 +193,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct MultiplicationBase
|
struct MultiplicationBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs * rhs);
|
return ShaderBuilder::Constant(lhs * rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -212,7 +212,7 @@ namespace Nz::ShaderAst
|
||||||
template<typename T1, typename T2>
|
template<typename T1, typename T2>
|
||||||
struct SubtractionBase
|
struct SubtractionBase
|
||||||
{
|
{
|
||||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||||
{
|
{
|
||||||
return ShaderBuilder::Constant(lhs - rhs);
|
return ShaderBuilder::Constant(lhs - rhs);
|
||||||
}
|
}
|
||||||
|
|
@ -382,17 +382,18 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
|
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
|
||||||
{
|
{
|
||||||
|
m_enabledOptions.reset();
|
||||||
return CloneStatement(statement);
|
return CloneStatement(statement);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
|
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
|
||||||
{
|
{
|
||||||
m_enabledConditions = enabledConditions;
|
m_enabledOptions = enabledConditions;
|
||||||
|
|
||||||
return CloneStatement(statement);
|
return CloneStatement(statement);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AstOptimizer::Visit(BinaryExpression& node)
|
ExpressionPtr AstOptimizer::Clone(BinaryExpression& node)
|
||||||
{
|
{
|
||||||
auto lhs = CloneExpression(node.left);
|
auto lhs = CloneExpression(node.left);
|
||||||
auto rhs = CloneExpression(node.right);
|
auto rhs = CloneExpression(node.right);
|
||||||
|
|
@ -402,44 +403,60 @@ namespace Nz::ShaderAst
|
||||||
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
|
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
|
||||||
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
|
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
|
||||||
|
|
||||||
|
ExpressionPtr optimized;
|
||||||
switch (node.op)
|
switch (node.op)
|
||||||
{
|
{
|
||||||
case BinaryType::Add:
|
case BinaryType::Add:
|
||||||
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::Subtract:
|
case BinaryType::Subtract:
|
||||||
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
|
||||||
case BinaryType::Multiply:
|
case BinaryType::Multiply:
|
||||||
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::Divide:
|
case BinaryType::Divide:
|
||||||
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::CompEq:
|
case BinaryType::CompEq:
|
||||||
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::CompGe:
|
case BinaryType::CompGe:
|
||||||
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::CompGt:
|
case BinaryType::CompGt:
|
||||||
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::CompLe:
|
case BinaryType::CompLe:
|
||||||
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::CompLt:
|
case BinaryType::CompLt:
|
||||||
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::CompNe:
|
case BinaryType::CompNe:
|
||||||
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (optimized)
|
||||||
|
return optimized;
|
||||||
}
|
}
|
||||||
|
|
||||||
AstCloner::Visit(node);
|
auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs));
|
||||||
|
binary->cachedExpressionType = node.cachedExpressionType;
|
||||||
|
|
||||||
|
return binary;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AstOptimizer::Visit(BranchStatement& node)
|
StatementPtr AstOptimizer::Clone(BranchStatement& node)
|
||||||
{
|
{
|
||||||
std::vector<BranchStatement::ConditionalStatement> statements;
|
std::vector<BranchStatement::ConditionalStatement> statements;
|
||||||
StatementPtr elseStatement;
|
StatementPtr elseStatement;
|
||||||
|
|
@ -465,8 +482,7 @@ namespace Nz::ShaderAst
|
||||||
if (statements.empty())
|
if (statements.empty())
|
||||||
{
|
{
|
||||||
// First condition is true, dismiss the branch
|
// First condition is true, dismiss the branch
|
||||||
condStatement.statement->Visit(*this);
|
return AstCloner::Clone(condStatement.statement);
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -487,54 +503,43 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
// All conditions have been removed, replace by else statement or no-op
|
// All conditions have been removed, replace by else statement or no-op
|
||||||
if (node.elseStatement)
|
if (node.elseStatement)
|
||||||
{
|
return AstCloner::Clone(node.elseStatement);
|
||||||
node.elseStatement->Visit(*this);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
return PushStatement(ShaderBuilder::NoOp());
|
return ShaderBuilder::NoOp();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!elseStatement)
|
if (!elseStatement)
|
||||||
elseStatement = CloneStatement(node.elseStatement);
|
elseStatement = CloneStatement(node.elseStatement);
|
||||||
|
|
||||||
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)));
|
return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement));
|
||||||
}
|
}
|
||||||
|
|
||||||
void AstOptimizer::Visit(ConditionalExpression& node)
|
ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node)
|
||||||
{
|
{
|
||||||
return AstCloner::Visit(node);
|
if (!m_enabledOptions)
|
||||||
|
return AstCloner::Clone(node);
|
||||||
|
|
||||||
/*if (!m_shaderAst)
|
if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
|
||||||
return ShaderAstCloner::Visit(node);
|
return AstCloner::Clone(node.truePath);
|
||||||
|
|
||||||
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
|
|
||||||
assert(conditionIndex != InvalidCondition);
|
|
||||||
|
|
||||||
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
|
|
||||||
Visit(node.truePath);
|
|
||||||
else
|
else
|
||||||
Visit(node.falsePath);*/
|
return AstCloner::Clone(node.falsePath);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AstOptimizer::Visit(ConditionalStatement& node)
|
StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
|
||||||
{
|
{
|
||||||
return AstCloner::Visit(node);
|
if (!m_enabledOptions)
|
||||||
|
return AstCloner::Clone(node);
|
||||||
|
|
||||||
/*if (!m_shaderAst)
|
if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
|
||||||
return ShaderAstCloner::Visit(node);
|
return AstCloner::Clone(node);
|
||||||
|
else
|
||||||
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
|
return ShaderBuilder::NoOp();
|
||||||
assert(conditionIndex != InvalidCondition);
|
|
||||||
|
|
||||||
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
|
|
||||||
Visit(node.statement);*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<BinaryType Type>
|
template<BinaryType Type>
|
||||||
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
||||||
{
|
{
|
||||||
ExpressionPtr optimized;
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
std::visit([&](auto&& arg1)
|
std::visit([&](auto&& arg1)
|
||||||
{
|
{
|
||||||
using T1 = std::decay_t<decltype(arg1)>;
|
using T1 = std::decay_t<decltype(arg1)>;
|
||||||
|
|
@ -555,8 +560,8 @@ namespace Nz::ShaderAst
|
||||||
}, lhs->value);
|
}, lhs->value);
|
||||||
|
|
||||||
if (optimized)
|
if (optimized)
|
||||||
PushExpression(std::move(optimized));
|
optimized->cachedExpressionType = optimized->GetExpressionType();
|
||||||
else
|
|
||||||
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs)));
|
return optimized;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -358,33 +358,7 @@ namespace Nz::ShaderAst
|
||||||
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
||||||
{
|
{
|
||||||
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
|
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
|
||||||
clone->cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
clone->cachedExpressionType = clone->GetExpressionType();
|
||||||
{
|
|
||||||
using T = std::decay_t<decltype(arg)>;
|
|
||||||
|
|
||||||
if constexpr (std::is_same_v<T, bool>)
|
|
||||||
return PrimitiveType::Boolean;
|
|
||||||
else if constexpr (std::is_same_v<T, float>)
|
|
||||||
return PrimitiveType::Float32;
|
|
||||||
else if constexpr (std::is_same_v<T, Int32>)
|
|
||||||
return PrimitiveType::Int32;
|
|
||||||
else if constexpr (std::is_same_v<T, UInt32>)
|
|
||||||
return PrimitiveType::UInt32;
|
|
||||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
|
||||||
return VectorType{ 2, PrimitiveType::Float32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
|
||||||
return VectorType{ 3, PrimitiveType::Float32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
|
||||||
return VectorType{ 4, PrimitiveType::Float32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
|
||||||
return VectorType{ 2, PrimitiveType::Int32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
|
||||||
return VectorType{ 3, PrimitiveType::Int32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
|
||||||
return VectorType{ 4, PrimitiveType::Int32 };
|
|
||||||
else
|
|
||||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
|
||||||
}, clone->value);
|
|
||||||
|
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue