Shader/DependencyCheckerVisitor: Handle aliases
This commit is contained in:
parent
c04b650e7c
commit
81b1b9b473
|
|
@ -45,26 +45,33 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
struct UsageSet
|
struct UsageSet
|
||||||
{
|
{
|
||||||
|
Bitset<> usedAliases;
|
||||||
Bitset<> usedFunctions;
|
Bitset<> usedFunctions;
|
||||||
Bitset<> usedStructs;
|
Bitset<> usedStructs;
|
||||||
Bitset<> usedVariables;
|
Bitset<> usedVariables;
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
UsageSet& GetContextUsageSet();
|
||||||
void Resolve(const UsageSet& usageSet);
|
void Resolve(const UsageSet& usageSet);
|
||||||
|
|
||||||
using AstRecursiveVisitor::Visit;
|
using AstRecursiveVisitor::Visit;
|
||||||
|
|
||||||
void Visit(CallFunctionExpression& node) override;
|
void Visit(AliasValueExpression& node) override;
|
||||||
|
void Visit(FunctionExpression& node) override;
|
||||||
|
void Visit(StructTypeExpression& node) override;
|
||||||
void Visit(VariableValueExpression& node) override;
|
void Visit(VariableValueExpression& node) override;
|
||||||
|
|
||||||
|
void Visit(DeclareAliasStatement& node) override;
|
||||||
void Visit(DeclareExternalStatement& node) override;
|
void Visit(DeclareExternalStatement& node) override;
|
||||||
void Visit(DeclareFunctionStatement& node) override;
|
void Visit(DeclareFunctionStatement& node) override;
|
||||||
void Visit(DeclareStructStatement& node) override;
|
void Visit(DeclareStructStatement& node) override;
|
||||||
void Visit(DeclareVariableStatement& node) override;
|
void Visit(DeclareVariableStatement& node) override;
|
||||||
|
|
||||||
|
std::optional<std::size_t> m_currentAliasDeclIndex;
|
||||||
std::optional<std::size_t> m_currentFunctionIndex;
|
std::optional<std::size_t> m_currentFunctionIndex;
|
||||||
std::optional<std::size_t> m_currentVariableDeclIndex;
|
std::optional<std::size_t> m_currentVariableDeclIndex;
|
||||||
|
std::unordered_map<std::size_t, UsageSet> m_aliasUsages;
|
||||||
std::unordered_map<std::size_t, UsageSet> m_functionUsages;
|
std::unordered_map<std::size_t, UsageSet> m_functionUsages;
|
||||||
std::unordered_map<std::size_t, UsageSet> m_structUsages;
|
std::unordered_map<std::size_t, UsageSet> m_structUsages;
|
||||||
std::unordered_map<std::size_t, UsageSet> m_variableUsages;
|
std::unordered_map<std::size_t, UsageSet> m_variableUsages;
|
||||||
|
|
|
||||||
|
|
@ -13,27 +13,48 @@ namespace Nz::ShaderAst
|
||||||
statement.Visit(*this);
|
statement.Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
|
auto DependencyCheckerVisitor::GetContextUsageSet() -> UsageSet&
|
||||||
{
|
{
|
||||||
const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
|
if (m_currentAliasDeclIndex)
|
||||||
assert(targetFuncType);
|
return Retrieve(m_aliasUsages, *m_currentAliasDeclIndex);
|
||||||
assert(std::holds_alternative<FunctionType>(*targetFuncType));
|
else if (m_currentVariableDeclIndex)
|
||||||
|
return Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
|
||||||
const auto& funcType = std::get<FunctionType>(*targetFuncType);
|
|
||||||
|
|
||||||
assert(m_currentFunctionIndex);
|
|
||||||
if (m_currentVariableDeclIndex)
|
|
||||||
{
|
|
||||||
UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
|
|
||||||
usageSet.usedFunctions.UnboundedSet(funcType.funcIndex);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex);
|
assert(m_currentFunctionIndex);
|
||||||
usageSet.usedFunctions.UnboundedSet(funcType.funcIndex);
|
return Retrieve(m_functionUsages, *m_currentFunctionIndex);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
|
||||||
|
{
|
||||||
|
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
|
||||||
|
m_resolvedUsage.usedStructs |= usageSet.usedStructs;
|
||||||
|
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
|
||||||
|
|
||||||
|
for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex))
|
||||||
|
Resolve(Retrieve(m_aliasUsages, aliasIndex));
|
||||||
|
|
||||||
|
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
|
||||||
|
Resolve(Retrieve(m_functionUsages, funcIndex));
|
||||||
|
|
||||||
|
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
|
||||||
|
Resolve(Retrieve(m_structUsages, structIndex));
|
||||||
|
|
||||||
|
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
|
||||||
|
Resolve(Retrieve(m_variableUsages, varIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node)
|
||||||
|
{
|
||||||
|
assert(node.aliasIndex);
|
||||||
|
assert(m_aliasUsages.find(*node.aliasIndex) == m_aliasUsages.end());
|
||||||
|
m_aliasUsages.emplace(*node.aliasIndex, UsageSet{});
|
||||||
|
|
||||||
|
assert(node.aliasIndex);
|
||||||
|
m_currentAliasDeclIndex = *node.aliasIndex;
|
||||||
AstRecursiveVisitor::Visit(node);
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
m_currentAliasDeclIndex = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node)
|
void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node)
|
||||||
|
|
@ -145,34 +166,27 @@ namespace Nz::ShaderAst
|
||||||
m_currentVariableDeclIndex = {};
|
m_currentVariableDeclIndex = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DependencyCheckerVisitor::Visit(AliasValueExpression& node)
|
||||||
|
{
|
||||||
|
UsageSet& usageSet = GetContextUsageSet();
|
||||||
|
usageSet.usedAliases.UnboundedSet(node.aliasId);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DependencyCheckerVisitor::Visit(FunctionExpression& node)
|
||||||
|
{
|
||||||
|
UsageSet& usageSet = GetContextUsageSet();
|
||||||
|
usageSet.usedFunctions.UnboundedSet(node.funcId);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DependencyCheckerVisitor::Visit(StructTypeExpression& node)
|
||||||
|
{
|
||||||
|
UsageSet& usageSet = GetContextUsageSet();
|
||||||
|
usageSet.usedStructs.UnboundedSet(node.structTypeId);
|
||||||
|
}
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Visit(VariableValueExpression& node)
|
void DependencyCheckerVisitor::Visit(VariableValueExpression& node)
|
||||||
{
|
{
|
||||||
assert(m_currentFunctionIndex);
|
UsageSet& usageSet = GetContextUsageSet();
|
||||||
if (m_currentVariableDeclIndex)
|
usageSet.usedVariables.UnboundedSet(node.variableId);
|
||||||
{
|
|
||||||
UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
|
|
||||||
usageSet.usedVariables.UnboundedSet(node.variableId);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex);
|
|
||||||
usageSet.usedVariables.UnboundedSet(node.variableId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
|
|
||||||
{
|
|
||||||
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
|
|
||||||
m_resolvedUsage.usedStructs |= usageSet.usedStructs;
|
|
||||||
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
|
|
||||||
|
|
||||||
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
|
|
||||||
Resolve(Retrieve(m_functionUsages, funcIndex));
|
|
||||||
|
|
||||||
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
|
|
||||||
Resolve(Retrieve(m_structUsages, structIndex));
|
|
||||||
|
|
||||||
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
|
|
||||||
Resolve(Retrieve(m_variableUsages, varIndex));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue