Renderer/ShaderAst: Add serialization

This commit is contained in:
Lynix
2020-06-17 20:09:21 +02:00
parent 0ff10bf1e2
commit 736ca1c409
7 changed files with 285 additions and 118 deletions

View File

@@ -330,9 +330,7 @@ namespace Nz
Append(node.variable->name);
if (node.expression)
{
Append(" ");
Append("=");
Append(" ");
Append(" = ");
Visit(node.expression);
}

View File

@@ -3,14 +3,18 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Renderer/ShaderSerializer.hpp>
#include <Nazara/Renderer/ShaderAst.hpp>
#include <Nazara/Renderer/ShaderVarVisitor.hpp>
#include <Nazara/Renderer/ShaderVisitor.hpp>
#include <Nazara/Renderer/Debug.hpp>
namespace Nz::ShaderNodes
namespace Nz
{
namespace
{
constexpr UInt32 s_magicNumber = 0x4E534852;
constexpr UInt32 s_currentVersion = 1;
class ShaderSerializerVisitor : public ShaderVisitor, public ShaderVarVisitor
{
public:
@@ -19,62 +23,62 @@ namespace Nz::ShaderNodes
{
}
void Visit(const AssignOp& node) override
void Visit(const ShaderNodes::AssignOp& node) override
{
Serialize(node);
}
void Visit(const BinaryOp& node) override
void Visit(const ShaderNodes::BinaryOp& node) override
{
Serialize(node);
}
void Visit(const Branch& node) override
void Visit(const ShaderNodes::Branch& node) override
{
Serialize(node);
}
void Visit(const Cast& node) override
void Visit(const ShaderNodes::Cast& node) override
{
Serialize(node);
}
void Visit(const Constant& node) override
void Visit(const ShaderNodes::Constant& node) override
{
Serialize(node);
}
void Visit(const DeclareVariable& node) override
void Visit(const ShaderNodes::DeclareVariable& node) override
{
Serialize(node);
}
void Visit(const ExpressionStatement& node) override
void Visit(const ShaderNodes::ExpressionStatement& node) override
{
Serialize(node);
}
void Visit(const Identifier& node) override
void Visit(const ShaderNodes::Identifier& node) override
{
Serialize(node);
}
void Visit(const IntrinsicCall& node) override
void Visit(const ShaderNodes::IntrinsicCall& node) override
{
Serialize(node);
}
void Visit(const Sample2D& node) override
void Visit(const ShaderNodes::Sample2D& node) override
{
Serialize(node);
}
void Visit(const StatementBlock& node) override
void Visit(const ShaderNodes::StatementBlock& node) override
{
Serialize(node);
}
void Visit(const SwizzleOp& node) override
void Visit(const ShaderNodes::SwizzleOp& node) override
{
Serialize(node);
}
@@ -122,21 +126,21 @@ namespace Nz::ShaderNodes
};
}
void ShaderSerializerBase::Serialize(AssignOp& node)
void ShaderSerializerBase::Serialize(ShaderNodes::AssignOp& node)
{
Enum(node.op);
Node(node.left);
Node(node.right);
}
void ShaderSerializerBase::Serialize(BinaryOp& node)
void ShaderSerializerBase::Serialize(ShaderNodes::BinaryOp& node)
{
Enum(node.op);
Node(node.left);
Node(node.right);
}
void ShaderSerializerBase::Serialize(Branch& node)
void ShaderSerializerBase::Serialize(ShaderNodes::Branch& node)
{
Container(node.condStatements);
for (auto& condStatement : node.condStatements)
@@ -148,64 +152,64 @@ namespace Nz::ShaderNodes
Node(node.elseStatement);
}
void ShaderSerializerBase::Serialize(BuiltinVariable& node)
void ShaderSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node)
{
Enum(node.type);
Enum(node.type);
}
void ShaderSerializerBase::Serialize(Cast& node)
void ShaderSerializerBase::Serialize(ShaderNodes::Cast& node)
{
Enum(node.exprType);
for (auto& expr : node.expressions)
Node(expr);
}
void ShaderSerializerBase::Serialize(Constant& node)
void ShaderSerializerBase::Serialize(ShaderNodes::Constant& node)
{
Enum(node.exprType);
switch (node.exprType)
{
case ExpressionType::Boolean:
case ShaderNodes::ExpressionType::Boolean:
Value(node.values.bool1);
break;
case ExpressionType::Float1:
case ShaderNodes::ExpressionType::Float1:
Value(node.values.vec1);
break;
case ExpressionType::Float2:
case ShaderNodes::ExpressionType::Float2:
Value(node.values.vec2);
break;
case ExpressionType::Float3:
case ShaderNodes::ExpressionType::Float3:
Value(node.values.vec3);
break;
case ExpressionType::Float4:
case ShaderNodes::ExpressionType::Float4:
Value(node.values.vec4);
break;
}
}
void ShaderSerializerBase::Serialize(DeclareVariable& node)
void ShaderSerializerBase::Serialize(ShaderNodes::DeclareVariable& node)
{
Variable(node.variable);
Node(node.expression);
}
void ShaderSerializerBase::Serialize(ExpressionStatement& node)
void ShaderSerializerBase::Serialize(ShaderNodes::ExpressionStatement& node)
{
Node(node.expression);
}
void ShaderSerializerBase::Serialize(Identifier& node)
void ShaderSerializerBase::Serialize(ShaderNodes::Identifier& node)
{
Variable(node.var);
}
void ShaderSerializerBase::Serialize(IntrinsicCall& node)
void ShaderSerializerBase::Serialize(ShaderNodes::IntrinsicCall& node)
{
Enum(node.intrinsic);
Container(node.parameters);
@@ -213,26 +217,26 @@ namespace Nz::ShaderNodes
Node(param);
}
void ShaderSerializerBase::Serialize(NamedVariable& node)
void ShaderSerializerBase::Serialize(ShaderNodes::NamedVariable& node)
{
Value(node.name);
Enum(node.type);
}
void ShaderSerializerBase::Serialize(Sample2D& node)
void ShaderSerializerBase::Serialize(ShaderNodes::Sample2D& node)
{
Node(node.sampler);
Node(node.coordinates);
}
void ShaderSerializerBase::Serialize(StatementBlock& node)
void ShaderSerializerBase::Serialize(ShaderNodes::StatementBlock& node)
{
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void ShaderSerializerBase::Serialize(SwizzleOp& node)
void ShaderSerializerBase::Serialize(ShaderNodes::SwizzleOp& node)
{
Value(node.componentCount);
Node(node.expression);
@@ -242,13 +246,50 @@ namespace Nz::ShaderNodes
}
void ShaderSerializer::Serialize(const StatementPtr& shader)
void ShaderSerializer::Serialize(const ShaderAst& shader)
{
assert(shader);
m_stream << static_cast<Int32>(shader->GetType());
UInt32 magicNumber = s_magicNumber;
UInt32 version = s_currentVersion;
ShaderSerializerVisitor visitor(*this);
shader->Visit(visitor);
m_stream << s_magicNumber << s_currentVersion;
auto SerializeInputOutput = [&](auto& inout)
{
m_stream << UInt32(inout.size());
for (const auto& data : inout)
{
m_stream << data.name << UInt32(data.type);
m_stream << data.locationIndex.has_value();
if (data.locationIndex)
m_stream << UInt32(data.locationIndex.value());
}
};
SerializeInputOutput(shader.GetInputs());
SerializeInputOutput(shader.GetOutputs());
m_stream << UInt32(shader.GetUniformCount());
for (const auto& uniform : shader.GetUniforms())
{
m_stream << uniform.name << UInt32(uniform.type);
m_stream << uniform.bindingIndex.has_value();
if (uniform.bindingIndex)
m_stream << UInt32(uniform.bindingIndex.value());
}
m_stream << UInt32(shader.GetFunctionCount());
for (const auto& func : shader.GetFunctions())
{
m_stream << func.name << UInt32(func.returnType);
m_stream << UInt32(func.parameters.size());
for (const auto& param : func.parameters)
m_stream << param.name << UInt32(param.type);
Node(func.statement);
}
m_stream.FlushBits();
}
@@ -258,9 +299,9 @@ namespace Nz::ShaderNodes
return true;
}
void ShaderSerializer::Node(NodePtr& node)
void ShaderSerializer::Node(ShaderNodes::NodePtr& node)
{
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
ShaderNodes::NodeType nodeType = (node) ? node->GetType() : ShaderNodes::NodeType::None;
m_stream << static_cast<Int32>(nodeType);
if (node)
@@ -270,6 +311,11 @@ namespace Nz::ShaderNodes
}
}
void ShaderSerializer::Node(const ShaderNodes::NodePtr& node)
{
Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing
}
void ShaderSerializer::Value(bool& val)
{
m_stream << val;
@@ -305,9 +351,9 @@ namespace Nz::ShaderNodes
m_stream << val;
}
void ShaderSerializer::Variable(VariablePtr& var)
void ShaderSerializer::Variable(ShaderNodes::VariablePtr& var)
{
VariableType nodeType = (var) ? var->GetType() : VariableType::None;
ShaderNodes::VariableType nodeType = (var) ? var->GetType() : ShaderNodes::VariableType::None;
m_stream << static_cast<Int32>(nodeType);
if (var)
@@ -317,29 +363,93 @@ namespace Nz::ShaderNodes
}
}
ByteArray Serialize(const StatementPtr& shader)
ShaderAst ShaderUnserializer::Unserialize()
{
ByteArray byteArray;
ShaderSerializer serializer(byteArray);
serializer.Serialize(shader);
UInt32 magicNumber;
UInt32 version;
m_stream >> magicNumber;
if (magicNumber != s_magicNumber)
throw std::runtime_error("invalid shader file");
return byteArray;
}
m_stream >> version;
if (version > s_currentVersion)
throw std::runtime_error("unsupported version");
StatementPtr Unserialize(const ByteArray& data)
{
ShaderUnserializer unserializer(data);
return unserializer.Unserialize();
}
ShaderAst shader;
StatementPtr ShaderUnserializer::Unserialize()
{
NodePtr statement;
Node(statement);
if (!statement || statement->GetType() != NodeType::StatementBlock)
throw std::runtime_error("Invalid shader");
UInt32 inputCount;
m_stream >> inputCount;
for (UInt32 i = 0; i < inputCount; ++i)
{
std::string inputName;
ShaderNodes::ExpressionType inputType;
std::optional<std::size_t> location;
return std::static_pointer_cast<Statement>(statement);
Value(inputName);
Enum(inputType);
OptVal(location);
shader.AddInput(std::move(inputName), inputType, location);
}
UInt32 outputCount;
m_stream >> outputCount;
for (UInt32 i = 0; i < outputCount; ++i)
{
std::string outputName;
ShaderNodes::ExpressionType outputType;
std::optional<std::size_t> location;
Value(outputName);
Enum(outputType);
OptVal(location);
shader.AddOutput(std::move(outputName), outputType, location);
}
UInt32 uniformCount;
m_stream >> uniformCount;
for (UInt32 i = 0; i < uniformCount; ++i)
{
std::string name;
ShaderNodes::ExpressionType type;
std::optional<std::size_t> binding;
Value(name);
Enum(type);
OptVal(binding);
shader.AddUniform(std::move(name), type, binding);
}
UInt32 funcCount;
m_stream >> funcCount;
for (UInt32 i = 0; i < funcCount; ++i)
{
std::string name;
ShaderNodes::ExpressionType retType;
std::vector<ShaderAst::FunctionParameter> parameters;
Value(name);
Enum(retType);
Container(parameters);
for (auto& param : parameters)
{
Value(param.name);
Enum(param.type);
}
ShaderNodes::NodePtr node;
Node(node);
if (!node || !node->IsStatement())
throw std::runtime_error("functions can only have statements");
ShaderNodes::StatementPtr statement = std::static_pointer_cast<ShaderNodes::Statement>(node);
shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), retType);
}
return shader;
}
bool ShaderUnserializer::IsWriting() const
@@ -347,17 +457,17 @@ namespace Nz::ShaderNodes
return false;
}
void ShaderUnserializer::Node(NodePtr& node)
void ShaderUnserializer::Node(ShaderNodes::NodePtr& node)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
ShaderNodes::NodeType nodeType = static_cast<ShaderNodes::NodeType>(nodeTypeInt);
#define HandleType(Type) case NodeType:: Type : node = std::make_shared<Type>(); break
#define HandleType(Type) case ShaderNodes::NodeType:: Type : node = std::make_shared<ShaderNodes:: Type>(); break
switch (nodeType)
{
case NodeType::None: break;
case ShaderNodes::NodeType::None: break;
HandleType(AssignOp);
HandleType(BinaryOp);
@@ -417,17 +527,17 @@ namespace Nz::ShaderNodes
m_stream >> val;
}
void ShaderUnserializer::Variable(VariablePtr& var)
void ShaderUnserializer::Variable(ShaderNodes::VariablePtr& var)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
VariableType nodeType = static_cast<VariableType>(nodeTypeInt);
ShaderNodes::VariableType nodeType = static_cast<ShaderNodes:: VariableType>(nodeTypeInt);
#define HandleType(Type) case VariableType:: Type : var = std::make_shared<Type>(); break
#define HandleType(Type) case ShaderNodes::VariableType:: Type : var = std::make_shared<ShaderNodes::Type>(); break
switch (nodeType)
{
case VariableType::None: break;
case ShaderNodes::VariableType::None: break;
HandleType(BuiltinVariable);
HandleType(InputVariable);
@@ -443,5 +553,21 @@ namespace Nz::ShaderNodes
var->Visit(visitor);
}
}
ByteArray SerializeShader(const ShaderAst& shader)
{
ByteArray byteArray;
ShaderSerializer serializer(byteArray);
serializer.Serialize(shader);
return byteArray;
}
ShaderAst UnserializeShader(const ByteArray& data)
{
ShaderUnserializer unserializer(data);
return unserializer.Unserialize();
}
}

View File

@@ -103,9 +103,6 @@ void MainWindow::OnCompileToGLSL()
{
Nz::ShaderNodes::StatementPtr shaderAst = m_shaderGraph.ToAst();
Nz::File file("shader.shader", Nz::OpenMode_WriteOnly);
file.Write(Nz::ShaderNodes::Serialize(shaderAst));
//TODO: Put in another function
auto GetExpressionFromInOut = [&] (InOutType type)
{
@@ -145,6 +142,9 @@ void MainWindow::OnCompileToGLSL()
shader.AddFunction("main", shaderAst);
Nz::File file("shader.shader", Nz::OpenMode_WriteOnly);
file.Write(Nz::SerializeShader(shader));
Nz::GlslWriter writer;
Nz::String glsl = writer.Generate(shader);