Shader/SanitizeVisitor: Fix an issue when double-sanitizing
... with differents parameters (like reducing loops to while, which introduces new variables which would take over existing var indices)
This commit is contained in:
@@ -32,12 +32,33 @@ namespace Nz::ShaderAst
|
|||||||
struct Callbacks
|
struct Callbacks
|
||||||
{
|
{
|
||||||
std::function<void(ShaderStageType stageType, const std::string& functionName)> onEntryPointDeclaration;
|
std::function<void(ShaderStageType stageType, const std::string& functionName)> onEntryPointDeclaration;
|
||||||
std::function<void(const std::string& optionName, const ExpressionValue<ExpressionType>& optionType)> onOptionDeclaration;
|
|
||||||
|
std::function<void(const DeclareAliasStatement& aliasDecl)> onAliasDeclaration;
|
||||||
|
std::function<void(const DeclareConstStatement& constDecl)> onConstDeclaration;
|
||||||
|
std::function<void(const DeclareExternalStatement& extDecl)> onExternalDeclaration;
|
||||||
|
std::function<void(const DeclareFunctionStatement& funcDecl)> onFunctionDeclaration;
|
||||||
|
std::function<void(const DeclareOptionStatement& optionDecl)> onOptionDeclaration;
|
||||||
|
std::function<void(const DeclareStructStatement& structDecl)> onStructDeclaration;
|
||||||
|
std::function<void(const DeclareVariableStatement& variableDecl)> onVariableDeclaration;
|
||||||
|
|
||||||
|
std::function<void(const std::string& name, std::size_t aliasIndex)> onAliasIndex;
|
||||||
|
std::function<void(const std::string& name, std::size_t constIndex)> onConstIndex;
|
||||||
|
std::function<void(const std::string& name, std::size_t funcIndex)> onFunctionIndex;
|
||||||
|
std::function<void(const std::string& name, std::size_t optIndex)> onOptionIndex;
|
||||||
|
std::function<void(const std::string& name, std::size_t structIndex)> onStructIndex;
|
||||||
|
std::function<void(const std::string& name, std::size_t varIndex)> onVariableIndex;
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void Visit(DeclareAliasStatement& node) override;
|
||||||
|
void Visit(DeclareConstStatement& node) override;
|
||||||
|
void Visit(DeclareExternalStatement& node) override;
|
||||||
void Visit(DeclareFunctionStatement& node) override;
|
void Visit(DeclareFunctionStatement& node) override;
|
||||||
void Visit(DeclareOptionStatement& node) override;
|
void Visit(DeclareOptionStatement& node) override;
|
||||||
|
void Visit(DeclareStructStatement& node) override;
|
||||||
|
void Visit(DeclareVariableStatement& node) override;
|
||||||
|
void Visit(ForStatement& node) override;
|
||||||
|
void Visit(ForEachStatement& node) override;
|
||||||
|
|
||||||
const Callbacks* m_callbacks;
|
const Callbacks* m_callbacks;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ namespace Nz::ShaderAst
|
|||||||
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
|
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||||
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
||||||
|
|
||||||
|
void PreregisterIndices(const Module& module);
|
||||||
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
|
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
|
||||||
|
|
||||||
void RegisterBuiltin();
|
void RegisterBuiltin();
|
||||||
|
|||||||
@@ -32,12 +32,12 @@ namespace Nz
|
|||||||
supportedStageType |= stageType;
|
supportedStageType |= stageType;
|
||||||
};
|
};
|
||||||
|
|
||||||
callbacks.onOptionDeclaration = [&](const std::string& optionName, const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& optionType)
|
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
|
||||||
{
|
{
|
||||||
//TODO: Check optionType
|
//TODO: Check optionType
|
||||||
|
|
||||||
m_optionIndexByName[optionName] = Option{
|
m_optionIndexByName[option.optName] = Option{
|
||||||
CRC32(optionName)
|
CRC32(option.optName)
|
||||||
};
|
};
|
||||||
|
|
||||||
optionCount++;
|
optionCount++;
|
||||||
|
|||||||
@@ -14,9 +14,55 @@ namespace Nz::ShaderAst
|
|||||||
statement.Visit(*this);
|
statement.Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(DeclareAliasStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onAliasDeclaration)
|
||||||
|
m_callbacks->onAliasDeclaration(node);
|
||||||
|
|
||||||
|
if (m_callbacks->onAliasIndex && node.aliasIndex)
|
||||||
|
m_callbacks->onAliasIndex(node.name, *node.aliasIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(DeclareConstStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onConstDeclaration)
|
||||||
|
m_callbacks->onConstDeclaration(node);
|
||||||
|
|
||||||
|
if (m_callbacks->onConstIndex && node.constIndex)
|
||||||
|
m_callbacks->onConstIndex(node.name, *node.constIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(DeclareExternalStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onExternalDeclaration)
|
||||||
|
m_callbacks->onExternalDeclaration(node);
|
||||||
|
|
||||||
|
if (m_callbacks->onVariableIndex)
|
||||||
|
{
|
||||||
|
for (const auto& extVar : node.externalVars)
|
||||||
|
{
|
||||||
|
if (extVar.varIndex)
|
||||||
|
m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
void AstReflect::Visit(DeclareFunctionStatement& node)
|
void AstReflect::Visit(DeclareFunctionStatement& node)
|
||||||
{
|
{
|
||||||
assert(m_callbacks);
|
assert(m_callbacks);
|
||||||
|
|
||||||
|
if (m_callbacks->onFunctionDeclaration)
|
||||||
|
m_callbacks->onFunctionDeclaration(node);
|
||||||
|
|
||||||
if (m_callbacks->onEntryPointDeclaration)
|
if (m_callbacks->onEntryPointDeclaration)
|
||||||
{
|
{
|
||||||
if (!node.entryStage.HasValue())
|
if (!node.entryStage.HasValue())
|
||||||
@@ -24,12 +70,70 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
m_callbacks->onEntryPointDeclaration(node.entryStage.GetResultingValue(), node.name);
|
m_callbacks->onEntryPointDeclaration(node.entryStage.GetResultingValue(), node.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (m_callbacks->onVariableIndex)
|
||||||
|
{
|
||||||
|
for (const auto& parameter : node.parameters)
|
||||||
|
{
|
||||||
|
if (parameter.varIndex)
|
||||||
|
m_callbacks->onVariableIndex(parameter.name, *parameter.varIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AstReflect::Visit(DeclareOptionStatement& node)
|
void AstReflect::Visit(DeclareOptionStatement& node)
|
||||||
{
|
{
|
||||||
assert(m_callbacks);
|
assert(m_callbacks);
|
||||||
if (m_callbacks->onOptionDeclaration)
|
if (m_callbacks->onOptionDeclaration)
|
||||||
m_callbacks->onOptionDeclaration(node.optName, node.optType);
|
m_callbacks->onOptionDeclaration(node);
|
||||||
|
|
||||||
|
if (m_callbacks->onOptionIndex && node.optIndex)
|
||||||
|
m_callbacks->onOptionIndex(node.optName, *node.optIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(DeclareStructStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onStructDeclaration)
|
||||||
|
m_callbacks->onStructDeclaration(node);
|
||||||
|
|
||||||
|
if (m_callbacks->onStructIndex && node.structIndex)
|
||||||
|
m_callbacks->onStructIndex(node.description.name, *node.structIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(DeclareVariableStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onVariableDeclaration)
|
||||||
|
m_callbacks->onVariableDeclaration(node);
|
||||||
|
|
||||||
|
if (m_callbacks->onVariableIndex && node.varIndex)
|
||||||
|
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(ForStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onVariableIndex && node.varIndex)
|
||||||
|
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AstReflect::Visit(ForEachStatement& node)
|
||||||
|
{
|
||||||
|
assert(m_callbacks);
|
||||||
|
if (m_callbacks->onVariableIndex && node.varIndex)
|
||||||
|
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
|
||||||
|
|
||||||
|
AstRecursiveVisitor::Visit(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
|
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
|
||||||
#include <Nazara/Shader/Ast/AstExportVisitor.hpp>
|
#include <Nazara/Shader/Ast/AstExportVisitor.hpp>
|
||||||
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
||||||
|
#include <Nazara/Shader/Ast/AstReflect.hpp>
|
||||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||||
#include <Nazara/Shader/Ast/DependencyCheckerVisitor.hpp>
|
#include <Nazara/Shader/Ast/DependencyCheckerVisitor.hpp>
|
||||||
#include <Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp>
|
#include <Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp>
|
||||||
@@ -53,6 +54,20 @@ namespace Nz::ShaderAst
|
|||||||
Bitset<UInt64> preregisteredIndices;
|
Bitset<UInt64> preregisteredIndices;
|
||||||
std::unordered_map<std::size_t, T> values;
|
std::unordered_map<std::size_t, T> values;
|
||||||
|
|
||||||
|
void PreregisterIndex(std::size_t index)
|
||||||
|
{
|
||||||
|
if (index < availableIndices.GetSize())
|
||||||
|
{
|
||||||
|
if (!availableIndices.Test(index))
|
||||||
|
throw AstError{ "cannot preregister used index " + std::to_string(index) + " as its already used" };
|
||||||
|
}
|
||||||
|
else if (index >= availableIndices.GetSize())
|
||||||
|
availableIndices.Resize(index + 1, true);
|
||||||
|
|
||||||
|
availableIndices.Set(index, false);
|
||||||
|
preregisteredIndices.UnboundedSet(index);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename U>
|
template<typename U>
|
||||||
std::size_t Register(U&& data, std::optional<std::size_t> index = {})
|
std::size_t Register(U&& data, std::optional<std::size_t> index = {})
|
||||||
{
|
{
|
||||||
@@ -172,6 +187,8 @@ namespace Nz::ShaderAst
|
|||||||
m_context = ¤tContext;
|
m_context = ¤tContext;
|
||||||
CallOnExit resetContext([&] { m_context = nullptr; });
|
CallOnExit resetContext([&] { m_context = nullptr; });
|
||||||
|
|
||||||
|
PreregisterIndices(module);
|
||||||
|
|
||||||
// Register builtin env
|
// Register builtin env
|
||||||
m_context->builtinEnv = std::make_shared<Environment>();
|
m_context->builtinEnv = std::make_shared<Environment>();
|
||||||
m_context->currentEnv = m_context->builtinEnv;
|
m_context->currentEnv = m_context->builtinEnv;
|
||||||
@@ -906,6 +923,7 @@ namespace Nz::ShaderAst
|
|||||||
auto& cloneParam = clone->parameters.emplace_back();
|
auto& cloneParam = clone->parameters.emplace_back();
|
||||||
cloneParam.name = parameter.name;
|
cloneParam.name = parameter.name;
|
||||||
cloneParam.type = ResolveType(parameter.type);
|
cloneParam.type = ResolveType(parameter.type);
|
||||||
|
cloneParam.varIndex = parameter.varIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node.returnType.HasValue())
|
if (node.returnType.HasValue())
|
||||||
@@ -1170,6 +1188,7 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
// Counter variable
|
// Counter variable
|
||||||
auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr));
|
auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr));
|
||||||
|
counterVariable->varIndex = node.varIndex;
|
||||||
Validate(*counterVariable);
|
Validate(*counterVariable);
|
||||||
|
|
||||||
std::size_t counterVarIndex = counterVariable->varIndex.value();
|
std::size_t counterVarIndex = counterVariable->varIndex.value();
|
||||||
@@ -1240,7 +1259,7 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
PushScope();
|
PushScope();
|
||||||
{
|
{
|
||||||
clone->varIndex = RegisterVariable(node.varName, fromExprType);
|
clone->varIndex = RegisterVariable(node.varName, fromExprType, node.varIndex);
|
||||||
clone->statement = CloneStatement(node.statement);
|
clone->statement = CloneStatement(node.statement);
|
||||||
}
|
}
|
||||||
PopScope();
|
PopScope();
|
||||||
@@ -1337,6 +1356,7 @@ namespace Nz::ShaderAst
|
|||||||
Validate(*accessIndex);
|
Validate(*accessIndex);
|
||||||
|
|
||||||
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
||||||
|
elementVariable->varIndex = node.varIndex; //< Preserve var index
|
||||||
Validate(*elementVariable);
|
Validate(*elementVariable);
|
||||||
body->statements.emplace_back(std::move(elementVariable));
|
body->statements.emplace_back(std::move(elementVariable));
|
||||||
|
|
||||||
@@ -1365,7 +1385,7 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
PushScope();
|
PushScope();
|
||||||
{
|
{
|
||||||
clone->varIndex = RegisterVariable(node.varName, innerType);
|
clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex);
|
||||||
clone->statement = CloneStatement(node.statement);
|
clone->statement = CloneStatement(node.statement);
|
||||||
}
|
}
|
||||||
PopScope();
|
PopScope();
|
||||||
@@ -1825,6 +1845,27 @@ namespace Nz::ShaderAst
|
|||||||
return static_unique_pointer_cast<T>(ShaderAst::PropagateConstants(node, optimizerOptions));
|
return static_unique_pointer_cast<T>(ShaderAst::PropagateConstants(node, optimizerOptions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SanitizeVisitor::PreregisterIndices(const Module& module)
|
||||||
|
{
|
||||||
|
// If AST has been sanitized before and is sanitized again but with differents options that may introduce new variables (for example reduceLoopsToWhile)
|
||||||
|
// we have to make sure we won't override variable indices. This is done by visiting the AST a first time and preregistering all indices.
|
||||||
|
// TODO: Only do this is the AST has been already sanitized, maybe using a flag stored in the module?
|
||||||
|
|
||||||
|
AstReflect::Callbacks registerCallbacks;
|
||||||
|
registerCallbacks.onAliasIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->aliases.PreregisterIndex(index); };
|
||||||
|
registerCallbacks.onConstIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); };
|
||||||
|
registerCallbacks.onFunctionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->functions.PreregisterIndex(index); };
|
||||||
|
registerCallbacks.onOptionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); };
|
||||||
|
registerCallbacks.onStructIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->structs.PreregisterIndex(index); };
|
||||||
|
registerCallbacks.onVariableIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->variableTypes.PreregisterIndex(index); };
|
||||||
|
|
||||||
|
AstReflect reflectVisitor;
|
||||||
|
for (const auto& importedModule : module.importedModules)
|
||||||
|
reflectVisitor.Reflect(*importedModule.module->rootNode, registerCallbacks);
|
||||||
|
|
||||||
|
reflectVisitor.Reflect(*module.rootNode, registerCallbacks);
|
||||||
|
}
|
||||||
|
|
||||||
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
||||||
{
|
{
|
||||||
auto& funcData = m_context->functions.Retrieve(funcIndex);
|
auto& funcData = m_context->functions.Retrieve(funcIndex);
|
||||||
|
|||||||
Reference in New Issue
Block a user