diff --git a/include/Nazara/Renderer/ShaderNodes.hpp b/include/Nazara/Renderer/ShaderNodes.hpp index d6bafc911..659e643a3 100644 --- a/include/Nazara/Renderer/ShaderNodes.hpp +++ b/include/Nazara/Renderer/ShaderNodes.hpp @@ -34,6 +34,7 @@ namespace Nz virtual ~Node(); inline NodeType GetType() const; + inline bool IsStatement() const; virtual void Visit(ShaderVisitor& visitor) = 0; @@ -41,10 +42,24 @@ namespace Nz static inline ExpressionType GetComponentType(ExpressionType type); protected: - inline Node(NodeType type); + inline Node(NodeType type, bool isStatement); private: NodeType m_type; + bool m_isStatement; + }; + + class Expression; + + using ExpressionPtr = std::shared_ptr; + + class NAZARA_RENDERER_API Expression : public Node + { + public: + inline Expression(NodeType type); + + virtual ExpressionCategory GetExpressionCategory() const; + virtual ExpressionType GetExpressionType() const = 0; }; class Statement; @@ -54,20 +69,7 @@ namespace Nz class NAZARA_RENDERER_API Statement : public Node { public: - using Node::Node; - }; - - class Expression; - - using ExpressionPtr = std::shared_ptr; - - class NAZARA_RENDERER_API Expression : public Node - { - public: - using Node::Node; - - virtual ExpressionCategory GetExpressionCategory() const; - virtual ExpressionType GetExpressionType() const = 0; + inline Statement(NodeType type); }; struct NAZARA_RENDERER_API ExpressionStatement : public Statement diff --git a/include/Nazara/Renderer/ShaderNodes.inl b/include/Nazara/Renderer/ShaderNodes.inl index 1bcf4effa..bafad38f5 100644 --- a/include/Nazara/Renderer/ShaderNodes.inl +++ b/include/Nazara/Renderer/ShaderNodes.inl @@ -7,8 +7,9 @@ namespace Nz::ShaderNodes { - inline Node::Node(NodeType type) : - m_type(type) + inline Node::Node(NodeType type, bool isStatement) : + m_type(type), + m_isStatement(isStatement) { } @@ -17,6 +18,11 @@ namespace Nz::ShaderNodes return m_type; } + inline bool Node::IsStatement() const + { + return m_isStatement; + } + inline unsigned int Node::GetComponentCount(ExpressionType type) { switch (type) @@ -55,6 +61,19 @@ namespace Nz::ShaderNodes } } + + inline Expression::Expression(NodeType type) : + Node(type, false) + { + } + + inline Statement::Statement(NodeType type) : + Node(type, true) + { + } + + + inline ExpressionStatement::ExpressionStatement() : Statement(NodeType::ExpressionStatement) { diff --git a/include/Nazara/Renderer/ShaderSerializer.hpp b/include/Nazara/Renderer/ShaderSerializer.hpp index daeb13e73..a5e012a9c 100644 --- a/include/Nazara/Renderer/ShaderSerializer.hpp +++ b/include/Nazara/Renderer/ShaderSerializer.hpp @@ -14,8 +14,10 @@ #include #include -namespace Nz::ShaderNodes +namespace Nz { + class ShaderAst; + class NAZARA_RENDERER_API ShaderSerializerBase { public: @@ -24,28 +26,29 @@ namespace Nz::ShaderNodes ShaderSerializerBase(ShaderSerializerBase&&) = delete; ~ShaderSerializerBase() = default; - void Serialize(AssignOp& node); - void Serialize(BinaryOp& node); - void Serialize(BuiltinVariable& var); - void Serialize(Branch& node); - void Serialize(Cast& node); - void Serialize(Constant& node); - void Serialize(DeclareVariable& node); - void Serialize(ExpressionStatement& node); - void Serialize(Identifier& node); - void Serialize(IntrinsicCall& node); - void Serialize(NamedVariable& var); - void Serialize(Sample2D& node); - void Serialize(StatementBlock& node); - void Serialize(SwizzleOp& node); + void Serialize(ShaderNodes::AssignOp& node); + void Serialize(ShaderNodes::BinaryOp& node); + void Serialize(ShaderNodes::BuiltinVariable& var); + void Serialize(ShaderNodes::Branch& node); + void Serialize(ShaderNodes::Cast& node); + void Serialize(ShaderNodes::Constant& node); + void Serialize(ShaderNodes::DeclareVariable& node); + void Serialize(ShaderNodes::ExpressionStatement& node); + void Serialize(ShaderNodes::Identifier& node); + void Serialize(ShaderNodes::IntrinsicCall& node); + void Serialize(ShaderNodes::NamedVariable& var); + void Serialize(ShaderNodes::Sample2D& node); + void Serialize(ShaderNodes::StatementBlock& node); + void Serialize(ShaderNodes::SwizzleOp& node); protected: template void Container(T& container); template void Enum(T& enumVal); + template void OptVal(std::optional& optVal); virtual bool IsWriting() const = 0; - virtual void Node(NodePtr& node) = 0; + virtual void Node(ShaderNodes::NodePtr& node) = 0; template void Node(std::shared_ptr& node); virtual void Value(bool& val) = 0; @@ -57,7 +60,7 @@ namespace Nz::ShaderNodes virtual void Value(UInt32& val) = 0; inline void Value(std::size_t& val); - virtual void Variable(VariablePtr& var) = 0; + virtual void Variable(ShaderNodes::VariablePtr& var) = 0; template void Variable(std::shared_ptr& var); }; @@ -67,11 +70,12 @@ namespace Nz::ShaderNodes inline ShaderSerializer(ByteArray& byteArray); ~ShaderSerializer() = default; - void Serialize(const StatementPtr& shader); + void Serialize(const ShaderAst& shader); private: bool IsWriting() const override; - void Node(NodePtr& node) override; + void Node(const ShaderNodes::NodePtr& node); + void Node(ShaderNodes::NodePtr& node) override; void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; @@ -79,7 +83,7 @@ namespace Nz::ShaderNodes void Value(Vector3f& val) override; void Value(Vector4f& val) override; void Value(UInt32& val) override; - void Variable(VariablePtr& var) override; + void Variable(ShaderNodes::VariablePtr& var) override; ByteArray& m_byteArray; ByteStream m_stream; @@ -91,11 +95,11 @@ namespace Nz::ShaderNodes ShaderUnserializer(const ByteArray& byteArray); ~ShaderUnserializer() = default; - StatementPtr Unserialize(); + ShaderAst Unserialize(); private: bool IsWriting() const override; - void Node(NodePtr& node) override; + void Node(ShaderNodes::NodePtr& node) override; void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; @@ -103,14 +107,14 @@ namespace Nz::ShaderNodes void Value(Vector3f& val) override; void Value(Vector4f& val) override; void Value(UInt32& val) override; - void Variable(VariablePtr& var) override; + void Variable(ShaderNodes::VariablePtr& var) override; const ByteArray& m_byteArray; ByteStream m_stream; }; - NAZARA_RENDERER_API ByteArray Serialize(const StatementPtr& shader); - NAZARA_RENDERER_API StatementPtr Unserialize(const ByteArray& data); + NAZARA_RENDERER_API ByteArray SerializeShader(const ShaderAst& shader); + NAZARA_RENDERER_API ShaderAst UnserializeShader(const ByteArray& data); } #include diff --git a/include/Nazara/Renderer/ShaderSerializer.inl b/include/Nazara/Renderer/ShaderSerializer.inl index 354986bb1..a2b680b95 100644 --- a/include/Nazara/Renderer/ShaderSerializer.inl +++ b/include/Nazara/Renderer/ShaderSerializer.inl @@ -5,7 +5,7 @@ #include #include -namespace Nz::ShaderNodes +namespace Nz { template void ShaderSerializerBase::Container(T& container) @@ -36,12 +36,30 @@ namespace Nz::ShaderNodes enumVal = static_cast(value); } + template + void ShaderSerializerBase::OptVal(std::optional& optVal) + { + bool isWriting = IsWriting(); + + bool hasValue; + if (isWriting) + hasValue = optVal.has_value(); + + Value(hasValue); + + if (!isWriting && hasValue) + optVal.emplace(); + + if (optVal.has_value()) + Value(optVal.value()); + } + template void ShaderSerializerBase::Node(std::shared_ptr& node) { bool isWriting = IsWriting(); - NodePtr value; + ShaderNodes::NodePtr value; if (isWriting) value = node; @@ -55,7 +73,7 @@ namespace Nz::ShaderNodes { bool isWriting = IsWriting(); - VariablePtr value; + ShaderNodes::VariablePtr value; if (isWriting) value = var; diff --git a/src/Nazara/Renderer/GlslWriter.cpp b/src/Nazara/Renderer/GlslWriter.cpp index f32239663..8617c522e 100644 --- a/src/Nazara/Renderer/GlslWriter.cpp +++ b/src/Nazara/Renderer/GlslWriter.cpp @@ -330,9 +330,7 @@ namespace Nz Append(node.variable->name); if (node.expression) { - Append(" "); - Append("="); - Append(" "); + Append(" = "); Visit(node.expression); } diff --git a/src/Nazara/Renderer/ShaderSerializer.cpp b/src/Nazara/Renderer/ShaderSerializer.cpp index 909422f52..7d1b7ac2e 100644 --- a/src/Nazara/Renderer/ShaderSerializer.cpp +++ b/src/Nazara/Renderer/ShaderSerializer.cpp @@ -3,14 +3,18 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include #include #include -namespace Nz::ShaderNodes +namespace Nz { namespace { + constexpr UInt32 s_magicNumber = 0x4E534852; + constexpr UInt32 s_currentVersion = 1; + class ShaderSerializerVisitor : public ShaderVisitor, public ShaderVarVisitor { public: @@ -19,62 +23,62 @@ namespace Nz::ShaderNodes { } - void Visit(const AssignOp& node) override + void Visit(const ShaderNodes::AssignOp& node) override { Serialize(node); } - void Visit(const BinaryOp& node) override + void Visit(const ShaderNodes::BinaryOp& node) override { Serialize(node); } - void Visit(const Branch& node) override + void Visit(const ShaderNodes::Branch& node) override { Serialize(node); } - void Visit(const Cast& node) override + void Visit(const ShaderNodes::Cast& node) override { Serialize(node); } - void Visit(const Constant& node) override + void Visit(const ShaderNodes::Constant& node) override { Serialize(node); } - void Visit(const DeclareVariable& node) override + void Visit(const ShaderNodes::DeclareVariable& node) override { Serialize(node); } - void Visit(const ExpressionStatement& node) override + void Visit(const ShaderNodes::ExpressionStatement& node) override { Serialize(node); } - void Visit(const Identifier& node) override + void Visit(const ShaderNodes::Identifier& node) override { Serialize(node); } - void Visit(const IntrinsicCall& node) override + void Visit(const ShaderNodes::IntrinsicCall& node) override { Serialize(node); } - void Visit(const Sample2D& node) override + void Visit(const ShaderNodes::Sample2D& node) override { Serialize(node); } - void Visit(const StatementBlock& node) override + void Visit(const ShaderNodes::StatementBlock& node) override { Serialize(node); } - void Visit(const SwizzleOp& node) override + void Visit(const ShaderNodes::SwizzleOp& node) override { Serialize(node); } @@ -122,21 +126,21 @@ namespace Nz::ShaderNodes }; } - void ShaderSerializerBase::Serialize(AssignOp& node) + void ShaderSerializerBase::Serialize(ShaderNodes::AssignOp& node) { Enum(node.op); Node(node.left); Node(node.right); } - void ShaderSerializerBase::Serialize(BinaryOp& node) + void ShaderSerializerBase::Serialize(ShaderNodes::BinaryOp& node) { Enum(node.op); Node(node.left); Node(node.right); } - void ShaderSerializerBase::Serialize(Branch& node) + void ShaderSerializerBase::Serialize(ShaderNodes::Branch& node) { Container(node.condStatements); for (auto& condStatement : node.condStatements) @@ -148,64 +152,64 @@ namespace Nz::ShaderNodes Node(node.elseStatement); } - void ShaderSerializerBase::Serialize(BuiltinVariable& node) + void ShaderSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node) { Enum(node.type); Enum(node.type); } - void ShaderSerializerBase::Serialize(Cast& node) + void ShaderSerializerBase::Serialize(ShaderNodes::Cast& node) { Enum(node.exprType); for (auto& expr : node.expressions) Node(expr); } - void ShaderSerializerBase::Serialize(Constant& node) + void ShaderSerializerBase::Serialize(ShaderNodes::Constant& node) { Enum(node.exprType); switch (node.exprType) { - case ExpressionType::Boolean: + case ShaderNodes::ExpressionType::Boolean: Value(node.values.bool1); break; - case ExpressionType::Float1: + case ShaderNodes::ExpressionType::Float1: Value(node.values.vec1); break; - case ExpressionType::Float2: + case ShaderNodes::ExpressionType::Float2: Value(node.values.vec2); break; - case ExpressionType::Float3: + case ShaderNodes::ExpressionType::Float3: Value(node.values.vec3); break; - case ExpressionType::Float4: + case ShaderNodes::ExpressionType::Float4: Value(node.values.vec4); break; } } - void ShaderSerializerBase::Serialize(DeclareVariable& node) + void ShaderSerializerBase::Serialize(ShaderNodes::DeclareVariable& node) { Variable(node.variable); Node(node.expression); } - void ShaderSerializerBase::Serialize(ExpressionStatement& node) + void ShaderSerializerBase::Serialize(ShaderNodes::ExpressionStatement& node) { Node(node.expression); } - void ShaderSerializerBase::Serialize(Identifier& node) + void ShaderSerializerBase::Serialize(ShaderNodes::Identifier& node) { Variable(node.var); } - void ShaderSerializerBase::Serialize(IntrinsicCall& node) + void ShaderSerializerBase::Serialize(ShaderNodes::IntrinsicCall& node) { Enum(node.intrinsic); Container(node.parameters); @@ -213,26 +217,26 @@ namespace Nz::ShaderNodes Node(param); } - void ShaderSerializerBase::Serialize(NamedVariable& node) + void ShaderSerializerBase::Serialize(ShaderNodes::NamedVariable& node) { Value(node.name); Enum(node.type); } - void ShaderSerializerBase::Serialize(Sample2D& node) + void ShaderSerializerBase::Serialize(ShaderNodes::Sample2D& node) { Node(node.sampler); Node(node.coordinates); } - void ShaderSerializerBase::Serialize(StatementBlock& node) + void ShaderSerializerBase::Serialize(ShaderNodes::StatementBlock& node) { Container(node.statements); for (auto& statement : node.statements) Node(statement); } - void ShaderSerializerBase::Serialize(SwizzleOp& node) + void ShaderSerializerBase::Serialize(ShaderNodes::SwizzleOp& node) { Value(node.componentCount); Node(node.expression); @@ -242,13 +246,50 @@ namespace Nz::ShaderNodes } - void ShaderSerializer::Serialize(const StatementPtr& shader) + void ShaderSerializer::Serialize(const ShaderAst& shader) { - assert(shader); - m_stream << static_cast(shader->GetType()); + UInt32 magicNumber = s_magicNumber; + UInt32 version = s_currentVersion; - ShaderSerializerVisitor visitor(*this); - shader->Visit(visitor); + m_stream << s_magicNumber << s_currentVersion; + + auto SerializeInputOutput = [&](auto& inout) + { + m_stream << UInt32(inout.size()); + for (const auto& data : inout) + { + m_stream << data.name << UInt32(data.type); + + m_stream << data.locationIndex.has_value(); + if (data.locationIndex) + m_stream << UInt32(data.locationIndex.value()); + } + }; + + SerializeInputOutput(shader.GetInputs()); + SerializeInputOutput(shader.GetOutputs()); + + m_stream << UInt32(shader.GetUniformCount()); + for (const auto& uniform : shader.GetUniforms()) + { + m_stream << uniform.name << UInt32(uniform.type); + + m_stream << uniform.bindingIndex.has_value(); + if (uniform.bindingIndex) + m_stream << UInt32(uniform.bindingIndex.value()); + } + + m_stream << UInt32(shader.GetFunctionCount()); + for (const auto& func : shader.GetFunctions()) + { + m_stream << func.name << UInt32(func.returnType); + + m_stream << UInt32(func.parameters.size()); + for (const auto& param : func.parameters) + m_stream << param.name << UInt32(param.type); + + Node(func.statement); + } m_stream.FlushBits(); } @@ -258,9 +299,9 @@ namespace Nz::ShaderNodes return true; } - void ShaderSerializer::Node(NodePtr& node) + void ShaderSerializer::Node(ShaderNodes::NodePtr& node) { - NodeType nodeType = (node) ? node->GetType() : NodeType::None; + ShaderNodes::NodeType nodeType = (node) ? node->GetType() : ShaderNodes::NodeType::None; m_stream << static_cast(nodeType); if (node) @@ -270,6 +311,11 @@ namespace Nz::ShaderNodes } } + void ShaderSerializer::Node(const ShaderNodes::NodePtr& node) + { + Node(const_cast(node)); //< Yes const_cast is ugly but it won't be used for writing + } + void ShaderSerializer::Value(bool& val) { m_stream << val; @@ -305,9 +351,9 @@ namespace Nz::ShaderNodes m_stream << val; } - void ShaderSerializer::Variable(VariablePtr& var) + void ShaderSerializer::Variable(ShaderNodes::VariablePtr& var) { - VariableType nodeType = (var) ? var->GetType() : VariableType::None; + ShaderNodes::VariableType nodeType = (var) ? var->GetType() : ShaderNodes::VariableType::None; m_stream << static_cast(nodeType); if (var) @@ -317,29 +363,93 @@ namespace Nz::ShaderNodes } } - ByteArray Serialize(const StatementPtr& shader) + ShaderAst ShaderUnserializer::Unserialize() { - ByteArray byteArray; - ShaderSerializer serializer(byteArray); - serializer.Serialize(shader); + UInt32 magicNumber; + UInt32 version; + m_stream >> magicNumber; + if (magicNumber != s_magicNumber) + throw std::runtime_error("invalid shader file"); - return byteArray; - } + m_stream >> version; + if (version > s_currentVersion) + throw std::runtime_error("unsupported version"); - StatementPtr Unserialize(const ByteArray& data) - { - ShaderUnserializer unserializer(data); - return unserializer.Unserialize(); - } + ShaderAst shader; - StatementPtr ShaderUnserializer::Unserialize() - { - NodePtr statement; - Node(statement); - if (!statement || statement->GetType() != NodeType::StatementBlock) - throw std::runtime_error("Invalid shader"); + UInt32 inputCount; + m_stream >> inputCount; + for (UInt32 i = 0; i < inputCount; ++i) + { + std::string inputName; + ShaderNodes::ExpressionType inputType; + std::optional location; - return std::static_pointer_cast(statement); + Value(inputName); + Enum(inputType); + OptVal(location); + + shader.AddInput(std::move(inputName), inputType, location); + } + + UInt32 outputCount; + m_stream >> outputCount; + for (UInt32 i = 0; i < outputCount; ++i) + { + std::string outputName; + ShaderNodes::ExpressionType outputType; + std::optional location; + + Value(outputName); + Enum(outputType); + OptVal(location); + + shader.AddOutput(std::move(outputName), outputType, location); + } + + UInt32 uniformCount; + m_stream >> uniformCount; + for (UInt32 i = 0; i < uniformCount; ++i) + { + std::string name; + ShaderNodes::ExpressionType type; + std::optional binding; + + Value(name); + Enum(type); + OptVal(binding); + + shader.AddUniform(std::move(name), type, binding); + } + + UInt32 funcCount; + m_stream >> funcCount; + for (UInt32 i = 0; i < funcCount; ++i) + { + std::string name; + ShaderNodes::ExpressionType retType; + std::vector parameters; + + Value(name); + Enum(retType); + Container(parameters); + for (auto& param : parameters) + { + Value(param.name); + Enum(param.type); + } + + ShaderNodes::NodePtr node; + Node(node); + if (!node || !node->IsStatement()) + throw std::runtime_error("functions can only have statements"); + + ShaderNodes::StatementPtr statement = std::static_pointer_cast(node); + + shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), retType); + } + + return shader; } bool ShaderUnserializer::IsWriting() const @@ -347,17 +457,17 @@ namespace Nz::ShaderNodes return false; } - void ShaderUnserializer::Node(NodePtr& node) + void ShaderUnserializer::Node(ShaderNodes::NodePtr& node) { Int32 nodeTypeInt; m_stream >> nodeTypeInt; - NodeType nodeType = static_cast(nodeTypeInt); + ShaderNodes::NodeType nodeType = static_cast(nodeTypeInt); -#define HandleType(Type) case NodeType:: Type : node = std::make_shared(); break +#define HandleType(Type) case ShaderNodes::NodeType:: Type : node = std::make_shared(); break switch (nodeType) { - case NodeType::None: break; + case ShaderNodes::NodeType::None: break; HandleType(AssignOp); HandleType(BinaryOp); @@ -417,17 +527,17 @@ namespace Nz::ShaderNodes m_stream >> val; } - void ShaderUnserializer::Variable(VariablePtr& var) + void ShaderUnserializer::Variable(ShaderNodes::VariablePtr& var) { Int32 nodeTypeInt; m_stream >> nodeTypeInt; - VariableType nodeType = static_cast(nodeTypeInt); + ShaderNodes::VariableType nodeType = static_cast(nodeTypeInt); -#define HandleType(Type) case VariableType:: Type : var = std::make_shared(); break +#define HandleType(Type) case ShaderNodes::VariableType:: Type : var = std::make_shared(); break switch (nodeType) { - case VariableType::None: break; + case ShaderNodes::VariableType::None: break; HandleType(BuiltinVariable); HandleType(InputVariable); @@ -443,5 +553,21 @@ namespace Nz::ShaderNodes var->Visit(visitor); } } + + + ByteArray SerializeShader(const ShaderAst& shader) + { + ByteArray byteArray; + ShaderSerializer serializer(byteArray); + serializer.Serialize(shader); + + return byteArray; + } + + ShaderAst UnserializeShader(const ByteArray& data) + { + ShaderUnserializer unserializer(data); + return unserializer.Unserialize(); + } } diff --git a/src/ShaderNode/Widgets/MainWindow.cpp b/src/ShaderNode/Widgets/MainWindow.cpp index c59a59254..7f39eafc3 100644 --- a/src/ShaderNode/Widgets/MainWindow.cpp +++ b/src/ShaderNode/Widgets/MainWindow.cpp @@ -103,9 +103,6 @@ void MainWindow::OnCompileToGLSL() { Nz::ShaderNodes::StatementPtr shaderAst = m_shaderGraph.ToAst(); - Nz::File file("shader.shader", Nz::OpenMode_WriteOnly); - file.Write(Nz::ShaderNodes::Serialize(shaderAst)); - //TODO: Put in another function auto GetExpressionFromInOut = [&] (InOutType type) { @@ -145,6 +142,9 @@ void MainWindow::OnCompileToGLSL() shader.AddFunction("main", shaderAst); + Nz::File file("shader.shader", Nz::OpenMode_WriteOnly); + file.Write(Nz::SerializeShader(shader)); + Nz::GlslWriter writer; Nz::String glsl = writer.Generate(shader);