Shader/AstOptimizer: Add swizzle optimization
This commit is contained in:
parent
22651255df
commit
101a3d70da
|
|
@ -44,12 +44,14 @@ namespace Nz::ShaderAst
|
|||
// Clone overload taking a CastExpression node; overrides the base-class virtual and returns the resulting expression.
ExpressionPtr Clone(CastExpression& node) override;
|
||||
// Clone overload taking a ConditionalExpression node; overrides the base-class virtual and returns the resulting expression.
ExpressionPtr Clone(ConditionalExpression& node) override;
|
||||
// Clone overload taking a ConstantExpression node; overrides the base-class virtual and returns the resulting expression.
ExpressionPtr Clone(ConstantExpression& node) override;
|
||||
// Clone overload for SwizzleExpression nodes. Clones the swizzled operand, then:
//  - if the operand is a constant value, tries to fold the swizzle into a new constant;
//  - if the operand is itself a swizzle, merges both into a single swizzle;
//  - otherwise rebuilds an equivalent swizzle around the cloned operand.
ExpressionPtr Clone(SwizzleExpression& node) override;
|
||||
// Clone overload taking a UnaryExpression node; overrides the base-class virtual and returns the resulting expression.
ExpressionPtr Clone(UnaryExpression& node) override;
|
||||
// Clone overload taking a BranchStatement node; overrides the base-class virtual and returns the resulting statement.
StatementPtr Clone(BranchStatement& node) override;
|
||||
// Clone overload taking a ConditionalStatement node; overrides the base-class virtual and returns the resulting statement.
StatementPtr Clone(ConditionalStatement& node) override;
|
||||
|
||||
// Helper taking the two constant operands of a binary operation; Type selects the operation.
// NOTE(review): presumably returns the folded constant, or a null pointer when no folding applies - confirm against the definition.
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
|
||||
// Helper taking a single constant operand and a TargetType.
// NOTE(review): presumably folds a cast of that constant to TargetType - confirm against the definition.
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand);
|
||||
// Tries to fold a swizzle applied to the constant `operand`.
// TargetComponentCount is the number of components produced by the swizzle and `components` holds the selected component indices.
// Returns the constant built by the swizzle operation enabled for the operand's base type and dimensions, or a null pointer when no such operation is available.
template<std::size_t TargetComponentCount> ExpressionPtr PropagateConstantSwizzle(const std::array<UInt32, 4>& components, const ConstantValueExpression& operand);
|
||||
// Helper taking the constant operand of a unary operation; Type selects the operation.
// NOTE(review): presumably returns the folded constant, or a null pointer when no folding applies - confirm against the definition.
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand);
|
||||
// Helper taking two TargetType values.
// NOTE(review): presumably builds a constant two-component vector from v1 and v2 - confirm against the definition.
template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2);
|
||||
// Helper taking three TargetType values.
// NOTE(review): presumably builds a constant three-component vector from v1, v2 and v3 - confirm against the definition.
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ namespace Nz::ShaderBuilder
|
|||
|
||||
struct Swizzle
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const;
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -289,6 +289,19 @@ namespace Nz::ShaderBuilder
|
|||
return returnNode;
|
||||
}
|
||||
|
||||
// Builds a swizzle node over `expression`.
//  expression        - operand being swizzled (ownership is transferred to the new node)
//  swizzleComponents - component indices; stored as-is in the node
//  componentCount    - number of meaningful entries in swizzleComponents, asserted to be within [1, 4]
// Returns the newly allocated SwizzleExpression.
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const
{
	// A swizzle selects at least one and at most four components
	assert(componentCount > 0);
	assert(componentCount <= 4);

	auto result = std::make_unique<ShaderAst::SwizzleExpression>();
	result->components = swizzleComponents;
	result->componentCount = componentCount;
	result->expression = std::move(expression);

	return result;
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const
|
||||
{
|
||||
auto swizzleNode = std::make_unique<ShaderAst::SwizzleExpression>();
|
||||
|
|
|
|||
|
|
@ -33,6 +33,34 @@ namespace Nz::ShaderAst
|
|||
template<typename T>
|
||||
inline constexpr bool is_complete_v = is_complete<T>::value;
|
||||
|
||||
// Traits describing a constant value type.
// Primary template: any type without a dedicated specialization is treated as a
// single-component value whose base type is the type itself.
template<class T>
struct VectorInfo
{
	using Base = T;
	static constexpr std::size_t Dimensions = 1;
};
|
||||
|
||||
template<typename T>
|
||||
struct VectorInfo<Vector2<T>>
|
||||
{
|
||||
static constexpr std::size_t Dimensions = 2;
|
||||
using Base = T;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct VectorInfo<Vector3<T>>
|
||||
{
|
||||
static constexpr std::size_t Dimensions = 3;
|
||||
using Base = T;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct VectorInfo<Vector4<T>>
|
||||
{
|
||||
static constexpr std::size_t Dimensions = 4;
|
||||
using Base = T;
|
||||
};
|
||||
|
||||
/*************************************************************************************************/
|
||||
|
||||
template<BinaryType Type, typename T1, typename T2>
|
||||
|
|
@ -288,6 +316,92 @@ namespace Nz::ShaderAst
|
|||
|
||||
/*************************************************************************************************/
|
||||
|
||||
template<typename T, std::size_t TargetComponentCount, std::size_t FromComponentCount>
|
||||
struct SwizzleBase;
|
||||
|
||||
template<typename T, std::size_t TargetComponentCount>
|
||||
struct SwizzleBase<T, TargetComponentCount, 1>
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& /*components*/, T value)
|
||||
{
|
||||
if constexpr (TargetComponentCount == 4)
|
||||
return ShaderBuilder::Constant(Vector4<T>(value, value, value, value));
|
||||
else if constexpr (TargetComponentCount == 3)
|
||||
return ShaderBuilder::Constant(Vector3<T>(value, value, value));
|
||||
else if constexpr (TargetComponentCount == 2)
|
||||
return ShaderBuilder::Constant(Vector2<T>(value, value));
|
||||
else if constexpr (TargetComponentCount == 1)
|
||||
return ShaderBuilder::Constant(value);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, std::size_t TargetComponentCount>
|
||||
struct SwizzleBase<T, TargetComponentCount, 2>
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector2<T>& value)
|
||||
{
|
||||
if constexpr (TargetComponentCount == 4)
|
||||
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
|
||||
else if constexpr (TargetComponentCount == 3)
|
||||
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
|
||||
else if constexpr (TargetComponentCount == 2)
|
||||
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
|
||||
else if constexpr (TargetComponentCount == 1)
|
||||
return ShaderBuilder::Constant(value[components[0]]);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, std::size_t TargetComponentCount>
|
||||
struct SwizzleBase<T, TargetComponentCount, 3>
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector3<T>& value)
|
||||
{
|
||||
if constexpr (TargetComponentCount == 4)
|
||||
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
|
||||
else if constexpr (TargetComponentCount == 3)
|
||||
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
|
||||
else if constexpr (TargetComponentCount == 2)
|
||||
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
|
||||
else if constexpr (TargetComponentCount == 1)
|
||||
return ShaderBuilder::Constant(value[components[0]]);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, std::size_t TargetComponentCount>
|
||||
struct SwizzleBase<T, TargetComponentCount, 4>
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector4<T>& value)
|
||||
{
|
||||
if constexpr (TargetComponentCount == 4)
|
||||
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
|
||||
else if constexpr (TargetComponentCount == 3)
|
||||
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
|
||||
else if constexpr (TargetComponentCount == 2)
|
||||
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
|
||||
else if constexpr (TargetComponentCount == 1)
|
||||
return ShaderBuilder::Constant(value[components[0]]);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>, "unexpected TargetComponentCount");
|
||||
}
|
||||
};
|
||||
|
||||
// Swizzle operation for a base type T and a pack of component counts.
// Only declared here: an operation exists for a given combination only if a definition is provided elsewhere.
// NOTE(review): presumably supplied by the EnableOptimisation(Swizzle, ...) lines - confirm against the macro.
template<class T, std::size_t... Args>
struct Swizzle;

// Always-complete wrapper naming the (possibly incomplete) Swizzle operation matching its arguments.
template<class T, std::size_t... Args>
struct SwizzlePropagation
{
	using Op = Swizzle<T, Args...>;
};
|
||||
|
||||
/*************************************************************************************************/
|
||||
|
||||
template<UnaryType Type, typename T>
|
||||
struct UnaryConstantPropagation;
|
||||
|
||||
|
|
@ -541,6 +655,67 @@ namespace Nz::ShaderAst
|
|||
//EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32);
|
||||
//EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32);
|
||||
|
||||
// Swizzle
|
||||
EnableOptimisation(Swizzle, double, 1, 1);
|
||||
EnableOptimisation(Swizzle, double, 1, 2);
|
||||
EnableOptimisation(Swizzle, double, 1, 3);
|
||||
EnableOptimisation(Swizzle, double, 1, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, double, 2, 1);
|
||||
EnableOptimisation(Swizzle, double, 2, 2);
|
||||
EnableOptimisation(Swizzle, double, 2, 3);
|
||||
EnableOptimisation(Swizzle, double, 2, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, double, 3, 1);
|
||||
EnableOptimisation(Swizzle, double, 3, 2);
|
||||
EnableOptimisation(Swizzle, double, 3, 3);
|
||||
EnableOptimisation(Swizzle, double, 3, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, double, 4, 1);
|
||||
EnableOptimisation(Swizzle, double, 4, 2);
|
||||
EnableOptimisation(Swizzle, double, 4, 3);
|
||||
EnableOptimisation(Swizzle, double, 4, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, float, 1, 1);
|
||||
EnableOptimisation(Swizzle, float, 1, 2);
|
||||
EnableOptimisation(Swizzle, float, 1, 3);
|
||||
EnableOptimisation(Swizzle, float, 1, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, float, 2, 1);
|
||||
EnableOptimisation(Swizzle, float, 2, 2);
|
||||
EnableOptimisation(Swizzle, float, 2, 3);
|
||||
EnableOptimisation(Swizzle, float, 2, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, float, 3, 1);
|
||||
EnableOptimisation(Swizzle, float, 3, 2);
|
||||
EnableOptimisation(Swizzle, float, 3, 3);
|
||||
EnableOptimisation(Swizzle, float, 3, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, float, 4, 1);
|
||||
EnableOptimisation(Swizzle, float, 4, 2);
|
||||
EnableOptimisation(Swizzle, float, 4, 3);
|
||||
EnableOptimisation(Swizzle, float, 4, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, Int32, 1, 1);
|
||||
EnableOptimisation(Swizzle, Int32, 1, 2);
|
||||
EnableOptimisation(Swizzle, Int32, 1, 3);
|
||||
EnableOptimisation(Swizzle, Int32, 1, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, Int32, 2, 1);
|
||||
EnableOptimisation(Swizzle, Int32, 2, 2);
|
||||
EnableOptimisation(Swizzle, Int32, 2, 3);
|
||||
EnableOptimisation(Swizzle, Int32, 2, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, Int32, 3, 1);
|
||||
EnableOptimisation(Swizzle, Int32, 3, 2);
|
||||
EnableOptimisation(Swizzle, Int32, 3, 3);
|
||||
EnableOptimisation(Swizzle, Int32, 3, 4);
|
||||
|
||||
EnableOptimisation(Swizzle, Int32, 4, 1);
|
||||
EnableOptimisation(Swizzle, Int32, 4, 2);
|
||||
EnableOptimisation(Swizzle, Int32, 4, 3);
|
||||
EnableOptimisation(Swizzle, Int32, 4, 4);
|
||||
|
||||
// Unary
|
||||
|
||||
EnableOptimisation(UnaryLogicalNot, bool);
|
||||
|
|
@ -847,6 +1022,57 @@ namespace Nz::ShaderAst
|
|||
return constant;
|
||||
}
|
||||
|
||||
// Clones a swizzle expression, simplifying it when possible.
//  - Swizzle of a constant value: folded into a new constant, provided a swizzle propagation
//    is available for the constant's type (otherwise the swizzle is kept).
//  - Swizzle of a swizzle: both are merged into a single swizzle applied to the innermost operand.
//  - Anything else: an equivalent swizzle is rebuilt around the cloned operand.
// Returns the resulting expression.
ExpressionPtr AstOptimizer::Clone(SwizzleExpression& node)
{
	auto expr = CloneExpression(node.expression);

	if (expr->GetType() == NodeType::ConstantValueExpression)
	{
		const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);

		// The target component count is a template parameter, hence the runtime dispatch
		ExpressionPtr optimized;
		switch (node.componentCount)
		{
			case 1:
				optimized = PropagateConstantSwizzle<1>(node.components, constantExpr);
				break;

			case 2:
				optimized = PropagateConstantSwizzle<2>(node.components, constantExpr);
				break;

			case 3:
				optimized = PropagateConstantSwizzle<3>(node.components, constantExpr);
				break;

			case 4:
				optimized = PropagateConstantSwizzle<4>(node.components, constantExpr);
				break;

			default:
				// Unexpected component count: leave the swizzle as it is
				break;
		}

		if (optimized)
			return optimized;
	}
	else if (expr->GetType() == NodeType::SwizzleExpression)
	{
		// expr is a fresh clone owned by this function, so it can be patched in place
		SwizzleExpression& innerSwizzle = static_cast<SwizzleExpression&>(*expr);

		// Compose both swizzles: each outer component selects one of the inner swizzle's components.
		// Value-initialized so that unused trailing slots hold a defined value instead of indeterminate ones.
		std::array<UInt32, 4> newComponents = {};
		for (std::size_t i = 0; i < node.componentCount; ++i)
			newComponents[i] = innerSwizzle.components[node.components[i]];

		innerSwizzle.componentCount = node.componentCount;
		innerSwizzle.components = newComponents;

		// The merged node now yields the outer swizzle's type, not the inner one's
		innerSwizzle.cachedExpressionType = node.cachedExpressionType;

		return expr;
	}

	auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount);
	swizzle->cachedExpressionType = node.cachedExpressionType;

	return swizzle;
}
|
||||
|
||||
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)
|
||||
{
|
||||
auto expr = CloneExpression(node.expression);
|
||||
|
|
@ -952,6 +1178,30 @@ namespace Nz::ShaderAst
|
|||
return optimized;
|
||||
}
|
||||
|
||||
template<std::size_t TargetComponentCount>
|
||||
ExpressionPtr AstOptimizer::PropagateConstantSwizzle(const std::array<UInt32, 4>& components, const ConstantValueExpression& operand)
|
||||
{
|
||||
std::unique_ptr<ConstantValueExpression> optimized;
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
using BaseType = typename VectorInfo<T>::Base;
|
||||
constexpr std::size_t FromComponentCount = VectorInfo<T>::Dimensions;
|
||||
|
||||
using SPType = SwizzlePropagation<BaseType, TargetComponentCount, FromComponentCount>;
|
||||
|
||||
if constexpr (is_complete_v<SPType>)
|
||||
{
|
||||
using Op = typename SPType::Op;
|
||||
if constexpr (is_complete_v<Op>)
|
||||
optimized = Op{}(components, arg);
|
||||
}
|
||||
}, operand.value);
|
||||
|
||||
return optimized;
|
||||
}
|
||||
|
||||
template<UnaryType Type>
|
||||
ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -140,6 +140,118 @@ fn main()
|
|||
let output: f32 = 0.000000;
|
||||
output = 3.000000;
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
// A single-component swizzle (".z") of a constant vec3 must fold to the selected scalar constant.
WHEN("optimizing out scalar swizzle")
|
||||
{
|
||||
ExpectOptimization(R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value = vec3<f32>(3.0, 0.0, 1.0).z;
|
||||
}
|
||||
)", R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: f32 = 1.000000;
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
// Swizzling a scalar constant to four components (".xxxx") must fold to a vec4 constant repeating that scalar.
WHEN("optimizing out scalar swizzle to vector")
|
||||
{
|
||||
ExpectOptimization(R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value = (42.0).xxxx;
|
||||
}
|
||||
)", R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: vec4<f32> = vec4<f32>(42.000000, 42.000000, 42.000000, 42.000000);
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
// A reordering swizzle (".yzwx") of a constant vec4 must fold to the reordered vec4 constant.
WHEN("optimizing out vector swizzle")
|
||||
{
|
||||
ExpectOptimization(R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value = vec4<f32>(3.0, 0.0, 1.0, 2.0).yzwx;
|
||||
}
|
||||
)", R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: vec4<f32> = vec4<f32>(0.000000, 1.000000, 2.000000, 3.000000);
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
// A swizzle repeating components (".zzxx") of a constant vec4 must fold to the matching vec4 constant.
WHEN("optimizing out vector swizzle with repetition")
|
||||
{
|
||||
ExpectOptimization(R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value = vec4<f32>(3.0, 0.0, 1.0, 2.0).zzxx;
|
||||
}
|
||||
)", R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: vec4<f32> = vec4<f32>(1.000000, 1.000000, 3.000000, 3.000000);
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
// A chain of swizzles (".xyz.yz.y.x.xxxx") on a constant vec4 must fold to a single vec4 constant.
WHEN("optimizing out complex swizzle")
|
||||
{
|
||||
ExpectOptimization(R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value = vec4<f32>(0.0, 1.0, 2.0, 3.0).xyz.yz.y.x.xxxx;
|
||||
}
|
||||
)", R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: vec4<f32> = vec4<f32>(2.000000, 2.000000, 2.000000, 2.000000);
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
// The same chain of swizzles on a non-constant value (a uniform struct member) must collapse into one swizzle (".zzzz").
WHEN("optimizing out complex swizzle on unknown value")
|
||||
{
|
||||
ExpectOptimization(R"(
|
||||
struct inputStruct
|
||||
{
|
||||
value: vec4<f32>
|
||||
}
|
||||
|
||||
external
|
||||
{
|
||||
[set(0), binding(0)] data: uniform<inputStruct>
|
||||
}
|
||||
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value = data.value.xyz.yz.y.x.xxxx;
|
||||
}
|
||||
)", R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: vec4<f32> = data.value.zzzz;
|
||||
}
|
||||
)");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue