diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp index 34ef0139f..89a951abc 100644 --- a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp @@ -45,26 +45,33 @@ namespace Nz::ShaderAst struct UsageSet { + Bitset<> usedAliases; Bitset<> usedFunctions; Bitset<> usedStructs; Bitset<> usedVariables; }; private: + UsageSet& GetContextUsageSet(); void Resolve(const UsageSet& usageSet); using AstRecursiveVisitor::Visit; - void Visit(CallFunctionExpression& node) override; + void Visit(AliasValueExpression& node) override; + void Visit(FunctionExpression& node) override; + void Visit(StructTypeExpression& node) override; void Visit(VariableValueExpression& node) override; + void Visit(DeclareAliasStatement& node) override; void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; void Visit(DeclareVariableStatement& node) override; + std::optional m_currentAliasDeclIndex; std::optional m_currentFunctionIndex; std::optional m_currentVariableDeclIndex; + std::unordered_map m_aliasUsages; std::unordered_map m_functionUsages; std::unordered_map m_structUsages; std::unordered_map m_variableUsages; diff --git a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp index 3000b01fc..1cf31712e 100644 --- a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp +++ b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp @@ -13,27 +13,48 @@ namespace Nz::ShaderAst statement.Visit(*this); } - void DependencyCheckerVisitor::Visit(CallFunctionExpression& node) + auto DependencyCheckerVisitor::GetContextUsageSet() -> UsageSet& { - const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction); - assert(targetFuncType); - assert(std::holds_alternative(*targetFuncType)); - - const auto& funcType = std::get(*targetFuncType); - - assert(m_currentFunctionIndex); - if (m_currentVariableDeclIndex) - { - UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex); - usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); - } + if (m_currentAliasDeclIndex) + return Retrieve(m_aliasUsages, *m_currentAliasDeclIndex); + else if (m_currentVariableDeclIndex) + return Retrieve(m_variableUsages, *m_currentVariableDeclIndex); else { - UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex); - usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); + assert(m_currentFunctionIndex); + return Retrieve(m_functionUsages, *m_currentFunctionIndex); } + } + void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet) + { + m_resolvedUsage.usedFunctions |= usageSet.usedFunctions; + m_resolvedUsage.usedStructs |= usageSet.usedStructs; + m_resolvedUsage.usedVariables |= usageSet.usedVariables; + + for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex)) + Resolve(Retrieve(m_aliasUsages, aliasIndex)); + + for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex)) + Resolve(Retrieve(m_functionUsages, funcIndex)); + + for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex)) + Resolve(Retrieve(m_structUsages, structIndex)); + + for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex)) + Resolve(Retrieve(m_variableUsages, varIndex)); + } + + void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node) + { + assert(node.aliasIndex); + assert(m_aliasUsages.find(*node.aliasIndex) == m_aliasUsages.end()); + m_aliasUsages.emplace(*node.aliasIndex, UsageSet{}); + + assert(node.aliasIndex); + m_currentAliasDeclIndex = *node.aliasIndex; AstRecursiveVisitor::Visit(node); + m_currentAliasDeclIndex = {}; } void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node) @@ -145,34 +166,27 @@ namespace Nz::ShaderAst m_currentVariableDeclIndex = {}; } + void DependencyCheckerVisitor::Visit(AliasValueExpression& node) + { + UsageSet& usageSet = GetContextUsageSet(); + usageSet.usedAliases.UnboundedSet(node.aliasId); + } + + void DependencyCheckerVisitor::Visit(FunctionExpression& node) + { + UsageSet& usageSet = GetContextUsageSet(); + usageSet.usedFunctions.UnboundedSet(node.funcId); + } + + void DependencyCheckerVisitor::Visit(StructTypeExpression& node) + { + UsageSet& usageSet = GetContextUsageSet(); + usageSet.usedStructs.UnboundedSet(node.structTypeId); + } + void DependencyCheckerVisitor::Visit(VariableValueExpression& node) { - assert(m_currentFunctionIndex); - if (m_currentVariableDeclIndex) - { - UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex); - usageSet.usedVariables.UnboundedSet(node.variableId); - } - else - { - UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex); - usageSet.usedVariables.UnboundedSet(node.variableId); - } - } - - void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet) - { - m_resolvedUsage.usedFunctions |= usageSet.usedFunctions; - m_resolvedUsage.usedStructs |= usageSet.usedStructs; - m_resolvedUsage.usedVariables |= usageSet.usedVariables; - - for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex)) - Resolve(Retrieve(m_functionUsages, funcIndex)); - - for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex)) - Resolve(Retrieve(m_structUsages, structIndex)); - - for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex)) - Resolve(Retrieve(m_variableUsages, varIndex)); + UsageSet& usageSet = GetContextUsageSet(); + usageSet.usedVariables.UnboundedSet(node.variableId); } }