Shader: Minor module fixes

This commit is contained in:
Jérôme Leclercq
2022-03-09 20:05:10 +01:00
parent da40a2db28
commit 39a2992791
11 changed files with 83 additions and 62 deletions

View File

@@ -481,7 +481,7 @@ namespace Nz
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
PreVisitor* preVisitor;
PreVisitor* previsitor;
// Output
SpirvSection header;
@@ -499,7 +499,7 @@ namespace Nz
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
{
ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst;
const ShaderAst::Module* targetModule;
if (!states.sanitized)
{
ShaderAst::SanitizeVisitor::Options options;
@@ -513,14 +513,12 @@ namespace Nz
options.useIdentifierAccessesForStructs = false;
sanitizedModule = ShaderAst::Sanitize(module, options);
targetAst = sanitizedModule->rootNode.get();
targetModule = sanitizedModule.get();
}
else
targetAst = module.rootNode.get();
targetModule = &module;
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst;
ShaderAst::ModulePtr optimizedModule;
if (states.optimize)
{
ShaderAst::StatementPtr tempAst;
@@ -528,12 +526,14 @@ namespace Nz
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
dependencyConfig.usedShaderStages = ShaderStageType_All;
tempAst = ShaderAst::PropagateConstants(*targetAst);
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
targetAst = optimizedAst.get();
targetModule = optimizedModule.get();
}
// Previsitor
m_context.states = &states;
State state;
@@ -544,15 +544,15 @@ namespace Nz
});
// Register all extended instruction sets
PreVisitor preVisitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
importedModule.module->rootNode->Visit(preVisitor);
PreVisitor previsitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(previsitor);
targetAst->Visit(preVisitor);
targetModule->rootNode->Visit(previsitor);
m_currentState->preVisitor = &preVisitor;
m_currentState->previsitor = &previsitor;
for (const std::string& extInst : preVisitor.extInsts)
for (const std::string& extInst : previsitor.extInsts)
state.extensionInstructionSet[extInst] = AllocateResultId();
// Assign function ID (required for forward declaration)
@@ -560,23 +560,23 @@ namespace Nz
func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(visitor);
targetAst->Visit(visitor);
targetModule->rootNode->Visit(visitor);
AppendHeader();
for (auto&& [varIndex, extVar] : preVisitor.extVars)
for (auto&& [varIndex, extVar] : previsitor.extVars)
{
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
}
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
for (auto&& [varId, builtin] : previsitor.builtinDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
for (auto&& [varId, location] : preVisitor.locationDecorations)
for (auto&& [varId, location] : previsitor.locationDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
@@ -699,8 +699,8 @@ namespace Nz
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
{
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
assert(it != m_currentState->preVisitor->extVars.end());
auto it = m_currentState->previsitor->extVars.find(extVarIndex);
assert(it != m_currentState->previsitor->extVars.end());
return it->second.pointerId;
}