Shader/AstOptimizer: Add swizzle optimization

This commit is contained in:
Jérôme Leclercq
2021-12-28 20:09:04 +01:00
parent 22651255df
commit 101a3d70da
5 changed files with 378 additions and 0 deletions

View File

@@ -33,6 +33,34 @@ namespace Nz::ShaderAst
// Convenience variable template: shorthand for is_complete<T>::value.
template<typename T>
inline constexpr bool is_complete_v = is_complete<T>::value;
// Type trait describing a constant value type.
// Primary template: any type without a dedicated specialization is reported
// as a single-component value whose component type is the type itself.
template<typename T>
struct VectorInfo
{
	// Component type of the value
	using Base = T;

	// Number of components held by the value
	static constexpr std::size_t Dimensions = 1;
};
// VectorInfo specialization for Vector2<T>: two components of type T.
template<typename T>
struct VectorInfo<Vector2<T>>
{
	// Component type of the vector
	using Base = T;

	// Number of components held by the vector
	static constexpr std::size_t Dimensions = 2;
};
// VectorInfo specialization for Vector3<T>: three components of type T.
template<typename T>
struct VectorInfo<Vector3<T>>
{
	// Component type of the vector
	using Base = T;

	// Number of components held by the vector
	static constexpr std::size_t Dimensions = 3;
};
// VectorInfo specialization for Vector4<T>: four components of type T.
template<typename T>
struct VectorInfo<Vector4<T>>
{
	// Component type of the vector
	using Base = T;

	// Number of components held by the vector
	static constexpr std::size_t Dimensions = 4;
};
/*************************************************************************************************/
template<BinaryType Type, typename T1, typename T2>
@@ -288,6 +316,92 @@ namespace Nz::ShaderAst
/*************************************************************************************************/
// Primary template of the constant-swizzle functor. It is only declared here:
// the implementations are the partial specializations on FromComponentCount (1 to 4).
//   T                    - component type of the constant
//   TargetComponentCount - number of components produced by the swizzle
//   FromComponentCount   - number of components of the swizzled constant
template<typename T, std::size_t TargetComponentCount, std::size_t FromComponentCount>
struct SwizzleBase;
// Constant swizzle of a scalar: the single value is broadcast to every
// component of the result, so the component indices are never consulted.
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 1>
{
	// Builds a constant expression holding TargetComponentCount copies of value.
	std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& /*components*/, T value)
	{
		if constexpr (TargetComponentCount == 1)
			return ShaderBuilder::Constant(value);
		else if constexpr (TargetComponentCount == 2)
			return ShaderBuilder::Constant(Vector2<T>(value, value));
		else if constexpr (TargetComponentCount == 3)
			return ShaderBuilder::Constant(Vector3<T>(value, value, value));
		else if constexpr (TargetComponentCount == 4)
			return ShaderBuilder::Constant(Vector4<T>(value, value, value, value));
		else
			static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
	}
};
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 2>
{
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector2<T>& value)
{
if constexpr (TargetComponentCount == 4)
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
else if constexpr (TargetComponentCount == 3)
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
else if constexpr (TargetComponentCount == 2)
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
else if constexpr (TargetComponentCount == 1)
return ShaderBuilder::Constant(value[components[0]]);
else
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
}
};
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 3>
{
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector3<T>& value)
{
if constexpr (TargetComponentCount == 4)
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
else if constexpr (TargetComponentCount == 3)
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
else if constexpr (TargetComponentCount == 2)
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
else if constexpr (TargetComponentCount == 1)
return ShaderBuilder::Constant(value[components[0]]);
else
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
}
};
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 4>
{
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector4<T>& value)
{
if constexpr (TargetComponentCount == 4)
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
else if constexpr (TargetComponentCount == 3)
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
else if constexpr (TargetComponentCount == 2)
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
else if constexpr (TargetComponentCount == 1)
return ShaderBuilder::Constant(value[components[0]]);
else
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
}
};
// Swizzle functor selected for a given (component type, component counts...)
// combination. Only declared here, so a combination without a definition
// remains an incomplete type.
template<typename T, std::size_t... ComponentCounts>
struct Swizzle;

// Maps a (component type, component counts...) combination to its Swizzle
// functor type without requiring that functor to be defined.
template<typename T, std::size_t... ComponentCounts>
struct SwizzlePropagation
{
	using Op = Swizzle<T, ComponentCounts...>;
};
/*************************************************************************************************/
// Forward declaration of the unary constant-propagation template, parameterized
// by the unary operation and the operand type. No definition is visible in this section.
template<UnaryType Type, typename T>
struct UnaryConstantPropagation;
@@ -541,6 +655,67 @@ namespace Nz::ShaderAst
//EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32);
//EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32);
// Swizzle
// Enables constant swizzle folding for double, float and Int32 constants,
// for every pairing of component counts from 1 to 4.
// NOTE(review): EnableOptimisation is defined outside this section; the arguments
// appear to be (operation, component type, target component count, source component
// count), matching the template parameters of SwizzleBase - confirm against the macro.

// double
EnableOptimisation(Swizzle, double, 1, 1);
EnableOptimisation(Swizzle, double, 1, 2);
EnableOptimisation(Swizzle, double, 1, 3);
EnableOptimisation(Swizzle, double, 1, 4);
EnableOptimisation(Swizzle, double, 2, 1);
EnableOptimisation(Swizzle, double, 2, 2);
EnableOptimisation(Swizzle, double, 2, 3);
EnableOptimisation(Swizzle, double, 2, 4);
EnableOptimisation(Swizzle, double, 3, 1);
EnableOptimisation(Swizzle, double, 3, 2);
EnableOptimisation(Swizzle, double, 3, 3);
EnableOptimisation(Swizzle, double, 3, 4);
EnableOptimisation(Swizzle, double, 4, 1);
EnableOptimisation(Swizzle, double, 4, 2);
EnableOptimisation(Swizzle, double, 4, 3);
EnableOptimisation(Swizzle, double, 4, 4);
// float
EnableOptimisation(Swizzle, float, 1, 1);
EnableOptimisation(Swizzle, float, 1, 2);
EnableOptimisation(Swizzle, float, 1, 3);
EnableOptimisation(Swizzle, float, 1, 4);
EnableOptimisation(Swizzle, float, 2, 1);
EnableOptimisation(Swizzle, float, 2, 2);
EnableOptimisation(Swizzle, float, 2, 3);
EnableOptimisation(Swizzle, float, 2, 4);
EnableOptimisation(Swizzle, float, 3, 1);
EnableOptimisation(Swizzle, float, 3, 2);
EnableOptimisation(Swizzle, float, 3, 3);
EnableOptimisation(Swizzle, float, 3, 4);
EnableOptimisation(Swizzle, float, 4, 1);
EnableOptimisation(Swizzle, float, 4, 2);
EnableOptimisation(Swizzle, float, 4, 3);
EnableOptimisation(Swizzle, float, 4, 4);
// Int32
EnableOptimisation(Swizzle, Int32, 1, 1);
EnableOptimisation(Swizzle, Int32, 1, 2);
EnableOptimisation(Swizzle, Int32, 1, 3);
EnableOptimisation(Swizzle, Int32, 1, 4);
EnableOptimisation(Swizzle, Int32, 2, 1);
EnableOptimisation(Swizzle, Int32, 2, 2);
EnableOptimisation(Swizzle, Int32, 2, 3);
EnableOptimisation(Swizzle, Int32, 2, 4);
EnableOptimisation(Swizzle, Int32, 3, 1);
EnableOptimisation(Swizzle, Int32, 3, 2);
EnableOptimisation(Swizzle, Int32, 3, 3);
EnableOptimisation(Swizzle, Int32, 3, 4);
EnableOptimisation(Swizzle, Int32, 4, 1);
EnableOptimisation(Swizzle, Int32, 4, 2);
EnableOptimisation(Swizzle, Int32, 4, 3);
EnableOptimisation(Swizzle, Int32, 4, 4);
// Unary
EnableOptimisation(UnaryLogicalNot, bool);
@@ -847,6 +1022,57 @@ namespace Nz::ShaderAst
return constant;
}
// Clones a swizzle expression, simplifying it when the cloned operand allows it:
//  - constant operand: the swizzle is folded into a new constant through
//    PropagateConstantSwizzle (when an optimization exists for the value type);
//  - swizzle operand: both swizzles are merged into the inner one, composing the
//    component indices, and the inner swizzle is returned in place of this node;
//  - otherwise: a new swizzle wrapping the cloned operand is built.
// Returns the resulting expression.
ExpressionPtr AstOptimizer::Clone(SwizzleExpression& node)
{
	auto expr = CloneExpression(node.expression);

	if (expr->GetType() == NodeType::ConstantValueExpression)
	{
		const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);

		ExpressionPtr optimized;
		switch (node.componentCount)
		{
			case 1:
				optimized = PropagateConstantSwizzle<1>(node.components, constantExpr);
				break;

			case 2:
				optimized = PropagateConstantSwizzle<2>(node.components, constantExpr);
				break;

			case 3:
				optimized = PropagateConstantSwizzle<3>(node.components, constantExpr);
				break;

			case 4:
				optimized = PropagateConstantSwizzle<4>(node.components, constantExpr);
				break;
		}

		// No folding available for this constant type: fall through to a regular swizzle
		if (optimized)
			return optimized;
	}
	else if (expr->GetType() == NodeType::SwizzleExpression)
	{
		SwizzleExpression& innerSwizzle = static_cast<SwizzleExpression&>(*expr);

		// Value-initialize so that slots beyond componentCount hold a defined value (0)
		// instead of indeterminate data when the array is copied below
		std::array<UInt32, 4> newComponents = {};
		for (std::size_t i = 0; i < node.componentCount; ++i)
			newComponents[i] = innerSwizzle.components[node.components[i]];

		innerSwizzle.componentCount = node.componentCount;
		innerSwizzle.components = newComponents;

		// The merged swizzle now yields what the outer swizzle yielded: its previous
		// cached type (computed for the inner component count) would be stale
		innerSwizzle.cachedExpressionType = node.cachedExpressionType;

		return expr;
	}

	auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount);
	swizzle->cachedExpressionType = node.cachedExpressionType;

	return swizzle;
}
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)
{
auto expr = CloneExpression(node.expression);
@@ -952,6 +1178,30 @@ namespace Nz::ShaderAst
return optimized;
}
// Attempts to evaluate a swizzle of TargetComponentCount components applied to a constant.
//   components - swizzle indices, forwarded unchanged to the swizzle functor
//   operand    - constant expression being swizzled
// Returns the folded constant expression, or a null pointer when no Swizzle
// functor is defined for the type of value stored in the operand.
template<std::size_t TargetComponentCount>
ExpressionPtr AstOptimizer::PropagateConstantSwizzle(const std::array<UInt32, 4>& components, const ConstantValueExpression& operand)
{
	std::unique_ptr<ConstantValueExpression> folded;

	auto foldValue = [&](auto&& storedValue)
	{
		using ValueType = std::decay_t<decltype(storedValue)>;
		using Traits = VectorInfo<ValueType>;
		using Propagation = SwizzlePropagation<typename Traits::Base, TargetComponentCount, Traits::Dimensions>;

		if constexpr (is_complete_v<Propagation>)
		{
			// Op is only a complete type for the combinations that were enabled
			using Functor = typename Propagation::Op;
			if constexpr (is_complete_v<Functor>)
				folded = Functor{}(components, storedValue);
		}
	};

	std::visit(foldValue, operand.value);

	return folded;
}
template<UnaryType Type>
ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
{