Shader/DependencyCheckerVisitor: Handle aliases

This commit is contained in:
SirLynix 2022-04-02 02:04:42 +02:00
parent c04b650e7c
commit 81b1b9b473
2 changed files with 64 additions and 43 deletions

View File

@ -45,26 +45,33 @@ namespace Nz::ShaderAst
struct UsageSet
{
Bitset<> usedAliases;
Bitset<> usedFunctions;
Bitset<> usedStructs;
Bitset<> usedVariables;
};
private:
UsageSet& GetContextUsageSet();
void Resolve(const UsageSet& usageSet);
using AstRecursiveVisitor::Visit;
void Visit(CallFunctionExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(FunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(VariableValueExpression& node) override;
void Visit(DeclareAliasStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
std::optional<std::size_t> m_currentAliasDeclIndex;
std::optional<std::size_t> m_currentFunctionIndex;
std::optional<std::size_t> m_currentVariableDeclIndex;
std::unordered_map<std::size_t, UsageSet> m_aliasUsages;
std::unordered_map<std::size_t, UsageSet> m_functionUsages;
std::unordered_map<std::size_t, UsageSet> m_structUsages;
std::unordered_map<std::size_t, UsageSet> m_variableUsages;

View File

@ -13,27 +13,48 @@ namespace Nz::ShaderAst
statement.Visit(*this);
}
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
auto DependencyCheckerVisitor::GetContextUsageSet() -> UsageSet&
{
const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
assert(targetFuncType);
assert(std::holds_alternative<FunctionType>(*targetFuncType));
const auto& funcType = std::get<FunctionType>(*targetFuncType);
assert(m_currentFunctionIndex);
if (m_currentVariableDeclIndex)
{
UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
usageSet.usedFunctions.UnboundedSet(funcType.funcIndex);
}
if (m_currentAliasDeclIndex)
return Retrieve(m_aliasUsages, *m_currentAliasDeclIndex);
else if (m_currentVariableDeclIndex)
return Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
else
{
UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex);
usageSet.usedFunctions.UnboundedSet(funcType.funcIndex);
assert(m_currentFunctionIndex);
return Retrieve(m_functionUsages, *m_currentFunctionIndex);
}
}
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
{
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
m_resolvedUsage.usedStructs |= usageSet.usedStructs;
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex))
Resolve(Retrieve(m_aliasUsages, aliasIndex));
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
Resolve(Retrieve(m_functionUsages, funcIndex));
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
Resolve(Retrieve(m_structUsages, structIndex));
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
Resolve(Retrieve(m_variableUsages, varIndex));
}
void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node)
{
assert(node.aliasIndex);
assert(m_aliasUsages.find(*node.aliasIndex) == m_aliasUsages.end());
m_aliasUsages.emplace(*node.aliasIndex, UsageSet{});
assert(node.aliasIndex);
m_currentAliasDeclIndex = *node.aliasIndex;
AstRecursiveVisitor::Visit(node);
m_currentAliasDeclIndex = {};
}
void DependencyCheckerVisitor::Visit(DeclareExternalStatement& node)
@ -145,34 +166,27 @@ namespace Nz::ShaderAst
m_currentVariableDeclIndex = {};
}
void DependencyCheckerVisitor::Visit(AliasValueExpression& node)
{
UsageSet& usageSet = GetContextUsageSet();
usageSet.usedAliases.UnboundedSet(node.aliasId);
}
void DependencyCheckerVisitor::Visit(FunctionExpression& node)
{
UsageSet& usageSet = GetContextUsageSet();
usageSet.usedFunctions.UnboundedSet(node.funcId);
}
void DependencyCheckerVisitor::Visit(StructTypeExpression& node)
{
UsageSet& usageSet = GetContextUsageSet();
usageSet.usedStructs.UnboundedSet(node.structTypeId);
}
void DependencyCheckerVisitor::Visit(VariableValueExpression& node)
{
assert(m_currentFunctionIndex);
if (m_currentVariableDeclIndex)
{
UsageSet& usageSet = Retrieve(m_variableUsages, *m_currentVariableDeclIndex);
usageSet.usedVariables.UnboundedSet(node.variableId);
}
else
{
UsageSet& usageSet = Retrieve(m_functionUsages, *m_currentFunctionIndex);
usageSet.usedVariables.UnboundedSet(node.variableId);
}
}
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
{
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
m_resolvedUsage.usedStructs |= usageSet.usedStructs;
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
Resolve(Retrieve(m_functionUsages, funcIndex));
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
Resolve(Retrieve(m_structUsages, structIndex));
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
Resolve(Retrieve(m_variableUsages, varIndex));
UsageSet& usageSet = GetContextUsageSet();
usageSet.usedVariables.UnboundedSet(node.variableId);
}
}