Shader/DependencyCheckerVisitor: Handle aliases
This commit is contained in:
parent
c04b650e7c
commit
81b1b9b473
|
|
@ -45,26 +45,33 @@ namespace Nz::ShaderAst
|
|||
|
||||
struct UsageSet
|
||||
{
|
||||
Bitset<> usedAliases;
|
||||
Bitset<> usedFunctions;
|
||||
Bitset<> usedStructs;
|
||||
Bitset<> usedVariables;
|
||||
};
|
||||
|
||||
private:
|
||||
UsageSet& GetContextUsageSet();
|
||||
void Resolve(const UsageSet& usageSet);
|
||||
|
||||
using AstRecursiveVisitor::Visit;
|
||||
|
||||
void Visit(CallFunctionExpression& node) override;
|
||||
void Visit(AliasValueExpression& node) override;
|
||||
void Visit(FunctionExpression& node) override;
|
||||
void Visit(StructTypeExpression& node) override;
|
||||
void Visit(VariableValueExpression& node) override;
|
||||
|
||||
void Visit(DeclareAliasStatement& node) override;
|
||||
void Visit(DeclareExternalStatement& node) override;
|
||||
void Visit(DeclareFunctionStatement& node) override;
|
||||
void Visit(DeclareStructStatement& node) override;
|
||||
void Visit(DeclareVariableStatement& node) override;
|
||||
|
||||
std::optional<std::size_t> m_currentAliasDeclIndex;
|
||||
std::optional<std::size_t> m_currentFunctionIndex;
|
||||
std::optional<std::size_t> m_currentVariableDeclIndex;
|
||||
std::unordered_map<std::size_t, UsageSet> m_aliasUsages;
|
||||
std::unordered_map<std::size_t, UsageSet> m_functionUsages;
|
||||
std::unordered_map<std::size_t, UsageSet> m_structUsages;
|
||||
std::unordered_map<std::size_t, UsageSet> m_variableUsages;
|
||||
|
|
|
|||
|
|
@ -13,27 +13,48 @@ namespace Nz::ShaderAst
|
|||
statement.Visit(*this);
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
|
||||
auto DependencyCheckerVisitor::GetContextUsageSet() -> UsageSet&
|
||||
{
|
||||
const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
|
||||
assert(targetFuncType);
|
||||
assert(std::holds_alternative<FunctionType>(*targetFuncType));
|
||||
|
||||
const auto& funcType = std::get<FunctionType>(*targetFuncType);
|
||||
|
||||
assert(m_currentFunctionIndex);
|
||||
if (m_currentVariableDeclIndex)
|
||||
{
|
||||
UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
|
||||
usageSet.usedFunctions.UnboundedSet(funcType.funcIndex);
|
||||
}
|
||||
if (m_currentAliasDeclIndex)
|
||||
return Retrieve(m_aliasUsages, *m_currentAliasDeclIndex);
|
||||
else if (m_currentVariableDeclIndex)
|
||||
return Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
|
||||
else
|
||||
{
|
||||
UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex);
|
||||
usageSet.usedFunctions.UnboundedSet(funcType.funcIndex);
|
||||
assert(m_currentFunctionIndex);
|
||||
return Retrieve(m_functionUsages, *m_currentFunctionIndex);
|
||||
}
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
|
||||
{
|
||||
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
|
||||
m_resolvedUsage.usedStructs |= usageSet.usedStructs;
|
||||
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
|
||||
|
||||
for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex))
|
||||
Resolve(Retrieve(m_aliasUsages, aliasIndex));
|
||||
|
||||
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
|
||||
Resolve(Retrieve(m_functionUsages, funcIndex));
|
||||
|
||||
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
|
||||
Resolve(Retrieve(m_structUsages, structIndex));
|
||||
|
||||
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
|
||||
Resolve(Retrieve(m_variableUsages, varIndex));
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node)
|
||||
{
|
||||
assert(node.aliasIndex);
|
||||
assert(m_aliasUsages.find(*node.aliasIndex) == m_aliasUsages.end());
|
||||
m_aliasUsages.emplace(*node.aliasIndex, UsageSet{});
|
||||
|
||||
assert(node.aliasIndex);
|
||||
m_currentAliasDeclIndex = *node.aliasIndex;
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
m_currentAliasDeclIndex = {};
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node)
|
||||
|
|
@ -145,34 +166,27 @@ namespace Nz::ShaderAst
|
|||
m_currentVariableDeclIndex = {};
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(AliasValueExpression& node)
|
||||
{
|
||||
UsageSet& usageSet = GetContextUsageSet();
|
||||
usageSet.usedAliases.UnboundedSet(node.aliasId);
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(FunctionExpression& node)
|
||||
{
|
||||
UsageSet& usageSet = GetContextUsageSet();
|
||||
usageSet.usedFunctions.UnboundedSet(node.funcId);
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(StructTypeExpression& node)
|
||||
{
|
||||
UsageSet& usageSet = GetContextUsageSet();
|
||||
usageSet.usedStructs.UnboundedSet(node.structTypeId);
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Visit(VariableValueExpression& node)
|
||||
{
|
||||
assert(m_currentFunctionIndex);
|
||||
if (m_currentVariableDeclIndex)
|
||||
{
|
||||
UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
|
||||
usageSet.usedVariables.UnboundedSet(node.variableId);
|
||||
}
|
||||
else
|
||||
{
|
||||
UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex);
|
||||
usageSet.usedVariables.UnboundedSet(node.variableId);
|
||||
}
|
||||
}
|
||||
|
||||
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
|
||||
{
|
||||
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
|
||||
m_resolvedUsage.usedStructs |= usageSet.usedStructs;
|
||||
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
|
||||
|
||||
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
|
||||
Resolve(Retrieve(m_functionUsages, funcIndex));
|
||||
|
||||
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
|
||||
Resolve(Retrieve(m_structUsages, structIndex));
|
||||
|
||||
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
|
||||
Resolve(Retrieve(m_variableUsages, varIndex));
|
||||
UsageSet& usageSet = GetContextUsageSet();
|
||||
usageSet.usedVariables.UnboundedSet(node.variableId);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue