diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index 4d3094028..5ec128ff9 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -10,17 +10,20 @@ #include #include #include -#include +#include #include namespace Nz::ShaderAst { inline bool Compare(const Expression& lhs, const Expression& rhs); + inline bool Compare(const Module& lhs, const Module& rhs); + inline bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs); inline bool Compare(const Statement& lhs, const Statement& rhs); template bool Compare(const T& lhs, const T& rhs); template bool Compare(const std::array& lhs, const std::array& rhs); template bool Compare(const std::vector& lhs, const std::vector& rhs); + template bool Compare(const std::unique_ptr& lhs, const std::unique_ptr& rhs); template bool Compare(const ExpressionValue& lhs, const ExpressionValue& rhs); inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs); inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index 95e6d40e8..86e2a73b3 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -26,6 +26,28 @@ namespace Nz::ShaderAst return true; } + bool Compare(const Module& lhs, const Module& rhs) + { + if (!Compare(*lhs.metadata, *rhs.metadata)) + return false; + + if (!Compare(*lhs.rootNode, *rhs.rootNode)) + return false; + + return true; + } + + bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs) + { + if (!Compare(lhs.moduleId, rhs.moduleId)) + return false; + + if (!Compare(lhs.shaderLangVersion, rhs.shaderLangVersion)) + return false; + + return true; + } + inline bool Compare(const Statement& lhs, const Statement& rhs) { if (lhs.GetType() != rhs.GetType()) @@ -77,6 +99,17 @@ namespace Nz::ShaderAst return true; } + template + bool Compare(const std::unique_ptr& lhs, const std::unique_ptr& rhs) + { + if (lhs == nullptr) + return rhs == nullptr; + else if (rhs == nullptr) + return false; + + return Compare(*lhs, *rhs); + } + template bool Compare(const ExpressionValue& lhs, const ExpressionValue& rhs) { diff --git a/tests/Engine/Shader/Serializations.cpp b/tests/Engine/Shader/Serializations.cpp index 7f01064fb..79b44eb39 100644 --- a/tests/Engine/Shader/Serializations.cpp +++ b/tests/Engine/Shader/Serializations.cpp @@ -11,20 +11,20 @@ void ParseSerializeUnserialize(std::string_view sourceCode, bool sanitize) { - Nz::ShaderAst::StatementPtr shader; - REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + Nz::ShaderAst::ModulePtr shaderModule; + REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); if (sanitize) - REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader)); + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule)); - Nz::ByteArray serializedShader; - REQUIRE_NOTHROW(serializedShader = Nz::ShaderAst::SerializeShader(shader)); + Nz::ByteArray serializedModule; + REQUIRE_NOTHROW(serializedModule = Nz::ShaderAst::SerializeShader(*shaderModule)); - Nz::ByteStream byteStream(&serializedShader); - Nz::ShaderAst::StatementPtr unserializedShader; + Nz::ByteStream byteStream(&serializedModule); + Nz::ShaderAst::ModulePtr unserializedShader; REQUIRE_NOTHROW(unserializedShader = Nz::ShaderAst::UnserializeShader(byteStream)); - CHECK(Nz::ShaderAst::Compare(*shader, *unserializedShader)); + CHECK(Nz::ShaderAst::Compare(*shaderModule, *unserializedShader)); } void ParseSerializeUnserialize(std::string_view sourceCode) @@ -38,6 +38,9 @@ TEST_CASE("serialization", "[Shader]") WHEN("serializing and unserializing a simple shader") { ParseSerializeUnserialize(R"( +[nzsl_version("1.0")] +module; + struct Data { value: f32 @@ -67,6 +70,9 @@ fn main() -> Output WHEN("serializing and unserializing branches") { ParseSerializeUnserialize(R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: f32 @@ -96,6 +102,9 @@ fn main() WHEN("serializing and unserializing consts") { ParseSerializeUnserialize(R"( +[nzsl_version("1.0")] +module; + option UseInt: bool = false; [cond(UseInt)] @@ -135,6 +144,9 @@ fn main() WHEN("serializing and unserializing loops") { ParseSerializeUnserialize(R"( +[nzsl_version("1.0")] +module; + struct inputStruct { value: array[f32, 10] @@ -174,6 +186,9 @@ fn main() WHEN("serializing and unserializing swizzles") { ParseSerializeUnserialize(R"( +[nzsl_version("1.0")] +module; + [entry(frag)] fn main() {