Renderer/ShaderNodes: Add support for accessing struct fields

This commit is contained in:
Jérôme Leclercq 2020-07-19 21:05:46 +02:00
parent 1d2fb88198
commit 086f76fb97
15 changed files with 221 additions and 62 deletions

View File

@ -61,6 +61,7 @@ namespace Nz
using ShaderVarVisitor::Visit; using ShaderVarVisitor::Visit;
using ShaderVisitor::Visit; using ShaderVisitor::Visit;
void Visit(const ShaderNodes::AccessMember& node) override;
void Visit(const ShaderNodes::AssignOp& node) override; void Visit(const ShaderNodes::AssignOp& node) override;
void Visit(const ShaderNodes::Branch& node) override; void Visit(const ShaderNodes::Branch& node) override;
void Visit(const ShaderNodes::BinaryOp& node) override; void Visit(const ShaderNodes::BinaryOp& node) override;
@ -85,6 +86,7 @@ namespace Nz
struct Context struct Context
{ {
const ShaderAst* shader = nullptr;
const ShaderAst::Function* currentFunction = nullptr; const ShaderAst::Function* currentFunction = nullptr;
}; };

View File

@ -44,6 +44,7 @@ namespace Nz::ShaderBuilder
template<typename... Args> std::shared_ptr<T> operator()(Args&&... args) const; template<typename... Args> std::shared_ptr<T> operator()(Args&&... args) const;
}; };
constexpr GenBuilder<ShaderNodes::AccessMember> AccessMember;
constexpr BinOpBuilder<ShaderNodes::BinaryType::Add> Add; constexpr BinOpBuilder<ShaderNodes::BinaryType::Add> Add;
constexpr AssignOpBuilder<ShaderNodes::AssignType::Simple> Assign; constexpr AssignOpBuilder<ShaderNodes::AssignType::Simple> Assign;
constexpr BuiltinBuilder Builtin; constexpr BuiltinBuilder Builtin;

View File

@ -65,6 +65,7 @@ namespace Nz::ShaderNodes
{ {
None = -1, None = -1,
AccessMember,
AssignOp, AssignOp,
BinaryOp, BinaryOp,
Branch, Branch,

View File

@ -8,7 +8,7 @@
#define NAZARA_SHADER_EXPRESSIONTYPE_HPP #define NAZARA_SHADER_EXPRESSIONTYPE_HPP
#include <Nazara/Prerequisites.hpp> #include <Nazara/Prerequisites.hpp>
#include <Nazara/Renderer/ShaderNodes.hpp> #include <Nazara/Renderer/ShaderEnums.hpp>
#include <string> #include <string>
#include <variant> #include <variant>

View File

@ -13,6 +13,7 @@
#include <Nazara/Math/Vector4.hpp> #include <Nazara/Math/Vector4.hpp>
#include <Nazara/Renderer/Config.hpp> #include <Nazara/Renderer/Config.hpp>
#include <Nazara/Renderer/ShaderEnums.hpp> #include <Nazara/Renderer/ShaderEnums.hpp>
#include <Nazara/Renderer/ShaderExpressionType.hpp>
#include <Nazara/Renderer/ShaderVariables.hpp> #include <Nazara/Renderer/ShaderVariables.hpp>
#include <array> #include <array>
#include <optional> #include <optional>
@ -59,7 +60,7 @@ namespace Nz
inline Expression(NodeType type); inline Expression(NodeType type);
virtual ExpressionCategory GetExpressionCategory() const; virtual ExpressionCategory GetExpressionCategory() const;
virtual BasicType GetExpressionType() const = 0; virtual ShaderExpressionType GetExpressionType() const = 0;
}; };
class Statement; class Statement;
@ -125,7 +126,7 @@ namespace Nz
inline Identifier(); inline Identifier();
ExpressionCategory GetExpressionCategory() const override; ExpressionCategory GetExpressionCategory() const override;
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
VariablePtr var; VariablePtr var;
@ -133,13 +134,28 @@ namespace Nz
static inline std::shared_ptr<Identifier> Build(VariablePtr variable); static inline std::shared_ptr<Identifier> Build(VariablePtr variable);
}; };
struct NAZARA_RENDERER_API AccessMember : public Expression
{
inline AccessMember();
ExpressionCategory GetExpressionCategory() const override;
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override;
std::size_t memberIndex;
ExpressionPtr structExpr;
ShaderExpressionType exprType; //< FIXME: Use ShaderAst to automate
static inline std::shared_ptr<AccessMember> Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType);
};
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
struct NAZARA_RENDERER_API AssignOp : public Expression struct NAZARA_RENDERER_API AssignOp : public Expression
{ {
inline AssignOp(); inline AssignOp();
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
AssignType op; AssignType op;
@ -153,7 +169,7 @@ namespace Nz
{ {
inline BinaryOp(); inline BinaryOp();
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
BinaryType op; BinaryType op;
@ -187,7 +203,7 @@ namespace Nz
{ {
inline Cast(); inline Cast();
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
BasicType exprType; BasicType exprType;
@ -201,7 +217,7 @@ namespace Nz
{ {
inline Constant(); inline Constant();
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
BasicType exprType; BasicType exprType;
@ -227,7 +243,7 @@ namespace Nz
inline SwizzleOp(); inline SwizzleOp();
ExpressionCategory GetExpressionCategory() const override; ExpressionCategory GetExpressionCategory() const override;
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
std::array<SwizzleComponent, 4> components; std::array<SwizzleComponent, 4> components;
@ -243,7 +259,7 @@ namespace Nz
{ {
inline Sample2D(); inline Sample2D();
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
ExpressionPtr sampler; ExpressionPtr sampler;
@ -258,7 +274,7 @@ namespace Nz
{ {
inline IntrinsicCall(); inline IntrinsicCall();
BasicType GetExpressionType() const override; ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderVisitor& visitor) override; void Visit(ShaderVisitor& visitor) override;
IntrinsicType intrinsic; IntrinsicType intrinsic;

View File

@ -146,6 +146,22 @@ namespace Nz::ShaderNodes
} }
inline AccessMember::AccessMember() :
Expression(NodeType::AccessMember)
{
}
inline std::shared_ptr<AccessMember> AccessMember::Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType)
{
auto node = std::make_shared<AccessMember>();
node->exprType = std::move(exprType);
node->memberIndex = memberIndex;
node->structExpr = std::move(structExpr);
return node;
}
inline AssignOp::AssignOp() : inline AssignOp::AssignOp() :
Expression(NodeType::AssignOp) Expression(NodeType::AssignOp)
{ {

View File

@ -25,6 +25,7 @@ namespace Nz
ShaderSerializerBase(ShaderSerializerBase&&) = delete; ShaderSerializerBase(ShaderSerializerBase&&) = delete;
~ShaderSerializerBase() = default; ~ShaderSerializerBase() = default;
void Serialize(ShaderNodes::AccessMember& node);
void Serialize(ShaderNodes::AssignOp& node); void Serialize(ShaderNodes::AssignOp& node);
void Serialize(ShaderNodes::BinaryOp& node); void Serialize(ShaderNodes::BinaryOp& node);
void Serialize(ShaderNodes::BuiltinVariable& var); void Serialize(ShaderNodes::BuiltinVariable& var);
@ -51,6 +52,8 @@ namespace Nz
virtual void Node(ShaderNodes::NodePtr& node) = 0; virtual void Node(ShaderNodes::NodePtr& node) = 0;
template<typename T> void Node(std::shared_ptr<T>& node); template<typename T> void Node(std::shared_ptr<T>& node);
virtual void Type(ShaderExpressionType& type) = 0;
virtual void Value(bool& val) = 0; virtual void Value(bool& val) = 0;
virtual void Value(float& val) = 0; virtual void Value(float& val) = 0;
virtual void Value(std::string& val) = 0; virtual void Value(std::string& val) = 0;
@ -78,6 +81,7 @@ namespace Nz
bool IsWriting() const override; bool IsWriting() const override;
void Node(const ShaderNodes::NodePtr& node); void Node(const ShaderNodes::NodePtr& node);
void Node(ShaderNodes::NodePtr& node) override; void Node(ShaderNodes::NodePtr& node) override;
void Type(ShaderExpressionType& type) override;
void Value(bool& val) override; void Value(bool& val) override;
void Value(float& val) override; void Value(float& val) override;
void Value(std::string& val) override; void Value(std::string& val) override;
@ -103,7 +107,7 @@ namespace Nz
private: private:
bool IsWriting() const override; bool IsWriting() const override;
void Node(ShaderNodes::NodePtr& node) override; void Node(ShaderNodes::NodePtr& node) override;
void Type(ShaderExpressionType& type); void Type(ShaderExpressionType& type) override;
void Value(bool& val) override; void Value(bool& val) override;
void Value(float& val) override; void Value(float& val) override;
void Value(std::string& val) override; void Value(std::string& val) override;

View File

@ -33,6 +33,7 @@ namespace Nz
void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right); void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right);
using ShaderVisitor::Visit; using ShaderVisitor::Visit;
void Visit(const ShaderNodes::AccessMember& node) override;
void Visit(const ShaderNodes::AssignOp& node) override; void Visit(const ShaderNodes::AssignOp& node) override;
void Visit(const ShaderNodes::BinaryOp& node) override; void Visit(const ShaderNodes::BinaryOp& node) override;
void Visit(const ShaderNodes::Branch& node) override; void Visit(const ShaderNodes::Branch& node) override;

View File

@ -12,7 +12,7 @@
#include <Nazara/Math/Vector3.hpp> #include <Nazara/Math/Vector3.hpp>
#include <Nazara/Math/Vector4.hpp> #include <Nazara/Math/Vector4.hpp>
#include <Nazara/Renderer/Config.hpp> #include <Nazara/Renderer/Config.hpp>
#include <Nazara/Renderer/ShaderEnums.hpp> #include <Nazara/Renderer/ShaderExpressionType.hpp>
#include <array> #include <array>
#include <optional> #include <optional>
#include <string> #include <string>
@ -34,7 +34,7 @@ namespace Nz
virtual VariableType GetType() const = 0; virtual VariableType GetType() const = 0;
virtual void Visit(ShaderVarVisitor& visitor) = 0; virtual void Visit(ShaderVarVisitor& visitor) = 0;
BasicType type; ShaderExpressionType type;
}; };
struct BuiltinVariable; struct BuiltinVariable;
@ -48,7 +48,7 @@ namespace Nz
VariableType GetType() const override; VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override; void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<BuiltinVariable> Build(BuiltinEntry entry, BasicType varType); static inline std::shared_ptr<BuiltinVariable> Build(BuiltinEntry entry, ShaderExpressionType varType);
}; };
struct NamedVariable; struct NamedVariable;
@ -69,7 +69,7 @@ namespace Nz
VariableType GetType() const override; VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override; void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<InputVariable> Build(std::string varName, BasicType varType); static inline std::shared_ptr<InputVariable> Build(std::string varName, ShaderExpressionType varType);
}; };
struct LocalVariable; struct LocalVariable;
@ -81,7 +81,7 @@ namespace Nz
VariableType GetType() const override; VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override; void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<LocalVariable> Build(std::string varName, BasicType varType); static inline std::shared_ptr<LocalVariable> Build(std::string varName, ShaderExpressionType varType);
}; };
struct OutputVariable; struct OutputVariable;
@ -93,7 +93,7 @@ namespace Nz
VariableType GetType() const override; VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override; void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<OutputVariable> Build(std::string varName, BasicType varType); static inline std::shared_ptr<OutputVariable> Build(std::string varName, ShaderExpressionType varType);
}; };
struct ParameterVariable; struct ParameterVariable;
@ -105,7 +105,7 @@ namespace Nz
VariableType GetType() const override; VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override; void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<ParameterVariable> Build(std::string varName, BasicType varType); static inline std::shared_ptr<ParameterVariable> Build(std::string varName, ShaderExpressionType varType);
}; };
struct UniformVariable; struct UniformVariable;
@ -117,7 +117,7 @@ namespace Nz
VariableType GetType() const override; VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override; void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<UniformVariable> Build(std::string varName, BasicType varType); static inline std::shared_ptr<UniformVariable> Build(std::string varName, ShaderExpressionType varType);
}; };
} }
} }

View File

@ -7,7 +7,7 @@
namespace Nz::ShaderNodes namespace Nz::ShaderNodes
{ {
inline std::shared_ptr<BuiltinVariable> BuiltinVariable::Build(BuiltinEntry variable, BasicType varType) inline std::shared_ptr<BuiltinVariable> BuiltinVariable::Build(BuiltinEntry variable, ShaderExpressionType varType)
{ {
auto node = std::make_shared<BuiltinVariable>(); auto node = std::make_shared<BuiltinVariable>();
node->entry = variable; node->entry = variable;
@ -16,7 +16,7 @@ namespace Nz::ShaderNodes
return node; return node;
} }
inline std::shared_ptr<InputVariable> InputVariable::Build(std::string varName, BasicType varType) inline std::shared_ptr<InputVariable> InputVariable::Build(std::string varName, ShaderExpressionType varType)
{ {
auto node = std::make_shared<InputVariable>(); auto node = std::make_shared<InputVariable>();
node->name = std::move(varName); node->name = std::move(varName);
@ -25,7 +25,7 @@ namespace Nz::ShaderNodes
return node; return node;
} }
inline std::shared_ptr<LocalVariable> LocalVariable::Build(std::string varName, BasicType varType) inline std::shared_ptr<LocalVariable> LocalVariable::Build(std::string varName, ShaderExpressionType varType)
{ {
auto node = std::make_shared<LocalVariable>(); auto node = std::make_shared<LocalVariable>();
node->name = std::move(varName); node->name = std::move(varName);
@ -34,7 +34,7 @@ namespace Nz::ShaderNodes
return node; return node;
} }
inline std::shared_ptr<OutputVariable> OutputVariable::Build(std::string varName, BasicType varType) inline std::shared_ptr<OutputVariable> OutputVariable::Build(std::string varName, ShaderExpressionType varType)
{ {
auto node = std::make_shared<OutputVariable>(); auto node = std::make_shared<OutputVariable>();
node->name = std::move(varName); node->name = std::move(varName);
@ -43,7 +43,7 @@ namespace Nz::ShaderNodes
return node; return node;
} }
inline std::shared_ptr<ParameterVariable> ParameterVariable::Build(std::string varName, BasicType varType) inline std::shared_ptr<ParameterVariable> ParameterVariable::Build(std::string varName, ShaderExpressionType varType)
{ {
auto node = std::make_shared<ParameterVariable>(); auto node = std::make_shared<ParameterVariable>();
node->name = std::move(varName); node->name = std::move(varName);
@ -52,7 +52,7 @@ namespace Nz::ShaderNodes
return node; return node;
} }
inline std::shared_ptr<UniformVariable> UniformVariable::Build(std::string varName, BasicType varType) inline std::shared_ptr<UniformVariable> UniformVariable::Build(std::string varName, ShaderExpressionType varType)
{ {
auto node = std::make_shared<UniformVariable>(); auto node = std::make_shared<UniformVariable>();
node->name = std::move(varName); node->name = std::move(varName);

View File

@ -27,6 +27,7 @@ namespace Nz
bool IsConditionEnabled(const std::string& name) const; bool IsConditionEnabled(const std::string& name) const;
virtual void Visit(const ShaderNodes::AccessMember& node) = 0;
virtual void Visit(const ShaderNodes::AssignOp& node) = 0; virtual void Visit(const ShaderNodes::AssignOp& node) = 0;
virtual void Visit(const ShaderNodes::BinaryOp& node) = 0; virtual void Visit(const ShaderNodes::BinaryOp& node) = 0;
virtual void Visit(const ShaderNodes::Branch& node) = 0; virtual void Visit(const ShaderNodes::Branch& node) = 0;

View File

@ -21,6 +21,8 @@ namespace Nz
if (!ValidateShader(shader, &error)) if (!ValidateShader(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error); throw std::runtime_error("Invalid shader AST: " + error);
m_context.shader = &shader;
State state; State state;
m_currentState = &state; m_currentState = &state;
CallOnExit onExit([this]() CallOnExit onExit([this]()
@ -294,7 +296,30 @@ namespace Nz
AppendLine(); AppendLine();
AppendLine("}"); AppendLine("}");
} }
void GlslWriter::Visit(const ShaderNodes::AccessMember& node)
{
Append("(");
Visit(node.structExpr);
Append(")");
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
assert(std::holds_alternative<std::string>(exprType));
const std::string& structName = std::get<std::string>(exprType);
const auto& structs = m_context.shader->GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
assert(it != structs.end());
const ShaderAst::Struct& s = *it;
assert(node.memberIndex < s.members.size());
const auto& member = s.members[node.memberIndex];
Append(".");
Append(member.name);
}
void GlslWriter::Visit(const ShaderNodes::AssignOp& node) void GlslWriter::Visit(const ShaderNodes::AssignOp& node)
{ {
Visit(node.left); Visit(node.left);
@ -374,9 +399,7 @@ namespace Nz
Append(node.exprType); Append(node.exprType);
Append("("); Append("(");
unsigned int i = 0; for (std::size_t i = 0; node.expressions[i]; ++i)
unsigned int requiredComponents = ShaderNodes::Node::GetComponentCount(node.exprType);
while (requiredComponents > 0)
{ {
if (i != 0) if (i != 0)
m_currentState->stream << ", "; m_currentState->stream << ", ";
@ -385,7 +408,6 @@ namespace Nz
NazaraAssert(exprPtr, "Invalid expression"); NazaraAssert(exprPtr, "Invalid expression");
Visit(exprPtr); Visit(exprPtr);
requiredComponents -= ShaderNodes::Node::GetComponentCount(exprPtr->GetExpressionType());
} }
Append(")"); Append(")");

View File

@ -47,7 +47,7 @@ namespace Nz::ShaderNodes
return ExpressionCategory::LValue; return ExpressionCategory::LValue;
} }
BasicType Identifier::GetExpressionType() const ShaderExpressionType Identifier::GetExpressionType() const
{ {
assert(var); assert(var);
return var->type; return var->type;
@ -58,8 +58,22 @@ namespace Nz::ShaderNodes
visitor.Visit(*this); visitor.Visit(*this);
} }
ExpressionCategory ShaderNodes::AccessMember::GetExpressionCategory() const
{
return ExpressionCategory::LValue;
}
BasicType AssignOp::GetExpressionType() const ShaderExpressionType AccessMember::GetExpressionType() const
{
return exprType;
}
void AccessMember::Visit(ShaderVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType AssignOp::GetExpressionType() const
{ {
return left->GetExpressionType(); return left->GetExpressionType();
} }
@ -70,31 +84,39 @@ namespace Nz::ShaderNodes
} }
BasicType BinaryOp::GetExpressionType() const ShaderExpressionType BinaryOp::GetExpressionType() const
{ {
ShaderNodes::BasicType exprType = ShaderNodes::BasicType::Void; std::optional<ShaderExpressionType> exprType;
switch (op) switch (op)
{ {
case ShaderNodes::BinaryType::Add: case BinaryType::Add:
case ShaderNodes::BinaryType::Substract: case BinaryType::Substract:
exprType = left->GetExpressionType(); exprType = left->GetExpressionType();
break; break;
case ShaderNodes::BinaryType::Divide: case BinaryType::Divide:
case ShaderNodes::BinaryType::Multiply: case BinaryType::Multiply:
//FIXME {
exprType = static_cast<BasicType>(std::max(UnderlyingCast(left->GetExpressionType()), UnderlyingCast(right->GetExpressionType()))); const ShaderExpressionType& leftExprType = left->GetExpressionType();
break; assert(std::holds_alternative<BasicType>(leftExprType));
case ShaderNodes::BinaryType::Equality: const ShaderExpressionType& rightExprType = right->GetExpressionType();
assert(std::holds_alternative<BasicType>(rightExprType));
//FIXME
exprType = static_cast<BasicType>(std::max(UnderlyingCast(std::get<BasicType>(leftExprType)), UnderlyingCast(std::get<BasicType>(rightExprType))));
break;
}
case BinaryType::Equality:
exprType = BasicType::Boolean; exprType = BasicType::Boolean;
break; break;
} }
NazaraAssert(exprType != ShaderNodes::BasicType::Void, "Unhandled builtin"); NazaraAssert(exprType.has_value(), "Unhandled builtin");
return exprType; return *exprType;
} }
void BinaryOp::Visit(ShaderVisitor& visitor) void BinaryOp::Visit(ShaderVisitor& visitor)
@ -109,7 +131,7 @@ namespace Nz::ShaderNodes
} }
BasicType Constant::GetExpressionType() const ShaderExpressionType Constant::GetExpressionType() const
{ {
return exprType; return exprType;
} }
@ -119,7 +141,7 @@ namespace Nz::ShaderNodes
visitor.Visit(*this); visitor.Visit(*this);
} }
BasicType Cast::GetExpressionType() const ShaderExpressionType Cast::GetExpressionType() const
{ {
return exprType; return exprType;
} }
@ -135,9 +157,12 @@ namespace Nz::ShaderNodes
return ExpressionCategory::LValue; return ExpressionCategory::LValue;
} }
BasicType SwizzleOp::GetExpressionType() const ShaderExpressionType SwizzleOp::GetExpressionType() const
{ {
return static_cast<BasicType>(UnderlyingCast(GetComponentType(expression->GetExpressionType())) + componentCount - 1); const ShaderExpressionType& exprType = expression->GetExpressionType();
assert(std::holds_alternative<BasicType>(exprType));
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
} }
void SwizzleOp::Visit(ShaderVisitor& visitor) void SwizzleOp::Visit(ShaderVisitor& visitor)
@ -146,7 +171,7 @@ namespace Nz::ShaderNodes
} }
BasicType Sample2D::GetExpressionType() const ShaderExpressionType Sample2D::GetExpressionType() const
{ {
return BasicType::Float4; return BasicType::Float4;
} }
@ -157,7 +182,7 @@ namespace Nz::ShaderNodes
} }
BasicType IntrinsicCall::GetExpressionType() const ShaderExpressionType IntrinsicCall::GetExpressionType() const
{ {
switch (intrinsic) switch (intrinsic)
{ {

View File

@ -22,6 +22,11 @@ namespace Nz
{ {
} }
void Visit(const ShaderNodes::AccessMember& node) override
{
Serialize(node);
}
void Visit(const ShaderNodes::AssignOp& node) override void Visit(const ShaderNodes::AssignOp& node) override
{ {
Serialize(node); Serialize(node);
@ -125,6 +130,13 @@ namespace Nz
}; };
} }
void ShaderSerializerBase::Serialize(ShaderNodes::AccessMember& node)
{
Value(node.memberIndex);
Node(node.structExpr);
Type(node.exprType);
}
void ShaderSerializerBase::Serialize(ShaderNodes::AssignOp& node) void ShaderSerializerBase::Serialize(ShaderNodes::AssignOp& node)
{ {
Enum(node.op); Enum(node.op);
@ -153,8 +165,8 @@ namespace Nz
void ShaderSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node) void ShaderSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node)
{ {
Enum(node.type); Enum(node.entry);
Enum(node.type); Type(node.type);
} }
void ShaderSerializerBase::Serialize(ShaderNodes::Cast& node) void ShaderSerializerBase::Serialize(ShaderNodes::Cast& node)
@ -219,7 +231,7 @@ namespace Nz
void ShaderSerializerBase::Serialize(ShaderNodes::NamedVariable& node) void ShaderSerializerBase::Serialize(ShaderNodes::NamedVariable& node)
{ {
Value(node.name); Value(node.name);
Enum(node.type); Type(node.type);
} }
void ShaderSerializerBase::Serialize(ShaderNodes::Sample2D& node) void ShaderSerializerBase::Serialize(ShaderNodes::Sample2D& node)
@ -348,6 +360,26 @@ namespace Nz
} }
} }
void ShaderSerializer::Type(ShaderExpressionType& type)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
{
m_stream << UInt8(0);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, std::string>)
{
m_stream << UInt8(1);
m_stream << arg;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
}
void ShaderSerializer::Node(const ShaderNodes::NodePtr& node) void ShaderSerializer::Node(const ShaderNodes::NodePtr& node)
{ {
Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing

View File

@ -21,7 +21,7 @@ namespace Nz
struct Local struct Local
{ {
std::string name; std::string name;
ShaderNodes::BasicType type; ShaderExpressionType type;
}; };
const ShaderAst::Function* currentFunction; const ShaderAst::Function* currentFunction;
@ -83,6 +83,28 @@ namespace Nz
throw AstError{ "Left expression type must match right expression type" }; throw AstError{ "Left expression type must match right expression type" };
} }
void ShaderValidator::Visit(const ShaderNodes::AccessMember& node)
{
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
if (!std::holds_alternative<std::string>(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
const auto& structs = m_shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
if (it == structs.end())
throw AstError{ "invalid structure" };
const ShaderAst::Struct& s = *it;
if (node.memberIndex >= s.members.size())
throw AstError{ "member index out of bounds" };
const auto& member = s.members[node.memberIndex];
if (member.type != node.exprType)
throw AstError{ "member type does not match node type" };
}
void ShaderValidator::Visit(const ShaderNodes::AssignOp& node) void ShaderValidator::Visit(const ShaderNodes::AssignOp& node)
{ {
MandatoryNode(node.left); MandatoryNode(node.left);
@ -101,8 +123,16 @@ namespace Nz
MandatoryNode(node.left); MandatoryNode(node.left);
MandatoryNode(node.right); MandatoryNode(node.right);
ShaderNodes::BasicType leftType = node.left->GetExpressionType(); const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
ShaderNodes::BasicType rightType = node.right->GetExpressionType(); if (!std::holds_alternative<ShaderNodes::BasicType>(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
switch (node.op) switch (node.op)
{ {
@ -179,7 +209,11 @@ namespace Nz
if (!exprPtr) if (!exprPtr)
break; break;
componentCount += node.GetComponentCount(exprPtr->GetExpressionType()); const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
throw AstError{ "incompatible type" };
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
Visit(exprPtr); Visit(exprPtr);
} }
@ -318,7 +352,7 @@ namespace Nz
for (auto& param : node.parameters) for (auto& param : node.parameters)
MandatoryNode(param); MandatoryNode(param);
ShaderNodes::BasicType type = node.parameters.front()->GetExpressionType(); ShaderExpressionType type = node.parameters.front()->GetExpressionType();
for (std::size_t i = 1; i < node.parameters.size(); ++i) for (std::size_t i = 1; i < node.parameters.size(); ++i)
{ {
if (type != node.parameters[i]->GetExpressionType()) if (type != node.parameters[i]->GetExpressionType())
@ -333,7 +367,7 @@ namespace Nz
{ {
case ShaderNodes::IntrinsicType::CrossProduct: case ShaderNodes::IntrinsicType::CrossProduct:
{ {
if (node.parameters[0]->GetExpressionType() != ShaderNodes::BasicType::Float3) if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 })
throw AstError{ "CrossProduct only works with Float3 expressions" }; throw AstError{ "CrossProduct only works with Float3 expressions" };
break; break;
@ -349,10 +383,10 @@ namespace Nz
void ShaderValidator::Visit(const ShaderNodes::Sample2D& node) void ShaderValidator::Visit(const ShaderNodes::Sample2D& node)
{ {
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderNodes::BasicType::Sampler2D) if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D })
throw AstError{ "Sampler must be a Sampler2D" }; throw AstError{ "Sampler must be a Sampler2D" };
if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderNodes::BasicType::Float2) if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 })
throw AstError{ "Coordinates must be a Float2" }; throw AstError{ "Coordinates must be a Float2" };
Visit(node.sampler); Visit(node.sampler);
@ -378,7 +412,11 @@ namespace Nz
if (node.componentCount > 4) if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" }; throw AstError{ "Cannot swizzle more than four elements" };
switch (MandatoryExpr(node.expression)->GetExpressionType()) const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<ShaderNodes::BasicType>(exprType))
{ {
case ShaderNodes::BasicType::Float1: case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float2: