Shader: Add constant cast optimization
This commit is contained in:
parent
e716b44aa3
commit
a002d5c210
|
|
@ -31,13 +31,18 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
ExpressionPtr Clone(BinaryExpression& node) override;
|
ExpressionPtr Clone(BinaryExpression& node) override;
|
||||||
|
ExpressionPtr Clone(CastExpression& node) override;
|
||||||
ExpressionPtr Clone(ConditionalExpression& node) override;
|
ExpressionPtr Clone(ConditionalExpression& node) override;
|
||||||
ExpressionPtr Clone(UnaryExpression& node) override;
|
ExpressionPtr Clone(UnaryExpression& node) override;
|
||||||
StatementPtr Clone(BranchStatement& node) override;
|
StatementPtr Clone(BranchStatement& node) override;
|
||||||
StatementPtr Clone(ConditionalStatement& node) override;
|
StatementPtr Clone(ConditionalStatement& node) override;
|
||||||
|
|
||||||
template<BinaryType Type> ExpressionPtr PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs);
|
template<BinaryType Type> ExpressionPtr PropagateBinaryConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs);
|
||||||
template<UnaryType Type> ExpressionPtr PropagateConstant(std::unique_ptr<ConstantExpression>&& operand);
|
template<typename TargetType> ExpressionPtr PropagateSingleValueCast(std::unique_ptr<ConstantExpression>&& operand);
|
||||||
|
template<UnaryType Type> ExpressionPtr PropagateUnaryConstant(std::unique_ptr<ConstantExpression>&& operand);
|
||||||
|
template<typename TargetType> ExpressionPtr PropagateVec2Cast(TargetType v1, TargetType v2);
|
||||||
|
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);
|
||||||
|
template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::optional<UInt64> m_enabledOptions;
|
std::optional<UInt64> m_enabledOptions;
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
#include <Nazara/Prerequisites.hpp>
|
#include <Nazara/Prerequisites.hpp>
|
||||||
#include <Nazara/Shader/Ast/Enums.hpp>
|
#include <Nazara/Shader/Ast/Enums.hpp>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@
|
||||||
#include <Nazara/Math/Vector2.hpp>
|
#include <Nazara/Math/Vector2.hpp>
|
||||||
#include <Nazara/Math/Vector3.hpp>
|
#include <Nazara/Math/Vector3.hpp>
|
||||||
#include <Nazara/Math/Vector4.hpp>
|
#include <Nazara/Math/Vector4.hpp>
|
||||||
|
#include <Nazara/Shader/Config.hpp>
|
||||||
|
#include <Nazara/Shader/Ast/ExpressionType.hpp>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
|
|
@ -27,6 +29,8 @@ namespace Nz::ShaderAst
|
||||||
Vector3i32,
|
Vector3i32,
|
||||||
Vector4i32
|
Vector4i32
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@
|
||||||
|
|
||||||
#include <Nazara/Prerequisites.hpp>
|
#include <Nazara/Prerequisites.hpp>
|
||||||
#include <Nazara/Utility/Enums.hpp>
|
#include <Nazara/Utility/Enums.hpp>
|
||||||
#include <Nazara/Shader/Ast/Enums.hpp>
|
|
||||||
#include <Nazara/Shader/Ast/Attribute.hpp>
|
#include <Nazara/Shader/Ast/Attribute.hpp>
|
||||||
|
#include <Nazara/Shader/Ast/Enums.hpp>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
|
||||||
|
|
@ -126,8 +126,6 @@ namespace Nz::ShaderAst
|
||||||
NodeType GetType() const override;
|
NodeType GetType() const override;
|
||||||
void Visit(AstExpressionVisitor& visitor) override;
|
void Visit(AstExpressionVisitor& visitor) override;
|
||||||
|
|
||||||
ExpressionType GetExpressionType() const;
|
|
||||||
|
|
||||||
ShaderAst::ConstantValue value;
|
ShaderAst::ConstantValue value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
#include <Nazara/Prerequisites.hpp>
|
#include <Nazara/Prerequisites.hpp>
|
||||||
#include <Nazara/Shader/Ast/Nodes.hpp>
|
#include <Nazara/Shader/Ast/Nodes.hpp>
|
||||||
|
#include <array>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
|
@ -39,6 +40,7 @@ namespace Nz::ShaderBuilder
|
||||||
|
|
||||||
struct Cast
|
struct Cast
|
||||||
{
|
{
|
||||||
|
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const;
|
||||||
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const;
|
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,15 @@ namespace Nz::ShaderBuilder
|
||||||
return branchNode;
|
return branchNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const
|
||||||
|
{
|
||||||
|
auto castNode = std::make_unique<ShaderAst::CastExpression>();
|
||||||
|
castNode->expressions = std::move(expressions);
|
||||||
|
castNode->targetType = std::move(targetType);
|
||||||
|
|
||||||
|
return castNode;
|
||||||
|
}
|
||||||
|
|
||||||
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const
|
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const
|
||||||
{
|
{
|
||||||
auto castNode = std::make_unique<ShaderAst::CastExpression>();
|
auto castNode = std::make_unique<ShaderAst::CastExpression>();
|
||||||
|
|
|
||||||
|
|
@ -227,7 +227,27 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
using Op = BinarySubtraction<T1, T2>;
|
using Op = BinarySubtraction<T1, T2>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*************************************************************************************************/
|
||||||
|
|
||||||
|
template<typename T, typename... Args>
|
||||||
|
struct CastConstantBase
|
||||||
|
{
|
||||||
|
std::unique_ptr<ConstantExpression> operator()(const Args&... args)
|
||||||
|
{
|
||||||
|
return ShaderBuilder::Constant(T(args...));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T, typename... Args>
|
||||||
|
struct CastConstant;
|
||||||
|
|
||||||
|
template<typename T, typename... Args>
|
||||||
|
struct CastConstantPropagation
|
||||||
|
{
|
||||||
|
using Op = CastConstant<T, Args...>;
|
||||||
|
};
|
||||||
|
|
||||||
/*************************************************************************************************/
|
/*************************************************************************************************/
|
||||||
|
|
||||||
template<UnaryType Type, typename T>
|
template<UnaryType Type, typename T>
|
||||||
|
|
@ -319,13 +339,13 @@ namespace Nz::ShaderAst
|
||||||
EnableOptimisation(BinaryCompGt, bool, bool);
|
EnableOptimisation(BinaryCompGt, bool, bool);
|
||||||
EnableOptimisation(BinaryCompGt, double, double);
|
EnableOptimisation(BinaryCompGt, double, double);
|
||||||
EnableOptimisation(BinaryCompGt, float, float);
|
EnableOptimisation(BinaryCompGt, float, float);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Int32, Nz::Int32);
|
EnableOptimisation(BinaryCompGt, Int32, Int32);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Vector2f, Nz::Vector2f);
|
EnableOptimisation(BinaryCompGt, Vector2f, Vector2f);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Vector3f, Nz::Vector3f);
|
EnableOptimisation(BinaryCompGt, Vector3f, Vector3f);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Vector4f, Nz::Vector4f);
|
EnableOptimisation(BinaryCompGt, Vector4f, Vector4f);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Vector2i32, Nz::Vector2i32);
|
EnableOptimisation(BinaryCompGt, Vector2i32, Vector2i32);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Vector3i32, Nz::Vector3i32);
|
EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32);
|
||||||
EnableOptimisation(BinaryCompGt, Nz::Vector4i32, Nz::Vector4i32);
|
EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32);
|
||||||
|
|
||||||
EnableOptimisation(BinaryCompLe, bool, bool);
|
EnableOptimisation(BinaryCompLe, bool, bool);
|
||||||
EnableOptimisation(BinaryCompLe, double, double);
|
EnableOptimisation(BinaryCompLe, double, double);
|
||||||
|
|
@ -442,6 +462,48 @@ namespace Nz::ShaderAst
|
||||||
EnableOptimisation(BinarySubtraction, Nz::Vector3i32, Nz::Vector3i32);
|
EnableOptimisation(BinarySubtraction, Nz::Vector3i32, Nz::Vector3i32);
|
||||||
EnableOptimisation(BinarySubtraction, Nz::Vector4i32, Nz::Vector4i32);
|
EnableOptimisation(BinarySubtraction, Nz::Vector4i32, Nz::Vector4i32);
|
||||||
|
|
||||||
|
// Cast
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, bool, bool);
|
||||||
|
EnableOptimisation(CastConstant, bool, Int32);
|
||||||
|
EnableOptimisation(CastConstant, bool, UInt32);
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, double, double);
|
||||||
|
EnableOptimisation(CastConstant, double, float);
|
||||||
|
EnableOptimisation(CastConstant, double, Int32);
|
||||||
|
EnableOptimisation(CastConstant, double, UInt32);
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, float, double);
|
||||||
|
EnableOptimisation(CastConstant, float, float);
|
||||||
|
EnableOptimisation(CastConstant, float, Int32);
|
||||||
|
EnableOptimisation(CastConstant, float, UInt32);
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, Int32, double);
|
||||||
|
EnableOptimisation(CastConstant, Int32, float);
|
||||||
|
EnableOptimisation(CastConstant, Int32, Int32);
|
||||||
|
EnableOptimisation(CastConstant, Int32, UInt32);
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, UInt32, double);
|
||||||
|
EnableOptimisation(CastConstant, UInt32, float);
|
||||||
|
EnableOptimisation(CastConstant, UInt32, Int32);
|
||||||
|
EnableOptimisation(CastConstant, UInt32, UInt32);
|
||||||
|
|
||||||
|
//EnableOptimisation(CastConstant, Vector2d, double, double);
|
||||||
|
//EnableOptimisation(CastConstant, Vector3d, double, double, double);
|
||||||
|
//EnableOptimisation(CastConstant, Vector4d, double, double, double, double);
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, Vector2f, float, float);
|
||||||
|
EnableOptimisation(CastConstant, Vector3f, float, float, float);
|
||||||
|
EnableOptimisation(CastConstant, Vector4f, float, float, float, float);
|
||||||
|
|
||||||
|
EnableOptimisation(CastConstant, Vector2i32, Int32, Int32);
|
||||||
|
EnableOptimisation(CastConstant, Vector3i32, Int32, Int32, Int32);
|
||||||
|
EnableOptimisation(CastConstant, Vector4i32, Int32, Int32, Int32, Int32);
|
||||||
|
|
||||||
|
//EnableOptimisation(CastConstant, Vector2ui32, UInt32, UInt32);
|
||||||
|
//EnableOptimisation(CastConstant, Vector3ui32, UInt32, UInt32, UInt32);
|
||||||
|
//EnableOptimisation(CastConstant, Vector4ui32, UInt32, UInt32, UInt32, UInt32);
|
||||||
|
|
||||||
// Unary
|
// Unary
|
||||||
|
|
||||||
EnableOptimisation(UnaryLogicalNot, bool);
|
EnableOptimisation(UnaryLogicalNot, bool);
|
||||||
|
|
@ -496,42 +558,43 @@ namespace Nz::ShaderAst
|
||||||
switch (node.op)
|
switch (node.op)
|
||||||
{
|
{
|
||||||
case BinaryType::Add:
|
case BinaryType::Add:
|
||||||
optimized = PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::Subtract:
|
case BinaryType::Subtract:
|
||||||
optimized = PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
|
break;
|
||||||
|
|
||||||
case BinaryType::Multiply:
|
case BinaryType::Multiply:
|
||||||
optimized = PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::Divide:
|
case BinaryType::Divide:
|
||||||
optimized = PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::CompEq:
|
case BinaryType::CompEq:
|
||||||
optimized = PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::CompGe:
|
case BinaryType::CompGe:
|
||||||
optimized = PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::CompGt:
|
case BinaryType::CompGt:
|
||||||
optimized = PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::CompLe:
|
case BinaryType::CompLe:
|
||||||
optimized = PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::CompLt:
|
case BinaryType::CompLt:
|
||||||
optimized = PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BinaryType::CompNe:
|
case BinaryType::CompNe:
|
||||||
optimized = PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
optimized = PropagateBinaryConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -545,6 +608,123 @@ namespace Nz::ShaderAst
|
||||||
return binary;
|
return binary;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ExpressionPtr AstOptimizer::Clone(CastExpression& node)
|
||||||
|
{
|
||||||
|
std::array<ExpressionPtr, 4> expressions;
|
||||||
|
|
||||||
|
std::size_t expressionCount = 0;
|
||||||
|
for (const auto& expression : node.expressions)
|
||||||
|
{
|
||||||
|
if (!expression)
|
||||||
|
break;
|
||||||
|
|
||||||
|
expressions[expressionCount] = CloneExpression(expression);
|
||||||
|
expressionCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionPtr optimized;
|
||||||
|
if (IsPrimitiveType(node.targetType))
|
||||||
|
{
|
||||||
|
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantExpression)
|
||||||
|
{
|
||||||
|
auto constantExpr = static_unique_pointer_cast<ConstantExpression>(std::move(expressions.front()));
|
||||||
|
|
||||||
|
switch (std::get<PrimitiveType>(node.targetType))
|
||||||
|
{
|
||||||
|
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(std::move(constantExpr)); break;
|
||||||
|
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(std::move(constantExpr)); break;
|
||||||
|
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(std::move(constantExpr)); break;
|
||||||
|
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(std::move(constantExpr)); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (IsVectorType(node.targetType))
|
||||||
|
{
|
||||||
|
const auto& vecType = std::get<VectorType>(node.targetType);
|
||||||
|
|
||||||
|
// Decompose vector into values (cast(vec3, float) => cast(float, float, float, float))
|
||||||
|
std::vector<ConstantValue> constantValues;
|
||||||
|
for (std::size_t i = 0; i < expressionCount; ++i)
|
||||||
|
{
|
||||||
|
if (expressions[i]->GetType() != NodeType::ConstantExpression)
|
||||||
|
{
|
||||||
|
constantValues.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& constantExpr = static_cast<ConstantExpression&>(*expressions[i]);
|
||||||
|
|
||||||
|
if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value))
|
||||||
|
{
|
||||||
|
// Unhandled case, all cast parameters are expected to be of the same type
|
||||||
|
constantValues.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::visit([&](auto&& arg)
|
||||||
|
{
|
||||||
|
using T = std::decay_t<decltype(arg)>;
|
||||||
|
|
||||||
|
if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
|
||||||
|
constantValues.push_back(arg);
|
||||||
|
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i32>)
|
||||||
|
{
|
||||||
|
constantValues.push_back(arg.x);
|
||||||
|
constantValues.push_back(arg.y);
|
||||||
|
}
|
||||||
|
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i32>)
|
||||||
|
{
|
||||||
|
constantValues.push_back(arg.x);
|
||||||
|
constantValues.push_back(arg.y);
|
||||||
|
constantValues.push_back(arg.z);
|
||||||
|
}
|
||||||
|
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i32>)
|
||||||
|
{
|
||||||
|
constantValues.push_back(arg.x);
|
||||||
|
constantValues.push_back(arg.y);
|
||||||
|
constantValues.push_back(arg.z);
|
||||||
|
constantValues.push_back(arg.w);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||||
|
}, constantExpr.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!constantValues.empty())
|
||||||
|
{
|
||||||
|
assert(constantValues.size() == vecType.componentCount);
|
||||||
|
|
||||||
|
std::visit([&](auto&& arg)
|
||||||
|
{
|
||||||
|
using T = std::decay_t<decltype(arg)>;
|
||||||
|
|
||||||
|
switch (vecType.componentCount)
|
||||||
|
{
|
||||||
|
case 2:
|
||||||
|
optimized = PropagateVec2Cast(std::get<T>(constantValues[0]), std::get<T>(constantValues[1]));
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 3:
|
||||||
|
optimized = PropagateVec3Cast(std::get<T>(constantValues[0]), std::get<T>(constantValues[1]), std::get<T>(constantValues[2]));
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 4:
|
||||||
|
optimized = PropagateVec4Cast(std::get<T>(constantValues[0]), std::get<T>(constantValues[1]), std::get<T>(constantValues[2]), std::get<T>(constantValues[3]));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}, constantValues.front());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (optimized)
|
||||||
|
return optimized;
|
||||||
|
|
||||||
|
auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions));
|
||||||
|
cast->cachedExpressionType = node.cachedExpressionType;
|
||||||
|
|
||||||
|
return cast;
|
||||||
|
}
|
||||||
|
|
||||||
StatementPtr AstOptimizer::Clone(BranchStatement& node)
|
StatementPtr AstOptimizer::Clone(BranchStatement& node)
|
||||||
{
|
{
|
||||||
std::vector<BranchStatement::ConditionalStatement> statements;
|
std::vector<BranchStatement::ConditionalStatement> statements;
|
||||||
|
|
@ -626,15 +806,15 @@ namespace Nz::ShaderAst
|
||||||
switch (node.op)
|
switch (node.op)
|
||||||
{
|
{
|
||||||
case UnaryType::LogicalNot:
|
case UnaryType::LogicalNot:
|
||||||
optimized = PropagateConstant<UnaryType::LogicalNot>(std::move(constantExpr));
|
optimized = PropagateUnaryConstant<UnaryType::LogicalNot>(std::move(constantExpr));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case UnaryType::Minus:
|
case UnaryType::Minus:
|
||||||
optimized = PropagateConstant<UnaryType::Minus>(std::move(constantExpr));
|
optimized = PropagateUnaryConstant<UnaryType::Minus>(std::move(constantExpr));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case UnaryType::Plus:
|
case UnaryType::Plus:
|
||||||
optimized = PropagateConstant<UnaryType::Plus>(std::move(constantExpr));
|
optimized = PropagateUnaryConstant<UnaryType::Plus>(std::move(constantExpr));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -660,7 +840,7 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
|
|
||||||
template<BinaryType Type>
|
template<BinaryType Type>
|
||||||
ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
ExpressionPtr AstOptimizer::PropagateBinaryConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
||||||
{
|
{
|
||||||
std::unique_ptr<ConstantExpression> optimized;
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
std::visit([&](auto&& arg1)
|
std::visit([&](auto&& arg1)
|
||||||
|
|
@ -683,13 +863,34 @@ namespace Nz::ShaderAst
|
||||||
}, lhs->value);
|
}, lhs->value);
|
||||||
|
|
||||||
if (optimized)
|
if (optimized)
|
||||||
optimized->cachedExpressionType = optimized->GetExpressionType();
|
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
||||||
|
|
||||||
|
return optimized;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TargetType>
|
||||||
|
ExpressionPtr AstOptimizer::PropagateSingleValueCast(std::unique_ptr<ConstantExpression>&& operand)
|
||||||
|
{
|
||||||
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
|
|
||||||
|
std::visit([&](auto&& arg)
|
||||||
|
{
|
||||||
|
using T = std::decay_t<decltype(arg)>;
|
||||||
|
using CCType = CastConstantPropagation<TargetType, T>;
|
||||||
|
|
||||||
|
if constexpr (is_complete_v<CCType>)
|
||||||
|
{
|
||||||
|
using Op = typename CCType::Op;
|
||||||
|
if constexpr (is_complete_v<Op>)
|
||||||
|
optimized = Op{}(arg);
|
||||||
|
}
|
||||||
|
}, operand->value);
|
||||||
|
|
||||||
return optimized;
|
return optimized;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<UnaryType Type>
|
template<UnaryType Type>
|
||||||
ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& operand)
|
ExpressionPtr AstOptimizer::PropagateUnaryConstant(std::unique_ptr<ConstantExpression>&& operand)
|
||||||
{
|
{
|
||||||
std::unique_ptr<ConstantExpression> optimized;
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
std::visit([&](auto&& arg)
|
std::visit([&](auto&& arg)
|
||||||
|
|
@ -706,7 +907,58 @@ namespace Nz::ShaderAst
|
||||||
}, operand->value);
|
}, operand->value);
|
||||||
|
|
||||||
if (optimized)
|
if (optimized)
|
||||||
optimized->cachedExpressionType = optimized->GetExpressionType();
|
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
||||||
|
|
||||||
|
return optimized;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TargetType>
|
||||||
|
ExpressionPtr AstOptimizer::PropagateVec2Cast(TargetType v1, TargetType v2)
|
||||||
|
{
|
||||||
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
|
|
||||||
|
using CCType = CastConstantPropagation<Vector2<TargetType>, TargetType, TargetType>;
|
||||||
|
|
||||||
|
if constexpr (is_complete_v<CCType>)
|
||||||
|
{
|
||||||
|
using Op = typename CCType::Op;
|
||||||
|
if constexpr (is_complete_v<Op>)
|
||||||
|
optimized = Op{}(v1, v2);
|
||||||
|
}
|
||||||
|
|
||||||
|
return optimized;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TargetType>
|
||||||
|
ExpressionPtr AstOptimizer::PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3)
|
||||||
|
{
|
||||||
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
|
|
||||||
|
using CCType = CastConstantPropagation<Vector3<TargetType>, TargetType, TargetType, TargetType>;
|
||||||
|
|
||||||
|
if constexpr (is_complete_v<CCType>)
|
||||||
|
{
|
||||||
|
using Op = typename CCType::Op;
|
||||||
|
if constexpr (is_complete_v<Op>)
|
||||||
|
optimized = Op{}(v1, v2, v3);
|
||||||
|
}
|
||||||
|
|
||||||
|
return optimized;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TargetType>
|
||||||
|
ExpressionPtr AstOptimizer::PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4)
|
||||||
|
{
|
||||||
|
std::unique_ptr<ConstantExpression> optimized;
|
||||||
|
|
||||||
|
using CCType = CastConstantPropagation<Vector3<TargetType>, TargetType, TargetType, TargetType, TargetType>;
|
||||||
|
|
||||||
|
if constexpr (is_complete_v<CCType>)
|
||||||
|
{
|
||||||
|
using Op = typename CCType::Op;
|
||||||
|
if constexpr (is_complete_v<Op>)
|
||||||
|
optimized = Op{}(v1, v2, v3, v4);
|
||||||
|
}
|
||||||
|
|
||||||
return optimized;
|
return optimized;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
// Copyright (C) 2020 Jérôme Leclercq
|
||||||
|
// This file is part of the "Nazara Engine - Shader generator"
|
||||||
|
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||||
|
|
||||||
|
#include <Nazara/Shader/Ast/ConstantValue.hpp>
|
||||||
|
#include <Nazara/Shader/Debug.hpp>
|
||||||
|
|
||||||
|
namespace Nz::ShaderAst
|
||||||
|
{
|
||||||
|
ExpressionType GetExpressionType(const ConstantValue& constant)
|
||||||
|
{
|
||||||
|
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||||
|
{
|
||||||
|
using T = std::decay_t<decltype(arg)>;
|
||||||
|
|
||||||
|
if constexpr (std::is_same_v<T, bool>)
|
||||||
|
return PrimitiveType::Boolean;
|
||||||
|
else if constexpr (std::is_same_v<T, float>)
|
||||||
|
return PrimitiveType::Float32;
|
||||||
|
else if constexpr (std::is_same_v<T, Int32>)
|
||||||
|
return PrimitiveType::Int32;
|
||||||
|
else if constexpr (std::is_same_v<T, UInt32>)
|
||||||
|
return PrimitiveType::UInt32;
|
||||||
|
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||||
|
return VectorType{ 2, PrimitiveType::Float32 };
|
||||||
|
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||||
|
return VectorType{ 3, PrimitiveType::Float32 };
|
||||||
|
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||||
|
return VectorType{ 4, PrimitiveType::Float32 };
|
||||||
|
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||||
|
return VectorType{ 2, PrimitiveType::Int32 };
|
||||||
|
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||||
|
return VectorType{ 3, PrimitiveType::Int32 };
|
||||||
|
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||||
|
return VectorType{ 4, PrimitiveType::Int32 };
|
||||||
|
else
|
||||||
|
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||||
|
}, constant);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -29,35 +29,4 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
|
|
||||||
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
||||||
|
|
||||||
ExpressionType ConstantExpression::GetExpressionType() const
|
|
||||||
{
|
|
||||||
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
|
||||||
{
|
|
||||||
using T = std::decay_t<decltype(arg)>;
|
|
||||||
|
|
||||||
if constexpr (std::is_same_v<T, bool>)
|
|
||||||
return PrimitiveType::Boolean;
|
|
||||||
else if constexpr (std::is_same_v<T, float>)
|
|
||||||
return PrimitiveType::Float32;
|
|
||||||
else if constexpr (std::is_same_v<T, Int32>)
|
|
||||||
return PrimitiveType::Int32;
|
|
||||||
else if constexpr (std::is_same_v<T, UInt32>)
|
|
||||||
return PrimitiveType::UInt32;
|
|
||||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
|
||||||
return VectorType{ 2, PrimitiveType::Float32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
|
||||||
return VectorType{ 3, PrimitiveType::Float32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
|
||||||
return VectorType{ 4, PrimitiveType::Float32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
|
||||||
return VectorType{ 2, PrimitiveType::Int32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
|
||||||
return VectorType{ 3, PrimitiveType::Int32 };
|
|
||||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
|
||||||
return VectorType{ 4, PrimitiveType::Int32 };
|
|
||||||
else
|
|
||||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
|
||||||
}, value);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -413,7 +413,7 @@ namespace Nz::ShaderAst
|
||||||
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
||||||
{
|
{
|
||||||
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
|
auto clone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));
|
||||||
clone->cachedExpressionType = clone->GetExpressionType();
|
clone->cachedExpressionType = GetExpressionType(clone->value);
|
||||||
|
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue