From bf7f06ac4cb7e6d841f9e8ec261d19d3034a6af4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Thu, 10 Mar 2022 12:31:00 +0100 Subject: [PATCH] Shader: Fix shader serialization --- include/Nazara/Shader/Ast/AstSerializer.hpp | 10 ++++- src/Nazara/Shader/Ast/AstSerializer.cpp | 46 +++++++++++++++++---- src/ShaderNode/Widgets/MainWindow.cpp | 2 +- tests/Engine/Shader/SerializationsTests.cpp | 2 +- 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index 99b68f1c0..fcc673f00 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -74,6 +74,8 @@ namespace Nz::ShaderAst virtual bool IsWriting() const = 0; + virtual void SerializeModule(ModulePtr& module) = 0; + virtual void Node(ExpressionPtr& node) = 0; virtual void Node(StatementPtr& node) = 0; @@ -102,11 +104,13 @@ namespace Nz::ShaderAst inline ShaderAstSerializer(ByteStream& stream); ~ShaderAstSerializer() = default; - void Serialize(Module& shader); + void Serialize(ModulePtr& shader); private: using AstSerializerBase::Serialize; + void SerializeModule(ModulePtr& module) override; + bool IsWriting() const override; void Node(ExpressionPtr& node) override; void Node(StatementPtr& node) override; @@ -140,6 +144,8 @@ namespace Nz::ShaderAst private: using AstSerializerBase::Serialize; + void SerializeModule(ModulePtr& module) override; + bool IsWriting() const override; void Node(ExpressionPtr& node) override; void Node(StatementPtr& node) override; @@ -162,7 +168,7 @@ namespace Nz::ShaderAst ByteStream& m_stream; }; - NAZARA_SHADER_API ByteArray SerializeShader(Module& shader); + NAZARA_SHADER_API ByteArray SerializeShader(ModulePtr& shader); inline ModulePtr UnserializeShader(const void* data, std::size_t size); NAZARA_SHADER_API ModulePtr UnserializeShader(ByteStream& stream); } diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 3cdcc9934..83d3a92dd 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -376,17 +376,31 @@ namespace Nz::ShaderAst Node(node.body); } - void ShaderAstSerializer::Serialize(Module& module) + void ShaderAstSerializer::Serialize(ModulePtr& module) { m_stream << s_magicNumber << s_currentVersion; - m_stream << module.metadata->moduleId; - m_stream << module.metadata->shaderLangVersion; - Serialize(*module.rootNode); + SerializeModule(module); m_stream.FlushBits(); } + void ShaderAstSerializer::SerializeModule(ModulePtr& module) + { + m_stream << module->metadata->moduleId; + m_stream << module->metadata->shaderLangVersion; + + Container(module->importedModules); + for (auto& importedModule : module->importedModules) + { + Value(importedModule.identifier); + SerializeModule(importedModule.module); + } + + ShaderSerializerVisitor visitor(*this); + module->rootNode->Visit(visitor); + } + bool ShaderAstSerializer::IsWriting() const { return true; @@ -583,16 +597,32 @@ namespace Nz::ShaderAst if (version > s_currentVersion) throw std::runtime_error("unsupported version"); + ModulePtr module; + SerializeModule(module); + + return module; + } + + void ShaderAstUnserializer::SerializeModule(ModulePtr& module) + { std::shared_ptr metadata = std::make_shared(); m_stream >> metadata->moduleId; m_stream >> metadata->shaderLangVersion; - ModulePtr module = std::make_shared(std::move(metadata)); + std::vector importedModules; + Container(importedModules); + for (auto& importedModule : importedModules) + { + Value(const_cast(importedModule.identifier)); //< not used for writing + SerializeModule(importedModule.module); + } + + MultiStatementPtr rootNode = std::make_unique(); ShaderSerializerVisitor visitor(*this); - module->rootNode->Visit(visitor); + rootNode->Visit(visitor); - return module; + module = std::make_shared(std::move(metadata), std::move(rootNode), std::move(importedModules)); } bool ShaderAstUnserializer::IsWriting() const @@ -903,7 +933,7 @@ namespace Nz::ShaderAst } - ByteArray SerializeShader(Module& module) + ByteArray SerializeShader(ModulePtr& module) { ByteArray byteArray; ByteStream stream(&byteArray, OpenModeFlags(OpenMode::WriteOnly)); diff --git a/src/ShaderNode/Widgets/MainWindow.cpp b/src/ShaderNode/Widgets/MainWindow.cpp index 82a8dfae9..8264bc003 100644 --- a/src/ShaderNode/Widgets/MainWindow.cpp +++ b/src/ShaderNode/Widgets/MainWindow.cpp @@ -191,7 +191,7 @@ void MainWindow::OnCompile() fileName += ".shader"; Nz::File file(fileName.toStdString(), Nz::OpenMode::WriteOnly); - file.Write(Nz::ShaderAst::SerializeShader(*shaderModule)); + file.Write(Nz::ShaderAst::SerializeShader(shaderModule)); } catch (const std::exception& e) { diff --git a/tests/Engine/Shader/SerializationsTests.cpp b/tests/Engine/Shader/SerializationsTests.cpp index 79b44eb39..5b9370873 100644 --- a/tests/Engine/Shader/SerializationsTests.cpp +++ b/tests/Engine/Shader/SerializationsTests.cpp @@ -18,7 +18,7 @@ void ParseSerializeUnserialize(std::string_view sourceCode, bool sanitize) REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule)); Nz::ByteArray serializedModule; - REQUIRE_NOTHROW(serializedModule = Nz::ShaderAst::SerializeShader(*shaderModule)); + REQUIRE_NOTHROW(serializedModule = Nz::ShaderAst::SerializeShader(shaderModule)); Nz::ByteStream byteStream(&serializedModule); Nz::ShaderAst::ModulePtr unserializedShader;