diff --git a/include/Nazara/Shader/Ast/AstOptimizer.hpp b/include/Nazara/Shader/Ast/AstOptimizer.hpp index 32b9cd821..c6f2d17fe 100644 --- a/include/Nazara/Shader/Ast/AstOptimizer.hpp +++ b/include/Nazara/Shader/Ast/AstOptimizer.hpp @@ -44,12 +44,14 @@ namespace Nz::ShaderAst ExpressionPtr Clone(CastExpression& node) override; ExpressionPtr Clone(ConditionalExpression& node) override; ExpressionPtr Clone(ConstantExpression& node) override; + ExpressionPtr Clone(SwizzleExpression& node) override; ExpressionPtr Clone(UnaryExpression& node) override; StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override; template ExpressionPtr PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs); template ExpressionPtr PropagateSingleValueCast(const ConstantValueExpression& operand); + template ExpressionPtr PropagateConstantSwizzle(const std::array& components, const ConstantValueExpression& operand); template ExpressionPtr PropagateUnaryConstant(const ConstantValueExpression& operand); template ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2); template ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3); diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 391a32975..bd1a1008b 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -134,6 +134,7 @@ namespace Nz::ShaderBuilder struct Swizzle { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::array swizzleComponents, std::size_t componentCount) const; inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const; }; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 1a9f7f0b1..9c903f35d 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -289,6 +289,19 @@ namespace Nz::ShaderBuilder return returnNode; } + inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array swizzleComponents, std::size_t componentCount) const + { + assert(componentCount > 0); + assert(componentCount <= 4); + + auto swizzleNode = std::make_unique(); + swizzleNode->expression = std::move(expression); + swizzleNode->componentCount = componentCount; + swizzleNode->components = swizzleComponents; + + return swizzleNode; + } + inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const { auto swizzleNode = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index 289a7904b..308a2da16 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -33,6 +33,34 @@ namespace Nz::ShaderAst template inline constexpr bool is_complete_v = is_complete::value; + template + struct VectorInfo + { + static constexpr std::size_t Dimensions = 1; + using Base = T; + }; + + template + struct VectorInfo> + { + static constexpr std::size_t Dimensions = 2; + using Base = T; + }; + + template + struct VectorInfo> + { + static constexpr std::size_t Dimensions = 3; + using Base = T; + }; + + template + struct VectorInfo> + { + static constexpr std::size_t Dimensions = 4; + using Base = T; + }; + /*************************************************************************************************/ template @@ -288,6 +316,92 @@ namespace Nz::ShaderAst /*************************************************************************************************/ + template + struct SwizzleBase; + + template + struct SwizzleBase + { + std::unique_ptr operator()(const std::array& /*components*/, T value) + { + if constexpr (TargetComponentCount == 4) + return ShaderBuilder::Constant(Vector4(value, value, value, value)); + else if constexpr (TargetComponentCount == 3) + return ShaderBuilder::Constant(Vector3(value, value, value)); + else if constexpr (TargetComponentCount == 2) + return ShaderBuilder::Constant(Vector2(value, value)); + else if constexpr (TargetComponentCount == 1) + return ShaderBuilder::Constant(value); + else + static_assert(AlwaysFalse, "unexpected TargetComponentCount"); + } + }; + + template + struct SwizzleBase + { + std::unique_ptr operator()(const std::array& components, const Vector2& value) + { + if constexpr (TargetComponentCount == 4) + return ShaderBuilder::Constant(Vector4(value[components[0]], value[components[1]], value[components[2]], value[components[3]])); + else if constexpr (TargetComponentCount == 3) + return ShaderBuilder::Constant(Vector3(value[components[0]], value[components[1]], value[components[2]])); + else if constexpr (TargetComponentCount == 2) + return ShaderBuilder::Constant(Vector2(value[components[0]], value[components[1]])); + else if constexpr (TargetComponentCount == 1) + return ShaderBuilder::Constant(value[components[0]]); + else + static_assert(AlwaysFalse, "unexpected TargetComponentCount"); + } + }; + + template + struct SwizzleBase + { + std::unique_ptr operator()(const std::array& components, const Vector3& value) + { + if constexpr (TargetComponentCount == 4) + return ShaderBuilder::Constant(Vector4(value[components[0]], value[components[1]], value[components[2]], value[components[3]])); + else if constexpr (TargetComponentCount == 3) + return ShaderBuilder::Constant(Vector3(value[components[0]], value[components[1]], value[components[2]])); + else if constexpr (TargetComponentCount == 2) + return ShaderBuilder::Constant(Vector2(value[components[0]], value[components[1]])); + else if constexpr (TargetComponentCount == 1) + return ShaderBuilder::Constant(value[components[0]]); + else + static_assert(AlwaysFalse, "unexpected TargetComponentCount"); + } + }; + + template + struct SwizzleBase + { + std::unique_ptr operator()(const std::array& components, const Vector4& value) + { + if constexpr (TargetComponentCount == 4) + return ShaderBuilder::Constant(Vector4(value[components[0]], value[components[1]], value[components[2]], value[components[3]])); + else if constexpr (TargetComponentCount == 3) + return ShaderBuilder::Constant(Vector3(value[components[0]], value[components[1]], value[components[2]])); + else if constexpr (TargetComponentCount == 2) + return ShaderBuilder::Constant(Vector2(value[components[0]], value[components[1]])); + else if constexpr (TargetComponentCount == 1) + return ShaderBuilder::Constant(value[components[0]]); + else + static_assert(AlwaysFalse, "unexpected TargetComponentCount"); + } + }; + + template + struct Swizzle; + + template + struct SwizzlePropagation + { + using Op = Swizzle; + }; + + /*************************************************************************************************/ + template struct UnaryConstantPropagation; @@ -541,6 +655,67 @@ namespace Nz::ShaderAst //EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32); //EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32); + // Swizzle + EnableOptimisation(Swizzle, double, 1, 1); + EnableOptimisation(Swizzle, double, 1, 2); + EnableOptimisation(Swizzle, double, 1, 3); + EnableOptimisation(Swizzle, double, 1, 4); + + EnableOptimisation(Swizzle, double, 2, 1); + EnableOptimisation(Swizzle, double, 2, 2); + EnableOptimisation(Swizzle, double, 2, 3); + EnableOptimisation(Swizzle, double, 2, 4); + + EnableOptimisation(Swizzle, double, 3, 1); + EnableOptimisation(Swizzle, double, 3, 2); + EnableOptimisation(Swizzle, double, 3, 3); + EnableOptimisation(Swizzle, double, 3, 4); + + EnableOptimisation(Swizzle, double, 4, 1); + EnableOptimisation(Swizzle, double, 4, 2); + EnableOptimisation(Swizzle, double, 4, 3); + EnableOptimisation(Swizzle, double, 4, 4); + + EnableOptimisation(Swizzle, float, 1, 1); + EnableOptimisation(Swizzle, float, 1, 2); + EnableOptimisation(Swizzle, float, 1, 3); + EnableOptimisation(Swizzle, float, 1, 4); + + EnableOptimisation(Swizzle, float, 2, 1); + EnableOptimisation(Swizzle, float, 2, 2); + EnableOptimisation(Swizzle, float, 2, 3); + EnableOptimisation(Swizzle, float, 2, 4); + + EnableOptimisation(Swizzle, float, 3, 1); + EnableOptimisation(Swizzle, float, 3, 2); + EnableOptimisation(Swizzle, float, 3, 3); + EnableOptimisation(Swizzle, float, 3, 4); + + EnableOptimisation(Swizzle, float, 4, 1); + EnableOptimisation(Swizzle, float, 4, 2); + EnableOptimisation(Swizzle, float, 4, 3); + EnableOptimisation(Swizzle, float, 4, 4); + + EnableOptimisation(Swizzle, Int32, 1, 1); + EnableOptimisation(Swizzle, Int32, 1, 2); + EnableOptimisation(Swizzle, Int32, 1, 3); + EnableOptimisation(Swizzle, Int32, 1, 4); + + EnableOptimisation(Swizzle, Int32, 2, 1); + EnableOptimisation(Swizzle, Int32, 2, 2); + EnableOptimisation(Swizzle, Int32, 2, 3); + EnableOptimisation(Swizzle, Int32, 2, 4); + + EnableOptimisation(Swizzle, Int32, 3, 1); + EnableOptimisation(Swizzle, Int32, 3, 2); + EnableOptimisation(Swizzle, Int32, 3, 3); + EnableOptimisation(Swizzle, Int32, 3, 4); + + EnableOptimisation(Swizzle, Int32, 4, 1); + EnableOptimisation(Swizzle, Int32, 4, 2); + EnableOptimisation(Swizzle, Int32, 4, 3); + EnableOptimisation(Swizzle, Int32, 4, 4); + // Unary EnableOptimisation(UnaryLogicalNot, bool); @@ -847,6 +1022,57 @@ namespace Nz::ShaderAst return constant; } + ExpressionPtr AstOptimizer::Clone(SwizzleExpression& node) + { + auto expr = CloneExpression(node.expression); + + if (expr->GetType() == NodeType::ConstantValueExpression) + { + const ConstantValueExpression& constantExpr = static_cast(*expr); + + ExpressionPtr optimized; + switch (node.componentCount) + { + case 1: + optimized = PropagateConstantSwizzle<1>(node.components, constantExpr); + break; + + case 2: + optimized = PropagateConstantSwizzle<2>(node.components, constantExpr); + break; + + case 3: + optimized = PropagateConstantSwizzle<3>(node.components, constantExpr); + break; + + case 4: + optimized = PropagateConstantSwizzle<4>(node.components, constantExpr); + break; + } + + if (optimized) + return optimized; + } + else if (expr->GetType() == NodeType::SwizzleExpression) + { + SwizzleExpression& constantExpr = static_cast(*expr); + + std::array newComponents; + for (std::size_t i = 0; i < node.componentCount; ++i) + newComponents[i] = constantExpr.components[node.components[i]]; + + constantExpr.componentCount = node.componentCount; + constantExpr.components = newComponents; + + return expr; + } + + auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount); + swizzle->cachedExpressionType = node.cachedExpressionType; + + return swizzle; + } + ExpressionPtr AstOptimizer::Clone(UnaryExpression& node) { auto expr = CloneExpression(node.expression); @@ -952,6 +1178,30 @@ namespace Nz::ShaderAst return optimized; } + template + ExpressionPtr AstOptimizer::PropagateConstantSwizzle(const std::array& components, const ConstantValueExpression& operand) + { + std::unique_ptr optimized; + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + using BaseType = typename VectorInfo::Base; + constexpr std::size_t FromComponentCount = VectorInfo::Dimensions; + + using SPType = SwizzlePropagation; + + if constexpr (is_complete_v) + { + using Op = typename SPType::Op; + if constexpr (is_complete_v) + optimized = Op{}(components, arg); + } + }, operand.value); + + return optimized; + } + template ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand) { diff --git a/tests/Engine/Shader/Optimizations.cpp b/tests/Engine/Shader/Optimizations.cpp index 4e51b8cc7..e23d1484f 100644 --- a/tests/Engine/Shader/Optimizations.cpp +++ b/tests/Engine/Shader/Optimizations.cpp @@ -140,6 +140,118 @@ fn main() let output: f32 = 0.000000; output = 3.000000; } +)"); + } + + WHEN("optimizing out scalar swizzle") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let value = vec3(3.0, 0.0, 1.0).z; +} +)", R"( +[entry(frag)] +fn main() +{ + let value: f32 = 1.000000; +} +)"); + } + + WHEN("optimizing out scalar swizzle to vector") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let value = (42.0).xxxx; +} +)", R"( +[entry(frag)] +fn main() +{ + let value: vec4 = vec4(42.000000, 42.000000, 42.000000, 42.000000); +} +)"); + } + + WHEN("optimizing out vector swizzle") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let value = vec4(3.0, 0.0, 1.0, 2.0).yzwx; +} +)", R"( +[entry(frag)] +fn main() +{ + let value: vec4 = vec4(0.000000, 1.000000, 2.000000, 3.000000); +} +)"); + } + + WHEN("optimizing out vector swizzle with repetition") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let value = vec4(3.0, 0.0, 1.0, 2.0).zzxx; +} +)", R"( +[entry(frag)] +fn main() +{ + let value: vec4 = vec4(1.000000, 1.000000, 3.000000, 3.000000); +} +)"); + } + + WHEN("optimizing out complex swizzle") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let value = vec4(0.0, 1.0, 2.0, 3.0).xyz.yz.y.x.xxxx; +} +)", R"( +[entry(frag)] +fn main() +{ + let value: vec4 = vec4(2.000000, 2.000000, 2.000000, 2.000000); +} +)"); + } + + WHEN("optimizing out complex swizzle on unknown value") + { + ExpectOptimization(R"( +struct inputStruct +{ + value: vec4 +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let value = data.value.xyz.yz.y.x.xxxx; +} +)", R"( +[entry(frag)] +fn main() +{ + let value: vec4 = data.value.zzzz; +} )"); } }