Shader/ShaderLang: Add support for Unary operators

This commit is contained in:
Lynix
2021-05-16 23:07:25 +02:00
parent 1f05e950e8
commit 525f24af2e
30 changed files with 566 additions and 208 deletions

View File

@@ -301,6 +301,17 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(UnaryExpression& node)
{
auto clone = std::make_unique<UnaryExpression>();
clone->expression = CloneExpression(node.expression);
clone->op = node.op;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
#define NAZARA_SHADERAST_EXPRESSION(NodeType) void AstCloner::Visit(NodeType& node) \
{ \
PushExpression(Clone(node)); \

View File

@@ -33,13 +33,14 @@ namespace Nz::ShaderAst
template<typename T>
inline constexpr bool is_complete_v = is_complete<T>::value;
/*************************************************************************************************/
template<BinaryType Type, typename T1, typename T2>
struct PropagateConstantType;
struct BinaryConstantPropagation;
// CompEq
template<typename T1, typename T2>
struct CompEqBase
struct BinaryCompEqBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -48,17 +49,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct CompEq;
struct BinaryCompEq;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::CompEq, T1, T2>
struct BinaryConstantPropagation<BinaryType::CompEq, T1, T2>
{
using Op = CompEq<T1, T2>;
using Op = BinaryCompEq<T1, T2>;
};
// CompGe
template<typename T1, typename T2>
struct CompGeBase
struct BinaryCompGeBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -67,17 +68,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct CompGe;
struct BinaryCompGe;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::CompGe, T1, T2>
struct BinaryConstantPropagation<BinaryType::CompGe, T1, T2>
{
using Op = CompGe<T1, T2>;
using Op = BinaryCompGe<T1, T2>;
};
// CompGt
template<typename T1, typename T2>
struct CompGtBase
struct BinaryCompGtBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -86,17 +87,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct CompGt;
struct BinaryCompGt;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::CompGt, T1, T2>
struct BinaryConstantPropagation<BinaryType::CompGt, T1, T2>
{
using Op = CompGt<T1, T2>;
using Op = BinaryCompGt<T1, T2>;
};
// CompLe
template<typename T1, typename T2>
struct CompLeBase
struct BinaryCompLeBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -105,17 +106,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct CompLe;
struct BinaryCompLe;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::CompLe, T1, T2>
struct BinaryConstantPropagation<BinaryType::CompLe, T1, T2>
{
using Op = CompLe<T1, T2>;
using Op = BinaryCompLe<T1, T2>;
};
// CompLt
template<typename T1, typename T2>
struct CompLtBase
struct BinaryCompLtBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -124,17 +125,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct CompLt;
struct BinaryCompLt;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::CompLt, T1, T2>
struct BinaryConstantPropagation<BinaryType::CompLt, T1, T2>
{
using Op = CompLe<T1, T2>;
using Op = BinaryCompLe<T1, T2>;
};
// CompNe
template<typename T1, typename T2>
struct CompNeBase
struct BinaryCompNeBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -143,17 +144,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct CompNe;
struct BinaryCompNe;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::CompNe, T1, T2>
struct BinaryConstantPropagation<BinaryType::CompNe, T1, T2>
{
using Op = CompNe<T1, T2>;
using Op = BinaryCompNe<T1, T2>;
};
// Addition
template<typename T1, typename T2>
struct AdditionBase
struct BinaryAdditionBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -162,17 +163,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct Addition;
struct BinaryAddition;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::Add, T1, T2>
struct BinaryConstantPropagation<BinaryType::Add, T1, T2>
{
using Op = Addition<T1, T2>;
using Op = BinaryAddition<T1, T2>;
};
// Division
template<typename T1, typename T2>
struct DivisionBase
struct BinaryDivisionBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -181,17 +182,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct Division;
struct BinaryDivision;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::Divide, T1, T2>
struct BinaryConstantPropagation<BinaryType::Divide, T1, T2>
{
using Op = Division<T1, T2>;
using Op = BinaryDivision<T1, T2>;
};
// Multiplication
template<typename T1, typename T2>
struct MultiplicationBase
struct BinaryMultiplicationBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -200,17 +201,17 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct Multiplication;
struct BinaryMultiplication;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::Multiply, T1, T2>
struct BinaryConstantPropagation<BinaryType::Multiply, T1, T2>
{
using Op = Multiplication<T1, T2>;
using Op = BinaryMultiplication<T1, T2>;
};
// Subtraction
template<typename T1, typename T2>
struct SubtractionBase
struct BinarySubtractionBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
@@ -219,163 +220,251 @@ namespace Nz::ShaderAst
};
template<typename T1, typename T2>
struct Subtraction;
struct BinarySubtraction;
template<typename T1, typename T2>
struct PropagateConstantType<BinaryType::Subtract, T1, T2>
struct BinaryConstantPropagation<BinaryType::Subtract, T1, T2>
{
using Op = Subtraction<T1, T2>;
using Op = BinarySubtraction<T1, T2>;
};
/*************************************************************************************************/
template<UnaryType Type, typename T>
struct UnaryConstantPropagation;
// LogicalNot
template<typename T>
struct UnaryLogicalNotBase
{
std::unique_ptr<ConstantExpression> operator()(const T& arg)
{
return ShaderBuilder::Constant(!arg);
}
};
#define EnableOptimisation(Op, T1, T2) template<> struct Op<T1, T2> : Op##Base<T1, T2> {}
template<typename T>
struct UnaryLogicalNot;
EnableOptimisation(CompEq, bool, bool);
EnableOptimisation(CompEq, double, double);
EnableOptimisation(CompEq, float, float);
EnableOptimisation(CompEq, Nz::Int32, Nz::Int32);
EnableOptimisation(CompEq, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompEq, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompEq, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompEq, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompEq, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompEq, Nz::Vector4i32, Nz::Vector4i32);
template<typename T>
struct UnaryConstantPropagation<UnaryType::LogicalNot, T>
{
using Op = UnaryLogicalNot<T>;
};
EnableOptimisation(CompGe, bool, bool);
EnableOptimisation(CompGe, double, double);
EnableOptimisation(CompGe, float, float);
EnableOptimisation(CompGe, Nz::Int32, Nz::Int32);
EnableOptimisation(CompGe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompGe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompGe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompGe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompGe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompGe, Nz::Vector4i32, Nz::Vector4i32);
// Minus
template<typename T>
struct UnaryMinusBase
{
std::unique_ptr<ConstantExpression> operator()(const T& arg)
{
return ShaderBuilder::Constant(-arg);
}
};
EnableOptimisation(CompGt, bool, bool);
EnableOptimisation(CompGt, double, double);
EnableOptimisation(CompGt, float, float);
EnableOptimisation(CompGt, Nz::Int32, Nz::Int32);
EnableOptimisation(CompGt, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompGt, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompGt, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompGt, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompGt, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompGt, Nz::Vector4i32, Nz::Vector4i32);
template<typename T>
struct UnaryMinus;
EnableOptimisation(CompLe, bool, bool);
EnableOptimisation(CompLe, double, double);
EnableOptimisation(CompLe, float, float);
EnableOptimisation(CompLe, Nz::Int32, Nz::Int32);
EnableOptimisation(CompLe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompLe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompLe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompLe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompLe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompLe, Nz::Vector4i32, Nz::Vector4i32);
template<typename T>
struct UnaryConstantPropagation<UnaryType::Minus, T>
{
using Op = UnaryMinus<T>;
};
EnableOptimisation(CompLt, bool, bool);
EnableOptimisation(CompLt, double, double);
EnableOptimisation(CompLt, float, float);
EnableOptimisation(CompLt, Nz::Int32, Nz::Int32);
EnableOptimisation(CompLt, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompLt, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompLt, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompLt, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompLt, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompLt, Nz::Vector4i32, Nz::Vector4i32);
// Plus
template<typename T>
struct UnaryPlusBase
{
std::unique_ptr<ConstantExpression> operator()(const T& arg)
{
return ShaderBuilder::Constant(arg);
}
};
EnableOptimisation(CompNe, bool, bool);
EnableOptimisation(CompNe, double, double);
EnableOptimisation(CompNe, float, float);
EnableOptimisation(CompNe, Nz::Int32, Nz::Int32);
EnableOptimisation(CompNe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompNe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompNe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompNe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompNe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompNe, Nz::Vector4i32, Nz::Vector4i32);
template<typename T>
struct UnaryPlus;
EnableOptimisation(Addition, double, double);
EnableOptimisation(Addition, float, float);
EnableOptimisation(Addition, Nz::Int32, Nz::Int32);
EnableOptimisation(Addition, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Addition, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Addition, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Addition, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Addition, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Addition, Nz::Vector4i32, Nz::Vector4i32);
template<typename T>
struct UnaryConstantPropagation<UnaryType::Plus, T>
{
using Op = UnaryPlus<T>;
};
EnableOptimisation(Division, double, double);
EnableOptimisation(Division, double, Nz::Vector2d);
EnableOptimisation(Division, double, Nz::Vector3d);
EnableOptimisation(Division, double, Nz::Vector4d);
EnableOptimisation(Division, float, float);
EnableOptimisation(Division, float, Nz::Vector2f);
EnableOptimisation(Division, float, Nz::Vector3f);
EnableOptimisation(Division, float, Nz::Vector4f);
EnableOptimisation(Division, Nz::Int32, Nz::Int32);
EnableOptimisation(Division, Nz::Int32, Nz::Vector2i32);
EnableOptimisation(Division, Nz::Int32, Nz::Vector3i32);
EnableOptimisation(Division, Nz::Int32, Nz::Vector4i32);
EnableOptimisation(Division, Nz::Vector2f, float);
EnableOptimisation(Division, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Division, Nz::Vector3f, float);
EnableOptimisation(Division, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Division, Nz::Vector4f, float);
EnableOptimisation(Division, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Division, Nz::Vector2d, double);
EnableOptimisation(Division, Nz::Vector2d, Nz::Vector2d);
EnableOptimisation(Division, Nz::Vector3d, double);
EnableOptimisation(Division, Nz::Vector3d, Nz::Vector3d);
EnableOptimisation(Division, Nz::Vector4d, double);
EnableOptimisation(Division, Nz::Vector4d, Nz::Vector4d);
EnableOptimisation(Division, Nz::Vector2i32, Nz::Int32);
EnableOptimisation(Division, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Division, Nz::Vector3i32, Nz::Int32);
EnableOptimisation(Division, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Division, Nz::Vector4i32, Nz::Int32);
EnableOptimisation(Division, Nz::Vector4i32, Nz::Vector4i32);
#define EnableOptimisation(Op, ...) template<> struct Op<__VA_ARGS__> : Op##Base<__VA_ARGS__> {}
EnableOptimisation(Multiplication, double, double);
EnableOptimisation(Multiplication, double, Nz::Vector2d);
EnableOptimisation(Multiplication, double, Nz::Vector3d);
EnableOptimisation(Multiplication, double, Nz::Vector4d);
EnableOptimisation(Multiplication, float, float);
EnableOptimisation(Multiplication, float, Nz::Vector2f);
EnableOptimisation(Multiplication, float, Nz::Vector3f);
EnableOptimisation(Multiplication, float, Nz::Vector4f);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector2i32);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector3i32);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector4i32);
EnableOptimisation(Multiplication, Nz::Vector2f, float);
EnableOptimisation(Multiplication, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Multiplication, Nz::Vector3f, float);
EnableOptimisation(Multiplication, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Multiplication, Nz::Vector4f, float);
EnableOptimisation(Multiplication, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Multiplication, Nz::Vector2d, double);
EnableOptimisation(Multiplication, Nz::Vector2d, Nz::Vector2d);
EnableOptimisation(Multiplication, Nz::Vector3d, double);
EnableOptimisation(Multiplication, Nz::Vector3d, Nz::Vector3d);
EnableOptimisation(Multiplication, Nz::Vector4d, double);
EnableOptimisation(Multiplication, Nz::Vector4d, Nz::Vector4d);
EnableOptimisation(Multiplication, Nz::Vector2i32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Multiplication, Nz::Vector3i32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Multiplication, Nz::Vector4i32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Vector4i32, Nz::Vector4i32);
// Binary
EnableOptimisation(Subtraction, double, double);
EnableOptimisation(Subtraction, float, float);
EnableOptimisation(Subtraction, Nz::Int32, Nz::Int32);
EnableOptimisation(Subtraction, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Subtraction, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Subtraction, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Subtraction, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Subtraction, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Subtraction, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryCompEq, bool, bool);
EnableOptimisation(BinaryCompEq, double, double);
EnableOptimisation(BinaryCompEq, float, float);
EnableOptimisation(BinaryCompEq, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryCompEq, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryCompEq, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryCompEq, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryCompEq, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryCompEq, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryCompEq, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryCompGe, bool, bool);
EnableOptimisation(BinaryCompGe, double, double);
EnableOptimisation(BinaryCompGe, float, float);
EnableOptimisation(BinaryCompGe, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryCompGe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryCompGe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryCompGe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryCompGe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryCompGe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryCompGe, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryCompGt, bool, bool);
EnableOptimisation(BinaryCompGt, double, double);
EnableOptimisation(BinaryCompGt, float, float);
EnableOptimisation(BinaryCompGt, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryCompGt, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryCompGt, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryCompGt, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryCompGt, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryCompGt, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryCompGt, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryCompLe, bool, bool);
EnableOptimisation(BinaryCompLe, double, double);
EnableOptimisation(BinaryCompLe, float, float);
EnableOptimisation(BinaryCompLe, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryCompLe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryCompLe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryCompLe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryCompLe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryCompLe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryCompLe, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryCompLt, bool, bool);
EnableOptimisation(BinaryCompLt, double, double);
EnableOptimisation(BinaryCompLt, float, float);
EnableOptimisation(BinaryCompLt, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryCompLt, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryCompLt, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryCompLt, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryCompLt, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryCompLt, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryCompLt, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryCompNe, bool, bool);
EnableOptimisation(BinaryCompNe, double, double);
EnableOptimisation(BinaryCompNe, float, float);
EnableOptimisation(BinaryCompNe, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryCompNe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryCompNe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryCompNe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryCompNe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryCompNe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryCompNe, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryAddition, double, double);
EnableOptimisation(BinaryAddition, float, float);
EnableOptimisation(BinaryAddition, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryAddition, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryAddition, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryAddition, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryAddition, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryAddition, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryAddition, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryDivision, double, double);
EnableOptimisation(BinaryDivision, double, Nz::Vector2d);
EnableOptimisation(BinaryDivision, double, Nz::Vector3d);
EnableOptimisation(BinaryDivision, double, Nz::Vector4d);
EnableOptimisation(BinaryDivision, float, float);
EnableOptimisation(BinaryDivision, float, Nz::Vector2f);
EnableOptimisation(BinaryDivision, float, Nz::Vector3f);
EnableOptimisation(BinaryDivision, float, Nz::Vector4f);
EnableOptimisation(BinaryDivision, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryDivision, Nz::Int32, Nz::Vector2i32);
EnableOptimisation(BinaryDivision, Nz::Int32, Nz::Vector3i32);
EnableOptimisation(BinaryDivision, Nz::Int32, Nz::Vector4i32);
EnableOptimisation(BinaryDivision, Nz::Vector2f, float);
EnableOptimisation(BinaryDivision, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryDivision, Nz::Vector3f, float);
EnableOptimisation(BinaryDivision, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryDivision, Nz::Vector4f, float);
EnableOptimisation(BinaryDivision, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryDivision, Nz::Vector2d, double);
EnableOptimisation(BinaryDivision, Nz::Vector2d, Nz::Vector2d);
EnableOptimisation(BinaryDivision, Nz::Vector3d, double);
EnableOptimisation(BinaryDivision, Nz::Vector3d, Nz::Vector3d);
EnableOptimisation(BinaryDivision, Nz::Vector4d, double);
EnableOptimisation(BinaryDivision, Nz::Vector4d, Nz::Vector4d);
EnableOptimisation(BinaryDivision, Nz::Vector2i32, Nz::Int32);
EnableOptimisation(BinaryDivision, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryDivision, Nz::Vector3i32, Nz::Int32);
EnableOptimisation(BinaryDivision, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryDivision, Nz::Vector4i32, Nz::Int32);
EnableOptimisation(BinaryDivision, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinaryMultiplication, double, double);
EnableOptimisation(BinaryMultiplication, double, Nz::Vector2d);
EnableOptimisation(BinaryMultiplication, double, Nz::Vector3d);
EnableOptimisation(BinaryMultiplication, double, Nz::Vector4d);
EnableOptimisation(BinaryMultiplication, float, float);
EnableOptimisation(BinaryMultiplication, float, Nz::Vector2f);
EnableOptimisation(BinaryMultiplication, float, Nz::Vector3f);
EnableOptimisation(BinaryMultiplication, float, Nz::Vector4f);
EnableOptimisation(BinaryMultiplication, Nz::Int32, Nz::Int32);
EnableOptimisation(BinaryMultiplication, Nz::Int32, Nz::Vector2i32);
EnableOptimisation(BinaryMultiplication, Nz::Int32, Nz::Vector3i32);
EnableOptimisation(BinaryMultiplication, Nz::Int32, Nz::Vector4i32);
EnableOptimisation(BinaryMultiplication, Nz::Vector2f, float);
EnableOptimisation(BinaryMultiplication, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinaryMultiplication, Nz::Vector3f, float);
EnableOptimisation(BinaryMultiplication, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinaryMultiplication, Nz::Vector4f, float);
EnableOptimisation(BinaryMultiplication, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinaryMultiplication, Nz::Vector2d, double);
EnableOptimisation(BinaryMultiplication, Nz::Vector2d, Nz::Vector2d);
EnableOptimisation(BinaryMultiplication, Nz::Vector3d, double);
EnableOptimisation(BinaryMultiplication, Nz::Vector3d, Nz::Vector3d);
EnableOptimisation(BinaryMultiplication, Nz::Vector4d, double);
EnableOptimisation(BinaryMultiplication, Nz::Vector4d, Nz::Vector4d);
EnableOptimisation(BinaryMultiplication, Nz::Vector2i32, Nz::Int32);
EnableOptimisation(BinaryMultiplication, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinaryMultiplication, Nz::Vector3i32, Nz::Int32);
EnableOptimisation(BinaryMultiplication, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinaryMultiplication, Nz::Vector4i32, Nz::Int32);
EnableOptimisation(BinaryMultiplication, Nz::Vector4i32, Nz::Vector4i32);
EnableOptimisation(BinarySubtraction, double, double);
EnableOptimisation(BinarySubtraction, float, float);
EnableOptimisation(BinarySubtraction, Nz::Int32, Nz::Int32);
EnableOptimisation(BinarySubtraction, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(BinarySubtraction, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(BinarySubtraction, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(BinarySubtraction, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(BinarySubtraction, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(BinarySubtraction, Nz::Vector4i32, Nz::Vector4i32);
// Unary
EnableOptimisation(UnaryLogicalNot, bool);
EnableOptimisation(UnaryMinus, double);
EnableOptimisation(UnaryMinus, float);
EnableOptimisation(UnaryMinus, Nz::Int32);
EnableOptimisation(UnaryMinus, Nz::Vector2f);
EnableOptimisation(UnaryMinus, Nz::Vector3f);
EnableOptimisation(UnaryMinus, Nz::Vector4f);
EnableOptimisation(UnaryMinus, Nz::Vector2i32);
EnableOptimisation(UnaryMinus, Nz::Vector3i32);
EnableOptimisation(UnaryMinus, Nz::Vector4i32);
EnableOptimisation(UnaryPlus, double);
EnableOptimisation(UnaryPlus, float);
EnableOptimisation(UnaryPlus, Nz::Int32);
EnableOptimisation(UnaryPlus, Nz::Vector2f);
EnableOptimisation(UnaryPlus, Nz::Vector3f);
EnableOptimisation(UnaryPlus, Nz::Vector4f);
EnableOptimisation(UnaryPlus, Nz::Vector2i32);
EnableOptimisation(UnaryPlus, Nz::Vector3i32);
EnableOptimisation(UnaryPlus, Nz::Vector4i32);
#undef EnableOptimisation
}
@@ -525,6 +614,40 @@ namespace Nz::ShaderAst
return AstCloner::Clone(node.falsePath);
}
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)
{
auto expr = CloneExpression(node.expression);
if (expr->GetType() == NodeType::ConstantExpression)
{
auto constantExpr = static_unique_pointer_cast<ConstantExpression>(std::move(expr));
ExpressionPtr optimized;
switch (node.op)
{
case UnaryType::LogicalNot:
optimized = PropagateConstant<UnaryType::LogicalNot>(std::move(constantExpr));
break;
case UnaryType::Minus:
optimized = PropagateConstant<UnaryType::Minus>(std::move(constantExpr));
break;
case UnaryType::Plus:
optimized = PropagateConstant<UnaryType::Plus>(std::move(constantExpr));
break;
}
if (optimized)
return optimized;
}
auto unary = ShaderBuilder::Unary(node.op, std::move(expr));
unary->cachedExpressionType = node.cachedExpressionType;
return unary;
}
StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
{
if (!m_enabledOptions)
@@ -547,7 +670,7 @@ namespace Nz::ShaderAst
std::visit([&](auto&& arg2)
{
using T2 = std::decay_t<decltype(arg2)>;
using PCType = PropagateConstantType<Type, T1, T2>;
using PCType = BinaryConstantPropagation<Type, T1, T2>;
if constexpr (is_complete_v<PCType>)
{
@@ -564,4 +687,27 @@ namespace Nz::ShaderAst
return optimized;
}
template<UnaryType Type>
ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& operand)
{
std::unique_ptr<ConstantExpression> optimized;
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
using PCType = UnaryConstantPropagation<Type, T>;
if constexpr (is_complete_v<PCType>)
{
using Op = typename PCType::Op;
if constexpr (is_complete_v<Op>)
optimized = Op{}(arg);
}
}, operand->value);
if (optimized)
optimized->cachedExpressionType = optimized->GetExpressionType();
return optimized;
}
}

View File

@@ -73,11 +73,16 @@ namespace Nz::ShaderAst
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(VariableExpression& node)
void AstRecursiveVisitor::Visit(VariableExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(UnaryExpression& node)
{
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(BranchStatement& node)
{
for (auto& cond : node.condStatements)
@@ -95,7 +100,7 @@ namespace Nz::ShaderAst
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareExternalStatement& node)
void AstRecursiveVisitor::Visit(DeclareExternalStatement& /*node*/)
{
/* Nothing to do */
}

View File

@@ -147,6 +147,12 @@ namespace Nz::ShaderAst
SizeT(node.variableId);
}
void AstSerializerBase::Serialize(UnaryExpression& node)
{
Enum(node.op);
Node(node.expression);
}
void AstSerializerBase::Serialize(BranchStatement& node)
{

View File

@@ -86,8 +86,13 @@ namespace Nz::ShaderAst
node.expression->Visit(*this);
}
void ShaderAstValueCategory::Visit(VariableExpression& node)
void ShaderAstValueCategory::Visit(VariableExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(UnaryExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::RValue;
}
}

View File

@@ -592,6 +592,41 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
{
auto clone = static_unique_pointer_cast<UnaryExpression>(AstCloner::Clone(node));
const ExpressionType& exprType = GetExpressionType(MandatoryExpr(clone->expression));
if (!IsPrimitiveType(exprType))
throw AstError{ "unary expression operand type does not support unary operation" };
PrimitiveType primitiveType = std::get<PrimitiveType>(exprType);
switch (node.op)
{
case UnaryType::LogicalNot:
{
if (primitiveType != PrimitiveType::Boolean)
throw AstError{ "logical not is only supported on booleans" };
break;
}
case UnaryType::Minus:
case UnaryType::Plus:
{
if (primitiveType != PrimitiveType::Float32 && primitiveType != PrimitiveType::Int32 && primitiveType != PrimitiveType::UInt32)
throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" };
break;
}
}
clone->cachedExpressionType = primitiveType;
return clone;
}
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
{
auto clone = std::make_unique<BranchStatement>();