Shader/SanitizeVisitor: Fix an issue when double-sanitizing
... with differents parameters (like reducing loops to while, which introduces new variables which would take over existing var indices)
This commit is contained in:
parent
bf7f06ac4c
commit
98bd04e35a
|
|
@ -32,12 +32,33 @@ namespace Nz::ShaderAst
|
|||
struct Callbacks
|
||||
{
|
||||
std::function<void(ShaderStageType stageType, const std::string& functionName)> onEntryPointDeclaration;
|
||||
std::function<void(const std::string& optionName, const ExpressionValue<ExpressionType>& optionType)> onOptionDeclaration;
|
||||
|
||||
std::function<void(const DeclareAliasStatement& aliasDecl)> onAliasDeclaration;
|
||||
std::function<void(const DeclareConstStatement& constDecl)> onConstDeclaration;
|
||||
std::function<void(const DeclareExternalStatement& extDecl)> onExternalDeclaration;
|
||||
std::function<void(const DeclareFunctionStatement& funcDecl)> onFunctionDeclaration;
|
||||
std::function<void(const DeclareOptionStatement& optionDecl)> onOptionDeclaration;
|
||||
std::function<void(const DeclareStructStatement& structDecl)> onStructDeclaration;
|
||||
std::function<void(const DeclareVariableStatement& variableDecl)> onVariableDeclaration;
|
||||
|
||||
std::function<void(const std::string& name, std::size_t aliasIndex)> onAliasIndex;
|
||||
std::function<void(const std::string& name, std::size_t constIndex)> onConstIndex;
|
||||
std::function<void(const std::string& name, std::size_t funcIndex)> onFunctionIndex;
|
||||
std::function<void(const std::string& name, std::size_t optIndex)> onOptionIndex;
|
||||
std::function<void(const std::string& name, std::size_t structIndex)> onStructIndex;
|
||||
std::function<void(const std::string& name, std::size_t varIndex)> onVariableIndex;
|
||||
};
|
||||
|
||||
private:
|
||||
void Visit(DeclareAliasStatement& node) override;
|
||||
void Visit(DeclareConstStatement& node) override;
|
||||
void Visit(DeclareExternalStatement& node) override;
|
||||
void Visit(DeclareFunctionStatement& node) override;
|
||||
void Visit(DeclareOptionStatement& node) override;
|
||||
void Visit(DeclareStructStatement& node) override;
|
||||
void Visit(DeclareVariableStatement& node) override;
|
||||
void Visit(ForStatement& node) override;
|
||||
void Visit(ForEachStatement& node) override;
|
||||
|
||||
const Callbacks* m_callbacks;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -124,6 +124,7 @@ namespace Nz::ShaderAst
|
|||
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
||||
|
||||
void PreregisterIndices(const Module& module);
|
||||
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
|
||||
|
||||
void RegisterBuiltin();
|
||||
|
|
|
|||
|
|
@ -32,12 +32,12 @@ namespace Nz
|
|||
supportedStageType |= stageType;
|
||||
};
|
||||
|
||||
callbacks.onOptionDeclaration = [&](const std::string& optionName, const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& optionType)
|
||||
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
|
||||
{
|
||||
//TODO: Check optionType
|
||||
|
||||
m_optionIndexByName[optionName] = Option{
|
||||
CRC32(optionName)
|
||||
m_optionIndexByName[option.optName] = Option{
|
||||
CRC32(option.optName)
|
||||
};
|
||||
|
||||
optionCount++;
|
||||
|
|
|
|||
|
|
@ -14,9 +14,55 @@ namespace Nz::ShaderAst
|
|||
statement.Visit(*this);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareAliasStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onAliasDeclaration)
|
||||
m_callbacks->onAliasDeclaration(node);
|
||||
|
||||
if (m_callbacks->onAliasIndex && node.aliasIndex)
|
||||
m_callbacks->onAliasIndex(node.name, *node.aliasIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareConstStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onConstDeclaration)
|
||||
m_callbacks->onConstDeclaration(node);
|
||||
|
||||
if (m_callbacks->onConstIndex && node.constIndex)
|
||||
m_callbacks->onConstIndex(node.name, *node.constIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareExternalStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onExternalDeclaration)
|
||||
m_callbacks->onExternalDeclaration(node);
|
||||
|
||||
if (m_callbacks->onVariableIndex)
|
||||
{
|
||||
for (const auto& extVar : node.externalVars)
|
||||
{
|
||||
if (extVar.varIndex)
|
||||
m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex);
|
||||
}
|
||||
}
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
|
||||
if (m_callbacks->onFunctionDeclaration)
|
||||
m_callbacks->onFunctionDeclaration(node);
|
||||
|
||||
if (m_callbacks->onEntryPointDeclaration)
|
||||
{
|
||||
if (!node.entryStage.HasValue())
|
||||
|
|
@ -24,12 +70,70 @@ namespace Nz::ShaderAst
|
|||
|
||||
m_callbacks->onEntryPointDeclaration(node.entryStage.GetResultingValue(), node.name);
|
||||
}
|
||||
|
||||
if (m_callbacks->onVariableIndex)
|
||||
{
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
if (parameter.varIndex)
|
||||
m_callbacks->onVariableIndex(parameter.name, *parameter.varIndex);
|
||||
}
|
||||
}
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareOptionStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onOptionDeclaration)
|
||||
m_callbacks->onOptionDeclaration(node.optName, node.optType);
|
||||
m_callbacks->onOptionDeclaration(node);
|
||||
|
||||
if (m_callbacks->onOptionIndex && node.optIndex)
|
||||
m_callbacks->onOptionIndex(node.optName, *node.optIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onStructDeclaration)
|
||||
m_callbacks->onStructDeclaration(node);
|
||||
|
||||
if (m_callbacks->onStructIndex && node.structIndex)
|
||||
m_callbacks->onStructIndex(node.description.name, *node.structIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onVariableDeclaration)
|
||||
m_callbacks->onVariableDeclaration(node);
|
||||
|
||||
if (m_callbacks->onVariableIndex && node.varIndex)
|
||||
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(ForStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onVariableIndex && node.varIndex)
|
||||
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstReflect::Visit(ForEachStatement& node)
|
||||
{
|
||||
assert(m_callbacks);
|
||||
if (m_callbacks->onVariableIndex && node.varIndex)
|
||||
m_callbacks->onVariableIndex(node.varName, *node.varIndex);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/AstExportVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/AstReflect.hpp>
|
||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||
#include <Nazara/Shader/Ast/DependencyCheckerVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp>
|
||||
|
|
@ -53,6 +54,20 @@ namespace Nz::ShaderAst
|
|||
Bitset<UInt64> preregisteredIndices;
|
||||
std::unordered_map<std::size_t, T> values;
|
||||
|
||||
void PreregisterIndex(std::size_t index)
|
||||
{
|
||||
if (index < availableIndices.GetSize())
|
||||
{
|
||||
if (!availableIndices.Test(index))
|
||||
throw AstError{ "cannot preregister used index " + std::to_string(index) + " as its already used" };
|
||||
}
|
||||
else if (index >= availableIndices.GetSize())
|
||||
availableIndices.Resize(index + 1, true);
|
||||
|
||||
availableIndices.Set(index, false);
|
||||
preregisteredIndices.UnboundedSet(index);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
std::size_t Register(U&& data, std::optional<std::size_t> index = {})
|
||||
{
|
||||
|
|
@ -172,6 +187,8 @@ namespace Nz::ShaderAst
|
|||
m_context = ¤tContext;
|
||||
CallOnExit resetContext([&] { m_context = nullptr; });
|
||||
|
||||
PreregisterIndices(module);
|
||||
|
||||
// Register builtin env
|
||||
m_context->builtinEnv = std::make_shared<Environment>();
|
||||
m_context->currentEnv = m_context->builtinEnv;
|
||||
|
|
@ -906,6 +923,7 @@ namespace Nz::ShaderAst
|
|||
auto& cloneParam = clone->parameters.emplace_back();
|
||||
cloneParam.name = parameter.name;
|
||||
cloneParam.type = ResolveType(parameter.type);
|
||||
cloneParam.varIndex = parameter.varIndex;
|
||||
}
|
||||
|
||||
if (node.returnType.HasValue())
|
||||
|
|
@ -1170,6 +1188,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
// Counter variable
|
||||
auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr));
|
||||
counterVariable->varIndex = node.varIndex;
|
||||
Validate(*counterVariable);
|
||||
|
||||
std::size_t counterVarIndex = counterVariable->varIndex.value();
|
||||
|
|
@ -1240,7 +1259,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
PushScope();
|
||||
{
|
||||
clone->varIndex = RegisterVariable(node.varName, fromExprType);
|
||||
clone->varIndex = RegisterVariable(node.varName, fromExprType, node.varIndex);
|
||||
clone->statement = CloneStatement(node.statement);
|
||||
}
|
||||
PopScope();
|
||||
|
|
@ -1337,6 +1356,7 @@ namespace Nz::ShaderAst
|
|||
Validate(*accessIndex);
|
||||
|
||||
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
||||
elementVariable->varIndex = node.varIndex; //< Preserve var index
|
||||
Validate(*elementVariable);
|
||||
body->statements.emplace_back(std::move(elementVariable));
|
||||
|
||||
|
|
@ -1365,7 +1385,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
PushScope();
|
||||
{
|
||||
clone->varIndex = RegisterVariable(node.varName, innerType);
|
||||
clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex);
|
||||
clone->statement = CloneStatement(node.statement);
|
||||
}
|
||||
PopScope();
|
||||
|
|
@ -1825,6 +1845,27 @@ namespace Nz::ShaderAst
|
|||
return static_unique_pointer_cast<T>(ShaderAst::PropagateConstants(node, optimizerOptions));
|
||||
}
|
||||
|
||||
void SanitizeVisitor::PreregisterIndices(const Module& module)
|
||||
{
|
||||
// If AST has been sanitized before and is sanitized again but with differents options that may introduce new variables (for example reduceLoopsToWhile)
|
||||
// we have to make sure we won't override variable indices. This is done by visiting the AST a first time and preregistering all indices.
|
||||
// TODO: Only do this is the AST has been already sanitized, maybe using a flag stored in the module?
|
||||
|
||||
AstReflect::Callbacks registerCallbacks;
|
||||
registerCallbacks.onAliasIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->aliases.PreregisterIndex(index); };
|
||||
registerCallbacks.onConstIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); };
|
||||
registerCallbacks.onFunctionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->functions.PreregisterIndex(index); };
|
||||
registerCallbacks.onOptionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); };
|
||||
registerCallbacks.onStructIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->structs.PreregisterIndex(index); };
|
||||
registerCallbacks.onVariableIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->variableTypes.PreregisterIndex(index); };
|
||||
|
||||
AstReflect reflectVisitor;
|
||||
for (const auto& importedModule : module.importedModules)
|
||||
reflectVisitor.Reflect(*importedModule.module->rootNode, registerCallbacks);
|
||||
|
||||
reflectVisitor.Reflect(*module.rootNode, registerCallbacks);
|
||||
}
|
||||
|
||||
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
||||
{
|
||||
auto& funcData = m_context->functions.Retrieve(funcIndex);
|
||||
|
|
|
|||
Loading…
Reference in New Issue