1279 lines
41 KiB
C++
1279 lines
41 KiB
C++
// Copyright (C) 2021 Jérôme "Lynix" Leclercq (lynix680@gmail.com)
|
|
// This file is part of the "Nazara Engine - Shader module"
|
|
// For conditions of distribution and use, see copyright notice in Config.hpp
|
|
|
|
#include <Nazara/Shader/Ast/AstOptimizer.hpp>
|
|
#include <Nazara/Shader/ShaderBuilder.hpp>
|
|
#include <cassert>
|
|
#include <stdexcept>
|
|
#include <Nazara/Shader/Debug.hpp>
|
|
|
|
namespace Nz::ShaderAst
|
|
{
|
|
namespace
|
|
{
|
|
template<typename T, typename U>
|
|
std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
|
|
{
|
|
return std::unique_ptr<T>(static_cast<T*>(ptr.release()));
|
|
}
|
|
|
|
template <typename T>
|
|
struct is_complete_helper
|
|
{
|
|
template <typename U> static auto test(U*)->std::integral_constant<bool, sizeof(U) == sizeof(U)>;
|
|
static auto test(...) -> std::false_type;
|
|
|
|
using type = decltype(test((T*)0));
|
|
};
|
|
|
|
template <typename T>
|
|
struct is_complete : is_complete_helper<T>::type {};
|
|
|
|
template<typename T>
|
|
inline constexpr bool is_complete_v = is_complete<T>::value;
|
|
|
|
template<typename T>
|
|
struct VectorInfo
|
|
{
|
|
static constexpr std::size_t Dimensions = 1;
|
|
using Base = T;
|
|
};
|
|
|
|
template<typename T>
|
|
struct VectorInfo<Vector2<T>>
|
|
{
|
|
static constexpr std::size_t Dimensions = 2;
|
|
using Base = T;
|
|
};
|
|
|
|
template<typename T>
|
|
struct VectorInfo<Vector3<T>>
|
|
{
|
|
static constexpr std::size_t Dimensions = 3;
|
|
using Base = T;
|
|
};
|
|
|
|
template<typename T>
|
|
struct VectorInfo<Vector4<T>>
|
|
{
|
|
static constexpr std::size_t Dimensions = 4;
|
|
using Base = T;
|
|
};
|
|
|
|
/*************************************************************************************************/
|
|
|
|
template<BinaryType Type, typename T1, typename T2>
|
|
struct BinaryConstantPropagation;
|
|
|
|
// CompEq
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompEqBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs == rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompEq;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::CompEq, T1, T2>
|
|
{
|
|
using Op = BinaryCompEq<T1, T2>;
|
|
};
|
|
|
|
// CompGe
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompGeBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs >= rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompGe;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::CompGe, T1, T2>
|
|
{
|
|
using Op = BinaryCompGe<T1, T2>;
|
|
};
|
|
|
|
// CompGt
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompGtBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs > rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompGt;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::CompGt, T1, T2>
|
|
{
|
|
using Op = BinaryCompGt<T1, T2>;
|
|
};
|
|
|
|
// CompLe
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompLeBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs <= rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompLe;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::CompLe, T1, T2>
|
|
{
|
|
using Op = BinaryCompLe<T1, T2>;
|
|
};
|
|
|
|
// CompLt
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompLtBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs < rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompLt;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::CompLt, T1, T2>
|
|
{
|
|
using Op = BinaryCompLe<T1, T2>;
|
|
};
|
|
|
|
// CompNe
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompNeBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs != rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryCompNe;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::CompNe, T1, T2>
|
|
{
|
|
using Op = BinaryCompNe<T1, T2>;
|
|
};
|
|
|
|
// LogicalAnd
|
|
template<typename T1, typename T2>
|
|
struct BinaryLogicalAndBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs && rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryLogicalAnd;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::LogicalAnd, T1, T2>
|
|
{
|
|
using Op = BinaryLogicalAnd<T1, T2>;
|
|
};
|
|
|
|
// LogicalOr
|
|
template<typename T1, typename T2>
|
|
struct BinaryLogicalOrBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs || rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryLogicalOr;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::LogicalOr, T1, T2>
|
|
{
|
|
using Op = BinaryLogicalOr<T1, T2>;
|
|
};
|
|
|
|
// Addition
|
|
template<typename T1, typename T2>
|
|
struct BinaryAdditionBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs + rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryAddition;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::Add, T1, T2>
|
|
{
|
|
using Op = BinaryAddition<T1, T2>;
|
|
};
|
|
|
|
// Division
|
|
template<typename T1, typename T2>
|
|
struct BinaryDivisionBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs / rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryDivision;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::Divide, T1, T2>
|
|
{
|
|
using Op = BinaryDivision<T1, T2>;
|
|
};
|
|
|
|
// Multiplication
|
|
template<typename T1, typename T2>
|
|
struct BinaryMultiplicationBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs * rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryMultiplication;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::Multiply, T1, T2>
|
|
{
|
|
using Op = BinaryMultiplication<T1, T2>;
|
|
};
|
|
|
|
// Subtraction
|
|
template<typename T1, typename T2>
|
|
struct BinarySubtractionBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T1& lhs, const T2& rhs)
|
|
{
|
|
return ShaderBuilder::Constant(lhs - rhs);
|
|
}
|
|
};
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinarySubtraction;
|
|
|
|
template<typename T1, typename T2>
|
|
struct BinaryConstantPropagation<BinaryType::Subtract, T1, T2>
|
|
{
|
|
using Op = BinarySubtraction<T1, T2>;
|
|
};
|
|
|
|
/*************************************************************************************************/
|
|
|
|
template<typename T, typename... Args>
|
|
struct CastConstantBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const Args&... args)
|
|
{
|
|
return ShaderBuilder::Constant(T(args...));
|
|
}
|
|
};
|
|
|
|
template<typename T, typename... Args>
|
|
struct CastConstant;
|
|
|
|
template<typename T, typename... Args>
|
|
struct CastConstantPropagation
|
|
{
|
|
using Op = CastConstant<T, Args...>;
|
|
};
|
|
|
|
/*************************************************************************************************/
|
|
|
|
template<typename T, std::size_t TargetComponentCount, std::size_t FromComponentCount>
|
|
struct SwizzleBase;
|
|
|
|
template<typename T, std::size_t TargetComponentCount>
|
|
struct SwizzleBase<T, TargetComponentCount, 1>
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& /*components*/, T value)
|
|
{
|
|
if constexpr (TargetComponentCount == 4)
|
|
return ShaderBuilder::Constant(Vector4<T>(value, value, value, value));
|
|
else if constexpr (TargetComponentCount == 3)
|
|
return ShaderBuilder::Constant(Vector3<T>(value, value, value));
|
|
else if constexpr (TargetComponentCount == 2)
|
|
return ShaderBuilder::Constant(Vector2<T>(value, value));
|
|
else if constexpr (TargetComponentCount == 1)
|
|
return ShaderBuilder::Constant(value);
|
|
else
|
|
static_assert(AlwaysFalse<T>(), "unexpected TargetComponentCount");
|
|
}
|
|
};
|
|
|
|
template<typename T, std::size_t TargetComponentCount>
|
|
struct SwizzleBase<T, TargetComponentCount, 2>
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector2<T>& value)
|
|
{
|
|
if constexpr (TargetComponentCount == 4)
|
|
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
|
|
else if constexpr (TargetComponentCount == 3)
|
|
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
|
|
else if constexpr (TargetComponentCount == 2)
|
|
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
|
|
else if constexpr (TargetComponentCount == 1)
|
|
return ShaderBuilder::Constant(value[components[0]]);
|
|
else
|
|
static_assert(AlwaysFalse<T>(), "unexpected TargetComponentCount");
|
|
}
|
|
};
|
|
|
|
template<typename T, std::size_t TargetComponentCount>
|
|
struct SwizzleBase<T, TargetComponentCount, 3>
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector3<T>& value)
|
|
{
|
|
if constexpr (TargetComponentCount == 4)
|
|
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
|
|
else if constexpr (TargetComponentCount == 3)
|
|
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
|
|
else if constexpr (TargetComponentCount == 2)
|
|
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
|
|
else if constexpr (TargetComponentCount == 1)
|
|
return ShaderBuilder::Constant(value[components[0]]);
|
|
else
|
|
static_assert(AlwaysFalse<T>(), "unexpected TargetComponentCount");
|
|
}
|
|
};
|
|
|
|
template<typename T, std::size_t TargetComponentCount>
|
|
struct SwizzleBase<T, TargetComponentCount, 4>
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const std::array<UInt32, 4>& components, const Vector4<T>& value)
|
|
{
|
|
if constexpr (TargetComponentCount == 4)
|
|
return ShaderBuilder::Constant(Vector4<T>(value[components[0]], value[components[1]], value[components[2]], value[components[3]]));
|
|
else if constexpr (TargetComponentCount == 3)
|
|
return ShaderBuilder::Constant(Vector3<T>(value[components[0]], value[components[1]], value[components[2]]));
|
|
else if constexpr (TargetComponentCount == 2)
|
|
return ShaderBuilder::Constant(Vector2<T>(value[components[0]], value[components[1]]));
|
|
else if constexpr (TargetComponentCount == 1)
|
|
return ShaderBuilder::Constant(value[components[0]]);
|
|
else
|
|
static_assert(AlwaysFalse<T>(), "unexpected TargetComponentCount");
|
|
}
|
|
};
|
|
|
|
template<typename T, std::size_t... Args>
|
|
struct Swizzle;
|
|
|
|
template<typename T, std::size_t... Args>
|
|
struct SwizzlePropagation
|
|
{
|
|
using Op = Swizzle<T, Args...>;
|
|
};
|
|
|
|
/*************************************************************************************************/
|
|
|
|
template<UnaryType Type, typename T>
|
|
struct UnaryConstantPropagation;
|
|
|
|
// LogicalNot
|
|
template<typename T>
|
|
struct UnaryLogicalNotBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T& arg)
|
|
{
|
|
return ShaderBuilder::Constant(!arg);
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
struct UnaryLogicalNot;
|
|
|
|
template<typename T>
|
|
struct UnaryConstantPropagation<UnaryType::LogicalNot, T>
|
|
{
|
|
using Op = UnaryLogicalNot<T>;
|
|
};
|
|
|
|
// Minus
|
|
template<typename T>
|
|
struct UnaryMinusBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T& arg)
|
|
{
|
|
return ShaderBuilder::Constant(-arg);
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
struct UnaryMinus;
|
|
|
|
template<typename T>
|
|
struct UnaryConstantPropagation<UnaryType::Minus, T>
|
|
{
|
|
using Op = UnaryMinus<T>;
|
|
};
|
|
|
|
// Plus
|
|
template<typename T>
|
|
struct UnaryPlusBase
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> operator()(const T& arg)
|
|
{
|
|
return ShaderBuilder::Constant(arg);
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
struct UnaryPlus;
|
|
|
|
template<typename T>
|
|
struct UnaryConstantPropagation<UnaryType::Plus, T>
|
|
{
|
|
using Op = UnaryPlus<T>;
|
|
};
|
|
|
|
#define EnableOptimisation(Op, ...) template<> struct Op<__VA_ARGS__> : Op##Base<__VA_ARGS__> {}
|
|
|
|
// Binary
|
|
|
|
EnableOptimisation(BinaryCompEq, bool, bool);
|
|
EnableOptimisation(BinaryCompEq, double, double);
|
|
EnableOptimisation(BinaryCompEq, float, float);
|
|
EnableOptimisation(BinaryCompEq, Int32, Int32);
|
|
EnableOptimisation(BinaryCompEq, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryCompEq, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryCompEq, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryCompEq, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryCompEq, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryCompEq, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryCompGe, double, double);
|
|
EnableOptimisation(BinaryCompGe, float, float);
|
|
EnableOptimisation(BinaryCompGe, Int32, Int32);
|
|
EnableOptimisation(BinaryCompGe, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryCompGe, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryCompGe, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryCompGe, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryCompGe, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryCompGe, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryCompGt, double, double);
|
|
EnableOptimisation(BinaryCompGt, float, float);
|
|
EnableOptimisation(BinaryCompGt, Int32, Int32);
|
|
EnableOptimisation(BinaryCompGt, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryCompGt, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryCompGt, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryCompGt, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryCompLe, double, double);
|
|
EnableOptimisation(BinaryCompLe, float, float);
|
|
EnableOptimisation(BinaryCompLe, Int32, Int32);
|
|
EnableOptimisation(BinaryCompLe, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryCompLe, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryCompLe, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryCompLe, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryCompLe, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryCompLe, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryCompLt, double, double);
|
|
EnableOptimisation(BinaryCompLt, float, float);
|
|
EnableOptimisation(BinaryCompLt, Int32, Int32);
|
|
EnableOptimisation(BinaryCompLt, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryCompLt, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryCompLt, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryCompLt, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryCompLt, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryCompLt, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryCompNe, bool, bool);
|
|
EnableOptimisation(BinaryCompNe, double, double);
|
|
EnableOptimisation(BinaryCompNe, float, float);
|
|
EnableOptimisation(BinaryCompNe, Int32, Int32);
|
|
EnableOptimisation(BinaryCompNe, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryCompNe, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryCompNe, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryCompNe, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryCompNe, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryCompNe, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryLogicalAnd, bool, bool);
|
|
EnableOptimisation(BinaryLogicalOr, bool, bool);
|
|
|
|
EnableOptimisation(BinaryAddition, double, double);
|
|
EnableOptimisation(BinaryAddition, float, float);
|
|
EnableOptimisation(BinaryAddition, Int32, Int32);
|
|
EnableOptimisation(BinaryAddition, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryAddition, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryAddition, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryAddition, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryAddition, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryAddition, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryDivision, double, double);
|
|
EnableOptimisation(BinaryDivision, double, Vector2d);
|
|
EnableOptimisation(BinaryDivision, double, Vector3d);
|
|
EnableOptimisation(BinaryDivision, double, Vector4d);
|
|
EnableOptimisation(BinaryDivision, float, float);
|
|
EnableOptimisation(BinaryDivision, float, Vector2f);
|
|
EnableOptimisation(BinaryDivision, float, Vector3f);
|
|
EnableOptimisation(BinaryDivision, float, Vector4f);
|
|
EnableOptimisation(BinaryDivision, Int32, Int32);
|
|
EnableOptimisation(BinaryDivision, Int32, Vector2i32);
|
|
EnableOptimisation(BinaryDivision, Int32, Vector3i32);
|
|
EnableOptimisation(BinaryDivision, Int32, Vector4i32);
|
|
EnableOptimisation(BinaryDivision, Vector2f, float);
|
|
EnableOptimisation(BinaryDivision, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryDivision, Vector3f, float);
|
|
EnableOptimisation(BinaryDivision, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryDivision, Vector4f, float);
|
|
EnableOptimisation(BinaryDivision, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryDivision, Vector2d, double);
|
|
EnableOptimisation(BinaryDivision, Vector2d, Vector2d);
|
|
EnableOptimisation(BinaryDivision, Vector3d, double);
|
|
EnableOptimisation(BinaryDivision, Vector3d, Vector3d);
|
|
EnableOptimisation(BinaryDivision, Vector4d, double);
|
|
EnableOptimisation(BinaryDivision, Vector4d, Vector4d);
|
|
EnableOptimisation(BinaryDivision, Vector2i32, Int32);
|
|
EnableOptimisation(BinaryDivision, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryDivision, Vector3i32, Int32);
|
|
EnableOptimisation(BinaryDivision, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryDivision, Vector4i32, Int32);
|
|
EnableOptimisation(BinaryDivision, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinaryMultiplication, double, double);
|
|
EnableOptimisation(BinaryMultiplication, double, Vector2d);
|
|
EnableOptimisation(BinaryMultiplication, double, Vector3d);
|
|
EnableOptimisation(BinaryMultiplication, double, Vector4d);
|
|
EnableOptimisation(BinaryMultiplication, float, float);
|
|
EnableOptimisation(BinaryMultiplication, float, Vector2f);
|
|
EnableOptimisation(BinaryMultiplication, float, Vector3f);
|
|
EnableOptimisation(BinaryMultiplication, float, Vector4f);
|
|
EnableOptimisation(BinaryMultiplication, Int32, Int32);
|
|
EnableOptimisation(BinaryMultiplication, Int32, Vector2i32);
|
|
EnableOptimisation(BinaryMultiplication, Int32, Vector3i32);
|
|
EnableOptimisation(BinaryMultiplication, Int32, Vector4i32);
|
|
EnableOptimisation(BinaryMultiplication, Vector2f, float);
|
|
EnableOptimisation(BinaryMultiplication, Vector2f, Vector2f);
|
|
EnableOptimisation(BinaryMultiplication, Vector3f, float);
|
|
EnableOptimisation(BinaryMultiplication, Vector3f, Vector3f);
|
|
EnableOptimisation(BinaryMultiplication, Vector4f, float);
|
|
EnableOptimisation(BinaryMultiplication, Vector4f, Vector4f);
|
|
EnableOptimisation(BinaryMultiplication, Vector2d, double);
|
|
EnableOptimisation(BinaryMultiplication, Vector2d, Vector2d);
|
|
EnableOptimisation(BinaryMultiplication, Vector3d, double);
|
|
EnableOptimisation(BinaryMultiplication, Vector3d, Vector3d);
|
|
EnableOptimisation(BinaryMultiplication, Vector4d, double);
|
|
EnableOptimisation(BinaryMultiplication, Vector4d, Vector4d);
|
|
EnableOptimisation(BinaryMultiplication, Vector2i32, Int32);
|
|
EnableOptimisation(BinaryMultiplication, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinaryMultiplication, Vector3i32, Int32);
|
|
EnableOptimisation(BinaryMultiplication, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinaryMultiplication, Vector4i32, Int32);
|
|
EnableOptimisation(BinaryMultiplication, Vector4i32, Vector4i32);
|
|
|
|
EnableOptimisation(BinarySubtraction, double, double);
|
|
EnableOptimisation(BinarySubtraction, float, float);
|
|
EnableOptimisation(BinarySubtraction, Int32, Int32);
|
|
EnableOptimisation(BinarySubtraction, Vector2f, Vector2f);
|
|
EnableOptimisation(BinarySubtraction, Vector3f, Vector3f);
|
|
EnableOptimisation(BinarySubtraction, Vector4f, Vector4f);
|
|
EnableOptimisation(BinarySubtraction, Vector2i32, Vector2i32);
|
|
EnableOptimisation(BinarySubtraction, Vector3i32, Vector3i32);
|
|
EnableOptimisation(BinarySubtraction, Vector4i32, Vector4i32);
|
|
|
|
// Cast
|
|
|
|
EnableOptimisation(CastConstant, bool, bool);
|
|
EnableOptimisation(CastConstant, bool, Int32);
|
|
EnableOptimisation(CastConstant, bool, UInt32);
|
|
|
|
EnableOptimisation(CastConstant, double, double);
|
|
EnableOptimisation(CastConstant, double, float);
|
|
EnableOptimisation(CastConstant, double, Int32);
|
|
EnableOptimisation(CastConstant, double, UInt32);
|
|
|
|
EnableOptimisation(CastConstant, float, double);
|
|
EnableOptimisation(CastConstant, float, float);
|
|
EnableOptimisation(CastConstant, float, Int32);
|
|
EnableOptimisation(CastConstant, float, UInt32);
|
|
|
|
EnableOptimisation(CastConstant, Int32, double);
|
|
EnableOptimisation(CastConstant, Int32, float);
|
|
EnableOptimisation(CastConstant, Int32, Int32);
|
|
EnableOptimisation(CastConstant, Int32, UInt32);
|
|
|
|
EnableOptimisation(CastConstant, UInt32, double);
|
|
EnableOptimisation(CastConstant, UInt32, float);
|
|
EnableOptimisation(CastConstant, UInt32, Int32);
|
|
EnableOptimisation(CastConstant, UInt32, UInt32);
|
|
|
|
//EnableOptimisation(CastConstant, Vector2d, double, double);
|
|
//EnableOptimisation(CastConstant, Vector3d, double, double, double);
|
|
//EnableOptimisation(CastConstant, Vector4d, double, double, double, double);
|
|
|
|
EnableOptimisation(CastConstant, Vector2f, float, float);
|
|
EnableOptimisation(CastConstant, Vector3f, float, float, float);
|
|
EnableOptimisation(CastConstant, Vector4f, float, float, float, float);
|
|
|
|
EnableOptimisation(CastConstant, Vector2i32, Int32, Int32);
|
|
EnableOptimisation(CastConstant, Vector3i32, Int32, Int32, Int32);
|
|
EnableOptimisation(CastConstant, Vector4i32, Int32, Int32, Int32, Int32);
|
|
|
|
//EnableOptimisation(CastConstant, Vector2ui32, UInt32, UInt32);
|
|
//EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32);
|
|
//EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32);
|
|
|
|
// Swizzle
|
|
EnableOptimisation(Swizzle, double, 1, 1);
|
|
EnableOptimisation(Swizzle, double, 1, 2);
|
|
EnableOptimisation(Swizzle, double, 1, 3);
|
|
EnableOptimisation(Swizzle, double, 1, 4);
|
|
|
|
EnableOptimisation(Swizzle, double, 2, 1);
|
|
EnableOptimisation(Swizzle, double, 2, 2);
|
|
EnableOptimisation(Swizzle, double, 2, 3);
|
|
EnableOptimisation(Swizzle, double, 2, 4);
|
|
|
|
EnableOptimisation(Swizzle, double, 3, 1);
|
|
EnableOptimisation(Swizzle, double, 3, 2);
|
|
EnableOptimisation(Swizzle, double, 3, 3);
|
|
EnableOptimisation(Swizzle, double, 3, 4);
|
|
|
|
EnableOptimisation(Swizzle, double, 4, 1);
|
|
EnableOptimisation(Swizzle, double, 4, 2);
|
|
EnableOptimisation(Swizzle, double, 4, 3);
|
|
EnableOptimisation(Swizzle, double, 4, 4);
|
|
|
|
EnableOptimisation(Swizzle, float, 1, 1);
|
|
EnableOptimisation(Swizzle, float, 1, 2);
|
|
EnableOptimisation(Swizzle, float, 1, 3);
|
|
EnableOptimisation(Swizzle, float, 1, 4);
|
|
|
|
EnableOptimisation(Swizzle, float, 2, 1);
|
|
EnableOptimisation(Swizzle, float, 2, 2);
|
|
EnableOptimisation(Swizzle, float, 2, 3);
|
|
EnableOptimisation(Swizzle, float, 2, 4);
|
|
|
|
EnableOptimisation(Swizzle, float, 3, 1);
|
|
EnableOptimisation(Swizzle, float, 3, 2);
|
|
EnableOptimisation(Swizzle, float, 3, 3);
|
|
EnableOptimisation(Swizzle, float, 3, 4);
|
|
|
|
EnableOptimisation(Swizzle, float, 4, 1);
|
|
EnableOptimisation(Swizzle, float, 4, 2);
|
|
EnableOptimisation(Swizzle, float, 4, 3);
|
|
EnableOptimisation(Swizzle, float, 4, 4);
|
|
|
|
EnableOptimisation(Swizzle, Int32, 1, 1);
|
|
EnableOptimisation(Swizzle, Int32, 1, 2);
|
|
EnableOptimisation(Swizzle, Int32, 1, 3);
|
|
EnableOptimisation(Swizzle, Int32, 1, 4);
|
|
|
|
EnableOptimisation(Swizzle, Int32, 2, 1);
|
|
EnableOptimisation(Swizzle, Int32, 2, 2);
|
|
EnableOptimisation(Swizzle, Int32, 2, 3);
|
|
EnableOptimisation(Swizzle, Int32, 2, 4);
|
|
|
|
EnableOptimisation(Swizzle, Int32, 3, 1);
|
|
EnableOptimisation(Swizzle, Int32, 3, 2);
|
|
EnableOptimisation(Swizzle, Int32, 3, 3);
|
|
EnableOptimisation(Swizzle, Int32, 3, 4);
|
|
|
|
EnableOptimisation(Swizzle, Int32, 4, 1);
|
|
EnableOptimisation(Swizzle, Int32, 4, 2);
|
|
EnableOptimisation(Swizzle, Int32, 4, 3);
|
|
EnableOptimisation(Swizzle, Int32, 4, 4);
|
|
|
|
// Unary
|
|
|
|
EnableOptimisation(UnaryLogicalNot, bool);
|
|
|
|
EnableOptimisation(UnaryMinus, double);
|
|
EnableOptimisation(UnaryMinus, float);
|
|
EnableOptimisation(UnaryMinus, Int32);
|
|
EnableOptimisation(UnaryMinus, Vector2f);
|
|
EnableOptimisation(UnaryMinus, Vector3f);
|
|
EnableOptimisation(UnaryMinus, Vector4f);
|
|
EnableOptimisation(UnaryMinus, Vector2i32);
|
|
EnableOptimisation(UnaryMinus, Vector3i32);
|
|
EnableOptimisation(UnaryMinus, Vector4i32);
|
|
|
|
EnableOptimisation(UnaryPlus, double);
|
|
EnableOptimisation(UnaryPlus, float);
|
|
EnableOptimisation(UnaryPlus, Int32);
|
|
EnableOptimisation(UnaryPlus, Vector2f);
|
|
EnableOptimisation(UnaryPlus, Vector3f);
|
|
EnableOptimisation(UnaryPlus, Vector4f);
|
|
EnableOptimisation(UnaryPlus, Vector2i32);
|
|
EnableOptimisation(UnaryPlus, Vector3i32);
|
|
EnableOptimisation(UnaryPlus, Vector4i32);
|
|
|
|
#undef EnableOptimisation
|
|
}
|
|
|
|
ExpressionPtr AstOptimizer::Clone(BinaryExpression& node)
|
|
{
|
|
auto lhs = CloneExpression(node.left);
|
|
auto rhs = CloneExpression(node.right);
|
|
|
|
if (lhs->GetType() == NodeType::ConstantValueExpression && rhs->GetType() == NodeType::ConstantValueExpression)
|
|
{
|
|
const ConstantValueExpression& lhsConstant = static_cast<const ConstantValueExpression&>(*lhs);
|
|
const ConstantValueExpression& rhsConstant = static_cast<const ConstantValueExpression&>(*rhs);
|
|
|
|
ExpressionPtr optimized;
|
|
switch (node.op)
|
|
{
|
|
case BinaryType::Add:
|
|
optimized = PropagateBinaryConstant<BinaryType::Add>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::Subtract:
|
|
optimized = PropagateBinaryConstant<BinaryType::Subtract>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::Multiply:
|
|
optimized = PropagateBinaryConstant<BinaryType::Multiply>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::Divide:
|
|
optimized = PropagateBinaryConstant<BinaryType::Divide>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::CompEq:
|
|
optimized = PropagateBinaryConstant<BinaryType::CompEq>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::CompGe:
|
|
optimized = PropagateBinaryConstant<BinaryType::CompGe>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::CompGt:
|
|
optimized = PropagateBinaryConstant<BinaryType::CompGt>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::CompLe:
|
|
optimized = PropagateBinaryConstant<BinaryType::CompLe>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::CompLt:
|
|
optimized = PropagateBinaryConstant<BinaryType::CompLt>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::CompNe:
|
|
optimized = PropagateBinaryConstant<BinaryType::CompNe>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::LogicalAnd:
|
|
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(lhsConstant, rhsConstant);
|
|
break;
|
|
|
|
case BinaryType::LogicalOr:
|
|
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(lhsConstant, rhsConstant);
|
|
break;
|
|
}
|
|
|
|
if (optimized)
|
|
return optimized;
|
|
}
|
|
|
|
auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs));
|
|
binary->cachedExpressionType = node.cachedExpressionType;
|
|
|
|
return binary;
|
|
}
|
|
|
|
ExpressionPtr AstOptimizer::Clone(CastExpression& node)
|
|
{
|
|
std::array<ExpressionPtr, 4> expressions;
|
|
|
|
std::size_t expressionCount = 0;
|
|
for (const auto& expression : node.expressions)
|
|
{
|
|
if (!expression)
|
|
break;
|
|
|
|
expressions[expressionCount] = CloneExpression(expression);
|
|
expressionCount++;
|
|
}
|
|
|
|
ExpressionPtr optimized;
|
|
if (IsPrimitiveType(node.targetType))
|
|
{
|
|
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression)
|
|
{
|
|
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expressions.front());
|
|
|
|
switch (std::get<PrimitiveType>(node.targetType))
|
|
{
|
|
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(constantExpr); break;
|
|
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
|
|
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(constantExpr); break;
|
|
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(constantExpr); break;
|
|
}
|
|
}
|
|
}
|
|
else if (IsVectorType(node.targetType))
|
|
{
|
|
const auto& vecType = std::get<VectorType>(node.targetType);
|
|
|
|
// Decompose vector into values (cast(vec3, float) => cast(float, float, float, float))
|
|
std::vector<ConstantValue> constantValues;
|
|
for (std::size_t i = 0; i < expressionCount; ++i)
|
|
{
|
|
if (expressions[i]->GetType() != NodeType::ConstantValueExpression)
|
|
{
|
|
constantValues.clear();
|
|
break;
|
|
}
|
|
|
|
const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]);
|
|
|
|
if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value))
|
|
{
|
|
// Unhandled case, all cast parameters are expected to be of the same type
|
|
constantValues.clear();
|
|
break;
|
|
}
|
|
|
|
std::visit([&](auto&& arg)
|
|
{
|
|
using T = std::decay_t<decltype(arg)>;
|
|
|
|
if constexpr (std::is_same_v<T, NoValue>)
|
|
throw std::runtime_error("invalid type (value expected)");
|
|
else if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
|
|
constantValues.push_back(arg);
|
|
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i32>)
|
|
{
|
|
constantValues.push_back(arg.x);
|
|
constantValues.push_back(arg.y);
|
|
}
|
|
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i32>)
|
|
{
|
|
constantValues.push_back(arg.x);
|
|
constantValues.push_back(arg.y);
|
|
constantValues.push_back(arg.z);
|
|
}
|
|
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i32>)
|
|
{
|
|
constantValues.push_back(arg.x);
|
|
constantValues.push_back(arg.y);
|
|
constantValues.push_back(arg.z);
|
|
constantValues.push_back(arg.w);
|
|
}
|
|
else
|
|
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
|
}, constantExpr.value);
|
|
}
|
|
|
|
if (!constantValues.empty())
|
|
{
|
|
assert(constantValues.size() == vecType.componentCount);
|
|
|
|
std::visit([&](auto&& arg)
|
|
{
|
|
using T = std::decay_t<decltype(arg)>;
|
|
|
|
switch (vecType.componentCount)
|
|
{
|
|
case 2:
|
|
optimized = PropagateVec2Cast(std::get<T>(constantValues[0]), std::get<T>(constantValues[1]));
|
|
break;
|
|
|
|
case 3:
|
|
optimized = PropagateVec3Cast(std::get<T>(constantValues[0]), std::get<T>(constantValues[1]), std::get<T>(constantValues[2]));
|
|
break;
|
|
|
|
case 4:
|
|
optimized = PropagateVec4Cast(std::get<T>(constantValues[0]), std::get<T>(constantValues[1]), std::get<T>(constantValues[2]), std::get<T>(constantValues[3]));
|
|
break;
|
|
}
|
|
}, constantValues.front());
|
|
}
|
|
}
|
|
|
|
if (optimized)
|
|
return optimized;
|
|
|
|
auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions));
|
|
cast->cachedExpressionType = node.cachedExpressionType;
|
|
|
|
return cast;
|
|
}
|
|
|
|
StatementPtr AstOptimizer::Clone(BranchStatement& node)
|
|
{
|
|
std::vector<BranchStatement::ConditionalStatement> statements;
|
|
StatementPtr elseStatement;
|
|
|
|
for (auto& condStatement : node.condStatements)
|
|
{
|
|
auto cond = CloneExpression(condStatement.condition);
|
|
|
|
if (cond->GetType() == NodeType::ConstantValueExpression)
|
|
{
|
|
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
|
|
|
const ExpressionType& constantType = GetExpressionType(constant);
|
|
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
|
|
continue;
|
|
|
|
bool cValue = std::get<bool>(constant.value);
|
|
if (!cValue)
|
|
continue;
|
|
|
|
if (statements.empty())
|
|
{
|
|
// First condition is true, dismiss the branch
|
|
return AstCloner::Clone(*condStatement.statement);
|
|
}
|
|
else
|
|
{
|
|
// Some condition after the first one is true, make it the else statement and stop there
|
|
elseStatement = CloneStatement(condStatement.statement);
|
|
break;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
auto& c = statements.emplace_back();
|
|
c.condition = std::move(cond);
|
|
c.statement = CloneStatement(condStatement.statement);
|
|
}
|
|
}
|
|
|
|
if (statements.empty())
|
|
{
|
|
// All conditions have been removed, replace by else statement or no-op
|
|
if (node.elseStatement)
|
|
return AstCloner::Clone(*node.elseStatement);
|
|
else
|
|
return ShaderBuilder::NoOp();
|
|
}
|
|
|
|
if (!elseStatement)
|
|
elseStatement = CloneStatement(node.elseStatement);
|
|
|
|
return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement));
|
|
}
|
|
|
|
ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node)
|
|
{
|
|
auto cond = CloneExpression(node.condition);
|
|
if (cond->GetType() != NodeType::ConstantValueExpression)
|
|
throw std::runtime_error("conditional expression condition must be a constant expression");
|
|
|
|
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
|
|
|
assert(constant.cachedExpressionType);
|
|
const ExpressionType& constantType = constant.cachedExpressionType.value();
|
|
|
|
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
|
|
throw std::runtime_error("conditional expression condition must resolve to a boolean");
|
|
|
|
bool cValue = std::get<bool>(constant.value);
|
|
if (cValue)
|
|
return AstCloner::Clone(*node.truePath);
|
|
else
|
|
return AstCloner::Clone(*node.falsePath);
|
|
}
|
|
|
|
ExpressionPtr AstOptimizer::Clone(ConstantExpression& node)
|
|
{
|
|
if (!m_options.constantQueryCallback)
|
|
return AstCloner::Clone(node);
|
|
|
|
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
|
|
constant->cachedExpressionType = GetExpressionType(constant->value);
|
|
|
|
return constant;
|
|
}
|
|
|
|
ExpressionPtr AstOptimizer::Clone(SwizzleExpression& node)
|
|
{
|
|
auto expr = CloneExpression(node.expression);
|
|
|
|
if (expr->GetType() == NodeType::ConstantValueExpression)
|
|
{
|
|
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);
|
|
|
|
ExpressionPtr optimized;
|
|
switch (node.componentCount)
|
|
{
|
|
case 1:
|
|
optimized = PropagateConstantSwizzle<1>(node.components, constantExpr);
|
|
break;
|
|
|
|
case 2:
|
|
optimized = PropagateConstantSwizzle<2>(node.components, constantExpr);
|
|
break;
|
|
|
|
case 3:
|
|
optimized = PropagateConstantSwizzle<3>(node.components, constantExpr);
|
|
break;
|
|
|
|
case 4:
|
|
optimized = PropagateConstantSwizzle<4>(node.components, constantExpr);
|
|
break;
|
|
}
|
|
|
|
if (optimized)
|
|
return optimized;
|
|
}
|
|
else if (expr->GetType() == NodeType::SwizzleExpression)
|
|
{
|
|
SwizzleExpression& constantExpr = static_cast<SwizzleExpression&>(*expr);
|
|
|
|
std::array<UInt32, 4> newComponents = {};
|
|
for (std::size_t i = 0; i < node.componentCount; ++i)
|
|
newComponents[i] = constantExpr.components[node.components[i]];
|
|
|
|
constantExpr.componentCount = node.componentCount;
|
|
constantExpr.components = newComponents;
|
|
|
|
return expr;
|
|
}
|
|
|
|
auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount);
|
|
swizzle->cachedExpressionType = node.cachedExpressionType;
|
|
|
|
return swizzle;
|
|
}
|
|
|
|
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)
|
|
{
|
|
auto expr = CloneExpression(node.expression);
|
|
|
|
if (expr->GetType() == NodeType::ConstantValueExpression)
|
|
{
|
|
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expr);
|
|
|
|
ExpressionPtr optimized;
|
|
switch (node.op)
|
|
{
|
|
case UnaryType::LogicalNot:
|
|
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(constantExpr);
|
|
break;
|
|
|
|
case UnaryType::Minus:
|
|
optimized = PropagateUnaryConstant<UnaryType::Minus>(constantExpr);
|
|
break;
|
|
|
|
case UnaryType::Plus:
|
|
optimized = PropagateUnaryConstant<UnaryType::Plus>(constantExpr);
|
|
break;
|
|
}
|
|
|
|
if (optimized)
|
|
return optimized;
|
|
}
|
|
|
|
auto unary = ShaderBuilder::Unary(node.op, std::move(expr));
|
|
unary->cachedExpressionType = node.cachedExpressionType;
|
|
|
|
return unary;
|
|
}
|
|
|
|
StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
|
|
{
|
|
auto cond = CloneExpression(node.condition);
|
|
if (cond->GetType() != NodeType::ConstantValueExpression)
|
|
throw std::runtime_error("conditional expression condition must be a constant expression");
|
|
|
|
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
|
|
|
assert(constant.cachedExpressionType);
|
|
const ExpressionType& constantType = constant.cachedExpressionType.value();
|
|
|
|
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
|
|
throw std::runtime_error("conditional expression condition must resolve to a boolean");
|
|
|
|
bool cValue = std::get<bool>(constant.value);
|
|
if (cValue)
|
|
return AstCloner::Clone(node);
|
|
else
|
|
return ShaderBuilder::NoOp();
|
|
}
|
|
|
|
template<BinaryType Type>
|
|
ExpressionPtr AstOptimizer::PropagateBinaryConstant(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
std::visit([&](auto&& arg1)
|
|
{
|
|
using T1 = std::decay_t<decltype(arg1)>;
|
|
|
|
std::visit([&](auto&& arg2)
|
|
{
|
|
using T2 = std::decay_t<decltype(arg2)>;
|
|
using PCType = BinaryConstantPropagation<Type, T1, T2>;
|
|
|
|
if constexpr (is_complete_v<PCType>)
|
|
{
|
|
using Op = typename PCType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(arg1, arg2);
|
|
}
|
|
|
|
}, rhs.value);
|
|
}, lhs.value);
|
|
|
|
if (optimized)
|
|
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
|
|
|
return optimized;
|
|
}
|
|
|
|
template<typename TargetType>
|
|
ExpressionPtr AstOptimizer::PropagateSingleValueCast(const ConstantValueExpression& operand)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
|
|
std::visit([&](auto&& arg)
|
|
{
|
|
using T = std::decay_t<decltype(arg)>;
|
|
using CCType = CastConstantPropagation<TargetType, T>;
|
|
|
|
if constexpr (is_complete_v<CCType>)
|
|
{
|
|
using Op = typename CCType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(arg);
|
|
}
|
|
}, operand.value);
|
|
|
|
return optimized;
|
|
}
|
|
|
|
template<std::size_t TargetComponentCount>
|
|
ExpressionPtr AstOptimizer::PropagateConstantSwizzle(const std::array<UInt32, 4>& components, const ConstantValueExpression& operand)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
std::visit([&](auto&& arg)
|
|
{
|
|
using T = std::decay_t<decltype(arg)>;
|
|
|
|
using BaseType = typename VectorInfo<T>::Base;
|
|
constexpr std::size_t FromComponentCount = VectorInfo<T>::Dimensions;
|
|
|
|
using SPType = SwizzlePropagation<BaseType, TargetComponentCount, FromComponentCount>;
|
|
|
|
if constexpr (is_complete_v<SPType>)
|
|
{
|
|
using Op = typename SPType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(components, arg);
|
|
}
|
|
}, operand.value);
|
|
|
|
return optimized;
|
|
}
|
|
|
|
template<UnaryType Type>
|
|
ExpressionPtr AstOptimizer::PropagateUnaryConstant(const ConstantValueExpression& operand)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
std::visit([&](auto&& arg)
|
|
{
|
|
using T = std::decay_t<decltype(arg)>;
|
|
using PCType = UnaryConstantPropagation<Type, T>;
|
|
|
|
if constexpr (is_complete_v<PCType>)
|
|
{
|
|
using Op = typename PCType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(arg);
|
|
}
|
|
}, operand.value);
|
|
|
|
if (optimized)
|
|
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
|
|
|
return optimized;
|
|
}
|
|
|
|
template<typename TargetType>
|
|
ExpressionPtr AstOptimizer::PropagateVec2Cast(TargetType v1, TargetType v2)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
|
|
using CCType = CastConstantPropagation<Vector2<TargetType>, TargetType, TargetType>;
|
|
|
|
if constexpr (is_complete_v<CCType>)
|
|
{
|
|
using Op = typename CCType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(v1, v2);
|
|
}
|
|
|
|
return optimized;
|
|
}
|
|
|
|
template<typename TargetType>
|
|
ExpressionPtr AstOptimizer::PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
|
|
using CCType = CastConstantPropagation<Vector3<TargetType>, TargetType, TargetType, TargetType>;
|
|
|
|
if constexpr (is_complete_v<CCType>)
|
|
{
|
|
using Op = typename CCType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(v1, v2, v3);
|
|
}
|
|
|
|
return optimized;
|
|
}
|
|
|
|
template<typename TargetType>
|
|
ExpressionPtr AstOptimizer::PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4)
|
|
{
|
|
std::unique_ptr<ConstantValueExpression> optimized;
|
|
|
|
using CCType = CastConstantPropagation<Vector4<TargetType>, TargetType, TargetType, TargetType, TargetType>;
|
|
|
|
if constexpr (is_complete_v<CCType>)
|
|
{
|
|
using Op = typename CCType::Op;
|
|
if constexpr (is_complete_v<Op>)
|
|
optimized = Op{}(v1, v2, v3, v4);
|
|
}
|
|
|
|
return optimized;
|
|
}
|
|
}
|