Shader: Minor module fixes

This commit is contained in:
Jérôme Leclercq
2022-03-09 20:05:10 +01:00
parent da40a2db28
commit 39a2992791
11 changed files with 83 additions and 62 deletions

View File

@@ -170,7 +170,7 @@ int main()
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output))); assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
Nz::ShaderWriter::States states; Nz::ShaderWriter::States states;
states.optimize = false; states.optimize = true;
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states); auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
if (!fragVertShader) if (!fragVertShader)

View File

@@ -45,9 +45,6 @@ namespace Nz::ShaderAst
if (!Compare(lhs.identifier, rhs.identifier)) if (!Compare(lhs.identifier, rhs.identifier))
return false; return false;
if (!Compare(lhs.dependencies, rhs.dependencies))
return false;
if (!Compare(*lhs.module, *rhs.module)) if (!Compare(*lhs.module, *rhs.module))
return false; return false;

View File

@@ -22,11 +22,12 @@ namespace Nz::ShaderAst
class Module class Module
{ {
public: public:
struct ImportedModule;
struct Metadata; struct Metadata;
inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate()); inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate());
inline Module(std::shared_ptr<const Metadata> metadata); inline Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules = {});
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode); inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode, std::vector<ImportedModule> importedModules = {});
Module(const Module&) = default; Module(const Module&) = default;
Module(Module&&) noexcept = default; Module(Module&&) noexcept = default;
~Module() = default; ~Module() = default;
@@ -37,7 +38,6 @@ namespace Nz::ShaderAst
struct ImportedModule struct ImportedModule
{ {
std::string identifier; std::string identifier;
std::vector<Uuid> dependencies;
ModulePtr module; ModulePtr module;
}; };

View File

@@ -18,13 +18,14 @@ namespace Nz::ShaderAst
rootNode = ShaderBuilder::MultiStatement(); rootNode = ShaderBuilder::MultiStatement();
} }
inline Module::Module(std::shared_ptr<const Metadata> metadata) : inline Module::Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules) :
Module(std::move(metadata), ShaderBuilder::MultiStatement()) Module(std::move(metadata), ShaderBuilder::MultiStatement(), std::move(importedModules))
{ {
} }
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode) : inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode, std::vector<ImportedModule> ImportedModules) :
metadata(std::move(Metadata)), metadata(std::move(Metadata)),
importedModules(std::move(ImportedModules)),
rootNode(std::move(RootNode)) rootNode(std::move(RootNode))
{ {
} }

View File

@@ -739,14 +739,14 @@ namespace Nz::ShaderAst
{ {
auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode)); auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode)); return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode), shaderModule.importedModules);
} }
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options) ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
{ {
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options)); auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode)); return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
} }
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node) ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)

View File

@@ -27,7 +27,7 @@ namespace Nz::ShaderAst
{ {
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet)); auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode)); return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
} }
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet) StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)

View File

@@ -163,8 +163,7 @@ namespace Nz::ShaderAst
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{ {
ModulePtr clone = std::make_shared<Module>(module.metadata); ModulePtr clone = std::make_shared<Module>(module.metadata, module.importedModules);
clone->importedModules = module.importedModules;
Context currentContext; Context currentContext;
currentContext.options = options; currentContext.options = options;

View File

@@ -170,18 +170,16 @@ namespace Nz
}); });
ShaderAst::ModulePtr sanitizedModule; ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst; const ShaderAst::Module* targetModule;
if (!states.sanitized) if (!states.sanitized)
{ {
sanitizedModule = Sanitize(module, states.optionValues); sanitizedModule = Sanitize(module, states.optionValues);
targetAst = sanitizedModule->rootNode.get(); targetModule = sanitizedModule.get();
} }
else else
targetAst = module.rootNode.get(); targetModule = &module;
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module; ShaderAst::ModulePtr optimizedModule;
ShaderAst::StatementPtr optimizedAst;
if (states.optimize) if (states.optimize)
{ {
ShaderAst::StatementPtr tempAst; ShaderAst::StatementPtr tempAst;
@@ -190,27 +188,31 @@ namespace Nz
if (shaderStage) if (shaderStage)
dependencyConfig.usedShaderStages = *shaderStage; dependencyConfig.usedShaderStages = *shaderStage;
tempAst = ShaderAst::PropagateConstants(*targetAst); optimizedModule = ShaderAst::PropagateConstants(*targetModule);
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig); optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
targetAst = optimizedAst.get(); targetModule = optimizedModule.get();
} }
// Previsitor
state.previsitor.selectedStage = shaderStage; state.previsitor.selectedStage = shaderStage;
targetAst->Visit(state.previsitor);
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(state.previsitor);
targetModule->rootNode->Visit(state.previsitor);
// Code generation
AppendHeader(); AppendHeader();
for (const auto& importedModule : targetModule.importedModules) for (const auto& importedModule : targetModule->importedModules)
{ {
m_currentState->moduleSuffix = importedModule.identifier; m_currentState->moduleSuffix = importedModule.identifier;
importedModule.module->rootNode->Visit(state.previsitor);
importedModule.module->rootNode->Visit(*this); importedModule.module->rootNode->Visit(*this);
} }
m_currentState->moduleSuffix = {}; m_currentState->moduleSuffix = {};
targetAst->Visit(*this); targetModule->rootNode->Visit(*this);
return state.stream.str(); return state.stream.str();
} }

View File

@@ -481,7 +481,7 @@ namespace Nz
std::vector<UInt32> resultIds; std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1; UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex SpirvConstantCache constantTypeCache; //< init after nextVarIndex
PreVisitor* preVisitor; PreVisitor* previsitor;
// Output // Output
SpirvSection header; SpirvSection header;
@@ -499,7 +499,7 @@ namespace Nz
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states) std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
{ {
ShaderAst::ModulePtr sanitizedModule; ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst; const ShaderAst::Module* targetModule;
if (!states.sanitized) if (!states.sanitized)
{ {
ShaderAst::SanitizeVisitor::Options options; ShaderAst::SanitizeVisitor::Options options;
@@ -513,14 +513,12 @@ namespace Nz
options.useIdentifierAccessesForStructs = false; options.useIdentifierAccessesForStructs = false;
sanitizedModule = ShaderAst::Sanitize(module, options); sanitizedModule = ShaderAst::Sanitize(module, options);
targetAst = sanitizedModule->rootNode.get(); targetModule = sanitizedModule.get();
} }
else else
targetAst = module.rootNode.get(); targetModule = &module;
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module; ShaderAst::ModulePtr optimizedModule;
ShaderAst::StatementPtr optimizedAst;
if (states.optimize) if (states.optimize)
{ {
ShaderAst::StatementPtr tempAst; ShaderAst::StatementPtr tempAst;
@@ -528,12 +526,14 @@ namespace Nz
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig; ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
dependencyConfig.usedShaderStages = ShaderStageType_All; dependencyConfig.usedShaderStages = ShaderStageType_All;
tempAst = ShaderAst::PropagateConstants(*targetAst); optimizedModule = ShaderAst::PropagateConstants(*targetModule);
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig); optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
targetAst = optimizedAst.get(); targetModule = optimizedModule.get();
} }
// Previsitor
m_context.states = &states; m_context.states = &states;
State state; State state;
@@ -544,15 +544,15 @@ namespace Nz
}); });
// Register all extended instruction sets // Register all extended instruction sets
PreVisitor preVisitor(state.constantTypeCache, state.funcs); PreVisitor previsitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule.importedModules) for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(preVisitor); importedModule.module->rootNode->Visit(previsitor);
targetAst->Visit(preVisitor); targetModule->rootNode->Visit(previsitor);
m_currentState->preVisitor = &preVisitor; m_currentState->previsitor = &previsitor;
for (const std::string& extInst : preVisitor.extInsts) for (const std::string& extInst : previsitor.extInsts)
state.extensionInstructionSet[extInst] = AllocateResultId(); state.extensionInstructionSet[extInst] = AllocateResultId();
// Assign function ID (required for forward declaration) // Assign function ID (required for forward declaration)
@@ -560,23 +560,23 @@ namespace Nz
func.funcId = AllocateResultId(); func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs); SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
for (const auto& importedModule : targetModule.importedModules) for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(visitor); importedModule.module->rootNode->Visit(visitor);
targetAst->Visit(visitor); targetModule->rootNode->Visit(visitor);
AppendHeader(); AppendHeader();
for (auto&& [varIndex, extVar] : preVisitor.extVars) for (auto&& [varIndex, extVar] : previsitor.extVars)
{ {
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex); state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet); state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
} }
for (auto&& [varId, builtin] : preVisitor.builtinDecorations) for (auto&& [varId, builtin] : previsitor.builtinDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin); state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
for (auto&& [varId, location] : preVisitor.locationDecorations) for (auto&& [varId, location] : previsitor.locationDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location); state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo); m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
@@ -699,8 +699,8 @@ namespace Nz
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
{ {
auto it = m_currentState->preVisitor->extVars.find(extVarIndex); auto it = m_currentState->previsitor->extVars.find(extVarIndex);
assert(it != m_currentState->preVisitor->extVars.end()); assert(it != m_currentState->previsitor->extVars.end());
return it->second.pointerId; return it->second.pointerId;
} }

View File

@@ -6,6 +6,7 @@
#include <Nazara/Shader/SpirvPrinter.hpp> #include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvWriter.hpp> #include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Ast/AstReflect.hpp> #include <Nazara/Shader/Ast/AstReflect.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <catch2/catch.hpp> #include <catch2/catch.hpp>
#include <glslang/Public/ShaderLang.h> #include <glslang/Public/ShaderLang.h>
#include <spirv-tools/libspirv.hpp> #include <spirv-tools/libspirv.hpp>
@@ -121,12 +122,19 @@ namespace
}; };
} }
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput) void ExpectGLSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
{ {
expectedOutput = Nz::Trim(expectedOutput); expectedOutput = Nz::Trim(expectedOutput);
SECTION("Generating GLSL") SECTION("Generating GLSL")
{ {
Nz::ShaderAst::ModulePtr sanitizedModule;
WHEN("Sanitizing a second time")
{
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
}
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
// Retrieve entry-point to get shader type // Retrieve entry-point to get shader type
std::optional<Nz::ShaderStageType> entryShaderStage; std::optional<Nz::ShaderStageType> entryShaderStage;
@@ -140,7 +148,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
}; };
Nz::ShaderAst::AstReflect reflectVisitor; Nz::ShaderAst::AstReflect reflectVisitor;
reflectVisitor.Reflect(*shader.rootNode, callbacks); reflectVisitor.Reflect(*targetModule.rootNode, callbacks);
{ {
INFO("no entry point found"); INFO("no entry point found");
@@ -148,7 +156,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
} }
Nz::GlslWriter writer; Nz::GlslWriter writer;
std::string output = writer.Generate(entryShaderStage, shader); std::string output = writer.Generate(entryShaderStage, targetModule);
WHEN("Validating expected code") WHEN("Validating expected code")
{ {
@@ -188,14 +196,21 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
} }
} }
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput) void ExpectNZSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
{ {
expectedOutput = Nz::Trim(expectedOutput); expectedOutput = Nz::Trim(expectedOutput);
SECTION("Generating NZSL") SECTION("Generating NZSL")
{ {
Nz::ShaderAst::ModulePtr sanitizedModule;
WHEN("Sanitizing a second time")
{
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
}
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
Nz::LangWriter writer; Nz::LangWriter writer;
std::string output = writer.Generate(shader); std::string output = writer.Generate(targetModule);
WHEN("Validating expected code") WHEN("Validating expected code")
{ {
@@ -211,12 +226,19 @@ void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
} }
} }
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter) void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput, bool outputParameter)
{ {
expectedOutput = Nz::Trim(expectedOutput); expectedOutput = Nz::Trim(expectedOutput);
SECTION("Generating SPIRV") SECTION("Generating SPIRV")
{ {
Nz::ShaderAst::ModulePtr sanitizedModule;
WHEN("Sanitizing a second time")
{
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
}
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
Nz::SpirvWriter writer; Nz::SpirvWriter writer;
Nz::SpirvPrinter printer; Nz::SpirvPrinter printer;
@@ -224,7 +246,7 @@ void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput,
settings.printHeader = false; settings.printHeader = false;
settings.printParameters = outputParameter; settings.printParameters = outputParameter;
auto spirv = writer.Generate(shader); auto spirv = writer.Generate(targetModule);
std::string output = printer.Print(spirv.data(), spirv.size(), settings); std::string output = printer.Print(spirv.data(), spirv.size(), settings);
WHEN("Validating expected code") WHEN("Validating expected code")

View File

@@ -6,8 +6,8 @@
#include <Nazara/Shader/Ast/Module.hpp> #include <Nazara/Shader/Ast/Module.hpp>
#include <string> #include <string>
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput); void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput); void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false); void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
#endif #endif