diff --git a/include/Nazara/Shader/Ast/AstOptimizer.hpp b/include/Nazara/Shader/Ast/AstOptimizer.hpp index 967d0cec9..d81eced76 100644 --- a/include/Nazara/Shader/Ast/AstOptimizer.hpp +++ b/include/Nazara/Shader/Ast/AstOptimizer.hpp @@ -31,13 +31,18 @@ namespace Nz::ShaderAst protected: ExpressionPtr Clone(BinaryExpression& node) override; + ExpressionPtr Clone(CastExpression& node) override; ExpressionPtr Clone(ConditionalExpression& node) override; ExpressionPtr Clone(UnaryExpression& node) override; StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override; - template ExpressionPtr PropagateConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs); - template ExpressionPtr PropagateConstant(std::unique_ptr&& operand); + template ExpressionPtr PropagateBinaryConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs); + template ExpressionPtr PropagateSingleValueCast(std::unique_ptr&& operand); + template ExpressionPtr PropagateUnaryConstant(std::unique_ptr&& operand); + template ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2); + template ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3); + template ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4); private: std::optional m_enabledOptions; diff --git a/include/Nazara/Shader/Ast/Attribute.hpp b/include/Nazara/Shader/Ast/Attribute.hpp index fb7bc091d..5bd1cf589 100644 --- a/include/Nazara/Shader/Ast/Attribute.hpp +++ b/include/Nazara/Shader/Ast/Attribute.hpp @@ -9,6 +9,7 @@ #include #include +#include namespace Nz::ShaderAst { diff --git a/include/Nazara/Shader/Ast/ConstantValue.hpp b/include/Nazara/Shader/Ast/ConstantValue.hpp index 0f8ed1fc4..3064603cf 100644 --- a/include/Nazara/Shader/Ast/ConstantValue.hpp +++ b/include/Nazara/Shader/Ast/ConstantValue.hpp @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include namespace Nz::ShaderAst @@ -27,6 +29,8 @@ namespace Nz::ShaderAst Vector3i32, Vector4i32 >; + + NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant); } #endif diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index aa6cd809e..e7a949e5b 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -9,8 +9,8 @@ #include #include -#include #include +#include #include #include #include diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 37d80a9bd..4bfc7b51a 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -126,8 +126,6 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - ExpressionType GetExpressionType() const; - ShaderAst::ConstantValue value; }; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 498dfa861..f80e595ff 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -39,6 +40,7 @@ namespace Nz::ShaderBuilder struct Cast { + inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::array expressions) const; inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const; }; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 194040b0e..dba57a872 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -58,6 +58,15 @@ namespace Nz::ShaderBuilder return branchNode; } + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array expressions) const + { + auto castNode = std::make_unique(); + castNode->expressions = std::move(expressions); + castNode->targetType = std::move(targetType); + + return castNode; + } + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const { auto castNode = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index e5d869a21..c58ee0f66 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -227,7 +227,27 @@ namespace Nz::ShaderAst { using Op = BinarySubtraction; }; - + + /*************************************************************************************************/ + + template + struct CastConstantBase + { + std::unique_ptr operator()(const Args&... args) + { + return ShaderBuilder::Constant(T(args...)); + } + }; + + template + struct CastConstant; + + template + struct CastConstantPropagation + { + using Op = CastConstant; + }; + /*************************************************************************************************/ template @@ -319,13 +339,13 @@ namespace Nz::ShaderAst EnableOptimisation(BinaryCompGt, bool, bool); EnableOptimisation(BinaryCompGt, double, double); EnableOptimisation(BinaryCompGt, float, float); - EnableOptimisation(BinaryCompGt, Nz::Int32, Nz::Int32); - EnableOptimisation(BinaryCompGt, Nz::Vector2f, Nz::Vector2f); - EnableOptimisation(BinaryCompGt, Nz::Vector3f, Nz::Vector3f); - EnableOptimisation(BinaryCompGt, Nz::Vector4f, Nz::Vector4f); - EnableOptimisation(BinaryCompGt, Nz::Vector2i32, Nz::Vector2i32); - EnableOptimisation(BinaryCompGt, Nz::Vector3i32, Nz::Vector3i32); - EnableOptimisation(BinaryCompGt, Nz::Vector4i32, Nz::Vector4i32); + EnableOptimisation(BinaryCompGt, Int32, Int32); + EnableOptimisation(BinaryCompGt, Vector2f, Vector2f); + EnableOptimisation(BinaryCompGt, Vector3f, Vector3f); + EnableOptimisation(BinaryCompGt, Vector4f, Vector4f); + EnableOptimisation(BinaryCompGt, Vector2i32, Vector2i32); + EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32); + EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompLe, bool, bool); EnableOptimisation(BinaryCompLe, double, double); @@ -442,6 +462,48 @@ namespace Nz::ShaderAst EnableOptimisation(BinarySubtraction, Nz::Vector3i32, Nz::Vector3i32); EnableOptimisation(BinarySubtraction, Nz::Vector4i32, Nz::Vector4i32); + // Cast + + EnableOptimisation(CastConstant, bool, bool); + EnableOptimisation(CastConstant, bool, Int32); + EnableOptimisation(CastConstant, bool, UInt32); + + EnableOptimisation(CastConstant, double, double); + EnableOptimisation(CastConstant, double, float); + EnableOptimisation(CastConstant, double, Int32); + EnableOptimisation(CastConstant, double, UInt32); + + EnableOptimisation(CastConstant, float, double); + EnableOptimisation(CastConstant, float, float); + EnableOptimisation(CastConstant, float, Int32); + EnableOptimisation(CastConstant, float, UInt32); + + EnableOptimisation(CastConstant, Int32, double); + EnableOptimisation(CastConstant, Int32, float); + EnableOptimisation(CastConstant, Int32, Int32); + EnableOptimisation(CastConstant, Int32, UInt32); + + EnableOptimisation(CastConstant, UInt32, double); + EnableOptimisation(CastConstant, UInt32, float); + EnableOptimisation(CastConstant, UInt32, Int32); + EnableOptimisation(CastConstant, UInt32, UInt32); + + //EnableOptimisation(CastConstant, Vector2d, double, double); + //EnableOptimisation(CastConstant, Vector3d, double, double, double); + //EnableOptimisation(CastConstant, Vector4d, double, double, double, double); + + EnableOptimisation(CastConstant, Vector2f, float, float); + EnableOptimisation(CastConstant, Vector3f, float, float, float); + EnableOptimisation(CastConstant, Vector4f, float, float, float, float); + + EnableOptimisation(CastConstant, Vector2i32, Int32, Int32); + EnableOptimisation(CastConstant, Vector3i32, Int32, Int32, Int32); + EnableOptimisation(CastConstant, Vector4i32, Int32, Int32, Int32, Int32); + + //EnableOptimisation(CastConstant, Vector2ui32, UInt32, UInt32); + //EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32); + //EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32); + // Unary EnableOptimisation(UnaryLogicalNot, bool); @@ -496,42 +558,43 @@ namespace Nz::ShaderAst switch (node.op) { case BinaryType::Add: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::Subtract: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::Multiply: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::Divide: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::CompEq: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::CompGe: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::CompGt: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::CompLe: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::CompLt: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; case BinaryType::CompNe: - optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; } @@ -545,6 +608,123 @@ namespace Nz::ShaderAst return binary; } + ExpressionPtr AstOptimizer::Clone(CastExpression& node) + { + std::array expressions; + + std::size_t expressionCount = 0; + for (const auto& expression : node.expressions) + { + if (!expression) + break; + + expressions[expressionCount] = CloneExpression(expression); + expressionCount++; + } + + ExpressionPtr optimized; + if (IsPrimitiveType(node.targetType)) + { + if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantExpression) + { + auto constantExpr = static_unique_pointer_cast(std::move(expressions.front())); + + switch (std::get(node.targetType)) + { + case PrimitiveType::Boolean: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; + case PrimitiveType::Float32: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; + case PrimitiveType::Int32: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; + case PrimitiveType::UInt32: optimized = PropagateSingleValueCast(std::move(constantExpr)); break; + } + } + } + else if (IsVectorType(node.targetType)) + { + const auto& vecType = std::get(node.targetType); + + // Decompose vector into values (cast(vec3, float) => cast(float, float, float, float)) + std::vector constantValues; + for (std::size_t i = 0; i < expressionCount; ++i) + { + if (expressions[i]->GetType() != NodeType::ConstantExpression) + { + constantValues.clear(); + break; + } + + const auto& constantExpr = static_cast(*expressions[i]); + + if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value)) + { + // Unhandled case, all cast parameters are expected to be of the same type + constantValues.clear(); + break; + } + + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) + constantValues.push_back(arg); + else if constexpr (std::is_same_v || std::is_same_v) + { + constantValues.push_back(arg.x); + constantValues.push_back(arg.y); + } + else if constexpr (std::is_same_v || std::is_same_v) + { + constantValues.push_back(arg.x); + constantValues.push_back(arg.y); + constantValues.push_back(arg.z); + } + else if constexpr (std::is_same_v || std::is_same_v) + { + constantValues.push_back(arg.x); + constantValues.push_back(arg.y); + constantValues.push_back(arg.z); + constantValues.push_back(arg.w); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, constantExpr.value); + } + + if (!constantValues.empty()) + { + assert(constantValues.size() == vecType.componentCount); + + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + switch (vecType.componentCount) + { + case 2: + optimized = PropagateVec2Cast(std::get(constantValues[0]), std::get(constantValues[1])); + break; + + case 3: + optimized = PropagateVec3Cast(std::get(constantValues[0]), std::get(constantValues[1]), std::get(constantValues[2])); + break; + + case 4: + optimized = PropagateVec4Cast(std::get(constantValues[0]), std::get(constantValues[1]), std::get(constantValues[2]), std::get(constantValues[3])); + break; + } + }, constantValues.front()); + } + } + + if (optimized) + return optimized; + + auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions)); + cast->cachedExpressionType = node.cachedExpressionType; + + return cast; + } + StatementPtr AstOptimizer::Clone(BranchStatement& node) { std::vector statements; @@ -626,15 +806,15 @@ namespace Nz::ShaderAst switch (node.op) { case UnaryType::LogicalNot: - optimized = PropagateConstant(std::move(constantExpr)); + optimized = PropagateUnaryConstant(std::move(constantExpr)); break; case UnaryType::Minus: - optimized = PropagateConstant(std::move(constantExpr)); + optimized = PropagateUnaryConstant(std::move(constantExpr)); break; case UnaryType::Plus: - optimized = PropagateConstant(std::move(constantExpr)); + optimized = PropagateUnaryConstant(std::move(constantExpr)); break; } @@ -660,7 +840,7 @@ namespace Nz::ShaderAst } template - ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs) + ExpressionPtr AstOptimizer::PropagateBinaryConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs) { std::unique_ptr optimized; std::visit([&](auto&& arg1) @@ -683,13 +863,34 @@ namespace Nz::ShaderAst }, lhs->value); if (optimized) - optimized->cachedExpressionType = optimized->GetExpressionType(); + optimized->cachedExpressionType = GetExpressionType(optimized->value); + + return optimized; + } + + template + ExpressionPtr AstOptimizer::PropagateSingleValueCast(std::unique_ptr&& operand) + { + std::unique_ptr optimized; + + std::visit([&](auto&& arg) + { + using T = std::decay_t; + using CCType = CastConstantPropagation; + + if constexpr (is_complete_v) + { + using Op = typename CCType::Op; + if constexpr (is_complete_v) + optimized = Op{}(arg); + } + }, operand->value); return optimized; } template - ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr&& operand) + ExpressionPtr AstOptimizer::PropagateUnaryConstant(std::unique_ptr&& operand) { std::unique_ptr optimized; std::visit([&](auto&& arg) @@ -706,7 +907,58 @@ namespace Nz::ShaderAst }, operand->value); if (optimized) - optimized->cachedExpressionType = optimized->GetExpressionType(); + optimized->cachedExpressionType = GetExpressionType(optimized->value); + + return optimized; + } + + template + ExpressionPtr AstOptimizer::PropagateVec2Cast(TargetType v1, TargetType v2) + { + std::unique_ptr optimized; + + using CCType = CastConstantPropagation, TargetType, TargetType>; + + if constexpr (is_complete_v) + { + using Op = typename CCType::Op; + if constexpr (is_complete_v) + optimized = Op{}(v1, v2); + } + + return optimized; + } + + template + ExpressionPtr AstOptimizer::PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3) + { + std::unique_ptr optimized; + + using CCType = CastConstantPropagation, TargetType, TargetType, TargetType>; + + if constexpr (is_complete_v) + { + using Op = typename CCType::Op; + if constexpr (is_complete_v) + optimized = Op{}(v1, v2, v3); + } + + return optimized; + } + + template + ExpressionPtr AstOptimizer::PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4) + { + std::unique_ptr optimized; + + using CCType = CastConstantPropagation, TargetType, TargetType, TargetType, TargetType>; + + if constexpr (is_complete_v) + { + using Op = typename CCType::Op; + if constexpr (is_complete_v) + optimized = Op{}(v1, v2, v3, v4); + } return optimized; } diff --git a/src/Nazara/Shader/Ast/ConstantValue.cpp b/src/Nazara/Shader/Ast/ConstantValue.cpp new file mode 100644 index 000000000..b9398212b --- /dev/null +++ b/src/Nazara/Shader/Ast/ConstantValue.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + ExpressionType GetExpressionType(const ConstantValue& constant) + { + return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return PrimitiveType::Boolean; + else if constexpr (std::is_same_v) + return PrimitiveType::Float32; + else if constexpr (std::is_same_v) + return PrimitiveType::Int32; + else if constexpr (std::is_same_v) + return PrimitiveType::UInt32; + else if constexpr (std::is_same_v) + return VectorType{ 2, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 3, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 4, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 2, PrimitiveType::Int32 }; + else if constexpr (std::is_same_v) + return VectorType{ 3, PrimitiveType::Int32 }; + else if constexpr (std::is_same_v) + return VectorType{ 4, PrimitiveType::Int32 }; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, constant); + } +} diff --git a/src/Nazara/Shader/Ast/Nodes.cpp b/src/Nazara/Shader/Ast/Nodes.cpp index b334d76ec..aed845fc5 100644 --- a/src/Nazara/Shader/Ast/Nodes.cpp +++ b/src/Nazara/Shader/Ast/Nodes.cpp @@ -29,35 +29,4 @@ namespace Nz::ShaderAst } #include - - ExpressionType ConstantExpression::GetExpressionType() const - { - return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return PrimitiveType::Boolean; - else if constexpr (std::is_same_v) - return PrimitiveType::Float32; - else if constexpr (std::is_same_v) - return PrimitiveType::Int32; - else if constexpr (std::is_same_v) - return PrimitiveType::UInt32; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Int32 }; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, value); - } } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 54fb5f8eb..5e96ddb0b 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -413,7 +413,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - clone->cachedExpressionType = clone->GetExpressionType(); + clone->cachedExpressionType = GetExpressionType(clone->value); return clone; }