diff --git a/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp b/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp index 151fa7fca..0bcc99e3b 100644 --- a/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp @@ -10,9 +10,7 @@ #include #include #include -#include -#include -#include +#include namespace Nz::ShaderAst { @@ -28,6 +26,8 @@ namespace Nz::ShaderAst inline ExpressionPtr Process(Expression& expression); inline ExpressionPtr Process(Expression& expression, const Options& options); + ModulePtr Process(const Module& shaderModule); + ModulePtr Process(const Module& shaderModule, const Options& options); inline StatementPtr Process(Statement& statement); inline StatementPtr Process(Statement& statement, const Options& options); @@ -65,6 +65,8 @@ namespace Nz::ShaderAst inline ExpressionPtr PropagateConstants(Expression& expr); inline ExpressionPtr PropagateConstants(Expression& expr, const AstConstantPropagationVisitor::Options& options); + inline ModulePtr PropagateConstants(const Module& shaderModule); + inline ModulePtr PropagateConstants(const Module& shaderModule, const AstConstantPropagationVisitor::Options& options); inline StatementPtr PropagateConstants(Statement& ast); inline StatementPtr PropagateConstants(Statement& ast, const AstConstantPropagationVisitor::Options& options); } diff --git a/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.inl b/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.inl index 97c9ba014..f24278bc3 100644 --- a/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.inl +++ b/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.inl @@ -43,6 +43,18 @@ namespace Nz::ShaderAst return optimize.Process(ast, options); } + inline ModulePtr PropagateConstants(const Module& shaderModule) + { + AstConstantPropagationVisitor optimize; + return optimize.Process(shaderModule); + } + + inline ModulePtr PropagateConstants(const Module& shaderModule, const AstConstantPropagationVisitor::Options& options) + { + AstConstantPropagationVisitor optimize; + return optimize.Process(shaderModule, options); + } + inline StatementPtr PropagateConstants(Statement& ast) { AstConstantPropagationVisitor optimize; diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp index 5b5119eb4..db497d9b7 100644 --- a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace Nz::ShaderAst { @@ -23,6 +24,7 @@ namespace Nz::ShaderAst EliminateUnusedPassVisitor(EliminateUnusedPassVisitor&&) = delete; ~EliminateUnusedPassVisitor() = default; + ModulePtr Process(const Module& shaderModule, const DependencyCheckerVisitor::UsageSet& usageSet); StatementPtr Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet); EliminateUnusedPassVisitor& operator=(const EliminateUnusedPassVisitor&) = delete; @@ -43,6 +45,10 @@ namespace Nz::ShaderAst Context* m_context; }; + inline ModulePtr EliminateUnusedPass(const Module& shaderModule); + inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::Config& config); + inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::UsageSet& usageSet); + inline StatementPtr EliminateUnusedPass(Statement& ast); inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::Config& config); inline StatementPtr EliminateUnusedPass(Statement& ast, const DependencyCheckerVisitor::UsageSet& usageSet); diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl index 0a227ff8f..61eb50cb1 100644 --- a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl @@ -7,6 +7,27 @@ namespace Nz::ShaderAst { + inline ModulePtr EliminateUnusedPass(const Module& shaderModule) + { + DependencyCheckerVisitor::Config defaultConfig; + return EliminateUnusedPass(shaderModule, defaultConfig); + } + + inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::Config& config) + { + DependencyCheckerVisitor dependencyVisitor; + dependencyVisitor.Process(*shaderModule.rootNode, config); + dependencyVisitor.Resolve(); + + return EliminateUnusedPass(shaderModule, dependencyVisitor.GetUsage()); + } + + ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::UsageSet& usageSet) + { + EliminateUnusedPassVisitor visitor; + return visitor.Process(shaderModule, usageSet); + } + inline StatementPtr EliminateUnusedPass(Statement& ast) { DependencyCheckerVisitor::Config defaultConfig; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index c9a6ff9e9..3f5736622 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -31,8 +31,8 @@ namespace Nz::ShaderAst SanitizeVisitor(SanitizeVisitor&&) = delete; ~SanitizeVisitor() = default; - inline ModulePtr Sanitize(Module& module, std::string* error = nullptr); - ModulePtr Sanitize(Module& module, const Options& options, std::string* error = nullptr); + inline ModulePtr Sanitize(const Module& module, std::string* error = nullptr); + ModulePtr Sanitize(const Module& module, const Options& options, std::string* error = nullptr); SanitizeVisitor& operator=(const SanitizeVisitor&) = delete; SanitizeVisitor& operator=(SanitizeVisitor&&) = delete; @@ -178,8 +178,8 @@ namespace Nz::ShaderAst Context* m_context; }; - inline ModulePtr Sanitize(Module& module, std::string* error = nullptr); - inline ModulePtr Sanitize(Module& module, const SanitizeVisitor::Options& options, std::string* error = nullptr); + inline ModulePtr Sanitize(const Module& module, std::string* error = nullptr); + inline ModulePtr Sanitize(const Module& module, const SanitizeVisitor::Options& options, std::string* error = nullptr); } #include diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.inl b/include/Nazara/Shader/Ast/SanitizeVisitor.inl index 3bb54f683..10dbf7eb6 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.inl +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.inl @@ -7,18 +7,18 @@ namespace Nz::ShaderAst { - inline ModulePtr SanitizeVisitor::Sanitize(Module& module, std::string* error) + inline ModulePtr SanitizeVisitor::Sanitize(const Module& module, std::string* error) { return Sanitize(module, {}, error); } - inline ModulePtr Sanitize(Module& module, std::string* error) + inline ModulePtr Sanitize(const Module& module, std::string* error) { SanitizeVisitor sanitizer; return sanitizer.Sanitize(module, error); } - inline ModulePtr Sanitize(Module& module, const SanitizeVisitor::Options& options, std::string* error) + inline ModulePtr Sanitize(const Module& module, const SanitizeVisitor::Options& options, std::string* error) { SanitizeVisitor sanitizer; return sanitizer.Sanitize(module, options, error); diff --git a/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp b/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp index 44a670b41..b9dff56b4 100644 --- a/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp @@ -735,6 +735,24 @@ namespace Nz::ShaderAst #undef EnableOptimisation } + ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule) + { + ModulePtr clone = std::make_shared(); + clone->metadata = shaderModule.metadata; + clone->rootNode = static_unique_pointer_cast(Process(*shaderModule.rootNode)); + + return clone; + } + + ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options) + { + ModulePtr clone = std::make_shared(); + clone->metadata = shaderModule.metadata; + clone->rootNode = static_unique_pointer_cast(Process(*shaderModule.rootNode, options)); + + return clone; + } + ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node) { auto lhs = CloneExpression(node.left); diff --git a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp index 7704b9938..b48161e0a 100644 --- a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp +++ b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp @@ -5,17 +5,33 @@ #include #include #include -#include -#include #include namespace Nz::ShaderAst { + namespace + { + template + std::unique_ptr static_unique_pointer_cast(std::unique_ptr&& ptr) + { + return std::unique_ptr(static_cast(ptr.release())); + } + } + struct EliminateUnusedPassVisitor::Context { const DependencyCheckerVisitor::UsageSet& usageSet; }; + ModulePtr EliminateUnusedPassVisitor::Process(const Module& shaderModule, const DependencyCheckerVisitor::UsageSet& usageSet) + { + ModulePtr clone = std::make_shared(); + clone->metadata = shaderModule.metadata; + clone->rootNode = static_unique_pointer_cast(Process(*shaderModule.rootNode, usageSet)); + + return clone; + } + StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet) { Context context{ diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 550f45b48..a469b9b33 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -112,7 +112,7 @@ namespace Nz::ShaderAst std::vector* currentStatementList = nullptr; }; - ModulePtr SanitizeVisitor::Sanitize(Module& module, const Options& options, std::string* error) + ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) { ModulePtr clone = std::make_shared(); clone->shaderLangVersion = module.shaderLangVersion; diff --git a/src/ShaderNode/ShaderGraph.cpp b/src/ShaderNode/ShaderGraph.cpp index 6e7b3940e..70b8603ac 100644 --- a/src/ShaderNode/ShaderGraph.cpp +++ b/src/ShaderNode/ShaderGraph.cpp @@ -457,13 +457,18 @@ QJsonObject ShaderGraph::Save() return sceneJson; } -Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const +Nz::ShaderAst::ModulePtr ShaderGraph::ToModule() const { - std::vector statements; + Nz::ShaderAst::ModulePtr shaderModule = std::make_shared(); + + std::shared_ptr moduleMetada = std::make_shared(); + moduleMetada->shaderLangVersion = 100; + + shaderModule->metadata = std::move(moduleMetada); // Declare all options for (const auto& option : m_options) - statements.push_back(Nz::ShaderBuilder::DeclareOption(option.name, Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Boolean })); + shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareOption(option.name, Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Boolean })); // Declare all structures for (const auto& structInfo : m_structs) @@ -479,7 +484,7 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const structMember.type = ToShaderExpressionType(memberInfo.type); } - statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc))); + shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc), false)); } // External block @@ -509,7 +514,7 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const } if (!external->externalVars.empty()) - statements.push_back(std::move(external)); + shaderModule->rootNode->statements.push_back(std::move(external)); // Inputs / outputs if (!m_inputs.empty()) @@ -525,7 +530,7 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const structMember.locationIndex = input.locationIndex; } - statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc))); + shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc), false)); } if (!m_outputs.empty()) @@ -549,13 +554,13 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const position.type = Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::VectorType{ 4, Nz::ShaderAst::PrimitiveType::Float32 } }; } - statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc))); + shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc), false)); } // Functions - statements.push_back(ToFunction()); + shaderModule->rootNode->statements.push_back(ToFunction()); - return Nz::ShaderBuilder::MultiStatement(std::move(statements)); + return shaderModule; } Nz::ShaderAst::ExpressionType ShaderGraph::ToShaderExpressionType(const std::variant& type) const diff --git a/src/ShaderNode/ShaderGraph.hpp b/src/ShaderNode/ShaderGraph.hpp index edc470cf9..d236ddbc0 100644 --- a/src/ShaderNode/ShaderGraph.hpp +++ b/src/ShaderNode/ShaderGraph.hpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include #include @@ -67,7 +67,7 @@ class ShaderGraph void Load(const QJsonObject& data); QJsonObject Save(); - Nz::ShaderAst::StatementPtr ToAst() const; + Nz::ShaderAst::ModulePtr ToModule() const; Nz::ShaderAst::ExpressionType ToShaderExpressionType(const std::variant& type) const; void UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t setIndex, std::size_t bindingIndex); diff --git a/src/ShaderNode/Widgets/CodeOutputWidget.cpp b/src/ShaderNode/Widgets/CodeOutputWidget.cpp index 04a788b3f..889a7440e 100644 --- a/src/ShaderNode/Widgets/CodeOutputWidget.cpp +++ b/src/ShaderNode/Widgets/CodeOutputWidget.cpp @@ -61,18 +61,18 @@ void CodeOutputWidget::Refresh() for (std::size_t i = 0; i < m_shaderGraph.GetOptionCount(); ++i) states.optionValues[i] = m_shaderGraph.IsOptionEnabled(i); - Nz::ShaderAst::StatementPtr shaderAst = m_shaderGraph.ToAst(); + Nz::ShaderAst::ModulePtr shaderModule = m_shaderGraph.ToModule(); if (m_optimisationCheckbox->isChecked()) { Nz::ShaderAst::SanitizeVisitor::Options sanitizeOptions; sanitizeOptions.optionValues = states.optionValues; - shaderAst = Nz::ShaderAst::Sanitize(*shaderAst, sanitizeOptions); + shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOptions); Nz::ShaderAst::AstConstantPropagationVisitor optimiser; - shaderAst = Nz::ShaderAst::PropagateConstants(*shaderAst); - shaderAst = Nz::ShaderAst::EliminateUnusedPass(*shaderAst); + shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule); + shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule); } std::string output; @@ -89,21 +89,21 @@ void CodeOutputWidget::Refresh() bindingMapping.emplace(Nz::UInt64(texture.setIndex) << 32 | Nz::UInt64(texture.bindingIndex), bindingMapping.size()); Nz::GlslWriter writer; - output = writer.Generate(ShaderGraph::ToShaderStageType(m_shaderGraph.GetType()), *shaderAst, bindingMapping, states); + output = writer.Generate(ShaderGraph::ToShaderStageType(m_shaderGraph.GetType()), *shaderModule, bindingMapping, states); break; } case OutputLanguage::NZSL: { Nz::LangWriter writer; - output = writer.Generate(*shaderAst, states); + output = writer.Generate(*shaderModule, states); break; } case OutputLanguage::SpirV: { Nz::SpirvWriter writer; - std::vector spirv = writer.Generate(*shaderAst, states); + std::vector spirv = writer.Generate(*shaderModule, states); Nz::SpirvPrinter printer; output = printer.Print(spirv.data(), spirv.size()); diff --git a/src/ShaderNode/Widgets/MainWindow.cpp b/src/ShaderNode/Widgets/MainWindow.cpp index 851356cda..82a8dfae9 100644 --- a/src/ShaderNode/Widgets/MainWindow.cpp +++ b/src/ShaderNode/Widgets/MainWindow.cpp @@ -181,7 +181,7 @@ void MainWindow::OnCompile() { try { - auto shader = m_shaderGraph.ToAst(); + auto shaderModule = m_shaderGraph.ToModule(); QString fileName = QFileDialog::getSaveFileName(nullptr, tr("Save shader"), QString(), tr("Shader Files (*.shader)")); if (fileName.isEmpty()) @@ -191,7 +191,7 @@ void MainWindow::OnCompile() fileName += ".shader"; Nz::File file(fileName.toStdString(), Nz::OpenMode::WriteOnly); - file.Write(Nz::ShaderAst::SerializeShader(shader)); + file.Write(Nz::ShaderAst::SerializeShader(*shaderModule)); } catch (const std::exception& e) { diff --git a/tests/Engine/Shader/AccessMemberTest.cpp b/tests/Engine/Shader/AccessMemberTest.cpp index dbec148ec..e87c99ea7 100644 --- a/tests/Engine/Shader/AccessMemberTest.cpp +++ b/tests/Engine/Shader/AccessMemberTest.cpp @@ -11,6 +11,9 @@ TEST_CASE("structure member access", "[Shader]") SECTION("Nested member loading") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct innerStruct { field: vec3[f32] @@ -27,9 +30,7 @@ external } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); - REQUIRE(shader->GetType() == Nz::ShaderAst::NodeType::MultiStatement); - Nz::ShaderAst::MultiStatement& multiStatement = static_cast(*shader); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); SECTION("Nested AccessMember") { @@ -40,16 +41,16 @@ external auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(secondAccess), { 2u }); auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Float32 }, std::move(swizzle)); - multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); + shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float result = ubo.s.field.z; } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(vert)] fn main() { @@ -57,7 +58,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -77,16 +78,16 @@ OpFunctionEnd)"); auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(access), { 2u }); auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Float32 }, std::move(swizzle)); - multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); + shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float result = ubo.s.field.z; } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(vert)] fn main() { @@ -94,7 +95,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable diff --git a/tests/Engine/Shader/Branch.cpp b/tests/Engine/Shader/Branch.cpp index cab7aa33e..d7d87a8d9 100644 --- a/tests/Engine/Shader/Branch.cpp +++ b/tests/Engine/Shader/Branch.cpp @@ -11,6 +11,9 @@ TEST_CASE("branching", "[Shader]") WHEN("using a simple branch") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: f32 @@ -32,9 +35,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float value; @@ -50,7 +53,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -67,7 +70,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -90,6 +93,9 @@ OpFunctionEnd)"); WHEN("discarding in a branch") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: f32 @@ -108,9 +114,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { if (data.value > (0.000000)) @@ -121,7 +127,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -133,7 +139,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpAccessChain @@ -154,6 +160,9 @@ OpFunctionEnd)"); WHEN("using a complex branch") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: f32 @@ -179,9 +188,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float value; @@ -205,7 +214,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -230,7 +239,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable diff --git a/tests/Engine/Shader/Const.cpp b/tests/Engine/Shader/Const.cpp index df99b2e31..451dc432d 100644 --- a/tests/Engine/Shader/Const.cpp +++ b/tests/Engine/Shader/Const.cpp @@ -8,10 +8,10 @@ #include #include -void ExpectOutput(Nz::ShaderAst::Statement& shader, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult) +void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult) { - Nz::ShaderAst::StatementPtr sanitizedShader; - REQUIRE_NOTHROW(sanitizedShader = Nz::ShaderAst::Sanitize(shader, options)); + Nz::ShaderAst::ModulePtr sanitizedShader; + REQUIRE_NOTHROW(sanitizedShader = Nz::ShaderAst::Sanitize(shaderModule, options)); ExpectNZSL(*sanitizedShader, expectedOptimizedResult); } @@ -21,6 +21,9 @@ TEST_CASE("const", "[Shader]") WHEN("using const if") { std::string_view sourceCode = R"( +[nzsl_version("1.0")] +module; + option UseInt: bool = false; [cond(UseInt)] @@ -56,8 +59,8 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader; - REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + Nz::ShaderAst::ModulePtr shaderModule; + REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); Nz::ShaderAst::SanitizeVisitor::Options options; @@ -65,7 +68,7 @@ fn main() { options.optionValues[0] = true; - ExpectOutput(*shader, options, R"( + ExpectOutput(*shaderModule, options, R"( struct inputStruct { value: i32 @@ -89,7 +92,7 @@ fn main() { options.optionValues[0] = false; - ExpectOutput(*shader, options, R"( + ExpectOutput(*shaderModule, options, R"( struct inputStruct { value: f32 @@ -113,6 +116,9 @@ fn main() WHEN("using [unroll] attribute on numerical for") { std::string_view sourceCode = R"( +[nzsl_version("1.0")] +module; + const LightCount = 3; [layout(std140)] @@ -145,10 +151,10 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader; - REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + Nz::ShaderAst::ModulePtr shaderModule; + REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); - ExpectOutput(*shader, {}, R"( + ExpectOutput(*shaderModule, {}, R"( [entry(frag)] fn main() { @@ -170,6 +176,9 @@ fn main() WHEN("using [unroll] attribute on for-each") { std::string_view sourceCode = R"( +[nzsl_version("1.0")] +module; + const LightCount = 3; [layout(std140)] @@ -202,10 +211,10 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader; - REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + Nz::ShaderAst::ModulePtr shaderModule; + REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); - ExpectOutput(*shader, {}, R"( + ExpectOutput(*shaderModule, {}, R"( [entry(frag)] fn main() { diff --git a/tests/Engine/Shader/Loops.cpp b/tests/Engine/Shader/Loops.cpp index 2f9c019fe..233d35ddc 100644 --- a/tests/Engine/Shader/Loops.cpp +++ b/tests/Engine/Shader/Loops.cpp @@ -11,6 +11,9 @@ TEST_CASE("loops", "[Shader]") WHEN("using a while") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: f32 @@ -34,9 +37,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float value = 0.000000; @@ -50,7 +53,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -65,7 +68,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -94,6 +97,9 @@ OpFunctionEnd)"); WHEN("using a for range") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -105,10 +111,10 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { int x = 0; @@ -123,7 +129,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -136,7 +142,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -169,6 +175,9 @@ OpFunctionEnd)"); WHEN("using a for range with step") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -180,10 +189,10 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { int x = 0; @@ -199,7 +208,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -212,7 +221,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -248,6 +257,9 @@ OpFunctionEnd)"); WHEN("using a for-each") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: array[f32, 10] @@ -269,10 +281,10 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float x = 0.000000; @@ -287,7 +299,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -300,7 +312,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable diff --git a/tests/Engine/Shader/Optimizations.cpp b/tests/Engine/Shader/Optimizations.cpp index 96d97f341..d157845f6 100644 --- a/tests/Engine/Shader/Optimizations.cpp +++ b/tests/Engine/Shader/Optimizations.cpp @@ -9,24 +9,33 @@ #include #include +template +std::unique_ptr static_unique_pointer_cast(std::unique_ptr&& ptr) +{ + return std::unique_ptr(Nz::SafeCast(ptr.release())); +} + void PropagateConstantAndExpect(std::string_view sourceCode, std::string_view expectedOptimizedResult) { - Nz::ShaderAst::StatementPtr shader; - REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); - REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader)); - REQUIRE_NOTHROW(shader = Nz::ShaderAst::PropagateConstants(*shader)); + Nz::ShaderAst::ModulePtr shaderModule; + REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule)); - ExpectNZSL(*shader, expectedOptimizedResult); + ExpectNZSL(*shaderModule, expectedOptimizedResult); } void EliminateUnusedAndExpect(std::string_view sourceCode, std::string_view expectedOptimizedResult) { - Nz::ShaderAst::StatementPtr shader; - REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); - REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader)); - REQUIRE_NOTHROW(shader = Nz::ShaderAst::EliminateUnusedPass(*shader)); + Nz::ShaderAst::DependencyCheckerVisitor::Config depConfig; + depConfig.usedShaderStages = Nz::ShaderStageType_All; - ExpectNZSL(*shader, expectedOptimizedResult); + Nz::ShaderAst::ModulePtr shaderModule; + REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig)); + + ExpectNZSL(*shaderModule, expectedOptimizedResult); } TEST_CASE("optimizations", "[Shader]") @@ -34,6 +43,9 @@ TEST_CASE("optimizations", "[Shader]") WHEN("propagating constants") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -51,6 +63,9 @@ fn main() WHEN("propagating vector constants") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -67,6 +82,9 @@ fn main() WHEN("eliminating simple branch") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -85,6 +103,9 @@ fn main() WHEN("eliminating multiple branches") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -116,6 +137,9 @@ fn main() WHEN("eliminating multiple split branches") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -158,6 +182,9 @@ fn main() WHEN("optimizing out scalar swizzle") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -175,6 +202,9 @@ fn main() WHEN("optimizing out scalar swizzle to vector") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -192,6 +222,9 @@ fn main() WHEN("optimizing out vector swizzle") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -209,6 +242,9 @@ fn main() WHEN("optimizing out vector swizzle with repetition") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -226,6 +262,9 @@ fn main() WHEN("optimizing out complex swizzle") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -243,6 +282,9 @@ fn main() WHEN("optimizing out complex swizzle on unknown value") { PropagateConstantAndExpect(R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: vec4[f32] @@ -270,6 +312,9 @@ fn main() WHEN("eliminating unused code") { EliminateUnusedAndExpect(R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: vec4[f32] @@ -305,6 +350,9 @@ fn main() -> Output output.value = data.value; return output; })", R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: vec4[f32] diff --git a/tests/Engine/Shader/Sanitizations.cpp b/tests/Engine/Shader/Sanitizations.cpp index e0e565500..abe810302 100644 --- a/tests/Engine/Shader/Sanitizations.cpp +++ b/tests/Engine/Shader/Sanitizations.cpp @@ -12,6 +12,9 @@ TEST_CASE("sanitizing", "[Shader]") WHEN("splitting branches") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: f32 @@ -37,14 +40,14 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::SanitizeVisitor::Options options; options.splitMultipleBranches = true; - REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options)); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -82,6 +85,9 @@ fn main() WHEN("reducing for-each to while") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: array[f32, 10] @@ -103,14 +109,14 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::SanitizeVisitor::Options options; options.reduceLoopsToWhile = true; - REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options)); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -131,6 +137,9 @@ fn main() WHEN("removing matrix casts") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + fn testMat2ToMat2(input: mat2[f32]) -> mat2[f32] { return mat2[f32](input); @@ -177,14 +186,14 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32] } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::SanitizeVisitor::Options options; options.removeMatrixCast = true; - REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options)); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( fn testMat2ToMat2(input: mat2[f32]) -> mat2[f32] { return input; diff --git a/tests/Engine/Shader/ShaderUtils.cpp b/tests/Engine/Shader/ShaderUtils.cpp index 27f20cb46..e17626331 100644 --- a/tests/Engine/Shader/ShaderUtils.cpp +++ b/tests/Engine/Shader/ShaderUtils.cpp @@ -121,7 +121,7 @@ namespace }; } -void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) +void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput) { expectedOutput = Nz::Trim(expectedOutput); @@ -140,7 +140,7 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu }; Nz::ShaderAst::AstReflect reflectVisitor; - reflectVisitor.Reflect(shader, callbacks); + reflectVisitor.Reflect(*shader.rootNode, callbacks); INFO("no entry point found"); REQUIRE(entryShaderStage.has_value()); @@ -186,7 +186,7 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu } } -void ExpectNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) +void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput) { expectedOutput = Nz::Trim(expectedOutput); @@ -209,7 +209,7 @@ void ExpectNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu } } -void ExpectSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput, bool outputParameter) +void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter) { expectedOutput = Nz::Trim(expectedOutput); diff --git a/tests/Engine/Shader/ShaderUtils.hpp b/tests/Engine/Shader/ShaderUtils.hpp index 023c5d8ad..416396cee 100644 --- a/tests/Engine/Shader/ShaderUtils.hpp +++ b/tests/Engine/Shader/ShaderUtils.hpp @@ -3,11 +3,11 @@ #ifndef NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP #define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP -#include +#include #include -void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); -void ExpectNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); -void ExpectSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput, bool outputParameter = false); +void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput); +void ExpectNZSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput); +void ExpectSPIRV(Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false); #endif diff --git a/tests/Engine/Shader/Swizzle.cpp b/tests/Engine/Shader/Swizzle.cpp index 6922e9424..b65867fc3 100644 --- a/tests/Engine/Shader/Swizzle.cpp +++ b/tests/Engine/Shader/Swizzle.cpp @@ -13,6 +13,9 @@ TEST_CASE("swizzle", "[Shader]") WHEN("reading") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -21,9 +24,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { vec4 vec = vec4(0.000000, 1.000000, 2.000000, 3.000000); @@ -31,7 +34,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -40,7 +43,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -57,6 +60,9 @@ OpFunctionEnd)"); WHEN("writing") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -65,9 +71,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { vec4 vec = vec4(0.000000, 0.000000, 0.000000, 0.000000); @@ -75,7 +81,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -84,7 +90,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -104,6 +110,9 @@ OpFunctionEnd)"); GIVEN("a variable") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -112,9 +121,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float value = 42.000000; @@ -122,7 +131,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -131,7 +140,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -147,6 +156,9 @@ OpFunctionEnd)"); GIVEN("a function value") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -155,9 +167,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { float cachedResult = max(2.000000, 1.000000); @@ -167,7 +179,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -176,7 +188,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -197,6 +209,9 @@ OpFunctionEnd)"); WHEN("reading") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -205,9 +220,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { vec4 vec = vec4(0.000000, 1.000000, 2.000000, 3.000000); @@ -215,7 +230,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -224,7 +239,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable @@ -244,6 +259,9 @@ OpFunctionEnd)"); WHEN("writing") { std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() { @@ -253,9 +271,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shaderModule, R"( void main() { vec4 vec = vec4(0.000000, 1.000000, 2.000000, 3.000000); @@ -264,7 +282,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shaderModule, R"( [entry(frag)] fn main() { @@ -274,7 +292,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel OpVariable