Shader: Minor module fixes
This commit is contained in:
@@ -739,14 +739,14 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
|
||||
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode));
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode), shaderModule.importedModules);
|
||||
}
|
||||
|
||||
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
|
||||
{
|
||||
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
|
||||
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
|
||||
}
|
||||
|
||||
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)
|
||||
|
||||
@@ -27,7 +27,7 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
|
||||
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
|
||||
}
|
||||
|
||||
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)
|
||||
|
||||
@@ -163,8 +163,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
|
||||
{
|
||||
ModulePtr clone = std::make_shared<Module>(module.metadata);
|
||||
clone->importedModules = module.importedModules;
|
||||
ModulePtr clone = std::make_shared<Module>(module.metadata, module.importedModules);
|
||||
|
||||
Context currentContext;
|
||||
currentContext.options = options;
|
||||
|
||||
@@ -170,18 +170,16 @@ namespace Nz
|
||||
});
|
||||
|
||||
ShaderAst::ModulePtr sanitizedModule;
|
||||
ShaderAst::Statement* targetAst;
|
||||
const ShaderAst::Module* targetModule;
|
||||
if (!states.sanitized)
|
||||
{
|
||||
sanitizedModule = Sanitize(module, states.optionValues);
|
||||
targetAst = sanitizedModule->rootNode.get();
|
||||
targetModule = sanitizedModule.get();
|
||||
}
|
||||
else
|
||||
targetAst = module.rootNode.get();
|
||||
targetModule = &module;
|
||||
|
||||
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
|
||||
|
||||
ShaderAst::StatementPtr optimizedAst;
|
||||
ShaderAst::ModulePtr optimizedModule;
|
||||
if (states.optimize)
|
||||
{
|
||||
ShaderAst::StatementPtr tempAst;
|
||||
@@ -190,27 +188,31 @@ namespace Nz
|
||||
if (shaderStage)
|
||||
dependencyConfig.usedShaderStages = *shaderStage;
|
||||
|
||||
tempAst = ShaderAst::PropagateConstants(*targetAst);
|
||||
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
|
||||
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
|
||||
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
|
||||
|
||||
targetAst = optimizedAst.get();
|
||||
targetModule = optimizedModule.get();
|
||||
}
|
||||
|
||||
// Previsitor
|
||||
state.previsitor.selectedStage = shaderStage;
|
||||
targetAst->Visit(state.previsitor);
|
||||
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
importedModule.module->rootNode->Visit(state.previsitor);
|
||||
|
||||
targetModule->rootNode->Visit(state.previsitor);
|
||||
|
||||
// Code generation
|
||||
AppendHeader();
|
||||
|
||||
for (const auto& importedModule : targetModule.importedModules)
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
{
|
||||
m_currentState->moduleSuffix = importedModule.identifier;
|
||||
|
||||
importedModule.module->rootNode->Visit(state.previsitor);
|
||||
importedModule.module->rootNode->Visit(*this);
|
||||
}
|
||||
|
||||
m_currentState->moduleSuffix = {};
|
||||
targetAst->Visit(*this);
|
||||
targetModule->rootNode->Visit(*this);
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
|
||||
@@ -481,7 +481,7 @@ namespace Nz
|
||||
std::vector<UInt32> resultIds;
|
||||
UInt32 nextVarIndex = 1;
|
||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||
PreVisitor* preVisitor;
|
||||
PreVisitor* previsitor;
|
||||
|
||||
// Output
|
||||
SpirvSection header;
|
||||
@@ -499,7 +499,7 @@ namespace Nz
|
||||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
|
||||
{
|
||||
ShaderAst::ModulePtr sanitizedModule;
|
||||
ShaderAst::Statement* targetAst;
|
||||
const ShaderAst::Module* targetModule;
|
||||
if (!states.sanitized)
|
||||
{
|
||||
ShaderAst::SanitizeVisitor::Options options;
|
||||
@@ -513,14 +513,12 @@ namespace Nz
|
||||
options.useIdentifierAccessesForStructs = false;
|
||||
|
||||
sanitizedModule = ShaderAst::Sanitize(module, options);
|
||||
targetAst = sanitizedModule->rootNode.get();
|
||||
targetModule = sanitizedModule.get();
|
||||
}
|
||||
else
|
||||
targetAst = module.rootNode.get();
|
||||
targetModule = &module;
|
||||
|
||||
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
|
||||
|
||||
ShaderAst::StatementPtr optimizedAst;
|
||||
ShaderAst::ModulePtr optimizedModule;
|
||||
if (states.optimize)
|
||||
{
|
||||
ShaderAst::StatementPtr tempAst;
|
||||
@@ -528,12 +526,14 @@ namespace Nz
|
||||
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
|
||||
dependencyConfig.usedShaderStages = ShaderStageType_All;
|
||||
|
||||
tempAst = ShaderAst::PropagateConstants(*targetAst);
|
||||
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
|
||||
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
|
||||
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
|
||||
|
||||
targetAst = optimizedAst.get();
|
||||
targetModule = optimizedModule.get();
|
||||
}
|
||||
|
||||
// Previsitor
|
||||
|
||||
m_context.states = &states;
|
||||
|
||||
State state;
|
||||
@@ -544,15 +544,15 @@ namespace Nz
|
||||
});
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(state.constantTypeCache, state.funcs);
|
||||
for (const auto& importedModule : targetModule.importedModules)
|
||||
importedModule.module->rootNode->Visit(preVisitor);
|
||||
PreVisitor previsitor(state.constantTypeCache, state.funcs);
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
importedModule.module->rootNode->Visit(previsitor);
|
||||
|
||||
targetAst->Visit(preVisitor);
|
||||
targetModule->rootNode->Visit(previsitor);
|
||||
|
||||
m_currentState->preVisitor = &preVisitor;
|
||||
m_currentState->previsitor = &previsitor;
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
for (const std::string& extInst : previsitor.extInsts)
|
||||
state.extensionInstructionSet[extInst] = AllocateResultId();
|
||||
|
||||
// Assign function ID (required for forward declaration)
|
||||
@@ -560,23 +560,23 @@ namespace Nz
|
||||
func.funcId = AllocateResultId();
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||
for (const auto& importedModule : targetModule.importedModules)
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
importedModule.module->rootNode->Visit(visitor);
|
||||
|
||||
targetAst->Visit(visitor);
|
||||
targetModule->rootNode->Visit(visitor);
|
||||
|
||||
AppendHeader();
|
||||
|
||||
for (auto&& [varIndex, extVar] : preVisitor.extVars)
|
||||
for (auto&& [varIndex, extVar] : previsitor.extVars)
|
||||
{
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
|
||||
}
|
||||
|
||||
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
|
||||
for (auto&& [varId, builtin] : previsitor.builtinDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
|
||||
|
||||
for (auto&& [varId, location] : preVisitor.locationDecorations)
|
||||
for (auto&& [varId, location] : previsitor.locationDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
@@ -699,8 +699,8 @@ namespace Nz
|
||||
|
||||
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
|
||||
{
|
||||
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
|
||||
assert(it != m_currentState->preVisitor->extVars.end());
|
||||
auto it = m_currentState->previsitor->extVars.find(extVarIndex);
|
||||
assert(it != m_currentState->previsitor->extVars.end());
|
||||
|
||||
return it->second.pointerId;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user