From a155aa598e2bded63e30262bf4c46e8318cb209f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sat, 26 Feb 2022 16:08:16 +0100 Subject: [PATCH] Shader: Move DependencyChecker to a public class --- .../Shader/Ast/DependencyCheckerVisitor.hpp | 66 ++++++ .../Shader/Ast/DependencyCheckerVisitor.inl | 21 ++ .../Shader/Ast/DependencyCheckerVisitor.cpp | 176 +++++++++++++++ .../Shader/Ast/EliminateUnusedPassVisitor.cpp | 203 +----------------- 4 files changed, 268 insertions(+), 198 deletions(-) create mode 100644 include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp create mode 100644 include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl create mode 100644 src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp new file mode 100644 index 000000000..c76a536ed --- /dev/null +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp @@ -0,0 +1,66 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_AST_DEPENDENCYCHECKERVISITOR_HPP +#define NAZARA_SHADER_AST_DEPENDENCYCHECKERVISITOR_HPP + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API DependencyCheckerVisitor : public AstRecursiveVisitor + { + public: + struct UsageSet; + + DependencyCheckerVisitor() = default; + DependencyCheckerVisitor(const DependencyCheckerVisitor&) = delete; + DependencyCheckerVisitor(DependencyCheckerVisitor&&) = delete; + ~DependencyCheckerVisitor() = default; + + inline const UsageSet& GetUsage() const; + + inline void Resolve(); + + using AstRecursiveVisitor::Visit; + + void Visit(CallFunctionExpression& node) override; + void Visit(VariableExpression& node) override; + + void Visit(DeclareExternalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + + DependencyCheckerVisitor& operator=(const DependencyCheckerVisitor&) = delete; + DependencyCheckerVisitor& operator=(DependencyCheckerVisitor&&) = delete; + + struct UsageSet + { + Bitset<> usedFunctions; + Bitset<> usedStructs; + Bitset<> usedVariables; + }; + + private: + void Resolve(const UsageSet& usageSet); + + std::optional m_currentFunctionIndex; + std::optional m_currentVariableDeclIndex; + std::unordered_map m_functionUsages; + std::unordered_map m_structUsages; + std::unordered_map m_variableUsages; + UsageSet m_globalUsage; + UsageSet m_resolvedUsage; + }; +} + +#include + +#endif // NAZARA_SHADER_AST_DEPENDENCYCHECKERVISITOR_HPP diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl new file mode 100644 index 000000000..30f4674e4 --- /dev/null +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl @@ -0,0 +1,21 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline auto DependencyCheckerVisitor::GetUsage() const -> const UsageSet& + { + return m_resolvedUsage; + } + + void DependencyCheckerVisitor::Resolve() + { + Resolve(m_globalUsage); + } +} + +#include diff --git a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp new file mode 100644 index 000000000..21e9f56aa --- /dev/null +++ b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp @@ -0,0 +1,176 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + namespace + { + template T& Retrieve(std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + + template const T& Retrieve(const std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + } + + void DependencyCheckerVisitor::Visit(CallFunctionExpression& node) + { + const auto& targetFuncType = GetExpressionType(node); + assert(std::holds_alternative(targetFuncType)); + + const auto& funcType = std::get(targetFuncType); + + assert(m_currentFunctionIndex); + UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex); + usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); + } + + void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node) + { + for (const auto& externalVar : node.externalVars) + { + assert(externalVar.varIndex); + std::size_t varIndex = *externalVar.varIndex; + + assert(m_variableUsages.find(varIndex) == m_variableUsages.end()); + UsageSet& usageSet = m_variableUsages[varIndex]; + + const auto& exprType = externalVar.type.GetResultingValue(); + + if (IsUniformType(exprType)) + { + const UniformType& uniformType = std::get(exprType); + usageSet.usedStructs.UnboundedSet(uniformType.containedType.structIndex); + } + + ++varIndex; + } + + AstRecursiveVisitor::Visit(node); + } + + void DependencyCheckerVisitor::Visit(DeclareFunctionStatement& node) + { + assert(node.funcIndex); + assert(m_functionUsages.find(*node.funcIndex) == m_functionUsages.end()); + UsageSet& usageSet = m_functionUsages[*node.funcIndex]; + + // Register struct used in parameters or return type + if (!node.parameters.empty()) + { + assert(node.varIndex); + std::size_t parameterVarIndex = *node.varIndex; + for (auto& parameter : node.parameters) + { + // Since parameters must always be defined, their type isn't a dependency of parameter variables + assert(m_variableUsages.find(parameterVarIndex) == m_variableUsages.end()); + m_variableUsages.emplace(parameterVarIndex, UsageSet{}); + + const auto& exprType = parameter.type.GetResultingValue(); + if (IsStructType(exprType)) + { + std::size_t structIndex = std::get(exprType).structIndex; + usageSet.usedStructs.UnboundedSet(structIndex); + } + + ++parameterVarIndex; + } + } + + if (node.returnType.HasValue()) + { + const auto& returnExprType = node.returnType.GetResultingValue(); + if (IsStructType(returnExprType)) + { + std::size_t structIndex = std::get(returnExprType).structIndex; + usageSet.usedStructs.UnboundedSet(structIndex); + } + } + + if (node.entryStage.HasValue()) + m_globalUsage.usedFunctions.UnboundedSet(*node.funcIndex); + + m_currentFunctionIndex = node.funcIndex; + AstRecursiveVisitor::Visit(node); + m_currentFunctionIndex = {}; + } + + void DependencyCheckerVisitor::Visit(DeclareStructStatement& node) + { + assert(node.structIndex); + assert(m_structUsages.find(*node.structIndex) == m_structUsages.end()); + UsageSet& usageSet = m_structUsages[*node.structIndex]; + + for (const auto& structMember : node.description.members) + { + const auto& memberExprType = structMember.type.GetResultingValue(); + if (IsStructType(memberExprType)) + { + std::size_t structIndex = std::get(memberExprType).structIndex; + usageSet.usedStructs.UnboundedSet(structIndex); + } + } + + AstRecursiveVisitor::Visit(node); + } + + void DependencyCheckerVisitor::Visit(DeclareVariableStatement& node) + { + assert(node.varIndex); + assert(m_variableUsages.find(*node.varIndex) == m_variableUsages.end()); + UsageSet& usageSet = m_variableUsages[*node.varIndex]; + + const auto& varType = node.varType.GetResultingValue(); + if (IsStructType(varType)) + { + const auto& structType = std::get(varType); + usageSet.usedStructs.UnboundedSet(structType.structIndex); + } + + m_currentVariableDeclIndex = node.varIndex; + AstRecursiveVisitor::Visit(node); + m_currentVariableDeclIndex = {}; + } + + void DependencyCheckerVisitor::Visit(VariableExpression& node) + { + assert(m_currentFunctionIndex); + if (m_currentVariableDeclIndex) + { + UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex); + usageSet.usedVariables.UnboundedSet(node.variableId); + } + else + { + UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex); + usageSet.usedVariables.UnboundedSet(node.variableId); + } + } + + void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet) + { + m_resolvedUsage.usedFunctions |= usageSet.usedFunctions; + m_resolvedUsage.usedStructs |= usageSet.usedStructs; + m_resolvedUsage.usedVariables |= usageSet.usedVariables; + + for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex)) + Resolve(Retrieve(m_functionUsages, funcIndex)); + + for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex)) + Resolve(Retrieve(m_structUsages, structIndex)); + + for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex)) + Resolve(Retrieve(m_variableUsages, varIndex)); + } +} diff --git a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp index 9fccbceac..3ddfbd8cb 100644 --- a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp +++ b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp @@ -6,208 +6,15 @@ #include #include #include +#include #include #include namespace Nz::ShaderAst { - namespace - { - template T& Retrieve(std::unordered_map& map, std::size_t id) - { - auto it = map.find(id); - assert(it != map.end()); - return it->second; - } - - template const T& Retrieve(const std::unordered_map& map, std::size_t id) - { - auto it = map.find(id); - assert(it != map.end()); - return it->second; - } - - struct UsageChecker : AstRecursiveVisitor - { - struct UsageSet; - - void Resolve() - { - Resolve(globalUsage); - } - - void Resolve(const UsageSet& usageSet) - { - resolvedUsage.usedFunctions |= usageSet.usedFunctions; - resolvedUsage.usedStructs |= usageSet.usedStructs; - resolvedUsage.usedVariables |= usageSet.usedVariables; - - for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex)) - Resolve(Retrieve(functionUsages, funcIndex)); - - for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex)) - Resolve(Retrieve(structUsages, structIndex)); - - for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex)) - Resolve(Retrieve(variableUsages, varIndex)); - } - - using AstRecursiveVisitor::Visit; - - void Visit(CallFunctionExpression& node) override - { - const auto& targetFuncType = GetExpressionType(node); - assert(std::holds_alternative(targetFuncType)); - - const auto& funcType = std::get(targetFuncType); - - assert(currentFunctionIndex); - UsageSet& usageSet = Retrieve(functionUsages, *currentFunctionIndex); - usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); - } - - void Visit(DeclareExternalStatement& node) override - { - for (const auto& externalVar : node.externalVars) - { - assert(externalVar.varIndex); - std::size_t varIndex = *externalVar.varIndex; - - assert(variableUsages.find(varIndex) == variableUsages.end()); - UsageSet& usageSet = variableUsages[varIndex]; - - const auto& exprType = externalVar.type.GetResultingValue(); - - if (IsUniformType(exprType)) - { - const UniformType& uniformType = std::get(exprType); - usageSet.usedStructs.UnboundedSet(uniformType.containedType.structIndex); - } - - ++varIndex; - } - - AstRecursiveVisitor::Visit(node); - } - - void Visit(DeclareFunctionStatement& node) override - { - assert(node.funcIndex); - assert(functionUsages.find(*node.funcIndex) == functionUsages.end()); - UsageSet& usageSet = functionUsages[*node.funcIndex]; - - // Register struct used in parameters or return type - if (!node.parameters.empty()) - { - assert(node.varIndex); - std::size_t parameterVarIndex = *node.varIndex; - for (auto& parameter : node.parameters) - { - // Since parameters must always be defined, their type isn't a dependency of parameter variables - assert(variableUsages.find(parameterVarIndex) == variableUsages.end()); - variableUsages.emplace(parameterVarIndex, UsageSet{}); - - const auto& exprType = parameter.type.GetResultingValue(); - if (IsStructType(exprType)) - { - std::size_t structIndex = std::get(exprType).structIndex; - usageSet.usedStructs.UnboundedSet(structIndex); - } - - ++parameterVarIndex; - } - } - - if (node.returnType.HasValue()) - { - const auto& returnExprType = node.returnType.GetResultingValue(); - if (IsStructType(returnExprType)) - { - std::size_t structIndex = std::get(returnExprType).structIndex; - usageSet.usedStructs.UnboundedSet(structIndex); - } - } - - if (node.entryStage.HasValue()) - globalUsage.usedFunctions.UnboundedSet(*node.funcIndex); - - currentFunctionIndex = node.funcIndex; - AstRecursiveVisitor::Visit(node); - currentFunctionIndex = {}; - } - - void Visit(DeclareStructStatement& node) override - { - assert(node.structIndex); - assert(structUsages.find(*node.structIndex) == structUsages.end()); - UsageSet& usageSet = structUsages[*node.structIndex]; - - for (const auto& structMember : node.description.members) - { - const auto& memberExprType = structMember.type.GetResultingValue(); - if (IsStructType(memberExprType)) - { - std::size_t structIndex = std::get(memberExprType).structIndex; - usageSet.usedStructs.UnboundedSet(structIndex); - } - } - - AstRecursiveVisitor::Visit(node); - } - - void Visit(DeclareVariableStatement& node) override - { - assert(node.varIndex); - assert(variableUsages.find(*node.varIndex) == variableUsages.end()); - UsageSet& usageSet = variableUsages[*node.varIndex]; - - const auto& varType = node.varType.GetResultingValue(); - if (IsStructType(varType)) - { - const auto& structType = std::get(varType); - usageSet.usedStructs.UnboundedSet(structType.structIndex); - } - - currentVariableDeclIndex = node.varIndex; - AstRecursiveVisitor::Visit(node); - currentVariableDeclIndex = {}; - } - - void Visit(VariableExpression& node) override - { - assert(currentFunctionIndex); - if (currentVariableDeclIndex) - { - UsageSet& usageSet = Retrieve(variableUsages, *currentVariableDeclIndex); - usageSet.usedVariables.UnboundedSet(node.variableId); - } - else - { - UsageSet& usageSet = Retrieve(functionUsages, *currentFunctionIndex); - usageSet.usedVariables.UnboundedSet(node.variableId); - } - } - - struct UsageSet - { - Bitset<> usedFunctions; - Bitset<> usedStructs; - Bitset<> usedVariables; - }; - - std::optional currentFunctionIndex; - std::optional currentVariableDeclIndex; - std::unordered_map functionUsages; - std::unordered_map structUsages; - std::unordered_map variableUsages; - UsageSet globalUsage; - UsageSet resolvedUsage; - }; - } - struct EliminateUnusedPassVisitor::Context { - UsageChecker usageChecker; + DependencyCheckerVisitor usageChecker; }; StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement) @@ -291,18 +98,18 @@ namespace Nz::ShaderAst bool EliminateUnusedPassVisitor::IsFunctionUsed(std::size_t varIndex) const { assert(m_context); - return m_context->usageChecker.resolvedUsage.usedFunctions.UnboundedTest(varIndex); + return m_context->usageChecker.GetUsage().usedFunctions.UnboundedTest(varIndex); } bool EliminateUnusedPassVisitor::IsStructUsed(std::size_t varIndex) const { assert(m_context); - return m_context->usageChecker.resolvedUsage.usedStructs.UnboundedTest(varIndex); + return m_context->usageChecker.GetUsage().usedStructs.UnboundedTest(varIndex); } bool EliminateUnusedPassVisitor::IsVariableUsed(std::size_t varIndex) const { assert(m_context); - return m_context->usageChecker.resolvedUsage.usedVariables.UnboundedTest(varIndex); + return m_context->usageChecker.GetUsage().usedVariables.UnboundedTest(varIndex); } }