// Copyright (C) 2021 Jérôme "Lynix" Leclercq (lynix680@gmail.com) // This file is part of the "Nazara Engine - Shader module" // For conditions of distribution and use, see copyright notice in Config.hpp #include #include #include #include #include namespace Nz::ShaderAst { namespace { template std::unique_ptr static_unique_pointer_cast(std::unique_ptr&& ptr) { return std::unique_ptr(static_cast(ptr.release())); } template struct is_complete_helper { template static auto test(U*)->std::integral_constant; static auto test(...) -> std::false_type; using type = decltype(test((T*)0)); }; template struct is_complete : is_complete_helper::type {}; template inline constexpr bool is_complete_v = is_complete::value; template struct VectorInfo { static constexpr std::size_t Dimensions = 1; using Base = T; }; template struct VectorInfo> { static constexpr std::size_t Dimensions = 2; using Base = T; }; template struct VectorInfo> { static constexpr std::size_t Dimensions = 3; using Base = T; }; template struct VectorInfo> { static constexpr std::size_t Dimensions = 4; using Base = T; }; /*************************************************************************************************/ template struct BinaryConstantPropagation; // CompEq template struct BinaryCompEqBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs == rhs); } }; template struct BinaryCompEq; template struct BinaryConstantPropagation { using Op = BinaryCompEq; }; // CompGe template struct BinaryCompGeBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs >= rhs); } }; template struct BinaryCompGe; template struct BinaryConstantPropagation { using Op = BinaryCompGe; }; // CompGt template struct BinaryCompGtBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs > rhs); } }; template struct BinaryCompGt; template struct BinaryConstantPropagation { using Op = BinaryCompGt; }; 
// CompLe template struct BinaryCompLeBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs <= rhs); } }; template struct BinaryCompLe; template struct BinaryConstantPropagation { using Op = BinaryCompLe; }; // CompLt template struct BinaryCompLtBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs < rhs); } }; template struct BinaryCompLt; template struct BinaryConstantPropagation { using Op = BinaryCompLe; }; // CompNe template struct BinaryCompNeBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs != rhs); } }; template struct BinaryCompNe; template struct BinaryConstantPropagation { using Op = BinaryCompNe; }; // LogicalAnd template struct BinaryLogicalAndBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs && rhs); } }; template struct BinaryLogicalAnd; template struct BinaryConstantPropagation { using Op = BinaryLogicalAnd; }; // LogicalOr template struct BinaryLogicalOrBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs || rhs); } }; template struct BinaryLogicalOr; template struct BinaryConstantPropagation { using Op = BinaryLogicalOr; }; // Addition template struct BinaryAdditionBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs + rhs); } }; template struct BinaryAddition; template struct BinaryConstantPropagation { using Op = BinaryAddition; }; // Division template struct BinaryDivisionBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs / rhs); } }; template struct BinaryDivision; template struct BinaryConstantPropagation { using Op = BinaryDivision; }; // Multiplication template struct BinaryMultiplicationBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs * rhs); } }; 
template struct BinaryMultiplication; template struct BinaryConstantPropagation { using Op = BinaryMultiplication; }; // Subtraction template struct BinarySubtractionBase { std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs - rhs); } }; template struct BinarySubtraction; template struct BinaryConstantPropagation { using Op = BinarySubtraction; }; /*************************************************************************************************/ template struct CastConstantBase { std::unique_ptr operator()(const Args&... args) { return ShaderBuilder::Constant(T(args...)); } }; template struct CastConstant; template struct CastConstantPropagation { using Op = CastConstant; }; /*************************************************************************************************/ template struct SwizzleBase; template struct SwizzleBase { std::unique_ptr operator()(const std::array& /*components*/, T value) { if constexpr (TargetComponentCount == 4) return ShaderBuilder::Constant(Vector4(value, value, value, value)); else if constexpr (TargetComponentCount == 3) return ShaderBuilder::Constant(Vector3(value, value, value)); else if constexpr (TargetComponentCount == 2) return ShaderBuilder::Constant(Vector2(value, value)); else if constexpr (TargetComponentCount == 1) return ShaderBuilder::Constant(value); else static_assert(AlwaysFalse(), "unexpected TargetComponentCount"); } }; template struct SwizzleBase { std::unique_ptr operator()(const std::array& components, const Vector2& value) { if constexpr (TargetComponentCount == 4) return ShaderBuilder::Constant(Vector4(value[components[0]], value[components[1]], value[components[2]], value[components[3]])); else if constexpr (TargetComponentCount == 3) return ShaderBuilder::Constant(Vector3(value[components[0]], value[components[1]], value[components[2]])); else if constexpr (TargetComponentCount == 2) return ShaderBuilder::Constant(Vector2(value[components[0]], 
value[components[1]])); else if constexpr (TargetComponentCount == 1) return ShaderBuilder::Constant(value[components[0]]); else static_assert(AlwaysFalse(), "unexpected TargetComponentCount"); } }; template struct SwizzleBase { std::unique_ptr operator()(const std::array& components, const Vector3& value) { if constexpr (TargetComponentCount == 4) return ShaderBuilder::Constant(Vector4(value[components[0]], value[components[1]], value[components[2]], value[components[3]])); else if constexpr (TargetComponentCount == 3) return ShaderBuilder::Constant(Vector3(value[components[0]], value[components[1]], value[components[2]])); else if constexpr (TargetComponentCount == 2) return ShaderBuilder::Constant(Vector2(value[components[0]], value[components[1]])); else if constexpr (TargetComponentCount == 1) return ShaderBuilder::Constant(value[components[0]]); else static_assert(AlwaysFalse(), "unexpected TargetComponentCount"); } }; template struct SwizzleBase { std::unique_ptr operator()(const std::array& components, const Vector4& value) { if constexpr (TargetComponentCount == 4) return ShaderBuilder::Constant(Vector4(value[components[0]], value[components[1]], value[components[2]], value[components[3]])); else if constexpr (TargetComponentCount == 3) return ShaderBuilder::Constant(Vector3(value[components[0]], value[components[1]], value[components[2]])); else if constexpr (TargetComponentCount == 2) return ShaderBuilder::Constant(Vector2(value[components[0]], value[components[1]])); else if constexpr (TargetComponentCount == 1) return ShaderBuilder::Constant(value[components[0]]); else static_assert(AlwaysFalse(), "unexpected TargetComponentCount"); } }; template struct Swizzle; template struct SwizzlePropagation { using Op = Swizzle; }; /*************************************************************************************************/ template struct UnaryConstantPropagation; // LogicalNot template struct UnaryLogicalNotBase { std::unique_ptr operator()(const T& arg) 
{ return ShaderBuilder::Constant(!arg); } }; template struct UnaryLogicalNot; template struct UnaryConstantPropagation { using Op = UnaryLogicalNot; }; // Minus template struct UnaryMinusBase { std::unique_ptr operator()(const T& arg) { return ShaderBuilder::Constant(-arg); } }; template struct UnaryMinus; template struct UnaryConstantPropagation { using Op = UnaryMinus; }; // Plus template struct UnaryPlusBase { std::unique_ptr operator()(const T& arg) { return ShaderBuilder::Constant(arg); } }; template struct UnaryPlus; template struct UnaryConstantPropagation { using Op = UnaryPlus; }; #define EnableOptimisation(Op, ...) template<> struct Op<__VA_ARGS__> : Op##Base<__VA_ARGS__> {} // Binary EnableOptimisation(BinaryCompEq, bool, bool); EnableOptimisation(BinaryCompEq, double, double); EnableOptimisation(BinaryCompEq, float, float); EnableOptimisation(BinaryCompEq, Int32, Int32); EnableOptimisation(BinaryCompEq, Vector2f, Vector2f); EnableOptimisation(BinaryCompEq, Vector3f, Vector3f); EnableOptimisation(BinaryCompEq, Vector4f, Vector4f); EnableOptimisation(BinaryCompEq, Vector2i32, Vector2i32); EnableOptimisation(BinaryCompEq, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompEq, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompGe, double, double); EnableOptimisation(BinaryCompGe, float, float); EnableOptimisation(BinaryCompGe, Int32, Int32); EnableOptimisation(BinaryCompGe, Vector2f, Vector2f); EnableOptimisation(BinaryCompGe, Vector3f, Vector3f); EnableOptimisation(BinaryCompGe, Vector4f, Vector4f); EnableOptimisation(BinaryCompGe, Vector2i32, Vector2i32); EnableOptimisation(BinaryCompGe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompGe, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompGt, double, double); EnableOptimisation(BinaryCompGt, float, float); EnableOptimisation(BinaryCompGt, Int32, Int32); EnableOptimisation(BinaryCompGt, Vector2f, Vector2f); EnableOptimisation(BinaryCompGt, Vector3f, Vector3f); 
// Each EnableOptimisation(Op, Args...) line defines Op<Args...> as a complete type, which enables constant folding
// for exactly that operand-type combination. Unlisted combinations stay incomplete and are left unfolded.
EnableOptimisation(BinaryCompGt, Vector4f, Vector4f);
EnableOptimisation(BinaryCompGt, Vector2i32, Vector2i32);
EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32);

// Ordering comparisons (CompGe/CompGt/CompLe/CompLt) have no bool specializations; only CompEq/CompNe do
EnableOptimisation(BinaryCompLe, double, double);
EnableOptimisation(BinaryCompLe, float, float);
EnableOptimisation(BinaryCompLe, Int32, Int32);
EnableOptimisation(BinaryCompLe, Vector2f, Vector2f);
EnableOptimisation(BinaryCompLe, Vector3f, Vector3f);
EnableOptimisation(BinaryCompLe, Vector4f, Vector4f);
EnableOptimisation(BinaryCompLe, Vector2i32, Vector2i32);
EnableOptimisation(BinaryCompLe, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompLe, Vector4i32, Vector4i32);

EnableOptimisation(BinaryCompLt, double, double);
EnableOptimisation(BinaryCompLt, float, float);
EnableOptimisation(BinaryCompLt, Int32, Int32);
EnableOptimisation(BinaryCompLt, Vector2f, Vector2f);
EnableOptimisation(BinaryCompLt, Vector3f, Vector3f);
EnableOptimisation(BinaryCompLt, Vector4f, Vector4f);
EnableOptimisation(BinaryCompLt, Vector2i32, Vector2i32);
EnableOptimisation(BinaryCompLt, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompLt, Vector4i32, Vector4i32);

EnableOptimisation(BinaryCompNe, bool, bool);
EnableOptimisation(BinaryCompNe, double, double);
EnableOptimisation(BinaryCompNe, float, float);
EnableOptimisation(BinaryCompNe, Int32, Int32);
EnableOptimisation(BinaryCompNe, Vector2f, Vector2f);
EnableOptimisation(BinaryCompNe, Vector3f, Vector3f);
EnableOptimisation(BinaryCompNe, Vector4f, Vector4f);
EnableOptimisation(BinaryCompNe, Vector2i32, Vector2i32);
EnableOptimisation(BinaryCompNe, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompNe, Vector4i32, Vector4i32);

// Logical operators only fold bool operands
EnableOptimisation(BinaryLogicalAnd, bool, bool);
EnableOptimisation(BinaryLogicalOr, bool, bool);

// Addition/subtraction: same-type operands only (no scalar/vector mixes)
EnableOptimisation(BinaryAddition, double, double);
EnableOptimisation(BinaryAddition, float, float);
EnableOptimisation(BinaryAddition, Int32, Int32);
EnableOptimisation(BinaryAddition, Vector2f, Vector2f);
EnableOptimisation(BinaryAddition, Vector3f, Vector3f);
EnableOptimisation(BinaryAddition, Vector4f, Vector4f);
EnableOptimisation(BinaryAddition, Vector2i32, Vector2i32);
EnableOptimisation(BinaryAddition, Vector3i32, Vector3i32);
EnableOptimisation(BinaryAddition, Vector4i32, Vector4i32);

// Division/multiplication also fold scalar/vector mixes in both orders (e.g. float / Vector2f and Vector2f / float).
// NOTE(review): the double / Vector*d entries only matter if ConstantValue can hold those types — confirm against its definition.
EnableOptimisation(BinaryDivision, double, double);
EnableOptimisation(BinaryDivision, double, Vector2d);
EnableOptimisation(BinaryDivision, double, Vector3d);
EnableOptimisation(BinaryDivision, double, Vector4d);
EnableOptimisation(BinaryDivision, float, float);
EnableOptimisation(BinaryDivision, float, Vector2f);
EnableOptimisation(BinaryDivision, float, Vector3f);
EnableOptimisation(BinaryDivision, float, Vector4f);
EnableOptimisation(BinaryDivision, Int32, Int32);
EnableOptimisation(BinaryDivision, Int32, Vector2i32);
EnableOptimisation(BinaryDivision, Int32, Vector3i32);
EnableOptimisation(BinaryDivision, Int32, Vector4i32);
EnableOptimisation(BinaryDivision, Vector2f, float);
EnableOptimisation(BinaryDivision, Vector2f, Vector2f);
EnableOptimisation(BinaryDivision, Vector3f, float);
EnableOptimisation(BinaryDivision, Vector3f, Vector3f);
EnableOptimisation(BinaryDivision, Vector4f, float);
EnableOptimisation(BinaryDivision, Vector4f, Vector4f);
EnableOptimisation(BinaryDivision, Vector2d, double);
EnableOptimisation(BinaryDivision, Vector2d, Vector2d);
EnableOptimisation(BinaryDivision, Vector3d, double);
EnableOptimisation(BinaryDivision, Vector3d, Vector3d);
EnableOptimisation(BinaryDivision, Vector4d, double);
EnableOptimisation(BinaryDivision, Vector4d, Vector4d);
EnableOptimisation(BinaryDivision, Vector2i32, Int32);
EnableOptimisation(BinaryDivision, Vector2i32, Vector2i32);
EnableOptimisation(BinaryDivision, Vector3i32, Int32);
EnableOptimisation(BinaryDivision, Vector3i32, Vector3i32);
EnableOptimisation(BinaryDivision, Vector4i32, Int32);
EnableOptimisation(BinaryDivision, Vector4i32, Vector4i32);

EnableOptimisation(BinaryMultiplication, double, double);
EnableOptimisation(BinaryMultiplication, double, Vector2d);
EnableOptimisation(BinaryMultiplication, double, Vector3d);
EnableOptimisation(BinaryMultiplication, double, Vector4d);
EnableOptimisation(BinaryMultiplication, float, float);
EnableOptimisation(BinaryMultiplication, float, Vector2f);
EnableOptimisation(BinaryMultiplication, float, Vector3f);
EnableOptimisation(BinaryMultiplication, float, Vector4f);
EnableOptimisation(BinaryMultiplication, Int32, Int32);
EnableOptimisation(BinaryMultiplication, Int32, Vector2i32);
EnableOptimisation(BinaryMultiplication, Int32, Vector3i32);
EnableOptimisation(BinaryMultiplication, Int32, Vector4i32);
EnableOptimisation(BinaryMultiplication, Vector2f, float);
EnableOptimisation(BinaryMultiplication, Vector2f, Vector2f);
EnableOptimisation(BinaryMultiplication, Vector3f, float);
EnableOptimisation(BinaryMultiplication, Vector3f, Vector3f);
EnableOptimisation(BinaryMultiplication, Vector4f, float);
EnableOptimisation(BinaryMultiplication, Vector4f, Vector4f);
EnableOptimisation(BinaryMultiplication, Vector2d, double);
EnableOptimisation(BinaryMultiplication, Vector2d, Vector2d);
EnableOptimisation(BinaryMultiplication, Vector3d, double);
EnableOptimisation(BinaryMultiplication, Vector3d, Vector3d);
EnableOptimisation(BinaryMultiplication, Vector4d, double);
EnableOptimisation(BinaryMultiplication, Vector4d, Vector4d);
EnableOptimisation(BinaryMultiplication, Vector2i32, Int32);
EnableOptimisation(BinaryMultiplication, Vector2i32, Vector2i32);
EnableOptimisation(BinaryMultiplication, Vector3i32, Int32);
EnableOptimisation(BinaryMultiplication, Vector3i32, Vector3i32);
EnableOptimisation(BinaryMultiplication, Vector4i32, Int32);
EnableOptimisation(BinaryMultiplication, Vector4i32, Vector4i32);

EnableOptimisation(BinarySubtraction, double, double);
EnableOptimisation(BinarySubtraction, float, float);
EnableOptimisation(BinarySubtraction, Int32, Int32);
EnableOptimisation(BinarySubtraction, Vector2f, Vector2f);
EnableOptimisation(BinarySubtraction, Vector3f, Vector3f);
EnableOptimisation(BinarySubtraction, Vector4f, Vector4f);
EnableOptimisation(BinarySubtraction, Vector2i32, Vector2i32);
EnableOptimisation(BinarySubtraction, Vector3i32, Vector3i32);
EnableOptimisation(BinarySubtraction, Vector4i32, Vector4i32);

// Cast: CastConstant<Target, Sources...> folds Target(sources...)
// Scalar conversions between bool/double/float/Int32/UInt32 (bool only from bool/Int32/UInt32)
EnableOptimisation(CastConstant, bool, bool);
EnableOptimisation(CastConstant, bool, Int32);
EnableOptimisation(CastConstant, bool, UInt32);

EnableOptimisation(CastConstant, double, double);
EnableOptimisation(CastConstant, double, float);
EnableOptimisation(CastConstant, double, Int32);
EnableOptimisation(CastConstant, double, UInt32);

EnableOptimisation(CastConstant, float, double);
EnableOptimisation(CastConstant, float, float);
EnableOptimisation(CastConstant, float, Int32);
EnableOptimisation(CastConstant, float, UInt32);

EnableOptimisation(CastConstant, Int32, double);
EnableOptimisation(CastConstant, Int32, float);
EnableOptimisation(CastConstant, Int32, Int32);
EnableOptimisation(CastConstant, Int32, UInt32);

EnableOptimisation(CastConstant, UInt32, double);
EnableOptimisation(CastConstant, UInt32, float);
EnableOptimisation(CastConstant, UInt32, Int32);
EnableOptimisation(CastConstant, UInt32, UInt32);

// Vector construction from same-typed scalar components.
// The double and UInt32 vector variants are commented out; the reason is not stated in this file — TODO confirm
//EnableOptimisation(CastConstant, Vector2d, double, double);
//EnableOptimisation(CastConstant, Vector3d, double, double, double);
//EnableOptimisation(CastConstant, Vector4d, double, double, double, double);

EnableOptimisation(CastConstant, Vector2f, float, float);
EnableOptimisation(CastConstant, Vector3f, float, float, float);
EnableOptimisation(CastConstant, Vector4f, float, float, float, float);

EnableOptimisation(CastConstant, Vector2i32, Int32, Int32);
EnableOptimisation(CastConstant, Vector3i32, Int32, Int32, Int32);
EnableOptimisation(CastConstant, Vector4i32, Int32, Int32, Int32, Int32);

//EnableOptimisation(CastConstant, Vector2ui32, UInt32, UInt32);
//EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32);
//EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32);

// Swizzle<ComponentType, FromComponentCount, TargetComponentCount>: every source count (1 = scalar) to every
// target count from 1 to 4, for double, float and Int32 components
EnableOptimisation(Swizzle, double, 1, 1);
EnableOptimisation(Swizzle, double, 1, 2);
EnableOptimisation(Swizzle, double, 1, 3);
EnableOptimisation(Swizzle, double, 1, 4);
EnableOptimisation(Swizzle, double, 2, 1);
EnableOptimisation(Swizzle, double, 2, 2);
EnableOptimisation(Swizzle, double, 2, 3);
EnableOptimisation(Swizzle, double, 2, 4);
EnableOptimisation(Swizzle, double, 3, 1);
EnableOptimisation(Swizzle, double, 3, 2);
EnableOptimisation(Swizzle, double, 3, 3);
EnableOptimisation(Swizzle, double, 3, 4);
EnableOptimisation(Swizzle, double, 4, 1);
EnableOptimisation(Swizzle, double, 4, 2);
EnableOptimisation(Swizzle, double, 4, 3);
EnableOptimisation(Swizzle, double, 4, 4);

EnableOptimisation(Swizzle, float, 1, 1);
EnableOptimisation(Swizzle, float, 1, 2);
EnableOptimisation(Swizzle, float, 1, 3);
EnableOptimisation(Swizzle, float, 1, 4);
EnableOptimisation(Swizzle, float, 2, 1);
EnableOptimisation(Swizzle, float, 2, 2);
EnableOptimisation(Swizzle, float, 2, 3);
EnableOptimisation(Swizzle, float, 2, 4);
EnableOptimisation(Swizzle, float, 3, 1);
EnableOptimisation(Swizzle, float, 3, 2);
EnableOptimisation(Swizzle, float, 3, 3);
EnableOptimisation(Swizzle, float, 3, 4);
EnableOptimisation(Swizzle, float, 4, 1);
EnableOptimisation(Swizzle, float, 4, 2);
EnableOptimisation(Swizzle, float, 4, 3);
EnableOptimisation(Swizzle, float, 4, 4);

EnableOptimisation(Swizzle, Int32, 1, 1);
EnableOptimisation(Swizzle, Int32, 1, 2);
EnableOptimisation(Swizzle, Int32, 1, 3);
EnableOptimisation(Swizzle, Int32, 1, 4);
EnableOptimisation(Swizzle, Int32, 2, 1);
EnableOptimisation(Swizzle, Int32, 2, 2);
EnableOptimisation(Swizzle, Int32, 2, 3);
EnableOptimisation(Swizzle, Int32, 2, 4);
EnableOptimisation(Swizzle, Int32, 3, 1);
EnableOptimisation(Swizzle, Int32, 3, 2);
EnableOptimisation(Swizzle, Int32, 3, 3);
EnableOptimisation(Swizzle, Int32, 3, 4); EnableOptimisation(Swizzle, Int32, 4, 1); EnableOptimisation(Swizzle, Int32, 4, 2); EnableOptimisation(Swizzle, Int32, 4, 3); EnableOptimisation(Swizzle, Int32, 4, 4); // Unary EnableOptimisation(UnaryLogicalNot, bool); EnableOptimisation(UnaryMinus, double); EnableOptimisation(UnaryMinus, float); EnableOptimisation(UnaryMinus, Int32); EnableOptimisation(UnaryMinus, Vector2f); EnableOptimisation(UnaryMinus, Vector3f); EnableOptimisation(UnaryMinus, Vector4f); EnableOptimisation(UnaryMinus, Vector2i32); EnableOptimisation(UnaryMinus, Vector3i32); EnableOptimisation(UnaryMinus, Vector4i32); EnableOptimisation(UnaryPlus, double); EnableOptimisation(UnaryPlus, float); EnableOptimisation(UnaryPlus, Int32); EnableOptimisation(UnaryPlus, Vector2f); EnableOptimisation(UnaryPlus, Vector3f); EnableOptimisation(UnaryPlus, Vector4f); EnableOptimisation(UnaryPlus, Vector2i32); EnableOptimisation(UnaryPlus, Vector3i32); EnableOptimisation(UnaryPlus, Vector4i32); #undef EnableOptimisation } ExpressionPtr AstOptimizer::Clone(BinaryExpression& node) { auto lhs = CloneExpression(node.left); auto rhs = CloneExpression(node.right); if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression) { const ConstantValueExpression& lhsConstant = static_cast(*lhs); const ConstantValueExpression& rhsConstant = static_cast(*rhs); ExpressionPtr optimized; switch (node.op) { case BinaryType::Add: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::Subtract: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::Multiply: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::Divide: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompEq: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompGe: optimized = 
PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompGt: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompLe: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompLt: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::CompNe: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::LogicalAnd: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; case BinaryType::LogicalOr: optimized = PropagateBinaryConstant(lhsConstant, rhsConstant); break; } if (optimized) return optimized; } auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs)); binary->cachedExpressionType = node.cachedExpressionType; return binary; } ExpressionPtr AstOptimizer::Clone(CastExpression& node) { std::array expressions; std::size_t expressionCount = 0; for (const auto& expression : node.expressions) { if (!expression) break; expressions[expressionCount] = CloneExpression(expression); expressionCount++; } ExpressionPtr optimized; if (IsPrimitiveType(node.targetType)) { if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression) { const ConstantValueExpression& constantExpr = static_cast(*expressions.front()); switch (std::get(node.targetType)) { case PrimitiveType::Boolean: optimized = PropagateSingleValueCast(constantExpr); break; case PrimitiveType::Float32: optimized = PropagateSingleValueCast(constantExpr); break; case PrimitiveType::Int32: optimized = PropagateSingleValueCast(constantExpr); break; case PrimitiveType::UInt32: optimized = PropagateSingleValueCast(constantExpr); break; } } } else if (IsVectorType(node.targetType)) { const auto& vecType = std::get(node.targetType); // Decompose vector into values (cast(vec3, float) => cast(float, float, float, float)) std::vector constantValues; for (std::size_t i = 0; i < expressionCount; ++i) { if 
(expressions[i]->GetType() != NodeType::ConstantValueExpression) { constantValues.clear(); break; } const auto& constantExpr = static_cast(*expressions[i]); if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value)) { // Unhandled case, all cast parameters are expected to be of the same type constantValues.clear(); break; } std::visit([&](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) throw std::runtime_error("invalid type (value expected)"); else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) constantValues.push_back(arg); else if constexpr (std::is_same_v || std::is_same_v) { constantValues.push_back(arg.x); constantValues.push_back(arg.y); } else if constexpr (std::is_same_v || std::is_same_v) { constantValues.push_back(arg.x); constantValues.push_back(arg.y); constantValues.push_back(arg.z); } else if constexpr (std::is_same_v || std::is_same_v) { constantValues.push_back(arg.x); constantValues.push_back(arg.y); constantValues.push_back(arg.z); constantValues.push_back(arg.w); } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, constantExpr.value); } if (!constantValues.empty()) { assert(constantValues.size() == vecType.componentCount); std::visit([&](auto&& arg) { using T = std::decay_t; switch (vecType.componentCount) { case 2: optimized = PropagateVec2Cast(std::get(constantValues[0]), std::get(constantValues[1])); break; case 3: optimized = PropagateVec3Cast(std::get(constantValues[0]), std::get(constantValues[1]), std::get(constantValues[2])); break; case 4: optimized = PropagateVec4Cast(std::get(constantValues[0]), std::get(constantValues[1]), std::get(constantValues[2]), std::get(constantValues[3])); break; } }, constantValues.front()); } } if (optimized) return optimized; auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions)); cast->cachedExpressionType = node.cachedExpressionType; return cast; } 
StatementPtr AstOptimizer::Clone(BranchStatement& node) { std::vector statements; StatementPtr elseStatement; for (auto& condStatement : node.condStatements) { auto cond = CloneExpression(condStatement.condition); if (cond->GetType() == NodeType::ConstantValueExpression) { auto& constant = static_cast(*cond); const ExpressionType& constantType = GetExpressionType(constant); if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) continue; bool cValue = std::get(constant.value); if (!cValue) continue; if (statements.empty()) { // First condition is true, dismiss the branch return AstCloner::Clone(*condStatement.statement); } else { // Some condition after the first one is true, make it the else statement and stop there elseStatement = CloneStatement(condStatement.statement); break; } } else { auto& c = statements.emplace_back(); c.condition = std::move(cond); c.statement = CloneStatement(condStatement.statement); } } if (statements.empty()) { // All conditions have been removed, replace by else statement or no-op if (node.elseStatement) return AstCloner::Clone(*node.elseStatement); else return ShaderBuilder::NoOp(); } if (!elseStatement) elseStatement = CloneStatement(node.elseStatement); return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)); } ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node) { auto cond = CloneExpression(node.condition); if (cond->GetType() != NodeType::ConstantValueExpression) throw std::runtime_error("conditional expression condition must be a constant expression"); auto& constant = static_cast(*cond); assert(constant.cachedExpressionType); const ExpressionType& constantType = constant.cachedExpressionType.value(); if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) throw std::runtime_error("conditional expression condition must resolve to a boolean"); bool cValue = std::get(constant.value); if (cValue) return 
AstCloner::Clone(*node.truePath); else return AstCloner::Clone(*node.falsePath); } ExpressionPtr AstOptimizer::Clone(ConstantExpression& node) { if (!m_options.constantQueryCallback) return AstCloner::Clone(node); auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId)); constant->cachedExpressionType = GetExpressionType(constant->value); return constant; } ExpressionPtr AstOptimizer::Clone(SwizzleExpression& node) { auto expr = CloneExpression(node.expression); if (expr->GetType() == NodeType::ConstantValueExpression) { const ConstantValueExpression& constantExpr = static_cast(*expr); ExpressionPtr optimized; switch (node.componentCount) { case 1: optimized = PropagateConstantSwizzle<1>(node.components, constantExpr); break; case 2: optimized = PropagateConstantSwizzle<2>(node.components, constantExpr); break; case 3: optimized = PropagateConstantSwizzle<3>(node.components, constantExpr); break; case 4: optimized = PropagateConstantSwizzle<4>(node.components, constantExpr); break; } if (optimized) return optimized; } else if (expr->GetType() == NodeType::SwizzleExpression) { SwizzleExpression& constantExpr = static_cast(*expr); std::array newComponents = {}; for (std::size_t i = 0; i < node.componentCount; ++i) newComponents[i] = constantExpr.components[node.components[i]]; constantExpr.componentCount = node.componentCount; constantExpr.components = newComponents; return expr; } auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount); swizzle->cachedExpressionType = node.cachedExpressionType; return swizzle; } ExpressionPtr AstOptimizer::Clone(UnaryExpression& node) { auto expr = CloneExpression(node.expression); if (expr->GetType() == NodeType::ConstantValueExpression) { const ConstantValueExpression& constantExpr = static_cast(*expr); ExpressionPtr optimized; switch (node.op) { case UnaryType::LogicalNot: optimized = PropagateUnaryConstant(constantExpr); break; case UnaryType::Minus: optimized 
= PropagateUnaryConstant(constantExpr); break; case UnaryType::Plus: optimized = PropagateUnaryConstant(constantExpr); break; } if (optimized) return optimized; } auto unary = ShaderBuilder::Unary(node.op, std::move(expr)); unary->cachedExpressionType = node.cachedExpressionType; return unary; } StatementPtr AstOptimizer::Clone(ConditionalStatement& node) { auto cond = CloneExpression(node.condition); if (cond->GetType() != NodeType::ConstantValueExpression) throw std::runtime_error("conditional expression condition must be a constant expression"); auto& constant = static_cast(*cond); assert(constant.cachedExpressionType); const ExpressionType& constantType = constant.cachedExpressionType.value(); if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) throw std::runtime_error("conditional expression condition must resolve to a boolean"); bool cValue = std::get(constant.value); if (cValue) return AstCloner::Clone(node); else return ShaderBuilder::NoOp(); } template ExpressionPtr AstOptimizer::PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs) { std::unique_ptr optimized; std::visit([&](auto&& arg1) { using T1 = std::decay_t; std::visit([&](auto&& arg2) { using T2 = std::decay_t; using PCType = BinaryConstantPropagation; if constexpr (is_complete_v) { using Op = typename PCType::Op; if constexpr (is_complete_v) optimized = Op{}(arg1, arg2); } }, rhs.value); }, lhs.value); if (optimized) optimized->cachedExpressionType = GetExpressionType(optimized->value); return optimized; } template ExpressionPtr AstOptimizer::PropagateSingleValueCast(const ConstantValueExpression& operand) { std::unique_ptr optimized; std::visit([&](auto&& arg) { using T = std::decay_t; using CCType = CastConstantPropagation; if constexpr (is_complete_v) { using Op = typename CCType::Op; if constexpr (is_complete_v) optimized = Op{}(arg); } }, operand.value); return optimized; } template ExpressionPtr 
AstOptimizer::PropagateConstantSwizzle(const std::array& components, const ConstantValueExpression& operand) { std::unique_ptr optimized; std::visit([&](auto&& arg) { using T = std::decay_t; using BaseType = typename VectorInfo::Base; constexpr std::size_t FromComponentCount = VectorInfo::Dimensions; using SPType = SwizzlePropagation; if constexpr (is_complete_v) { using Op = typename SPType::Op; if constexpr (is_complete_v) optimized = Op{}(components, arg); } }, operand.value); return optimized; } template ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand) { std::unique_ptr optimized; std::visit([&](auto&& arg) { using T = std::decay_t; using PCType = UnaryConstantPropagation; if constexpr (is_complete_v) { using Op = typename PCType::Op; if constexpr (is_complete_v) optimized = Op{}(arg); } }, operand.value); if (optimized) optimized->cachedExpressionType = GetExpressionType(optimized->value); return optimized; } template ExpressionPtr AstOptimizer::PropagateVec2Cast(TargetType v1, TargetType v2) { std::unique_ptr optimized; using CCType = CastConstantPropagation, TargetType, TargetType>; if constexpr (is_complete_v) { using Op = typename CCType::Op; if constexpr (is_complete_v) optimized = Op{}(v1, v2); } return optimized; } template ExpressionPtr AstOptimizer::PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3) { std::unique_ptr optimized; using CCType = CastConstantPropagation, TargetType, TargetType, TargetType>; if constexpr (is_complete_v) { using Op = typename CCType::Op; if constexpr (is_complete_v) optimized = Op{}(v1, v2, v3); } return optimized; } template ExpressionPtr AstOptimizer::PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4) { std::unique_ptr optimized; using CCType = CastConstantPropagation, TargetType, TargetType, TargetType, TargetType>; if constexpr (is_complete_v) { using Op = typename CCType::Op; if constexpr (is_complete_v) optimized = Op{}(v1, v2, v3, v4); } 
return optimized; } }