Shader/AstOptimizer: Add swizzle optimization

This commit is contained in:
Jérôme Leclercq 2021-12-28 20:09:04 +01:00
parent 22651255df
commit 101a3d70da
5 changed files with 378 additions and 0 deletions

View File

@ -44,12 +44,14 @@ namespace Nz::ShaderAst
ExpressionPtr Clone(CastExpression& node) override;
ExpressionPtr Clone(ConditionalExpression& node) override;
ExpressionPtr Clone(ConstantExpression& node) override;
// Clones a swizzle; the implementation folds swizzles of constant values and merges a swizzle applied to another swizzle
ExpressionPtr Clone(SwizzleExpression& node) override;
ExpressionPtr Clone(UnaryExpression& node) override;
StatementPtr Clone(BranchStatement& node) override;
StatementPtr Clone(ConditionalStatement& node) override;
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand);
// Evaluates a swizzle producing TargetComponentCount components on a constant operand; yields a null pointer when no swizzle optimisation exists for the operand's type
template<std::size_t TargetComponentCount> ExpressionPtr PropagateConstantSwizzle(const std::array<UInt32, 4>& components, const ConstantValueExpression& operand);
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand);
template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2);
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);

View File

@ -134,6 +134,7 @@ namespace Nz::ShaderBuilder
// Builder functor creating ShaderAst::SwizzleExpression nodes
struct Swizzle
{
// Builds a swizzle over `expression` from a fixed-size component array and an explicit component count
// (the implementation asserts that componentCount is in the [1, 4] range)
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const;
// Builds a swizzle over `expression` from a vector of component indices
// NOTE(review): presumably the component count is taken from the vector size - confirm against the implementation
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const;
};

View File

@ -289,6 +289,19 @@ namespace Nz::ShaderBuilder
return returnNode;
}
// Creates a SwizzleExpression reading `expression` through the given component indices.
// `componentCount` tells how many entries of `swizzleComponents` are meaningful and must be
// within [1, 4]; this precondition is only enforced through assert.
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const
{
	assert(componentCount > 0);
	assert(componentCount <= 4);

	auto node = std::make_unique<ShaderAst::SwizzleExpression>();
	node->components = swizzleComponents;
	node->componentCount = componentCount;
	// Ownership of the swizzled expression is transferred to the new node
	node->expression = std::move(expression);

	return node;
}
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const
{
auto swizzleNode = std::make_unique<ShaderAst::SwizzleExpression>();

View File

@ -33,6 +33,34 @@ namespace Nz::ShaderAst
// Shorthand for is_complete<T>::value
template<typename T>
inline constexpr bool is_complete_v = is_complete<T>::value;
// Describes the shape of a value type: its component type (Base) and its component count (Dimensions).
// This primary template handles every type without a dedicated specialization, treating it as a
// single-component value whose component type is the type itself.
template<typename T>
struct VectorInfo
{
	using Base = T;

	static constexpr std::size_t Dimensions = 1;
};
// Vector2<T>: two components of type T
template<typename T>
struct VectorInfo<Vector2<T>>
{
	using Base = T;

	static constexpr std::size_t Dimensions = 2;
};

// Vector3<T>: three components of type T
template<typename T>
struct VectorInfo<Vector3<T>>
{
	using Base = T;

	static constexpr std::size_t Dimensions = 3;
};

// Vector4<T>: four components of type T
template<typename T>
struct VectorInfo<Vector4<T>>
{
	using Base = T;

	static constexpr std::size_t Dimensions = 4;
};
/*************************************************************************************************/
template<BinaryType Type, typename T1, typename T2>
@ -288,6 +316,92 @@ namespace Nz::ShaderAst
/*************************************************************************************************/
// Constant swizzle evaluator, parameterized by component type, number of components produced and number of components of the source value.
// Declared only: the partial specializations on FromComponentCount (1 to 4) provide the implementations.
template<typename T, std::size_t TargetComponentCount, std::size_t FromComponentCount>
struct SwizzleBase;
// Swizzle applied to a scalar constant: every produced component is a copy of the scalar,
// so the component indices are not looked at.
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 1>
{
	std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& /*components*/, T value)
	{
		if constexpr (TargetComponentCount == 1)
			return ShaderBuilder::Constant(value);
		else if constexpr (TargetComponentCount == 2)
			return ShaderBuilder::Constant(Vector2<T>(value, value));
		else if constexpr (TargetComponentCount == 3)
			return ShaderBuilder::Constant(Vector3<T>(value, value, value));
		else if constexpr (TargetComponentCount == 4)
			return ShaderBuilder::Constant(Vector4<T>(value, value, value, value));
		else
			static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
	}
};
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 2>
{
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector2<T>& value)
{
if constexpr (TargetComponentCount == 4)
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
else if constexpr (TargetComponentCount == 3)
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
else if constexpr (TargetComponentCount == 2)
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
else if constexpr (TargetComponentCount == 1)
return ShaderBuilder::Constant(value[components[0]]);
else
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
}
};
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 3>
{
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector3<T>& value)
{
if constexpr (TargetComponentCount == 4)
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
else if constexpr (TargetComponentCount == 3)
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
else if constexpr (TargetComponentCount == 2)
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
else if constexpr (TargetComponentCount == 1)
return ShaderBuilder::Constant(value[components[0]]);
else
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
}
};
template<typename T, std::size_t TargetComponentCount>
struct SwizzleBase<T, TargetComponentCount, 4>
{
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector4<T>& value)
{
if constexpr (TargetComponentCount == 4)
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
else if constexpr (TargetComponentCount == 3)
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
else if constexpr (TargetComponentCount == 2)
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
else if constexpr (TargetComponentCount == 1)
return ShaderBuilder::Constant(value[components[0]]);
else
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
}
};
// Swizzle operation for a component type and a list of component counts; declared only here.
// NOTE(review): presumably the EnableOptimisation(Swizzle, ...) entries define the concrete types - confirm against the macro
template<typename T, std::size_t... Args>
struct Swizzle;
// Maps (T, Args...) to the matching Swizzle operation type.
// Op names an incomplete type when no Swizzle<T, Args...> was defined, which callers detect with is_complete_v.
template<typename T, std::size_t... Args>
struct SwizzlePropagation
{
using Op = Swizzle<T, Args...>;
};
/*************************************************************************************************/
template<UnaryType Type, typename T>
struct UnaryConstantPropagation;
@ -541,6 +655,67 @@ namespace Nz::ShaderAst
//EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32);
//EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32);
// Swizzle: every pair of component counts from 1 to 4 is enabled for the double, float and Int32 component types
// NOTE(review): the two numeric arguments presumably map to SwizzlePropagation's (TargetComponentCount, FromComponentCount)
// in that order - confirm against the EnableOptimisation macro (all 16 pairs are listed, so coverage is the same either way)
EnableOptimisation(Swizzle, double, 1, 1);
EnableOptimisation(Swizzle, double, 1, 2);
EnableOptimisation(Swizzle, double, 1, 3);
EnableOptimisation(Swizzle, double, 1, 4);
EnableOptimisation(Swizzle, double, 2, 1);
EnableOptimisation(Swizzle, double, 2, 2);
EnableOptimisation(Swizzle, double, 2, 3);
EnableOptimisation(Swizzle, double, 2, 4);
EnableOptimisation(Swizzle, double, 3, 1);
EnableOptimisation(Swizzle, double, 3, 2);
EnableOptimisation(Swizzle, double, 3, 3);
EnableOptimisation(Swizzle, double, 3, 4);
EnableOptimisation(Swizzle, double, 4, 1);
EnableOptimisation(Swizzle, double, 4, 2);
EnableOptimisation(Swizzle, double, 4, 3);
EnableOptimisation(Swizzle, double, 4, 4);
EnableOptimisation(Swizzle, float, 1, 1);
EnableOptimisation(Swizzle, float, 1, 2);
EnableOptimisation(Swizzle, float, 1, 3);
EnableOptimisation(Swizzle, float, 1, 4);
EnableOptimisation(Swizzle, float, 2, 1);
EnableOptimisation(Swizzle, float, 2, 2);
EnableOptimisation(Swizzle, float, 2, 3);
EnableOptimisation(Swizzle, float, 2, 4);
EnableOptimisation(Swizzle, float, 3, 1);
EnableOptimisation(Swizzle, float, 3, 2);
EnableOptimisation(Swizzle, float, 3, 3);
EnableOptimisation(Swizzle, float, 3, 4);
EnableOptimisation(Swizzle, float, 4, 1);
EnableOptimisation(Swizzle, float, 4, 2);
EnableOptimisation(Swizzle, float, 4, 3);
EnableOptimisation(Swizzle, float, 4, 4);
EnableOptimisation(Swizzle, Int32, 1, 1);
EnableOptimisation(Swizzle, Int32, 1, 2);
EnableOptimisation(Swizzle, Int32, 1, 3);
EnableOptimisation(Swizzle, Int32, 1, 4);
EnableOptimisation(Swizzle, Int32, 2, 1);
EnableOptimisation(Swizzle, Int32, 2, 2);
EnableOptimisation(Swizzle, Int32, 2, 3);
EnableOptimisation(Swizzle, Int32, 2, 4);
EnableOptimisation(Swizzle, Int32, 3, 1);
EnableOptimisation(Swizzle, Int32, 3, 2);
EnableOptimisation(Swizzle, Int32, 3, 3);
EnableOptimisation(Swizzle, Int32, 3, 4);
EnableOptimisation(Swizzle, Int32, 4, 1);
EnableOptimisation(Swizzle, Int32, 4, 2);
EnableOptimisation(Swizzle, Int32, 4, 3);
EnableOptimisation(Swizzle, Int32, 4, 4);
// Unary
EnableOptimisation(UnaryLogicalNot, bool);
@ -847,6 +1022,57 @@ namespace Nz::ShaderAst
return constant;
}
// Clones a swizzle expression, simplifying it when its operand allows it:
// - a swizzle of a constant value is evaluated and replaced by the resulting constant
//   (when a swizzle optimisation exists for the constant's type);
// - a swizzle of another swizzle is merged into a single swizzle of the innermost operand;
// - otherwise a plain swizzle of the cloned operand is rebuilt.
// Returns the resulting expression, which carries the original node's cached expression type
// on the merged and rebuilt paths.
ExpressionPtr AstOptimizer::Clone(SwizzleExpression& node)
{
	auto expr = CloneExpression(node.expression);

	if (expr->GetType() == NodeType::ConstantValueExpression)
	{
		const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);

		ExpressionPtr optimized;
		switch (node.componentCount)
		{
			case 1:
				optimized = PropagateConstantSwizzle<1>(node.components, constantExpr);
				break;

			case 2:
				optimized = PropagateConstantSwizzle<2>(node.components, constantExpr);
				break;

			case 3:
				optimized = PropagateConstantSwizzle<3>(node.components, constantExpr);
				break;

			case 4:
				optimized = PropagateConstantSwizzle<4>(node.components, constantExpr);
				break;

			default:
				// Unexpected component count: leave the swizzle untouched
				break;
		}

		if (optimized)
			return optimized;
	}
	else if (expr->GetType() == NodeType::SwizzleExpression)
	{
		// The operand is itself a swizzle: compose both index lists and reuse the cloned inner node
		SwizzleExpression& innerSwizzle = static_cast<SwizzleExpression&>(*expr);

		// Value-initialized so that the slots past componentCount hold zero instead of
		// indeterminate values when the whole array is copied into the node below
		std::array<UInt32, 4> newComponents = {};
		for (std::size_t i = 0; i < node.componentCount; ++i)
			newComponents[i] = innerSwizzle.components[node.components[i]];

		innerSwizzle.componentCount = node.componentCount;
		innerSwizzle.components = newComponents;

		// The merged node now produces the outer swizzle's value, so it must expose the outer
		// swizzle's type rather than the inner one's (same as the rebuilt swizzle below)
		innerSwizzle.cachedExpressionType = node.cachedExpressionType;

		return expr;
	}

	auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount);
	swizzle->cachedExpressionType = node.cachedExpressionType;

	return swizzle;
}
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)
{
auto expr = CloneExpression(node.expression);
@ -952,6 +1178,30 @@ namespace Nz::ShaderAst
return optimized;
}
// Tries to evaluate a swizzle yielding TargetComponentCount components on a constant operand.
// `components` holds the selected source component indices and `operand` the constant being swizzled.
// Returns the resulting constant expression, or a null pointer when no swizzle operation is
// available for the type currently held by the operand.
template<std::size_t TargetComponentCount>
ExpressionPtr AstOptimizer::PropagateConstantSwizzle(const std::array<UInt32, 4>& components, const ConstantValueExpression& operand)
{
	std::unique_ptr<ConstantValueExpression> folded;

	const auto tryFold = [&](auto&& constantValue)
	{
		using ValueType = std::decay_t<decltype(constantValue)>;
		using Info = VectorInfo<ValueType>;
		using Propagation = SwizzlePropagation<typename Info::Base, TargetComponentCount, Info::Dimensions>;

		if constexpr (is_complete_v<Propagation>)
		{
			// Op is only a complete type when an operation exists for this combination
			using Operation = typename Propagation::Op;
			if constexpr (is_complete_v<Operation>)
				folded = Operation{}(components, constantValue);
		}
	};

	std::visit(tryFold, operand.value);

	return folded;
}
template<UnaryType Type>
ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
{

View File

@ -140,6 +140,118 @@ fn main()
let output: f32 = 0.000000;
output = 3.000000;
}
)");
}
// Selecting a single component (.z) of a constant vec3 is expected to be replaced by the scalar constant itself
WHEN("optimizing out scalar swizzle")
{
ExpectOptimization(R"(
[entry(frag)]
fn main()
{
let value = vec3<f32>(3.0, 0.0, 1.0).z;
}
)", R"(
[entry(frag)]
fn main()
{
let value: f32 = 1.000000;
}
)");
}
// Swizzling a scalar constant into four components (.xxxx) is expected to produce a constant vec4 repeating that scalar
WHEN("optimizing out scalar swizzle to vector")
{
ExpectOptimization(R"(
[entry(frag)]
fn main()
{
let value = (42.0).xxxx;
}
)", R"(
[entry(frag)]
fn main()
{
let value: vec4<f32> = vec4<f32>(42.000000, 42.000000, 42.000000, 42.000000);
}
)");
}
// Reordering the components of a constant vec4 (.yzwx) is expected to produce the reordered constant vec4
WHEN("optimizing out vector swizzle")
{
ExpectOptimization(R"(
[entry(frag)]
fn main()
{
let value = vec4<f32>(3.0, 0.0, 1.0, 2.0).yzwx;
}
)", R"(
[entry(frag)]
fn main()
{
let value: vec4<f32> = vec4<f32>(0.000000, 1.000000, 2.000000, 3.000000);
}
)");
}
// A swizzle selecting the same source component several times (.zzxx) is expected to fold as well
WHEN("optimizing out vector swizzle with repetition")
{
ExpectOptimization(R"(
[entry(frag)]
fn main()
{
let value = vec4<f32>(3.0, 0.0, 1.0, 2.0).zzxx;
}
)", R"(
[entry(frag)]
fn main()
{
let value: vec4<f32> = vec4<f32>(1.000000, 1.000000, 3.000000, 3.000000);
}
)");
}
// A chain of swizzles on a constant is expected to fold down to one constant:
// .xyz -> (0, 1, 2), .yz -> (1, 2), .y -> 2, .x -> 2, .xxxx -> (2, 2, 2, 2)
WHEN("optimizing out complex swizzle")
{
ExpectOptimization(R"(
[entry(frag)]
fn main()
{
let value = vec4<f32>(0.0, 1.0, 2.0, 3.0).xyz.yz.y.x.xxxx;
}
)", R"(
[entry(frag)]
fn main()
{
let value: vec4<f32> = vec4<f32>(2.000000, 2.000000, 2.000000, 2.000000);
}
)");
}
// The same swizzle chain applied to a non-constant value (a uniform field) cannot be evaluated,
// but is expected to collapse into a single swizzle: .xyz.yz.y.x.xxxx selects z four times (.zzzz)
WHEN("optimizing out complex swizzle on unknown value")
{
ExpectOptimization(R"(
struct inputStruct
{
value: vec4<f32>
}
external
{
[set(0), binding(0)] data: uniform<inputStruct>
}
[entry(frag)]
fn main()
{
let value = data.value.xyz.yz.y.x.xxxx;
}
)", R"(
[entry(frag)]
fn main()
{
let value: vec4<f32> = data.value.zzzz;
}
)");
}
}