Shader: Fix dependency check of modules
This commit is contained in:
parent
7f15c66f16
commit
18edd41048
|
|
@ -30,10 +30,10 @@ namespace Nz::ShaderAst
|
||||||
inline void MarkFunctionAsUsed(std::size_t funcIndex);
|
inline void MarkFunctionAsUsed(std::size_t funcIndex);
|
||||||
inline void MarkStructAsUsed(std::size_t structIndex);
|
inline void MarkStructAsUsed(std::size_t structIndex);
|
||||||
|
|
||||||
inline void Process(Statement& statement);
|
inline void Register(Statement& statement);
|
||||||
void Process(Statement& statement, const Config& config);
|
void Register(Statement& statement, const Config& config);
|
||||||
|
|
||||||
inline void Resolve();
|
inline void Resolve(bool allowUnknownId = false);
|
||||||
|
|
||||||
DependencyCheckerVisitor& operator=(const DependencyCheckerVisitor&) = delete;
|
DependencyCheckerVisitor& operator=(const DependencyCheckerVisitor&) = delete;
|
||||||
DependencyCheckerVisitor& operator=(DependencyCheckerVisitor&&) = delete;
|
DependencyCheckerVisitor& operator=(DependencyCheckerVisitor&&) = delete;
|
||||||
|
|
@ -53,7 +53,8 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
private:
|
private:
|
||||||
UsageSet& GetContextUsageSet();
|
UsageSet& GetContextUsageSet();
|
||||||
void Resolve(const UsageSet& usageSet);
|
void RegisterType(UsageSet& usageSet, const ExpressionType& exprType);
|
||||||
|
void Resolve(const UsageSet& usageSet, bool allowUnknownId);
|
||||||
|
|
||||||
using AstRecursiveVisitor::Visit;
|
using AstRecursiveVisitor::Visit;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,15 +22,15 @@ namespace Nz::ShaderAst
|
||||||
m_globalUsage.usedStructs.UnboundedSet(structIndex);
|
m_globalUsage.usedStructs.UnboundedSet(structIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void DependencyCheckerVisitor::Process(Statement& statement)
|
inline void DependencyCheckerVisitor::Register(Statement& statement)
|
||||||
{
|
{
|
||||||
Config defaultConfig;
|
Config defaultConfig;
|
||||||
return Process(statement, defaultConfig);
|
return Register(statement, defaultConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Resolve()
|
inline void DependencyCheckerVisitor::Resolve(bool allowUnknownId)
|
||||||
{
|
{
|
||||||
Resolve(m_globalUsage);
|
Resolve(m_globalUsage, allowUnknownId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
DependencyCheckerVisitor dependencyVisitor;
|
DependencyCheckerVisitor dependencyVisitor;
|
||||||
for (const auto& importedModule : shaderModule.importedModules)
|
for (const auto& importedModule : shaderModule.importedModules)
|
||||||
dependencyVisitor.Process(*importedModule.module->rootNode, config);
|
dependencyVisitor.Register(*importedModule.module->rootNode, config);
|
||||||
|
|
||||||
dependencyVisitor.Process(*shaderModule.rootNode, config);
|
dependencyVisitor.Register(*shaderModule.rootNode, config);
|
||||||
dependencyVisitor.Resolve();
|
dependencyVisitor.Resolve();
|
||||||
|
|
||||||
return EliminateUnusedPass(shaderModule, dependencyVisitor.GetUsage());
|
return EliminateUnusedPass(shaderModule, dependencyVisitor.GetUsage());
|
||||||
|
|
@ -40,7 +40,7 @@ namespace Nz::ShaderAst
|
||||||
inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::Config& config)
|
inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::Config& config)
|
||||||
{
|
{
|
||||||
DependencyCheckerVisitor dependencyVisitor;
|
DependencyCheckerVisitor dependencyVisitor;
|
||||||
dependencyVisitor.Process(ast, config);
|
dependencyVisitor.Register(ast, config);
|
||||||
dependencyVisitor.Resolve();
|
dependencyVisitor.Resolve();
|
||||||
|
|
||||||
return EliminateUnusedPass(ast, dependencyVisitor.GetUsage());
|
return EliminateUnusedPass(ast, dependencyVisitor.GetUsage());
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
void DependencyCheckerVisitor::Process(Statement& statement, const Config& config)
|
void DependencyCheckerVisitor::Register(Statement& statement, const Config& config)
|
||||||
{
|
{
|
||||||
m_config = config;
|
m_config = config;
|
||||||
statement.Visit(*this);
|
statement.Visit(*this);
|
||||||
|
|
@ -26,7 +26,23 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet)
|
void DependencyCheckerVisitor::RegisterType(UsageSet& usageSet, const ExpressionType& exprType)
|
||||||
|
{
|
||||||
|
std::visit([&](auto&& arg)
|
||||||
|
{
|
||||||
|
using T = std::decay_t<decltype(arg)>;
|
||||||
|
|
||||||
|
if constexpr (std::is_same_v<T, AliasType>)
|
||||||
|
usageSet.usedAliases.UnboundedSet(arg.aliasIndex);
|
||||||
|
else if constexpr (std::is_same_v<T, StructType>)
|
||||||
|
usageSet.usedStructs.UnboundedSet(arg.structIndex);
|
||||||
|
else if constexpr (std::is_same_v<T, UniformType>)
|
||||||
|
usageSet.usedStructs.UnboundedSet(arg.containedType.structIndex);
|
||||||
|
|
||||||
|
}, exprType);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DependencyCheckerVisitor::Resolve(const UsageSet& usageSet, bool allowUnknownId)
|
||||||
{
|
{
|
||||||
m_resolvedUsage.usedAliases |= usageSet.usedAliases;
|
m_resolvedUsage.usedAliases |= usageSet.usedAliases;
|
||||||
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
|
m_resolvedUsage.usedFunctions |= usageSet.usedFunctions;
|
||||||
|
|
@ -34,16 +50,40 @@ namespace Nz::ShaderAst
|
||||||
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
|
m_resolvedUsage.usedVariables |= usageSet.usedVariables;
|
||||||
|
|
||||||
for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex))
|
for (std::size_t aliasIndex = usageSet.usedAliases.FindFirst(); aliasIndex != usageSet.usedAliases.npos; aliasIndex = usageSet.usedAliases.FindNext(aliasIndex))
|
||||||
Resolve(Retrieve(m_aliasUsages, aliasIndex));
|
{
|
||||||
|
auto it = m_aliasUsages.find(aliasIndex);
|
||||||
|
if (it != m_aliasUsages.end())
|
||||||
|
Resolve(it->second, allowUnknownId);
|
||||||
|
else if (!allowUnknownId)
|
||||||
|
throw std::runtime_error("unknown alias #" + std::to_string(aliasIndex));
|
||||||
|
}
|
||||||
|
|
||||||
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
|
for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex))
|
||||||
Resolve(Retrieve(m_functionUsages, funcIndex));
|
{
|
||||||
|
auto it = m_functionUsages.find(funcIndex);
|
||||||
|
if (it != m_functionUsages.end())
|
||||||
|
Resolve(it->second, allowUnknownId);
|
||||||
|
else if (!allowUnknownId)
|
||||||
|
throw std::runtime_error("unknown func #" + std::to_string(funcIndex));
|
||||||
|
}
|
||||||
|
|
||||||
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
|
for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex))
|
||||||
Resolve(Retrieve(m_structUsages, structIndex));
|
{
|
||||||
|
auto it = m_structUsages.find(structIndex);
|
||||||
|
if (it != m_structUsages.end())
|
||||||
|
Resolve(it->second, allowUnknownId);
|
||||||
|
else if (!allowUnknownId)
|
||||||
|
throw std::runtime_error("unknown struct #" + std::to_string(structIndex));
|
||||||
|
}
|
||||||
|
|
||||||
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
|
for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex))
|
||||||
Resolve(Retrieve(m_variableUsages, varIndex));
|
{
|
||||||
|
auto it = m_variableUsages.find(varIndex);
|
||||||
|
if (it != m_variableUsages.end())
|
||||||
|
Resolve(it->second, allowUnknownId);
|
||||||
|
else if (!allowUnknownId)
|
||||||
|
throw std::runtime_error("unknown var #" + std::to_string(varIndex));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node)
|
void DependencyCheckerVisitor::Visit(DeclareAliasStatement& node)
|
||||||
|
|
@ -69,12 +109,7 @@ namespace Nz::ShaderAst
|
||||||
UsageSet& usageSet = m_variableUsages[varIndex];
|
UsageSet& usageSet = m_variableUsages[varIndex];
|
||||||
|
|
||||||
const auto& exprType = externalVar.type.GetResultingValue();
|
const auto& exprType = externalVar.type.GetResultingValue();
|
||||||
|
RegisterType(usageSet, exprType);
|
||||||
if (IsUniformType(exprType))
|
|
||||||
{
|
|
||||||
const UniformType& uniformType = std::get<UniformType>(exprType);
|
|
||||||
usageSet.usedStructs.UnboundedSet(uniformType.containedType.structIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
++varIndex;
|
++varIndex;
|
||||||
}
|
}
|
||||||
|
|
@ -100,22 +135,14 @@ namespace Nz::ShaderAst
|
||||||
m_variableUsages.emplace(*parameter.varIndex, UsageSet{});
|
m_variableUsages.emplace(*parameter.varIndex, UsageSet{});
|
||||||
|
|
||||||
const auto& exprType = parameter.type.GetResultingValue();
|
const auto& exprType = parameter.type.GetResultingValue();
|
||||||
if (IsStructType(exprType))
|
RegisterType(usageSet, exprType);
|
||||||
{
|
|
||||||
std::size_t structIndex = std::get<ShaderAst::StructType>(exprType).structIndex;
|
|
||||||
usageSet.usedStructs.UnboundedSet(structIndex);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node.returnType.HasValue())
|
if (node.returnType.HasValue())
|
||||||
{
|
{
|
||||||
const auto& returnExprType = node.returnType.GetResultingValue();
|
const auto& returnExprType = node.returnType.GetResultingValue();
|
||||||
if (IsStructType(returnExprType))
|
RegisterType(usageSet, returnExprType);
|
||||||
{
|
|
||||||
std::size_t structIndex = std::get<ShaderAst::StructType>(returnExprType).structIndex;
|
|
||||||
usageSet.usedStructs.UnboundedSet(structIndex);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node.entryStage.HasValue())
|
if (node.entryStage.HasValue())
|
||||||
|
|
@ -139,11 +166,7 @@ namespace Nz::ShaderAst
|
||||||
for (const auto& structMember : node.description.members)
|
for (const auto& structMember : node.description.members)
|
||||||
{
|
{
|
||||||
const auto& memberExprType = structMember.type.GetResultingValue();
|
const auto& memberExprType = structMember.type.GetResultingValue();
|
||||||
if (IsStructType(memberExprType))
|
RegisterType(usageSet, memberExprType);
|
||||||
{
|
|
||||||
std::size_t structIndex = std::get<ShaderAst::StructType>(memberExprType).structIndex;
|
|
||||||
usageSet.usedStructs.UnboundedSet(structIndex);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AstRecursiveVisitor::Visit(node);
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
|
@ -156,11 +179,7 @@ namespace Nz::ShaderAst
|
||||||
UsageSet& usageSet = m_variableUsages[*node.varIndex];
|
UsageSet& usageSet = m_variableUsages[*node.varIndex];
|
||||||
|
|
||||||
const auto& varType = node.varType.GetResultingValue();
|
const auto& varType = node.varType.GetResultingValue();
|
||||||
if (IsStructType(varType))
|
RegisterType(usageSet, varType);
|
||||||
{
|
|
||||||
const auto& structType = std::get<StructType>(varType);
|
|
||||||
usageSet.usedStructs.UnboundedSet(structType.structIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
m_currentVariableDeclIndex = node.varIndex;
|
m_currentVariableDeclIndex = node.varIndex;
|
||||||
AstRecursiveVisitor::Visit(node);
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,8 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
if (moduleData.dependenciesVisitor)
|
if (moduleData.dependenciesVisitor)
|
||||||
{
|
{
|
||||||
moduleData.dependenciesVisitor->Resolve();
|
moduleData.dependenciesVisitor->Resolve(true); //< allow unknown identifiers since we may be referencing other modules
|
||||||
|
|
||||||
importedModule.module = EliminateUnusedPass(*importedModule.module, moduleData.dependenciesVisitor->GetUsage());
|
importedModule.module = EliminateUnusedPass(*importedModule.module, moduleData.dependenciesVisitor->GetUsage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1750,7 +1751,7 @@ namespace Nz::ShaderAst
|
||||||
if (!m_context->options.allowPartialSanitization)
|
if (!m_context->options.allowPartialSanitization)
|
||||||
{
|
{
|
||||||
moduleData.dependenciesVisitor = std::make_unique<DependencyCheckerVisitor>();
|
moduleData.dependenciesVisitor = std::make_unique<DependencyCheckerVisitor>();
|
||||||
moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode);
|
moduleData.dependenciesVisitor->Register(*sanitizedModule->rootNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
moduleData.environment = std::move(moduleEnvironment);
|
moduleData.environment = std::move(moduleEnvironment);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue