From 3a7f5c263092d34e17948bf66787937f2fce0536 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sat, 17 Apr 2021 14:43:29 +0200 Subject: [PATCH] Shader: Fix optimization --- include/Nazara/Shader/Ast/Nodes.hpp | 2 + src/Nazara/Shader/Ast/AstOptimizer.cpp | 121 +++++++++++----------- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 28 +---- 3 files changed, 66 insertions(+), 85 deletions(-) diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 8ac4942b3..5fc836bef 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -126,6 +126,8 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; + ExpressionType GetExpressionType() const; + ShaderAst::ConstantValue value; }; diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index 3c00fbd2b..53aff83a9 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -41,7 +41,7 @@ namespace Nz::ShaderAst template struct CompEqBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs == rhs); } @@ -60,7 +60,7 @@ namespace Nz::ShaderAst template struct CompGeBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs >= rhs); } @@ -79,7 +79,7 @@ namespace Nz::ShaderAst template struct CompGtBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs > rhs); } @@ -98,7 +98,7 @@ namespace Nz::ShaderAst template struct CompLeBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs <= rhs); } @@ -117,7 +117,7 @@ namespace Nz::ShaderAst template struct CompLtBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs < rhs); } @@ -136,7 +136,7 @@ namespace Nz::ShaderAst template struct CompNeBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs != rhs); } @@ -155,7 +155,7 @@ namespace Nz::ShaderAst template struct AdditionBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs + rhs); } @@ -174,7 +174,7 @@ namespace Nz::ShaderAst template struct DivisionBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs / rhs); } @@ -193,7 +193,7 @@ namespace Nz::ShaderAst template struct MultiplicationBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs * rhs); } @@ -212,7 +212,7 @@ namespace Nz::ShaderAst template struct SubtractionBase { - ExpressionPtr operator()(const T1& lhs, const T2& rhs) + std::unique_ptr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs - rhs); } @@ -382,17 +382,18 @@ namespace Nz::ShaderAst StatementPtr AstOptimizer::Optimise(StatementPtr& statement) { + m_enabledOptions.reset(); return CloneStatement(statement); } StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions) { - m_enabledConditions = enabledConditions; + m_enabledOptions = enabledConditions; return CloneStatement(statement); } - void AstOptimizer::Visit(BinaryExpression& node) + ExpressionPtr AstOptimizer::Clone(BinaryExpression& node) { auto lhs = CloneExpression(node.left); auto rhs = CloneExpression(node.right); @@ -402,44 +403,60 @@ namespace Nz::ShaderAst auto lhsConstant = static_unique_pointer_cast(std::move(lhs)); auto rhsConstant = static_unique_pointer_cast(std::move(rhs)); + ExpressionPtr optimized; switch (node.op) { case BinaryType::Add: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::Subtract: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); case BinaryType::Multiply: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::Divide: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::CompEq: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::CompGe: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::CompGt: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::CompLe: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::CompLt: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; case BinaryType::CompNe: - return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + optimized = PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; } + + if (optimized) + return optimized; } - AstCloner::Visit(node); + auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs)); + binary->cachedExpressionType = node.cachedExpressionType; + + return binary; } - void AstOptimizer::Visit(BranchStatement& node) + StatementPtr AstOptimizer::Clone(BranchStatement& node) { std::vector statements; StatementPtr elseStatement; @@ -465,8 +482,7 @@ namespace Nz::ShaderAst if (statements.empty()) { // First condition is true, dismiss the branch - condStatement.statement->Visit(*this); - return; + return AstCloner::Clone(condStatement.statement); } else { @@ -487,54 +503,43 @@ namespace Nz::ShaderAst { // All conditions have been removed, replace by else statement or no-op if (node.elseStatement) - { - node.elseStatement->Visit(*this); - return; - } + return AstCloner::Clone(node.elseStatement); else - return PushStatement(ShaderBuilder::NoOp()); + return ShaderBuilder::NoOp(); } if (!elseStatement) elseStatement = CloneStatement(node.elseStatement); - PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement))); + return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)); } - void AstOptimizer::Visit(ConditionalExpression& node) + ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node) { - return AstCloner::Visit(node); + if (!m_enabledOptions) + return AstCloner::Clone(node); - /*if (!m_shaderAst) - return ShaderAstCloner::Visit(node); - - std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName); - assert(conditionIndex != InvalidCondition); - - if (TestBit(m_enabledConditions, conditionIndex)) - Visit(node.truePath); + if (TestBit(*m_enabledOptions, node.optionIndex)) + return AstCloner::Clone(node.truePath); else - Visit(node.falsePath);*/ + return AstCloner::Clone(node.falsePath); } - void AstOptimizer::Visit(ConditionalStatement& node) + StatementPtr AstOptimizer::Clone(ConditionalStatement& node) { - return AstCloner::Visit(node); + if (!m_enabledOptions) + return AstCloner::Clone(node); - /*if (!m_shaderAst) - return ShaderAstCloner::Visit(node); - - std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName); - assert(conditionIndex != InvalidCondition); - - if (TestBit(m_enabledConditions, conditionIndex)) - Visit(node.statement);*/ + if (TestBit(*m_enabledOptions, node.optionIndex)) + return AstCloner::Clone(node); + else + return ShaderBuilder::NoOp(); } template - void AstOptimizer::PropagateConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs) + ExpressionPtr AstOptimizer::PropagateConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs) { - ExpressionPtr optimized; + std::unique_ptr optimized; std::visit([&](auto&& arg1) { using T1 = std::decay_t; @@ -555,8 +560,8 @@ namespace Nz::ShaderAst }, lhs->value); if (optimized) - PushExpression(std::move(optimized)); - else - PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs))); + optimized->cachedExpressionType = optimized->GetExpressionType(); + + return optimized; } } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index b05194a82..3698a0485 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -358,33 +358,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - clone->cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return PrimitiveType::Boolean; - else if constexpr (std::is_same_v) - return PrimitiveType::Float32; - else if constexpr (std::is_same_v) - return PrimitiveType::Int32; - else if constexpr (std::is_same_v) - return PrimitiveType::UInt32; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Int32 }; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, clone->value); + clone->cachedExpressionType = clone->GetExpressionType(); return clone; }