Shader: Fix optimization
This commit is contained in:
parent
87ce2edc6e
commit
3a7f5c2630
|
|
@ -126,6 +126,8 @@ namespace Nz::ShaderAst
|
|||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
ExpressionType GetExpressionType() const;
|
||||
|
||||
ShaderAst::ConstantValue value;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct CompEqBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs == rhs);
|
||||
}
|
||||
|
|
@ -60,7 +60,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct CompGeBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs >= rhs);
|
||||
}
|
||||
|
|
@ -79,7 +79,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct CompGtBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs > rhs);
|
||||
}
|
||||
|
|
@ -98,7 +98,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct CompLeBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs <= rhs);
|
||||
}
|
||||
|
|
@ -117,7 +117,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct CompLtBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs < rhs);
|
||||
}
|
||||
|
|
@ -136,7 +136,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct CompNeBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs != rhs);
|
||||
}
|
||||
|
|
@ -155,7 +155,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct AdditionBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs + rhs);
|
||||
}
|
||||
|
|
@ -174,7 +174,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct DivisionBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs / rhs);
|
||||
}
|
||||
|
|
@ -193,7 +193,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct MultiplicationBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs * rhs);
|
||||
}
|
||||
|
|
@ -212,7 +212,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T1, typename T2>
|
||||
struct SubtractionBase
|
||||
{
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs - rhs);
|
||||
}
|
||||
|
|
@ -382,17 +382,18 @@ namespace Nz::ShaderAst
|
|||
|
||||
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
|
||||
{
|
||||
m_enabledOptions.reset();
|
||||
return CloneStatement(statement);
|
||||
}
|
||||
|
||||
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
|
||||
{
|
||||
m_enabledConditions = enabledConditions;
|
||||
m_enabledOptions = enabledConditions;
|
||||
|
||||
return CloneStatement(statement);
|
||||
}
|
||||
|
||||
void AstOptimizer::Visit(BinaryExpression& node)
|
||||
ExpressionPtr AstOptimizer::Clone(BinaryExpression& node)
|
||||
{
|
||||
auto lhs = CloneExpression(node.left);
|
||||
auto rhs = CloneExpression(node.right);
|
||||
|
|
@ -402,44 +403,60 @@ namespace Nz::ShaderAst
|
|||
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
|
||||
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
|
||||
|
||||
ExpressionPtr optimized;
|
||||
switch (node.op)
|
||||
{
|
||||
case BinaryType::Add:
|
||||
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::Subtract:
|
||||
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case BinaryType::Multiply:
|
||||
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::Divide:
|
||||
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::CompEq:
|
||||
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::CompGe:
|
||||
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::CompGt:
|
||||
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::CompLe:
|
||||
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::CompLt:
|
||||
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
optimized = PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
|
||||
case BinaryType::CompNe:
|
||||
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
}
|
||||
optimized = PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
break;
|
||||
}
|
||||
|
||||
AstCloner::Visit(node);
|
||||
if (optimized)
|
||||
return optimized;
|
||||
}
|
||||
|
||||
void AstOptimizer::Visit(BranchStatement& node)
|
||||
auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs));
|
||||
binary->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return binary;
|
||||
}
|
||||
|
||||
StatementPtr AstOptimizer::Clone(BranchStatement& node)
|
||||
{
|
||||
std::vector<BranchStatement::ConditionalStatement> statements;
|
||||
StatementPtr elseStatement;
|
||||
|
|
@ -465,8 +482,7 @@ namespace Nz::ShaderAst
|
|||
if (statements.empty())
|
||||
{
|
||||
// First condition is true, dismiss the branch
|
||||
condStatement.statement->Visit(*this);
|
||||
return;
|
||||
return AstCloner::Clone(condStatement.statement);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -487,54 +503,43 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
// All conditions have been removed, replace by else statement or no-op
|
||||
if (node.elseStatement)
|
||||
{
|
||||
node.elseStatement->Visit(*this);
|
||||
return;
|
||||
}
|
||||
return AstCloner::Clone(node.elseStatement);
|
||||
else
|
||||
return PushStatement(ShaderBuilder::NoOp());
|
||||
return ShaderBuilder::NoOp();
|
||||
}
|
||||
|
||||
if (!elseStatement)
|
||||
elseStatement = CloneStatement(node.elseStatement);
|
||||
|
||||
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)));
|
||||
return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement));
|
||||
}
|
||||
|
||||
void AstOptimizer::Visit(ConditionalExpression& node)
|
||||
ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node)
|
||||
{
|
||||
return AstCloner::Visit(node);
|
||||
if (!m_enabledOptions)
|
||||
return AstCloner::Clone(node);
|
||||
|
||||
/*if (!m_shaderAst)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
|
||||
Visit(node.truePath);
|
||||
if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
|
||||
return AstCloner::Clone(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);*/
|
||||
return AstCloner::Clone(node.falsePath);
|
||||
}
|
||||
|
||||
void AstOptimizer::Visit(ConditionalStatement& node)
|
||||
StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
|
||||
{
|
||||
return AstCloner::Visit(node);
|
||||
if (!m_enabledOptions)
|
||||
return AstCloner::Clone(node);
|
||||
|
||||
/*if (!m_shaderAst)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
|
||||
Visit(node.statement);*/
|
||||
if (TestBit<UInt64>(*m_enabledOptions, node.optionIndex))
|
||||
return AstCloner::Clone(node);
|
||||
else
|
||||
return ShaderBuilder::NoOp();
|
||||
}
|
||||
|
||||
template<BinaryType Type>
|
||||
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
||||
ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
||||
{
|
||||
ExpressionPtr optimized;
|
||||
std::unique_ptr<ConstantExpression> optimized;
|
||||
std::visit([&](auto&& arg1)
|
||||
{
|
||||
using T1 = std::decay_t<decltype(arg1)>;
|
||||
|
|
@ -555,8 +560,8 @@ namespace Nz::ShaderAst
|
|||
}, lhs->value);
|
||||
|
||||
if (optimized)
|
||||
PushExpression(std::move(optimized));
|
||||
else
|
||||
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs)));
|
||||
optimized->cachedExpressionType = optimized->GetExpressionType();
|
||||
|
||||
return optimized;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -358,33 +358,7 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
|
||||
clone->cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return PrimitiveType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return PrimitiveType::Float32;
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return PrimitiveType::Int32;
|
||||
else if constexpr (std::is_same_v<T, UInt32>)
|
||||
return PrimitiveType::UInt32;
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return VectorType{ 2, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return VectorType{ 3, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return VectorType{ 4, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return VectorType{ 2, PrimitiveType::Int32 };
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return VectorType{ 3, PrimitiveType::Int32 };
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return VectorType{ 4, PrimitiveType::Int32 };
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, clone->value);
|
||||
clone->cachedExpressionType = clone->GetExpressionType();
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue