Shader/SanitizeVisitor: Fix an issue when double-sanitizing

... with differents parameters (like reducing loops to while, which introduces new variables which would take over existing var indices)
This commit is contained in:
Jérôme Leclercq 2022-03-10 12:44:47 +01:00
parent bf7f06ac4c
commit 98bd04e35a
5 changed files with 174 additions and 7 deletions

View File

@ -32,12 +32,33 @@ namespace Nz::ShaderAst
struct Callbacks
{
std::function<void(ShaderStageType stageType, const std::string& functionName)> onEntryPointDeclaration;
std::function<void(const std::string& optionName, const ExpressionValue<ExpressionType>& optionType)> onOptionDeclaration;
std::function<void(const DeclareAliasStatement& aliasDecl)> onAliasDeclaration;
std::function<void(const DeclareConstStatement& constDecl)> onConstDeclaration;
std::function<void(const DeclareExternalStatement& extDecl)> onExternalDeclaration;
std::function<void(const DeclareFunctionStatement& funcDecl)> onFunctionDeclaration;
std::function<void(const DeclareOptionStatement& optionDecl)> onOptionDeclaration;
std::function<void(const DeclareStructStatement& structDecl)> onStructDeclaration;
std::function<void(const DeclareVariableStatement& variableDecl)> onVariableDeclaration;
std::function<void(const std::string& name, std::size_t aliasIndex)> onAliasIndex;
std::function<void(const std::string& name, std::size_t constIndex)> onConstIndex;
std::function<void(const std::string& name, std::size_t funcIndex)> onFunctionIndex;
std::function<void(const std::string& name, std::size_t optIndex)> onOptionIndex;
std::function<void(const std::string& name, std::size_t structIndex)> onStructIndex;
std::function<void(const std::string& name, std::size_t varIndex)> onVariableIndex;
};
private:
void Visit(DeclareAliasStatement& node) override;
void Visit(DeclareConstStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareOptionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(ForStatement& node) override;
void Visit(ForEachStatement& node) override;
const Callbacks* m_callbacks;
};

View File

@ -124,6 +124,7 @@ namespace Nz::ShaderAst
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
void PreregisterIndices(const Module& module);
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
void RegisterBuiltin();

View File

@ -32,12 +32,12 @@ namespace Nz
supportedStageType |= stageType;
};
callbacks.onOptionDeclaration = [&](const std::string& optionName, const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& optionType)
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
{
//TODO: Check optionType
m_optionIndexByName[optionName] = Option{
CRC32(optionName)
m_optionIndexByName[option.optName] = Option{
CRC32(option.optName)
};
optionCount++;

View File

@ -14,9 +14,55 @@ namespace Nz::ShaderAst
statement.Visit(*this);
}
void AstReflect::Visit(DeclareAliasStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onAliasDeclaration)
m_callbacks->onAliasDeclaration(node);
if (m_callbacks->onAliasIndex && node.aliasIndex)
m_callbacks->onAliasIndex(node.name, *node.aliasIndex);
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(DeclareConstStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onConstDeclaration)
m_callbacks->onConstDeclaration(node);
if (m_callbacks->onConstIndex && node.constIndex)
m_callbacks->onConstIndex(node.name, *node.constIndex);
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(DeclareExternalStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onExternalDeclaration)
m_callbacks->onExternalDeclaration(node);
if (m_callbacks->onVariableIndex)
{
for (const auto& extVar : node.externalVars)
{
if (extVar.varIndex)
m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex);
}
}
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(DeclareFunctionStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onFunctionDeclaration)
m_callbacks->onFunctionDeclaration(node);
if (m_callbacks->onEntryPointDeclaration)
{
if (!node.entryStage.HasValue())
@ -24,12 +70,70 @@ namespace Nz::ShaderAst
m_callbacks->onEntryPointDeclaration(node.entryStage.GetResultingValue(), node.name);
}
if (m_callbacks->onVariableIndex)
{
for (const auto& parameter : node.parameters)
{
if (parameter.varIndex)
m_callbacks->onVariableIndex(parameter.name, *parameter.varIndex);
}
}
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(DeclareOptionStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onOptionDeclaration)
m_callbacks->onOptionDeclaration(node.optName, node.optType);
m_callbacks->onOptionDeclaration(node);
if (m_callbacks->onOptionIndex && node.optIndex)
m_callbacks->onOptionIndex(node.optName, *node.optIndex);
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(DeclareStructStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onStructDeclaration)
m_callbacks->onStructDeclaration(node);
if (m_callbacks->onStructIndex && node.structIndex)
m_callbacks->onStructIndex(node.description.name, *node.structIndex);
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(DeclareVariableStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onVariableDeclaration)
m_callbacks->onVariableDeclaration(node);
if (m_callbacks->onVariableIndex && node.varIndex)
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(ForStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onVariableIndex && node.varIndex)
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
AstRecursiveVisitor::Visit(node);
}
void AstReflect::Visit(ForEachStatement& node)
{
assert(m_callbacks);
if (m_callbacks->onVariableIndex && node.varIndex)
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
AstRecursiveVisitor::Visit(node);
}
}

View File

@ -12,6 +12,7 @@
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
#include <Nazara/Shader/Ast/AstExportVisitor.hpp>
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
#include <Nazara/Shader/Ast/AstReflect.hpp>
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <Nazara/Shader/Ast/DependencyCheckerVisitor.hpp>
#include <Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp>
@ -53,6 +54,20 @@ namespace Nz::ShaderAst
Bitset<UInt64> preregisteredIndices;
std::unordered_map<std::size_t, T> values;
void PreregisterIndex(std::size_t index)
{
if (index < availableIndices.GetSize())
{
if (!availableIndices.Test(index))
throw AstError{ "cannot preregister used index " + std::to_string(index) + " as its already used" };
}
else if (index >= availableIndices.GetSize())
availableIndices.Resize(index + 1, true);
availableIndices.Set(index, false);
preregisteredIndices.UnboundedSet(index);
}
template<typename U>
std::size_t Register(U&& data, std::optional<std::size_t> index = {})
{
@ -172,6 +187,8 @@ namespace Nz::ShaderAst
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
PreregisterIndices(module);
// Register builtin env
m_context->builtinEnv = std::make_shared<Environment>();
m_context->currentEnv = m_context->builtinEnv;
@ -906,6 +923,7 @@ namespace Nz::ShaderAst
auto& cloneParam = clone->parameters.emplace_back();
cloneParam.name = parameter.name;
cloneParam.type = ResolveType(parameter.type);
cloneParam.varIndex = parameter.varIndex;
}
if (node.returnType.HasValue())
@ -1170,6 +1188,7 @@ namespace Nz::ShaderAst
// Counter variable
auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr));
counterVariable->varIndex = node.varIndex;
Validate(*counterVariable);
std::size_t counterVarIndex = counterVariable->varIndex.value();
@ -1240,7 +1259,7 @@ namespace Nz::ShaderAst
PushScope();
{
clone->varIndex = RegisterVariable(node.varName, fromExprType);
clone->varIndex = RegisterVariable(node.varName, fromExprType, node.varIndex);
clone->statement = CloneStatement(node.statement);
}
PopScope();
@ -1337,6 +1356,7 @@ namespace Nz::ShaderAst
Validate(*accessIndex);
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
elementVariable->varIndex = node.varIndex; //< Preserve var index
Validate(*elementVariable);
body->statements.emplace_back(std::move(elementVariable));
@ -1365,7 +1385,7 @@ namespace Nz::ShaderAst
PushScope();
{
clone->varIndex = RegisterVariable(node.varName, innerType);
clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex);
clone->statement = CloneStatement(node.statement);
}
PopScope();
@ -1825,6 +1845,27 @@ namespace Nz::ShaderAst
return static_unique_pointer_cast<T>(ShaderAst::PropagateConstants(node, optimizerOptions));
}
void SanitizeVisitor::PreregisterIndices(const Module& module)
{
// If AST has been sanitized before and is sanitized again but with differents options that may introduce new variables (for example reduceLoopsToWhile)
// we have to make sure we won't override variable indices. This is done by visiting the AST a first time and preregistering all indices.
// TODO: Only do this is the AST has been already sanitized, maybe using a flag stored in the module?
AstReflect::Callbacks registerCallbacks;
registerCallbacks.onAliasIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->aliases.PreregisterIndex(index); };
registerCallbacks.onConstIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); };
registerCallbacks.onFunctionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->functions.PreregisterIndex(index); };
registerCallbacks.onOptionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); };
registerCallbacks.onStructIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->structs.PreregisterIndex(index); };
registerCallbacks.onVariableIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->variableTypes.PreregisterIndex(index); };
AstReflect reflectVisitor;
for (const auto& importedModule : module.importedModules)
reflectVisitor.Reflect(*importedModule.module->rootNode, registerCallbacks);
reflectVisitor.Reflect(*module.rootNode, registerCallbacks);
}
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
{
auto& funcData = m_context->functions.Retrieve(funcIndex);