Shader: Fix index remapping when importing a text shader in a precompiled shader

This commit is contained in:
SirLynix 2022-05-12 23:08:21 +02:00
parent 6469ab5fde
commit 5544d336ab
6 changed files with 57 additions and 28 deletions

View File

@ -42,12 +42,12 @@ namespace Nz::ShaderAst
std::function<void(const DeclareStructStatement& structDecl)> onStructDeclaration;
std::function<void(const DeclareVariableStatement& variableDecl)> onVariableDeclaration;
std::function<void(const std::string& name, std::size_t aliasIndex, const ShaderLang::SourceLocation& sourceLocation)> onAliasIndex;
std::function<void(const std::string& name, std::size_t constIndex, const ShaderLang::SourceLocation& sourceLocation)> onConstIndex;
std::function<void(const std::string& name, std::size_t funcIndex, const ShaderLang::SourceLocation& sourceLocation)> onFunctionIndex;
std::function<void(const std::string& name, std::size_t optIndex, const ShaderLang::SourceLocation& sourceLocation)> onOptionIndex;
std::function<void(const std::string& name, std::size_t aliasIndex, const ShaderLang::SourceLocation& sourceLocation)> onAliasIndex;
std::function<void(const std::string& name, std::size_t constIndex, const ShaderLang::SourceLocation& sourceLocation)> onConstIndex;
std::function<void(const std::string& name, std::size_t funcIndex, const ShaderLang::SourceLocation& sourceLocation)> onFunctionIndex;
std::function<void(const std::string& name, std::size_t optIndex, const ShaderLang::SourceLocation& sourceLocation)> onOptionIndex;
std::function<void(const std::string& name, std::size_t structIndex, const ShaderLang::SourceLocation& sourceLocation)> onStructIndex;
std::function<void(const std::string& name, std::size_t varIndex, const ShaderLang::SourceLocation& sourceLocation)> onVariableIndex;
std::function<void(const std::string& name, std::size_t varIndex, const ShaderLang::SourceLocation& sourceLocation)> onVariableIndex;
};
private:

View File

@ -17,19 +17,19 @@ namespace Nz::ShaderAst
class NAZARA_SHADER_API IndexRemapperVisitor : public AstCloner
{
public:
struct Callbacks;
struct Options;
IndexRemapperVisitor() = default;
IndexRemapperVisitor(const IndexRemapperVisitor&) = delete;
IndexRemapperVisitor(IndexRemapperVisitor&&) = delete;
~IndexRemapperVisitor() = default;
StatementPtr Clone(Statement& statement, const Callbacks& callbacks);
StatementPtr Clone(Statement& statement, const Options& options);
IndexRemapperVisitor& operator=(const IndexRemapperVisitor&) = delete;
IndexRemapperVisitor& operator=(IndexRemapperVisitor&&) = delete;
struct Callbacks
struct Options
{
std::function<std::size_t(std::size_t previousIndex)> aliasIndexGenerator;
std::function<std::size_t(std::size_t previousIndex)> constIndexGenerator;
@ -37,6 +37,7 @@ namespace Nz::ShaderAst
std::function<std::size_t(std::size_t previousIndex) > structIndexGenerator;
//std::function<std::size_t()> typeIndexGenerator;
std::function<std::size_t(std::size_t previousIndex)> varIndexGenerator;
bool forceIndexGeneration = false;
};
private:
@ -60,7 +61,7 @@ namespace Nz::ShaderAst
Context* m_context;
};
inline StatementPtr RemapIndices(Statement& statement, const IndexRemapperVisitor::Callbacks& callbacks);
inline StatementPtr RemapIndices(Statement& statement, const IndexRemapperVisitor::Options& options);
}
#include <Nazara/Shader/Ast/IndexRemapperVisitor.inl>

View File

@ -7,10 +7,10 @@
namespace Nz::ShaderAst
{
StatementPtr RemapIndices(Statement& statement, const IndexRemapperVisitor::Callbacks& callbacks)
StatementPtr RemapIndices(Statement& statement, const IndexRemapperVisitor::Options& options)
{
IndexRemapperVisitor visitor;
return visitor.Clone(statement, callbacks);
return visitor.Clone(statement, options);
}
}

View File

@ -63,6 +63,9 @@ namespace Nz::ShaderAst
if (m_callbacks->onFunctionDeclaration)
m_callbacks->onFunctionDeclaration(node);
if (node.funcIndex && m_callbacks->onFunctionIndex)
m_callbacks->onFunctionIndex(node.name, *node.funcIndex, node.sourceLocation);
if (m_callbacks->onEntryPointDeclaration)
{
if (!node.entryStage.HasValue())

View File

@ -20,7 +20,7 @@ namespace Nz::ShaderAst
struct IndexRemapperVisitor::Context
{
const IndexRemapperVisitor::Callbacks* callbacks;
const IndexRemapperVisitor::Options* options;
std::unordered_map<std::size_t, std::size_t> newAliasIndices;
std::unordered_map<std::size_t, std::size_t> newConstIndices;
std::unordered_map<std::size_t, std::size_t> newFuncIndices;
@ -28,17 +28,17 @@ namespace Nz::ShaderAst
std::unordered_map<std::size_t, std::size_t> newVarIndices;
};
StatementPtr IndexRemapperVisitor::Clone(Statement& statement, const Callbacks& callbacks)
StatementPtr IndexRemapperVisitor::Clone(Statement& statement, const Options& options)
{
assert(callbacks.aliasIndexGenerator);
assert(callbacks.constIndexGenerator);
assert(callbacks.funcIndexGenerator);
assert(callbacks.structIndexGenerator);
//assert(callbacks.typeIndexGenerator);
assert(callbacks.varIndexGenerator);
assert(options.aliasIndexGenerator);
assert(options.constIndexGenerator);
assert(options.funcIndexGenerator);
assert(options.structIndexGenerator);
//assert(options.typeIndexGenerator);
assert(options.varIndexGenerator);
Context context;
context.callbacks = &callbacks;
context.options = &options;
m_context = &context;
return AstCloner::Clone(statement);
@ -50,10 +50,12 @@ namespace Nz::ShaderAst
if (clone->aliasIndex)
{
std::size_t newAliasIndex = m_context->callbacks->aliasIndexGenerator(*clone->aliasIndex);
std::size_t newAliasIndex = m_context->options->aliasIndexGenerator(*clone->aliasIndex);
UniqueInsert(m_context->newAliasIndices, *clone->aliasIndex, newAliasIndex);
clone->aliasIndex = newAliasIndex;
}
else if (m_context->options->forceIndexGeneration)
clone->aliasIndex = m_context->options->aliasIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}
@ -64,10 +66,12 @@ namespace Nz::ShaderAst
if (clone->constIndex)
{
std::size_t newConstIndex = m_context->callbacks->constIndexGenerator(*clone->constIndex);
std::size_t newConstIndex = m_context->options->constIndexGenerator(*clone->constIndex);
UniqueInsert(m_context->newConstIndices, *clone->constIndex, newConstIndex);
clone->constIndex = newConstIndex;
}
else if (m_context->options->forceIndexGeneration)
clone->constIndex = m_context->options->constIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}
@ -80,10 +84,12 @@ namespace Nz::ShaderAst
{
if (extVar.varIndex)
{
std::size_t newVarIndex = m_context->callbacks->varIndexGenerator(*extVar.varIndex);
std::size_t newVarIndex = m_context->options->varIndexGenerator(*extVar.varIndex);
UniqueInsert(m_context->newVarIndices, *extVar.varIndex, newVarIndex);
extVar.varIndex = newVarIndex;
}
else if (m_context->options->forceIndexGeneration)
extVar.varIndex = m_context->options->varIndexGenerator(std::numeric_limits<std::size_t>::max());
}
return clone;
@ -95,17 +101,21 @@ namespace Nz::ShaderAst
if (clone->funcIndex)
{
std::size_t newFuncIndex = m_context->callbacks->funcIndexGenerator(*clone->funcIndex);
std::size_t newFuncIndex = m_context->options->funcIndexGenerator(*clone->funcIndex);
UniqueInsert(m_context->newFuncIndices, *clone->funcIndex, newFuncIndex);
clone->funcIndex = newFuncIndex;
}
else if (m_context->options->forceIndexGeneration)
clone->funcIndex = m_context->options->funcIndexGenerator(std::numeric_limits<std::size_t>::max());
if (!clone->parameters.empty())
{
for (auto& parameter : node.parameters)
for (auto& parameter : clone->parameters)
{
if (parameter.varIndex)
parameter.varIndex = Retrieve(m_context->newVarIndices, *parameter.varIndex);
else if (m_context->options->forceIndexGeneration)
parameter.varIndex = m_context->options->varIndexGenerator(std::numeric_limits<std::size_t>::max());
HandleType(parameter.type);
}
@ -123,10 +133,12 @@ namespace Nz::ShaderAst
if (clone->structIndex)
{
std::size_t newStructIndex = m_context->callbacks->structIndexGenerator(*clone->structIndex);
std::size_t newStructIndex = m_context->options->structIndexGenerator(*clone->structIndex);
UniqueInsert(m_context->newStructIndices, *clone->structIndex, newStructIndex);
clone->structIndex = newStructIndex;
}
else if (m_context->options->forceIndexGeneration)
clone->structIndex = m_context->options->structIndexGenerator(std::numeric_limits<std::size_t>::max());
for (auto& structMember : clone->description.members)
HandleType(structMember.type);
@ -140,10 +152,12 @@ namespace Nz::ShaderAst
if (clone->varIndex)
{
std::size_t newVarIndex = m_context->callbacks->varIndexGenerator(*clone->varIndex);
std::size_t newVarIndex = m_context->options->varIndexGenerator(*clone->varIndex);
UniqueInsert(m_context->newConstIndices, *clone->varIndex, newVarIndex);
clone->varIndex = newVarIndex;
}
else if (m_context->options->forceIndexGeneration)
clone->varIndex = m_context->options->varIndexGenerator(std::numeric_limits<std::size_t>::max());
HandleType(node.varType);
@ -156,6 +170,8 @@ namespace Nz::ShaderAst
if (clone->aliasId)
clone->aliasId = Retrieve(m_context->newAliasIndices, clone->aliasId);
else if (m_context->options->forceIndexGeneration)
clone->aliasId = m_context->options->aliasIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}
@ -166,6 +182,8 @@ namespace Nz::ShaderAst
if (clone->constantId)
clone->constantId = Retrieve(m_context->newConstIndices, clone->constantId);
else if (m_context->options->forceIndexGeneration)
clone->constantId = m_context->options->constIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}
@ -176,6 +194,8 @@ namespace Nz::ShaderAst
if (clone->funcId)
clone->funcId = Retrieve(m_context->newFuncIndices, clone->funcId);
else if (m_context->options->forceIndexGeneration)
clone->funcId = m_context->options->funcIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}
@ -186,6 +206,8 @@ namespace Nz::ShaderAst
if (clone->structTypeId)
clone->structTypeId = Retrieve(m_context->newStructIndices, clone->structTypeId);
else if (m_context->options->forceIndexGeneration)
clone->structTypeId = m_context->options->structIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}
@ -196,6 +218,8 @@ namespace Nz::ShaderAst
if (clone->variableId)
clone->variableId = Retrieve(m_context->newVarIndices, clone->variableId);
else if (m_context->options->forceIndexGeneration)
clone->variableId = m_context->options->varIndexGenerator(std::numeric_limits<std::size_t>::max());
return clone;
}

View File

@ -1780,12 +1780,13 @@ namespace Nz::ShaderAst
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
// Remap already used indices
IndexRemapperVisitor::Callbacks indexCallbacks;
IndexRemapperVisitor::Options indexCallbacks;
indexCallbacks.aliasIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->aliases.RegisterNewIndex(true); };
indexCallbacks.constIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->constantValues.RegisterNewIndex(true); };
indexCallbacks.funcIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->functions.RegisterNewIndex(true); };
indexCallbacks.structIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->structs.RegisterNewIndex(true); };
indexCallbacks.varIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->variableTypes.RegisterNewIndex(true); };
indexCallbacks.forceIndexGeneration = true;
sanitizedModule->rootNode = StaticUniquePointerCast<MultiStatement>(RemapIndices(*targetModule->rootNode, indexCallbacks));