diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp index c76a536ed..614f4ff2f 100644 --- a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp @@ -17,6 +17,7 @@ namespace Nz::ShaderAst class NAZARA_SHADER_API DependencyCheckerVisitor : public AstRecursiveVisitor { public: + struct Config; struct UsageSet; DependencyCheckerVisitor() = default; @@ -26,21 +27,18 @@ namespace Nz::ShaderAst inline const UsageSet& GetUsage() const; + void Process(Statement& statement, const Config& config = {}); + inline void Resolve(); - using AstRecursiveVisitor::Visit; - - void Visit(CallFunctionExpression& node) override; - void Visit(VariableExpression& node) override; - - void Visit(DeclareExternalStatement& node) override; - void Visit(DeclareFunctionStatement& node) override; - void Visit(DeclareStructStatement& node) override; - void Visit(DeclareVariableStatement& node) override; - DependencyCheckerVisitor& operator=(const DependencyCheckerVisitor&) = delete; DependencyCheckerVisitor& operator=(DependencyCheckerVisitor&&) = delete; + struct Config + { + ShaderStageTypeFlags usedShaderStages = ShaderStageType_All; + }; + struct UsageSet { Bitset<> usedFunctions; @@ -51,11 +49,22 @@ namespace Nz::ShaderAst private: void Resolve(const UsageSet& usageSet); + using AstRecursiveVisitor::Visit; + + void Visit(CallFunctionExpression& node) override; + void Visit(VariableExpression& node) override; + + void Visit(DeclareExternalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + std::optional m_currentFunctionIndex; std::optional m_currentVariableDeclIndex; std::unordered_map m_functionUsages; std::unordered_map m_structUsages; std::unordered_map m_variableUsages; + Config m_config; UsageSet m_globalUsage; UsageSet m_resolvedUsage; }; diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp index 854ee69df..0bcf2d721 100644 --- a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp @@ -11,29 +11,23 @@ #include #include #include +#include namespace Nz::ShaderAst { class NAZARA_SHADER_API EliminateUnusedPassVisitor : AstCloner { public: - struct Config; - EliminateUnusedPassVisitor() = default; EliminateUnusedPassVisitor(const EliminateUnusedPassVisitor&) = delete; EliminateUnusedPassVisitor(EliminateUnusedPassVisitor&&) = delete; ~EliminateUnusedPassVisitor() = default; - StatementPtr Process(Statement& statement, const Config& config = {}); + StatementPtr Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet); EliminateUnusedPassVisitor& operator=(const EliminateUnusedPassVisitor&) = delete; EliminateUnusedPassVisitor& operator=(EliminateUnusedPassVisitor&&) = delete; - struct Config - { - ShaderStageTypeFlags usedShaderStages = ShaderStageType_All; - }; - private: using AstCloner::Clone; StatementPtr Clone(DeclareExternalStatement& node) override; @@ -49,7 +43,8 @@ namespace Nz::ShaderAst Context* m_context; }; - inline StatementPtr EliminateUnusedPass(Statement& ast, const EliminateUnusedPassVisitor::Config& config = {}); + inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::Config& config = {}); + inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::UsageSet& usageSet); } #include diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl index db8124116..29792e99c 100644 --- a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl @@ -7,10 +7,19 @@ namespace Nz::ShaderAst { - inline StatementPtr EliminateUnusedPass(Statement& ast, const EliminateUnusedPassVisitor::Config& config) + inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::Config& config) + { + DependencyCheckerVisitor dependencyVisitor; + dependencyVisitor.Process(ast, config); + dependencyVisitor.Resolve(); + + return EliminateUnusedPass(ast, dependencyVisitor.GetUsage()); + } + + StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::UsageSet& usageSet) { EliminateUnusedPassVisitor visitor; - return visitor.Process(ast, config); + return visitor.Process(ast, usageSet); } } diff --git a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp index 21e9f56aa..0d2175d8d 100644 --- a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp +++ b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp @@ -24,6 +24,12 @@ namespace Nz::ShaderAst } } + void DependencyCheckerVisitor::Process(Statement& statement, const Config& config) + { + m_config = config; + statement.Visit(*this); + } + void DependencyCheckerVisitor::Visit(CallFunctionExpression& node) { const auto& targetFuncType = GetExpressionType(node); @@ -32,8 +38,18 @@ namespace Nz::ShaderAst const auto& funcType = std::get(targetFuncType); assert(m_currentFunctionIndex); - UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex); - usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); + if (m_currentVariableDeclIndex) + { + UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex); + usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); + } + else + { + UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex); + usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); + } + + AstRecursiveVisitor::Visit(node); } void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node) @@ -99,7 +115,11 @@ namespace Nz::ShaderAst } if (node.entryStage.HasValue()) - m_globalUsage.usedFunctions.UnboundedSet(*node.funcIndex); + { + ShaderStageType shaderStage = node.entryStage.GetResultingValue(); + if (m_config.usedShaderStages & shaderStage) + m_globalUsage.usedFunctions.UnboundedSet(*node.funcIndex); + } m_currentFunctionIndex = node.funcIndex; AstRecursiveVisitor::Visit(node); diff --git a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp index 961176082..7704b9938 100644 --- a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp +++ b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include @@ -14,14 +13,14 @@ namespace Nz::ShaderAst { struct EliminateUnusedPassVisitor::Context { - DependencyCheckerVisitor usageChecker; + const DependencyCheckerVisitor::UsageSet& usageSet; }; - StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const Config& config) + StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet) { - Context context(config); - statement.Visit(context.usageChecker); - context.usageChecker.Resolve(); + Context context{ + usageSet + }; m_context = &context; CallOnExit onExit([this]() @@ -98,18 +97,18 @@ namespace Nz::ShaderAst bool EliminateUnusedPassVisitor::IsFunctionUsed(std::size_t varIndex) const { assert(m_context); - return m_context->usageChecker.GetUsage().usedFunctions.UnboundedTest(varIndex); + return m_context->usageSet.usedFunctions.UnboundedTest(varIndex); } bool EliminateUnusedPassVisitor::IsStructUsed(std::size_t varIndex) const { assert(m_context); - return m_context->usageChecker.GetUsage().usedStructs.UnboundedTest(varIndex); + return m_context->usageSet.usedStructs.UnboundedTest(varIndex); } bool EliminateUnusedPassVisitor::IsVariableUsed(std::size_t varIndex) const { assert(m_context); - return m_context->usageChecker.GetUsage().usedVariables.UnboundedTest(varIndex); + return m_context->usageSet.usedVariables.UnboundedTest(varIndex); } } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 6a60c6991..a00c2056c 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -180,12 +180,12 @@ namespace Nz { ShaderAst::StatementPtr tempAst; - ShaderAst::EliminateUnusedPassVisitor::Config eliminateUnunsedConfig; + ShaderAst::DependencyCheckerVisitor::Config dependencyConfig; if (shaderStage) - eliminateUnunsedConfig.usedShaderStages = *shaderStage; + dependencyConfig.usedShaderStages = *shaderStage; tempAst = ShaderAst::PropagateConstants(*targetAst); - optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, eliminateUnunsedConfig); + optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig); targetAst = optimizedAst.get(); }