Shader: Minor module fixes

This commit is contained in:
Jérôme Leclercq 2022-03-09 20:05:10 +01:00
parent da40a2db28
commit 39a2992791
11 changed files with 83 additions and 62 deletions

View File

@ -170,7 +170,7 @@ int main()
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
Nz::ShaderWriter::States states;
states.optimize = false;
states.optimize = true;
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
if (!fragVertShader)

View File

@ -45,9 +45,6 @@ namespace Nz::ShaderAst
if (!Compare(lhs.identifier, rhs.identifier))
return false;
if (!Compare(lhs.dependencies, rhs.dependencies))
return false;
if (!Compare(*lhs.module, *rhs.module))
return false;

View File

@ -22,11 +22,12 @@ namespace Nz::ShaderAst
class Module
{
public:
struct ImportedModule;
struct Metadata;
inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate());
inline Module(std::shared_ptr<const Metadata> metadata);
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode);
inline Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules = {});
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode, std::vector<ImportedModule> importedModules = {});
Module(const Module&) = default;
Module(Module&&) noexcept = default;
~Module() = default;
@ -37,7 +38,6 @@ namespace Nz::ShaderAst
struct ImportedModule
{
std::string identifier;
std::vector<Uuid> dependencies;
ModulePtr module;
};

View File

@ -18,13 +18,14 @@ namespace Nz::ShaderAst
rootNode = ShaderBuilder::MultiStatement();
}
inline Module::Module(std::shared_ptr<const Metadata> metadata) :
Module(std::move(metadata), ShaderBuilder::MultiStatement())
inline Module::Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules) :
Module(std::move(metadata), ShaderBuilder::MultiStatement(), std::move(importedModules))
{
}
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode) :
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode, std::vector<ImportedModule> ImportedModules) :
metadata(std::move(Metadata)),
importedModules(std::move(ImportedModules)),
rootNode(std::move(RootNode))
{
}

View File

@ -739,14 +739,14 @@ namespace Nz::ShaderAst
{
auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode), shaderModule.importedModules);
}
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
{
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
}
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)

View File

@ -27,7 +27,7 @@ namespace Nz::ShaderAst
{
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
}
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)

View File

@ -163,8 +163,7 @@ namespace Nz::ShaderAst
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{
ModulePtr clone = std::make_shared<Module>(module.metadata);
clone->importedModules = module.importedModules;
ModulePtr clone = std::make_shared<Module>(module.metadata, module.importedModules);
Context currentContext;
currentContext.options = options;

View File

@ -170,18 +170,16 @@ namespace Nz
});
ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst;
const ShaderAst::Module* targetModule;
if (!states.sanitized)
{
sanitizedModule = Sanitize(module, states.optionValues);
targetAst = sanitizedModule->rootNode.get();
targetModule = sanitizedModule.get();
}
else
targetAst = module.rootNode.get();
targetModule = &module;
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst;
ShaderAst::ModulePtr optimizedModule;
if (states.optimize)
{
ShaderAst::StatementPtr tempAst;
@ -190,27 +188,31 @@ namespace Nz
if (shaderStage)
dependencyConfig.usedShaderStages = *shaderStage;
tempAst = ShaderAst::PropagateConstants(*targetAst);
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
targetAst = optimizedAst.get();
targetModule = optimizedModule.get();
}
// Previsitor
state.previsitor.selectedStage = shaderStage;
targetAst->Visit(state.previsitor);
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(state.previsitor);
targetModule->rootNode->Visit(state.previsitor);
// Code generation
AppendHeader();
for (const auto& importedModule : targetModule.importedModules)
for (const auto& importedModule : targetModule->importedModules)
{
m_currentState->moduleSuffix = importedModule.identifier;
importedModule.module->rootNode->Visit(state.previsitor);
importedModule.module->rootNode->Visit(*this);
}
m_currentState->moduleSuffix = {};
targetAst->Visit(*this);
targetModule->rootNode->Visit(*this);
return state.stream.str();
}

View File

@ -481,7 +481,7 @@ namespace Nz
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
PreVisitor* preVisitor;
PreVisitor* previsitor;
// Output
SpirvSection header;
@ -499,7 +499,7 @@ namespace Nz
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
{
ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst;
const ShaderAst::Module* targetModule;
if (!states.sanitized)
{
ShaderAst::SanitizeVisitor::Options options;
@ -513,14 +513,12 @@ namespace Nz
options.useIdentifierAccessesForStructs = false;
sanitizedModule = ShaderAst::Sanitize(module, options);
targetAst = sanitizedModule->rootNode.get();
targetModule = sanitizedModule.get();
}
else
targetAst = module.rootNode.get();
targetModule = &module;
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst;
ShaderAst::ModulePtr optimizedModule;
if (states.optimize)
{
ShaderAst::StatementPtr tempAst;
@ -528,12 +526,14 @@ namespace Nz
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
dependencyConfig.usedShaderStages = ShaderStageType_All;
tempAst = ShaderAst::PropagateConstants(*targetAst);
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
targetAst = optimizedAst.get();
targetModule = optimizedModule.get();
}
// Previsitor
m_context.states = &states;
State state;
@ -544,15 +544,15 @@ namespace Nz
});
// Register all extended instruction sets
PreVisitor preVisitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
importedModule.module->rootNode->Visit(preVisitor);
PreVisitor previsitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(previsitor);
targetAst->Visit(preVisitor);
targetModule->rootNode->Visit(previsitor);
m_currentState->preVisitor = &preVisitor;
m_currentState->previsitor = &previsitor;
for (const std::string& extInst : preVisitor.extInsts)
for (const std::string& extInst : previsitor.extInsts)
state.extensionInstructionSet[extInst] = AllocateResultId();
// Assign function ID (required for forward declaration)
@ -560,23 +560,23 @@ namespace Nz
func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(visitor);
targetAst->Visit(visitor);
targetModule->rootNode->Visit(visitor);
AppendHeader();
for (auto&& [varIndex, extVar] : preVisitor.extVars)
for (auto&& [varIndex, extVar] : previsitor.extVars)
{
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
}
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
for (auto&& [varId, builtin] : previsitor.builtinDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
for (auto&& [varId, location] : preVisitor.locationDecorations)
for (auto&& [varId, location] : previsitor.locationDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
@ -699,8 +699,8 @@ namespace Nz
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
{
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
assert(it != m_currentState->preVisitor->extVars.end());
auto it = m_currentState->previsitor->extVars.find(extVarIndex);
assert(it != m_currentState->previsitor->extVars.end());
return it->second.pointerId;
}

View File

@ -6,6 +6,7 @@
#include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Ast/AstReflect.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <catch2/catch.hpp>
#include <glslang/Public/ShaderLang.h>
#include <spirv-tools/libspirv.hpp>
@ -121,12 +122,19 @@ namespace
};
}
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
void ExpectGLSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
{
expectedOutput = Nz::Trim(expectedOutput);
SECTION("Generating GLSL")
{
Nz::ShaderAst::ModulePtr sanitizedModule;
WHEN("Sanitizing a second time")
{
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
}
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
// Retrieve entry-point to get shader type
std::optional<Nz::ShaderStageType> entryShaderStage;
@ -140,7 +148,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
};
Nz::ShaderAst::AstReflect reflectVisitor;
reflectVisitor.Reflect(*shader.rootNode, callbacks);
reflectVisitor.Reflect(*targetModule.rootNode, callbacks);
{
INFO("no entry point found");
@ -148,7 +156,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
}
Nz::GlslWriter writer;
std::string output = writer.Generate(entryShaderStage, shader);
std::string output = writer.Generate(entryShaderStage, targetModule);
WHEN("Validating expected code")
{
@ -188,14 +196,21 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
}
}
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
void ExpectNZSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
{
expectedOutput = Nz::Trim(expectedOutput);
SECTION("Generating NZSL")
{
Nz::ShaderAst::ModulePtr sanitizedModule;
WHEN("Sanitizing a second time")
{
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
}
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
Nz::LangWriter writer;
std::string output = writer.Generate(shader);
std::string output = writer.Generate(targetModule);
WHEN("Validating expected code")
{
@ -211,12 +226,19 @@ void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
}
}
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter)
void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput, bool outputParameter)
{
expectedOutput = Nz::Trim(expectedOutput);
SECTION("Generating SPIRV")
{
Nz::ShaderAst::ModulePtr sanitizedModule;
WHEN("Sanitizing a second time")
{
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
}
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
Nz::SpirvWriter writer;
Nz::SpirvPrinter printer;
@ -224,7 +246,7 @@ void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput,
settings.printHeader = false;
settings.printParameters = outputParameter;
auto spirv = writer.Generate(shader);
auto spirv = writer.Generate(targetModule);
std::string output = printer.Print(spirv.data(), spirv.size(), settings);
WHEN("Validating expected code")

View File

@ -6,8 +6,8 @@
#include <Nazara/Shader/Ast/Module.hpp>
#include <string>
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
#endif