From 18edd410486b56b7813f708c791020bc302db5fb Mon Sep 17 00:00:00 2001 From: SirLynix Date: Tue, 5 Apr 2022 08:35:12 +0200 Subject: [PATCH] Shader: Fix dependency check of modules --- .../Shader/Ast/DependencyCheckerVisitor.hpp | 9 +- .../Shader/Ast/DependencyCheckerVisitor.inl | 8 +- .../Shader/Ast/EliminateUnusedPassVisitor.inl | 6 +- .../Shader/Ast/DependencyCheckerVisitor.cpp | 83 ++++++++++++------- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 5 +- 5 files changed, 66 insertions(+), 45 deletions(-) diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp index 89a951abc..72a0f3f81 100644 --- a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp @@ -30,10 +30,10 @@ namespace Nz::ShaderAst inline void MarkFunctionAsUsed(std::size_t funcIndex); inline void MarkStructAsUsed(std::size_t structIndex); - inline void Process(Statement& statement); - void Process(Statement& statement, const Config& config); + inline void Register(Statement& statement); + void Register(Statement& statement, const Config& config); - inline void Resolve(); + inline void Resolve(bool allowUnknownId = false); DependencyCheckerVisitor& operator=(const DependencyCheckerVisitor&) = delete; DependencyCheckerVisitor& operator=(DependencyCheckerVisitor&&) = delete; @@ -53,7 +53,8 @@ namespace Nz::ShaderAst private: UsageSet& GetContextUsageSet(); - void Resolve(const UsageSet& usageSet); + void RegisterType(UsageSet& usageSet, const ExpressionType& exprType); + void Resolve(const UsageSet& usageSet, bool allowUnknownId); using AstRecursiveVisitor::Visit; diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl index 3e2a7a8c0..b420d4ea0 100644 --- a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl @@ -22,15 +22,15 @@ namespace Nz::ShaderAst m_globalUsage.usedStructs.UnboundedSet(structIndex); } - inline void DependencyCheckerVisitor::Process(Statement& statement) + inline void DependencyCheckerVisitor::Register(Statement& statement) { Config defaultConfig; - return Process(statement, defaultConfig); + return Register(statement, defaultConfig); } - void DependencyCheckerVisitor::Resolve() + inline void DependencyCheckerVisitor::Resolve(bool allowUnknownId) { - Resolve(m_globalUsage); + Resolve(m_globalUsage, allowUnknownId); } } diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl index d7642b8d7..775990e67 100644 --- a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl @@ -17,9 +17,9 @@ namespace Nz::ShaderAst { DependencyCheckerVisitor dependencyVisitor; for (const auto& importedModule : shaderModule.importedModules) - dependencyVisitor.Process(*importedModule.module->rootNode, config); + dependencyVisitor.Register(*importedModule.module->rootNode, config); - dependencyVisitor.Process(*shaderModule.rootNode, config); + dependencyVisitor.Register(*shaderModule.rootNode, config); dependencyVisitor.Resolve(); return EliminateUnusedPass(shaderModule, dependencyVisitor.GetUsage()); @@ -40,7 +40,7 @@ namespace Nz::ShaderAst inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::Config& config) { DependencyCheckerVisitor dependencyVisitor; - dependencyVisitor.Process(ast, config); + dependencyVisitor.Register(ast, config); dependencyVisitor.Resolve(); return EliminateUnusedPass(ast, dependencyVisitor.GetUsage()); diff --git a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp index 1385358ad..c94d97c2d 100644 --- a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp +++ b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp @@ -7,7 +7,7 @@ namespace Nz::ShaderAst { - void DependencyCheckerVisitor::Process(Statement& statement, const Config& config) + void DependencyCheckerVisitor::Register(Statement& statement, const Config& config) { m_config = config; statement.Visit(*this); @@ -26,7 +26,23 @@ namespace Nz::ShaderAst } } - void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet) + void DependencyCheckerVisitor::RegisterType(UsageSet& usageSet, const ExpressionType& exprType) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + usageSet.usedAliases.UnboundedSet(arg.aliasIndex); + else if constexpr (std::is_same_v) + usageSet.usedStructs.UnboundedSet(arg.structIndex); + else if constexpr (std::is_same_v) + usageSet.usedStructs.UnboundedSet(arg.containedType.structIndex); + + }, exprType); + } + + void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet, bool allowUnknownId) { m_resolvedUsage.usedAliases |= usageSet.usedAliases; m_resolvedUsage.usedFunctions |= usageSet.usedFunctions; @@ -34,16 +50,40 @@ namespace Nz::ShaderAst m_resolvedUsage.usedVariables |= usageSet.usedVariables; for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex)) - Resolve(Retrieve(m_aliasUsages, aliasIndex)); + { + auto it = m_aliasUsages.find(aliasIndex); + if (it != m_aliasUsages.end()) + Resolve(it->second, allowUnknownId); + else if (!allowUnknownId) + throw std::runtime_error("unknown alias #" + std::to_string(aliasIndex)); + } for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex)) - Resolve(Retrieve(m_functionUsages, funcIndex)); + { + auto it = m_functionUsages.find(funcIndex); + if (it != m_functionUsages.end()) + Resolve(it->second, allowUnknownId); + else if (!allowUnknownId) + throw std::runtime_error("unknown func #" + std::to_string(funcIndex)); + } for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex)) - Resolve(Retrieve(m_structUsages, structIndex)); + { + auto it = m_structUsages.find(structIndex); + if (it != m_structUsages.end()) + Resolve(it->second, allowUnknownId); + else if (!allowUnknownId) + throw std::runtime_error("unknown struct #" + std::to_string(structIndex)); + } for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex)) - Resolve(Retrieve(m_variableUsages, varIndex)); + { + auto it = m_variableUsages.find(varIndex); + if (it != m_variableUsages.end()) + Resolve(it->second, allowUnknownId); + else if (!allowUnknownId) + throw std::runtime_error("unknown var #" + std::to_string(varIndex)); + } } void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node) @@ -69,12 +109,7 @@ namespace Nz::ShaderAst UsageSet& usageSet = m_variableUsages[varIndex]; const auto& exprType = externalVar.type.GetResultingValue(); - - if (IsUniformType(exprType)) - { - const UniformType& uniformType = std::get(exprType); - usageSet.usedStructs.UnboundedSet(uniformType.containedType.structIndex); - } + RegisterType(usageSet, exprType); ++varIndex; } @@ -100,22 +135,14 @@ namespace Nz::ShaderAst m_variableUsages.emplace(*parameter.varIndex, UsageSet{}); const auto& exprType = parameter.type.GetResultingValue(); - if (IsStructType(exprType)) - { - std::size_t structIndex = std::get(exprType).structIndex; - usageSet.usedStructs.UnboundedSet(structIndex); - } + RegisterType(usageSet, exprType); } } if (node.returnType.HasValue()) { const auto& returnExprType = node.returnType.GetResultingValue(); - if (IsStructType(returnExprType)) - { - std::size_t structIndex = std::get(returnExprType).structIndex; - usageSet.usedStructs.UnboundedSet(structIndex); - } + RegisterType(usageSet, returnExprType); } if (node.entryStage.HasValue()) @@ -139,11 +166,7 @@ namespace Nz::ShaderAst for (const auto& structMember : node.description.members) { const auto& memberExprType = structMember.type.GetResultingValue(); - if (IsStructType(memberExprType)) - { - std::size_t structIndex = std::get(memberExprType).structIndex; - usageSet.usedStructs.UnboundedSet(structIndex); - } + RegisterType(usageSet, memberExprType); } AstRecursiveVisitor::Visit(node); @@ -156,11 +179,7 @@ namespace Nz::ShaderAst UsageSet& usageSet = m_variableUsages[*node.varIndex]; const auto& varType = node.varType.GetResultingValue(); - if (IsStructType(varType)) - { - const auto& structType = std::get(varType); - usageSet.usedStructs.UnboundedSet(structType.structIndex); - } + RegisterType(usageSet, varType); m_currentVariableDeclIndex = node.varIndex; AstRecursiveVisitor::Visit(node); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 79821e575..18b70eb6e 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -241,7 +241,8 @@ namespace Nz::ShaderAst if (moduleData.dependenciesVisitor) { - moduleData.dependenciesVisitor->Resolve(); + moduleData.dependenciesVisitor->Resolve(true); //< allow unknown identifiers since we may be referencing other modules + importedModule.module = EliminateUnusedPass(*importedModule.module, moduleData.dependenciesVisitor->GetUsage()); } } @@ -1750,7 +1751,7 @@ namespace Nz::ShaderAst if (!m_context->options.allowPartialSanitization) { moduleData.dependenciesVisitor = std::make_unique(); - moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode); + moduleData.dependenciesVisitor->Register(*sanitizedModule->rootNode); } moduleData.environment = std::move(moduleEnvironment);