Shader: Minor module fixes
This commit is contained in:
@@ -170,7 +170,7 @@ int main()
|
|||||||
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
|
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
|
||||||
|
|
||||||
Nz::ShaderWriter::States states;
|
Nz::ShaderWriter::States states;
|
||||||
states.optimize = false;
|
states.optimize = true;
|
||||||
|
|
||||||
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
|
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
|
||||||
if (!fragVertShader)
|
if (!fragVertShader)
|
||||||
|
|||||||
@@ -45,9 +45,6 @@ namespace Nz::ShaderAst
|
|||||||
if (!Compare(lhs.identifier, rhs.identifier))
|
if (!Compare(lhs.identifier, rhs.identifier))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (!Compare(lhs.dependencies, rhs.dependencies))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (!Compare(*lhs.module, *rhs.module))
|
if (!Compare(*lhs.module, *rhs.module))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
|||||||
@@ -22,11 +22,12 @@ namespace Nz::ShaderAst
|
|||||||
class Module
|
class Module
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
struct ImportedModule;
|
||||||
struct Metadata;
|
struct Metadata;
|
||||||
|
|
||||||
inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate());
|
inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate());
|
||||||
inline Module(std::shared_ptr<const Metadata> metadata);
|
inline Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules = {});
|
||||||
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode);
|
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode, std::vector<ImportedModule> importedModules = {});
|
||||||
Module(const Module&) = default;
|
Module(const Module&) = default;
|
||||||
Module(Module&&) noexcept = default;
|
Module(Module&&) noexcept = default;
|
||||||
~Module() = default;
|
~Module() = default;
|
||||||
@@ -37,7 +38,6 @@ namespace Nz::ShaderAst
|
|||||||
struct ImportedModule
|
struct ImportedModule
|
||||||
{
|
{
|
||||||
std::string identifier;
|
std::string identifier;
|
||||||
std::vector<Uuid> dependencies;
|
|
||||||
ModulePtr module;
|
ModulePtr module;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -18,13 +18,14 @@ namespace Nz::ShaderAst
|
|||||||
rootNode = ShaderBuilder::MultiStatement();
|
rootNode = ShaderBuilder::MultiStatement();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Module::Module(std::shared_ptr<const Metadata> metadata) :
|
inline Module::Module(std::shared_ptr<const Metadata> metadata, std::vector<ImportedModule> importedModules) :
|
||||||
Module(std::move(metadata), ShaderBuilder::MultiStatement())
|
Module(std::move(metadata), ShaderBuilder::MultiStatement(), std::move(importedModules))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode) :
|
inline Module::Module(std::shared_ptr<const Metadata> Metadata, MultiStatementPtr RootNode, std::vector<ImportedModule> ImportedModules) :
|
||||||
metadata(std::move(Metadata)),
|
metadata(std::move(Metadata)),
|
||||||
|
importedModules(std::move(ImportedModules)),
|
||||||
rootNode(std::move(RootNode))
|
rootNode(std::move(RootNode))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -739,14 +739,14 @@ namespace Nz::ShaderAst
|
|||||||
{
|
{
|
||||||
auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
|
auto rootnode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
|
||||||
|
|
||||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode));
|
return std::make_shared<Module>(shaderModule.metadata, std::move(rootnode), shaderModule.importedModules);
|
||||||
}
|
}
|
||||||
|
|
||||||
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
|
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
|
||||||
{
|
{
|
||||||
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
|
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
|
||||||
|
|
||||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
|
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
|
||||||
}
|
}
|
||||||
|
|
||||||
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)
|
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ namespace Nz::ShaderAst
|
|||||||
{
|
{
|
||||||
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
|
auto rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
|
||||||
|
|
||||||
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode));
|
return std::make_shared<Module>(shaderModule.metadata, std::move(rootNode), shaderModule.importedModules);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)
|
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)
|
||||||
|
|||||||
@@ -163,8 +163,7 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
|
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
|
||||||
{
|
{
|
||||||
ModulePtr clone = std::make_shared<Module>(module.metadata);
|
ModulePtr clone = std::make_shared<Module>(module.metadata, module.importedModules);
|
||||||
clone->importedModules = module.importedModules;
|
|
||||||
|
|
||||||
Context currentContext;
|
Context currentContext;
|
||||||
currentContext.options = options;
|
currentContext.options = options;
|
||||||
|
|||||||
@@ -170,18 +170,16 @@ namespace Nz
|
|||||||
});
|
});
|
||||||
|
|
||||||
ShaderAst::ModulePtr sanitizedModule;
|
ShaderAst::ModulePtr sanitizedModule;
|
||||||
ShaderAst::Statement* targetAst;
|
const ShaderAst::Module* targetModule;
|
||||||
if (!states.sanitized)
|
if (!states.sanitized)
|
||||||
{
|
{
|
||||||
sanitizedModule = Sanitize(module, states.optionValues);
|
sanitizedModule = Sanitize(module, states.optionValues);
|
||||||
targetAst = sanitizedModule->rootNode.get();
|
targetModule = sanitizedModule.get();
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
targetAst = module.rootNode.get();
|
targetModule = &module;
|
||||||
|
|
||||||
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
|
ShaderAst::ModulePtr optimizedModule;
|
||||||
|
|
||||||
ShaderAst::StatementPtr optimizedAst;
|
|
||||||
if (states.optimize)
|
if (states.optimize)
|
||||||
{
|
{
|
||||||
ShaderAst::StatementPtr tempAst;
|
ShaderAst::StatementPtr tempAst;
|
||||||
@@ -190,27 +188,31 @@ namespace Nz
|
|||||||
if (shaderStage)
|
if (shaderStage)
|
||||||
dependencyConfig.usedShaderStages = *shaderStage;
|
dependencyConfig.usedShaderStages = *shaderStage;
|
||||||
|
|
||||||
tempAst = ShaderAst::PropagateConstants(*targetAst);
|
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
|
||||||
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
|
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
|
||||||
|
|
||||||
targetAst = optimizedAst.get();
|
targetModule = optimizedModule.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Previsitor
|
||||||
state.previsitor.selectedStage = shaderStage;
|
state.previsitor.selectedStage = shaderStage;
|
||||||
targetAst->Visit(state.previsitor);
|
|
||||||
|
|
||||||
|
for (const auto& importedModule : targetModule->importedModules)
|
||||||
|
importedModule.module->rootNode->Visit(state.previsitor);
|
||||||
|
|
||||||
|
targetModule->rootNode->Visit(state.previsitor);
|
||||||
|
|
||||||
|
// Code generation
|
||||||
AppendHeader();
|
AppendHeader();
|
||||||
|
|
||||||
for (const auto& importedModule : targetModule.importedModules)
|
for (const auto& importedModule : targetModule->importedModules)
|
||||||
{
|
{
|
||||||
m_currentState->moduleSuffix = importedModule.identifier;
|
m_currentState->moduleSuffix = importedModule.identifier;
|
||||||
|
|
||||||
importedModule.module->rootNode->Visit(state.previsitor);
|
|
||||||
importedModule.module->rootNode->Visit(*this);
|
importedModule.module->rootNode->Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
m_currentState->moduleSuffix = {};
|
m_currentState->moduleSuffix = {};
|
||||||
targetAst->Visit(*this);
|
targetModule->rootNode->Visit(*this);
|
||||||
|
|
||||||
return state.stream.str();
|
return state.stream.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -481,7 +481,7 @@ namespace Nz
|
|||||||
std::vector<UInt32> resultIds;
|
std::vector<UInt32> resultIds;
|
||||||
UInt32 nextVarIndex = 1;
|
UInt32 nextVarIndex = 1;
|
||||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||||
PreVisitor* preVisitor;
|
PreVisitor* previsitor;
|
||||||
|
|
||||||
// Output
|
// Output
|
||||||
SpirvSection header;
|
SpirvSection header;
|
||||||
@@ -499,7 +499,7 @@ namespace Nz
|
|||||||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
|
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
|
||||||
{
|
{
|
||||||
ShaderAst::ModulePtr sanitizedModule;
|
ShaderAst::ModulePtr sanitizedModule;
|
||||||
ShaderAst::Statement* targetAst;
|
const ShaderAst::Module* targetModule;
|
||||||
if (!states.sanitized)
|
if (!states.sanitized)
|
||||||
{
|
{
|
||||||
ShaderAst::SanitizeVisitor::Options options;
|
ShaderAst::SanitizeVisitor::Options options;
|
||||||
@@ -513,14 +513,12 @@ namespace Nz
|
|||||||
options.useIdentifierAccessesForStructs = false;
|
options.useIdentifierAccessesForStructs = false;
|
||||||
|
|
||||||
sanitizedModule = ShaderAst::Sanitize(module, options);
|
sanitizedModule = ShaderAst::Sanitize(module, options);
|
||||||
targetAst = sanitizedModule->rootNode.get();
|
targetModule = sanitizedModule.get();
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
targetAst = module.rootNode.get();
|
targetModule = &module;
|
||||||
|
|
||||||
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
|
ShaderAst::ModulePtr optimizedModule;
|
||||||
|
|
||||||
ShaderAst::StatementPtr optimizedAst;
|
|
||||||
if (states.optimize)
|
if (states.optimize)
|
||||||
{
|
{
|
||||||
ShaderAst::StatementPtr tempAst;
|
ShaderAst::StatementPtr tempAst;
|
||||||
@@ -528,12 +526,14 @@ namespace Nz
|
|||||||
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
|
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
|
||||||
dependencyConfig.usedShaderStages = ShaderStageType_All;
|
dependencyConfig.usedShaderStages = ShaderStageType_All;
|
||||||
|
|
||||||
tempAst = ShaderAst::PropagateConstants(*targetAst);
|
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
|
||||||
optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst, dependencyConfig);
|
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
|
||||||
|
|
||||||
targetAst = optimizedAst.get();
|
targetModule = optimizedModule.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Previsitor
|
||||||
|
|
||||||
m_context.states = &states;
|
m_context.states = &states;
|
||||||
|
|
||||||
State state;
|
State state;
|
||||||
@@ -544,15 +544,15 @@ namespace Nz
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Register all extended instruction sets
|
// Register all extended instruction sets
|
||||||
PreVisitor preVisitor(state.constantTypeCache, state.funcs);
|
PreVisitor previsitor(state.constantTypeCache, state.funcs);
|
||||||
for (const auto& importedModule : targetModule.importedModules)
|
for (const auto& importedModule : targetModule->importedModules)
|
||||||
importedModule.module->rootNode->Visit(preVisitor);
|
importedModule.module->rootNode->Visit(previsitor);
|
||||||
|
|
||||||
targetAst->Visit(preVisitor);
|
targetModule->rootNode->Visit(previsitor);
|
||||||
|
|
||||||
m_currentState->preVisitor = &preVisitor;
|
m_currentState->previsitor = &previsitor;
|
||||||
|
|
||||||
for (const std::string& extInst : preVisitor.extInsts)
|
for (const std::string& extInst : previsitor.extInsts)
|
||||||
state.extensionInstructionSet[extInst] = AllocateResultId();
|
state.extensionInstructionSet[extInst] = AllocateResultId();
|
||||||
|
|
||||||
// Assign function ID (required for forward declaration)
|
// Assign function ID (required for forward declaration)
|
||||||
@@ -560,23 +560,23 @@ namespace Nz
|
|||||||
func.funcId = AllocateResultId();
|
func.funcId = AllocateResultId();
|
||||||
|
|
||||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||||
for (const auto& importedModule : targetModule.importedModules)
|
for (const auto& importedModule : targetModule->importedModules)
|
||||||
importedModule.module->rootNode->Visit(visitor);
|
importedModule.module->rootNode->Visit(visitor);
|
||||||
|
|
||||||
targetAst->Visit(visitor);
|
targetModule->rootNode->Visit(visitor);
|
||||||
|
|
||||||
AppendHeader();
|
AppendHeader();
|
||||||
|
|
||||||
for (auto&& [varIndex, extVar] : preVisitor.extVars)
|
for (auto&& [varIndex, extVar] : previsitor.extVars)
|
||||||
{
|
{
|
||||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
|
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
|
||||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
|
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
|
for (auto&& [varId, builtin] : previsitor.builtinDecorations)
|
||||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
|
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
|
||||||
|
|
||||||
for (auto&& [varId, location] : preVisitor.locationDecorations)
|
for (auto&& [varId, location] : previsitor.locationDecorations)
|
||||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
|
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
|
||||||
|
|
||||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||||
@@ -699,8 +699,8 @@ namespace Nz
|
|||||||
|
|
||||||
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
|
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
|
||||||
{
|
{
|
||||||
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
|
auto it = m_currentState->previsitor->extVars.find(extVarIndex);
|
||||||
assert(it != m_currentState->preVisitor->extVars.end());
|
assert(it != m_currentState->previsitor->extVars.end());
|
||||||
|
|
||||||
return it->second.pointerId;
|
return it->second.pointerId;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
#include <Nazara/Shader/SpirvPrinter.hpp>
|
#include <Nazara/Shader/SpirvPrinter.hpp>
|
||||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||||
#include <Nazara/Shader/Ast/AstReflect.hpp>
|
#include <Nazara/Shader/Ast/AstReflect.hpp>
|
||||||
|
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||||
#include <catch2/catch.hpp>
|
#include <catch2/catch.hpp>
|
||||||
#include <glslang/Public/ShaderLang.h>
|
#include <glslang/Public/ShaderLang.h>
|
||||||
#include <spirv-tools/libspirv.hpp>
|
#include <spirv-tools/libspirv.hpp>
|
||||||
@@ -121,12 +122,19 @@ namespace
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
void ExpectGLSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
|
||||||
{
|
{
|
||||||
expectedOutput = Nz::Trim(expectedOutput);
|
expectedOutput = Nz::Trim(expectedOutput);
|
||||||
|
|
||||||
SECTION("Generating GLSL")
|
SECTION("Generating GLSL")
|
||||||
{
|
{
|
||||||
|
Nz::ShaderAst::ModulePtr sanitizedModule;
|
||||||
|
WHEN("Sanitizing a second time")
|
||||||
|
{
|
||||||
|
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
|
||||||
|
}
|
||||||
|
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
|
||||||
|
|
||||||
// Retrieve entry-point to get shader type
|
// Retrieve entry-point to get shader type
|
||||||
std::optional<Nz::ShaderStageType> entryShaderStage;
|
std::optional<Nz::ShaderStageType> entryShaderStage;
|
||||||
|
|
||||||
@@ -140,7 +148,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||||||
};
|
};
|
||||||
|
|
||||||
Nz::ShaderAst::AstReflect reflectVisitor;
|
Nz::ShaderAst::AstReflect reflectVisitor;
|
||||||
reflectVisitor.Reflect(*shader.rootNode, callbacks);
|
reflectVisitor.Reflect(*targetModule.rootNode, callbacks);
|
||||||
|
|
||||||
{
|
{
|
||||||
INFO("no entry point found");
|
INFO("no entry point found");
|
||||||
@@ -148,7 +156,7 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||||||
}
|
}
|
||||||
|
|
||||||
Nz::GlslWriter writer;
|
Nz::GlslWriter writer;
|
||||||
std::string output = writer.Generate(entryShaderStage, shader);
|
std::string output = writer.Generate(entryShaderStage, targetModule);
|
||||||
|
|
||||||
WHEN("Validating expected code")
|
WHEN("Validating expected code")
|
||||||
{
|
{
|
||||||
@@ -188,14 +196,21 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
void ExpectNZSL(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput)
|
||||||
{
|
{
|
||||||
expectedOutput = Nz::Trim(expectedOutput);
|
expectedOutput = Nz::Trim(expectedOutput);
|
||||||
|
|
||||||
SECTION("Generating NZSL")
|
SECTION("Generating NZSL")
|
||||||
{
|
{
|
||||||
|
Nz::ShaderAst::ModulePtr sanitizedModule;
|
||||||
|
WHEN("Sanitizing a second time")
|
||||||
|
{
|
||||||
|
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
|
||||||
|
}
|
||||||
|
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
|
||||||
|
|
||||||
Nz::LangWriter writer;
|
Nz::LangWriter writer;
|
||||||
std::string output = writer.Generate(shader);
|
std::string output = writer.Generate(targetModule);
|
||||||
|
|
||||||
WHEN("Validating expected code")
|
WHEN("Validating expected code")
|
||||||
{
|
{
|
||||||
@@ -211,12 +226,19 @@ void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter)
|
void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view expectedOutput, bool outputParameter)
|
||||||
{
|
{
|
||||||
expectedOutput = Nz::Trim(expectedOutput);
|
expectedOutput = Nz::Trim(expectedOutput);
|
||||||
|
|
||||||
SECTION("Generating SPIRV")
|
SECTION("Generating SPIRV")
|
||||||
{
|
{
|
||||||
|
Nz::ShaderAst::ModulePtr sanitizedModule;
|
||||||
|
WHEN("Sanitizing a second time")
|
||||||
|
{
|
||||||
|
CHECK_NOTHROW(sanitizedModule = Nz::ShaderAst::Sanitize(shaderModule));
|
||||||
|
}
|
||||||
|
const Nz::ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : shaderModule;
|
||||||
|
|
||||||
Nz::SpirvWriter writer;
|
Nz::SpirvWriter writer;
|
||||||
Nz::SpirvPrinter printer;
|
Nz::SpirvPrinter printer;
|
||||||
|
|
||||||
@@ -224,7 +246,7 @@ void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput,
|
|||||||
settings.printHeader = false;
|
settings.printHeader = false;
|
||||||
settings.printParameters = outputParameter;
|
settings.printParameters = outputParameter;
|
||||||
|
|
||||||
auto spirv = writer.Generate(shader);
|
auto spirv = writer.Generate(targetModule);
|
||||||
std::string output = printer.Print(spirv.data(), spirv.size(), settings);
|
std::string output = printer.Print(spirv.data(), spirv.size(), settings);
|
||||||
|
|
||||||
WHEN("Validating expected code")
|
WHEN("Validating expected code")
|
||||||
|
|||||||
@@ -6,8 +6,8 @@
|
|||||||
#include <Nazara/Shader/Ast/Module.hpp>
|
#include <Nazara/Shader/Ast/Module.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||||
void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||||
void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user