Fix after rebase

This commit is contained in:
Lynix 2022-03-06 19:48:46 +01:00 committed by Jérôme Leclercq
parent a7acf32886
commit 8dcce73738
3 changed files with 60 additions and 9 deletions

View File

@ -10,17 +10,20 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <Nazara/Shader/Ast/Nodes.hpp>
#include <Nazara/Shader/Ast/Module.hpp>
#include <vector>
namespace Nz::ShaderAst
{
inline bool Compare(const Expression& lhs, const Expression& rhs);
inline bool Compare(const Module& lhs, const Module& rhs);
inline bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs);
inline bool Compare(const Statement& lhs, const Statement& rhs);
template<typename T> bool Compare(const T& lhs, const T& rhs);
template<typename T, std::size_t S> bool Compare(const std::array<T, S>& lhs, const std::array<T, S>& rhs);
template<typename T> bool Compare(const std::vector<T>& lhs, const std::vector<T>& rhs);
template<typename T> bool Compare(const std::unique_ptr<T>& lhs, const std::unique_ptr<T>& rhs);
template<typename T> bool Compare(const ExpressionValue<T>& lhs, const ExpressionValue<T>& rhs);
inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs);
inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs);

View File

@ -26,6 +26,28 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const Module& lhs, const Module& rhs)
{
if (!Compare(*lhs.metadata, *rhs.metadata))
return false;
if (!Compare(*lhs.rootNode, *rhs.rootNode))
return false;
return true;
}
bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs)
{
if (!Compare(lhs.moduleId, rhs.moduleId))
return false;
if (!Compare(lhs.shaderLangVersion, rhs.shaderLangVersion))
return false;
return true;
}
inline bool Compare(const Statement& lhs, const Statement& rhs)
{
if (lhs.GetType() != rhs.GetType())
@ -77,6 +99,17 @@ namespace Nz::ShaderAst
return true;
}
template<typename T>
bool Compare(const std::unique_ptr<T>& lhs, const std::unique_ptr<T>& rhs)
{
if (lhs == nullptr)
return rhs == nullptr;
else if (rhs == nullptr)
return false;
return Compare(*lhs, *rhs);
}
template<typename T>
bool Compare(const ExpressionValue<T>& lhs, const ExpressionValue<T>& rhs)
{

View File

@ -11,20 +11,20 @@
void ParseSerializeUnserialize(std::string_view sourceCode, bool sanitize)
{
Nz::ShaderAst::StatementPtr shader;
REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode));
Nz::ShaderAst::ModulePtr shaderModule;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
if (sanitize)
REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader));
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
Nz::ByteArray serializedShader;
REQUIRE_NOTHROW(serializedShader = Nz::ShaderAst::SerializeShader(shader));
Nz::ByteArray serializedModule;
REQUIRE_NOTHROW(serializedModule = Nz::ShaderAst::SerializeShader(*shaderModule));
Nz::ByteStream byteStream(&serializedShader);
Nz::ShaderAst::StatementPtr unserializedShader;
Nz::ByteStream byteStream(&serializedModule);
Nz::ShaderAst::ModulePtr unserializedShader;
REQUIRE_NOTHROW(unserializedShader = Nz::ShaderAst::UnserializeShader(byteStream));
CHECK(Nz::ShaderAst::Compare(*shader, *unserializedShader));
CHECK(Nz::ShaderAst::Compare(*shaderModule, *unserializedShader));
}
void ParseSerializeUnserialize(std::string_view sourceCode)
@ -38,6 +38,9 @@ TEST_CASE("serialization", "[Shader]")
WHEN("serializing and unserializing a simple shader")
{
ParseSerializeUnserialize(R"(
[nzsl_version("1.0")]
module;
struct Data
{
value: f32
@ -67,6 +70,9 @@ fn main() -> Output
WHEN("serializing and unserializing branches")
{
ParseSerializeUnserialize(R"(
[nzsl_version("1.0")]
module;
struct inputStruct
{
value: f32
@ -96,6 +102,9 @@ fn main()
WHEN("serializing and unserializing consts")
{
ParseSerializeUnserialize(R"(
[nzsl_version("1.0")]
module;
option UseInt: bool = false;
[cond(UseInt)]
@ -135,6 +144,9 @@ fn main()
WHEN("serializing and unserializing loops")
{
ParseSerializeUnserialize(R"(
[nzsl_version("1.0")]
module;
struct inputStruct
{
value: array[f32, 10]
@ -174,6 +186,9 @@ fn main()
WHEN("serializing and unserializing swizzles")
{
ParseSerializeUnserialize(R"(
[nzsl_version("1.0")]
module;
[entry(frag)]
fn main()
{