Shader: Minor module fixes
This commit is contained in:
parent
da40a2db28
commit
39a2992791
|
|
@ -170,7 +170,7 @@ int main()
|
|||
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
|
||||
|
||||
Nz::ShaderWriter::States states;
|
||||
states.optimize = false;
|
||||
states.optimize = true;
|
||||
|
||||
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
|
||||
if (!fragVertShader)
|
||||
|
|
|
|||
|
|
@ -45,9 +45,6 @@ namespace Nz::ShaderAst
|
|||
if (!Compare(lhs.identifier, rhs.identifier))
|
||||
return false;
|
||||
|
||||
if (!Compare(lhs.dependencies, rhs.dependencies))
|
||||
return false;
|
||||
|
||||
if (!Compare(*lhs.module, *rhs.module))
|
||||
return false;
|
||||
|
||||
|
|
|
|||
|
|
@ -22,11 +22,12 @@ namespace Nz::ShaderAst
|
|||
class Module
|
||||
{
|
||||
public:
|
||||
struct ImportedModule;
|
||||
struct Metadata;
|
||||
|
||||
inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate());
|
||||
inline Module(std::shared_ptr<const Metadata> metadata);
|
||||
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode);
|
||||
inline Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules = {});
|
||||
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode, std::vector<ImportedModule> importedModules = {});
|
||||
Module(const Module&) = default;
|
||||
Module(Module&&) noexcept = default;
|
||||
~Module() = default;
|
||||
|
|
@ -37,7 +38,6 @@ namespace Nz::ShaderAst
|
|||
struct ImportedModule
|
||||
{
|
||||
std::string identifier;
|
||||
std::vector<Uuid> dependencies;
|
||||
ModulePtr module;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -18,13 +18,14 @@ namespace Nz::ShaderAst
|
|||
rootNode = ShaderBuilder::MultiStatement();
|
||||
}
|
||||
|
||||
inline Module::Module(std::shared_ptr<const Metadata> metadata) :
|
||||
Module(std::move(metadata), ShaderBuilder::MultiStatement())
|
||||
inline Module::Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules) :
|
||||
Module(std::move(metadata), ShaderBuilder::MultiStatement(), std::move(importedModules))
|
||||
{
|
||||
}
|
||||
|
||||
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode) :
|
||||
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode, std::vector<ImportedModule> ImportedModules) :
|
||||
metadata(std::move(Metadata)),
|
||||
importedModules(std::move(ImportedModules)),
|
||||
rootNode(std::move(RootNode))
|
||||
{
|
||||
}
|
||||
|
|
|
|||
|
|
@ -739,14 +739,14 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
|
||||
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode));
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode), shaderModule.importedModules);
|
||||
}
|
||||
|
||||
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
|
||||
{
|
||||
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
|
||||
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
|
||||
}
|
||||
|
||||
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
|
||||
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
|
||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
|
||||
}
|
||||
|
||||
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)
|
||||
|
|
|
|||
|
|
@ -163,8 +163,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
|
||||
{
|
||||
ModulePtr clone = std::make_shared<Module>(module.metadata);
|
||||
clone->importedModules = module.importedModules;
|
||||
ModulePtr clone = std::make_shared<Module>(module.metadata, module.importedModules);
|
||||
|
||||
Context currentContext;
|
||||
currentContext.options = options;
|
||||
|
|
|
|||
|
|
@ -170,18 +170,16 @@ namespace Nz
|
|||
});
|
||||
|
||||
ShaderAst::ModulePtr sanitizedModule;
|
||||
ShaderAst::Statement* targetAst;
|
||||
const ShaderAst::Module* targetModule;
|
||||
if (!states.sanitized)
|
||||
{
|
||||
sanitizedModule = Sanitize(module, states.optionValues);
|
||||
targetAst = sanitizedModule->rootNode.get();
|
||||
targetModule = sanitizedModule.get();
|
||||
}
|
||||
else
|
||||
targetAst = module.rootNode.get();
|
||||
targetModule = &module;
|
||||
|
||||
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
|
||||
|
||||
ShaderAst::StatementPtr optimizedAst;
|
||||
ShaderAst::ModulePtr optimizedModule;
|
||||
if (states.optimize)
|
||||
{
|
||||
ShaderAst::StatementPtr tempAst;
|
||||
|
|
@ -190,27 +188,31 @@ namespace Nz
|
|||
if (shaderStage)
|
||||
dependencyConfig.usedShaderStages = *shaderStage;
|
||||
|
||||
tempAst = ShaderAst::PropagateConstants(*targetAst);
|
||||
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
|
||||
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
|
||||
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
|
||||
|
||||
targetAst = optimizedAst.get();
|
||||
targetModule = optimizedModule.get();
|
||||
}
|
||||
|
||||
// Previsitor
|
||||
state.previsitor.selectedStage = shaderStage;
|
||||
targetAst->Visit(state.previsitor);
|
||||
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
importedModule.module->rootNode->Visit(state.previsitor);
|
||||
|
||||
targetModule->rootNode->Visit(state.previsitor);
|
||||
|
||||
// Code generation
|
||||
AppendHeader();
|
||||
|
||||
for (const auto& importedModule : targetModule.importedModules)
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
{
|
||||
m_currentState->moduleSuffix = importedModule.identifier;
|
||||
|
||||
importedModule.module->rootNode->Visit(state.previsitor);
|
||||
importedModule.module->rootNode->Visit(*this);
|
||||
}
|
||||
|
||||
m_currentState->moduleSuffix = {};
|
||||
targetAst->Visit(*this);
|
||||
targetModule->rootNode->Visit(*this);
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -481,7 +481,7 @@ namespace Nz
|
|||
std::vector<UInt32> resultIds;
|
||||
UInt32 nextVarIndex = 1;
|
||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||
PreVisitor* preVisitor;
|
||||
PreVisitor* previsitor;
|
||||
|
||||
// Output
|
||||
SpirvSection header;
|
||||
|
|
@ -499,7 +499,7 @@ namespace Nz
|
|||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
|
||||
{
|
||||
ShaderAst::ModulePtr sanitizedModule;
|
||||
ShaderAst::Statement* targetAst;
|
||||
const ShaderAst::Module* targetModule;
|
||||
if (!states.sanitized)
|
||||
{
|
||||
ShaderAst::SanitizeVisitor::Options options;
|
||||
|
|
@ -513,14 +513,12 @@ namespace Nz
|
|||
options.useIdentifierAccessesForStructs = false;
|
||||
|
||||
sanitizedModule = ShaderAst::Sanitize(module, options);
|
||||
targetAst = sanitizedModule->rootNode.get();
|
||||
targetModule = sanitizedModule.get();
|
||||
}
|
||||
else
|
||||
targetAst = module.rootNode.get();
|
||||
targetModule = &module;
|
||||
|
||||
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
|
||||
|
||||
ShaderAst::StatementPtr optimizedAst;
|
||||
ShaderAst::ModulePtr optimizedModule;
|
||||
if (states.optimize)
|
||||
{
|
||||
ShaderAst::StatementPtr tempAst;
|
||||
|
|
@ -528,12 +526,14 @@ namespace Nz
|
|||
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
|
||||
dependencyConfig.usedShaderStages = ShaderStageType_All;
|
||||
|
||||
tempAst = ShaderAst::PropagateConstants(*targetAst);
|
||||
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
|
||||
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
|
||||
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
|
||||
|
||||
targetAst = optimizedAst.get();
|
||||
targetModule = optimizedModule.get();
|
||||
}
|
||||
|
||||
// Previsitor
|
||||
|
||||
m_context.states = &states;
|
||||
|
||||
State state;
|
||||
|
|
@ -544,15 +544,15 @@ namespace Nz
|
|||
});
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(state.constantTypeCache, state.funcs);
|
||||
for (const auto& importedModule : targetModule.importedModules)
|
||||
importedModule.module->rootNode->Visit(preVisitor);
|
||||
PreVisitor previsitor(state.constantTypeCache, state.funcs);
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
importedModule.module->rootNode->Visit(previsitor);
|
||||
|
||||
targetAst->Visit(preVisitor);
|
||||
targetModule->rootNode->Visit(previsitor);
|
||||
|
||||
m_currentState->preVisitor = &preVisitor;
|
||||
m_currentState->previsitor = &previsitor;
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
for (const std::string& extInst : previsitor.extInsts)
|
||||
state.extensionInstructionSet[extInst] = AllocateResultId();
|
||||
|
||||
// Assign function ID (required for forward declaration)
|
||||
|
|
@ -560,23 +560,23 @@ namespace Nz
|
|||
func.funcId = AllocateResultId();
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||
for (const auto& importedModule : targetModule.importedModules)
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
importedModule.module->rootNode->Visit(visitor);
|
||||
|
||||
targetAst->Visit(visitor);
|
||||
targetModule->rootNode->Visit(visitor);
|
||||
|
||||
AppendHeader();
|
||||
|
||||
for (auto&& [varIndex, extVar] : preVisitor.extVars)
|
||||
for (auto&& [varIndex, extVar] : previsitor.extVars)
|
||||
{
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
|
||||
}
|
||||
|
||||
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
|
||||
for (auto&& [varId, builtin] : previsitor.builtinDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
|
||||
|
||||
for (auto&& [varId, location] : preVisitor.locationDecorations)
|
||||
for (auto&& [varId, location] : previsitor.locationDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
|
|
@ -699,8 +699,8 @@ namespace Nz
|
|||
|
||||
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
|
||||
{
|
||||
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
|
||||
assert(it != m_currentState->preVisitor->extVars.end());
|
||||
auto it = m_currentState->previsitor->extVars.find(extVarIndex);
|
||||
assert(it != m_currentState->previsitor->extVars.end());
|
||||
|
||||
return it->second.pointerId;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <Nazara/Shader/SpirvPrinter.hpp>
|
||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||
#include <Nazara/Shader/Ast/AstReflect.hpp>
|
||||
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||
#include <catch2/catch.hpp>
|
||||
#include <glslang/Public/ShaderLang.h>
|
||||
#include <spirv-tools/libspirv.hpp>
|
||||
|
|
@ -121,12 +122,19 @@ namespace
|
|||
};
|
||||
}
|
||||
|
||||
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
||||
void ExpectGLSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
|
||||
{
|
||||
expectedOutput = Nz::Trim(expectedOutput);
|
||||
|
||||
SECTION("Generating GLSL")
|
||||
{
|
||||
Nz::ShaderAst::ModulePtr sanitizedModule;
|
||||
WHEN("Sanitizing a second time")
|
||||
{
|
||||
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
|
||||
}
|
||||
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
|
||||
|
||||
// Retrieve entry-point to get shader type
|
||||
std::optional<Nz::ShaderStageType> entryShaderStage;
|
||||
|
||||
|
|
@ -140,7 +148,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||
};
|
||||
|
||||
Nz::ShaderAst::AstReflect reflectVisitor;
|
||||
reflectVisitor.Reflect(*shader.rootNode, callbacks);
|
||||
reflectVisitor.Reflect(*targetModule.rootNode, callbacks);
|
||||
|
||||
{
|
||||
INFO("no entry point found");
|
||||
|
|
@ -148,7 +156,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||
}
|
||||
|
||||
Nz::GlslWriter writer;
|
||||
std::string output = writer.Generate(entryShaderStage, shader);
|
||||
std::string output = writer.Generate(entryShaderStage, targetModule);
|
||||
|
||||
WHEN("Validating expected code")
|
||||
{
|
||||
|
|
@ -188,14 +196,21 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||
}
|
||||
}
|
||||
|
||||
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
||||
void ExpectNZSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
|
||||
{
|
||||
expectedOutput = Nz::Trim(expectedOutput);
|
||||
|
||||
SECTION("Generating NZSL")
|
||||
{
|
||||
Nz::ShaderAst::ModulePtr sanitizedModule;
|
||||
WHEN("Sanitizing a second time")
|
||||
{
|
||||
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
|
||||
}
|
||||
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
|
||||
|
||||
Nz::LangWriter writer;
|
||||
std::string output = writer.Generate(shader);
|
||||
std::string output = writer.Generate(targetModule);
|
||||
|
||||
WHEN("Validating expected code")
|
||||
{
|
||||
|
|
@ -211,12 +226,19 @@ void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||
}
|
||||
}
|
||||
|
||||
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter)
|
||||
void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput, bool outputParameter)
|
||||
{
|
||||
expectedOutput = Nz::Trim(expectedOutput);
|
||||
|
||||
SECTION("Generating SPIRV")
|
||||
{
|
||||
Nz::ShaderAst::ModulePtr sanitizedModule;
|
||||
WHEN("Sanitizing a second time")
|
||||
{
|
||||
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
|
||||
}
|
||||
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
|
||||
|
||||
Nz::SpirvWriter writer;
|
||||
Nz::SpirvPrinter printer;
|
||||
|
||||
|
|
@ -224,7 +246,7 @@ void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput,
|
|||
settings.printHeader = false;
|
||||
settings.printParameters = outputParameter;
|
||||
|
||||
auto spirv = writer.Generate(shader);
|
||||
auto spirv = writer.Generate(targetModule);
|
||||
std::string output = printer.Print(spirv.data(), spirv.size(), settings);
|
||||
|
||||
WHEN("Validating expected code")
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@
|
|||
#include <Nazara/Shader/Ast/Module.hpp>
|
||||
#include <string>
|
||||
|
||||
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
||||
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
||||
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue