diff --git a/include/Nazara/Renderer/GlslWriter.hpp b/include/Nazara/Renderer/GlslWriter.hpp index 94ca21676..ba2444aec 100644 --- a/include/Nazara/Renderer/GlslWriter.hpp +++ b/include/Nazara/Renderer/GlslWriter.hpp @@ -61,6 +61,7 @@ namespace Nz using ShaderVarVisitor::Visit; using ShaderVisitor::Visit; + void Visit(const ShaderNodes::AccessMember& node) override; void Visit(const ShaderNodes::AssignOp& node) override; void Visit(const ShaderNodes::Branch& node) override; void Visit(const ShaderNodes::BinaryOp& node) override; @@ -85,6 +86,7 @@ namespace Nz struct Context { + const ShaderAst* shader = nullptr; const ShaderAst::Function* currentFunction = nullptr; }; diff --git a/include/Nazara/Renderer/ShaderBuilder.hpp b/include/Nazara/Renderer/ShaderBuilder.hpp index 7d8502151..8f544e018 100644 --- a/include/Nazara/Renderer/ShaderBuilder.hpp +++ b/include/Nazara/Renderer/ShaderBuilder.hpp @@ -44,6 +44,7 @@ namespace Nz::ShaderBuilder template std::shared_ptr operator()(Args&&... args) const; }; + constexpr GenBuilder AccessMember; constexpr BinOpBuilder Add; constexpr AssignOpBuilder Assign; constexpr BuiltinBuilder Builtin; diff --git a/include/Nazara/Renderer/ShaderEnums.hpp b/include/Nazara/Renderer/ShaderEnums.hpp index a9db1abe8..6edaad1de 100644 --- a/include/Nazara/Renderer/ShaderEnums.hpp +++ b/include/Nazara/Renderer/ShaderEnums.hpp @@ -65,6 +65,7 @@ namespace Nz::ShaderNodes { None = -1, + AccessMember, AssignOp, BinaryOp, Branch, diff --git a/include/Nazara/Renderer/ShaderExpressionType.hpp b/include/Nazara/Renderer/ShaderExpressionType.hpp index c937d6a10..68671b07c 100644 --- a/include/Nazara/Renderer/ShaderExpressionType.hpp +++ b/include/Nazara/Renderer/ShaderExpressionType.hpp @@ -8,7 +8,7 @@ #define NAZARA_SHADER_EXPRESSIONTYPE_HPP #include -#include +#include #include #include diff --git a/include/Nazara/Renderer/ShaderNodes.hpp b/include/Nazara/Renderer/ShaderNodes.hpp index 57ff5ab60..0c5049df9 100644 --- a/include/Nazara/Renderer/ShaderNodes.hpp +++ b/include/Nazara/Renderer/ShaderNodes.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -59,7 +60,7 @@ namespace Nz inline Expression(NodeType type); virtual ExpressionCategory GetExpressionCategory() const; - virtual BasicType GetExpressionType() const = 0; + virtual ShaderExpressionType GetExpressionType() const = 0; }; class Statement; @@ -125,7 +126,7 @@ namespace Nz inline Identifier(); ExpressionCategory GetExpressionCategory() const override; - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; VariablePtr var; @@ -133,13 +134,28 @@ namespace Nz static inline std::shared_ptr Build(VariablePtr variable); }; + struct NAZARA_RENDERER_API AccessMember : public Expression + { + inline AccessMember(); + + ExpressionCategory GetExpressionCategory() const override; + ShaderExpressionType GetExpressionType() const override; + void Visit(ShaderVisitor& visitor) override; + + std::size_t memberIndex; + ExpressionPtr structExpr; + ShaderExpressionType exprType; //< FIXME: Use ShaderAst to automate + + static inline std::shared_ptr Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType); + }; + ////////////////////////////////////////////////////////////////////////// struct NAZARA_RENDERER_API AssignOp : public Expression { inline AssignOp(); - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; AssignType op; @@ -153,7 +169,7 @@ namespace Nz { inline BinaryOp(); - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; BinaryType op; @@ -187,7 +203,7 @@ namespace Nz { inline Cast(); - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; BasicType exprType; @@ -201,7 +217,7 @@ namespace Nz { inline Constant(); - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; BasicType exprType; @@ -227,7 +243,7 @@ namespace Nz inline SwizzleOp(); ExpressionCategory GetExpressionCategory() const override; - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; std::array components; @@ -243,7 +259,7 @@ namespace Nz { inline Sample2D(); - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; ExpressionPtr sampler; @@ -258,7 +274,7 @@ namespace Nz { inline IntrinsicCall(); - BasicType GetExpressionType() const override; + ShaderExpressionType GetExpressionType() const override; void Visit(ShaderVisitor& visitor) override; IntrinsicType intrinsic; diff --git a/include/Nazara/Renderer/ShaderNodes.inl b/include/Nazara/Renderer/ShaderNodes.inl index 31d3028ed..606487e0c 100644 --- a/include/Nazara/Renderer/ShaderNodes.inl +++ b/include/Nazara/Renderer/ShaderNodes.inl @@ -146,6 +146,22 @@ namespace Nz::ShaderNodes } + inline AccessMember::AccessMember() : + Expression(NodeType::AccessMember) + { + } + + inline std::shared_ptr AccessMember::Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType) + { + auto node = std::make_shared(); + node->exprType = std::move(exprType); + node->memberIndex = memberIndex; + node->structExpr = std::move(structExpr); + + return node; + } + + inline AssignOp::AssignOp() : Expression(NodeType::AssignOp) { diff --git a/include/Nazara/Renderer/ShaderSerializer.hpp b/include/Nazara/Renderer/ShaderSerializer.hpp index e4e6be7b3..c6596b58d 100644 --- a/include/Nazara/Renderer/ShaderSerializer.hpp +++ b/include/Nazara/Renderer/ShaderSerializer.hpp @@ -25,6 +25,7 @@ namespace Nz ShaderSerializerBase(ShaderSerializerBase&&) = delete; ~ShaderSerializerBase() = default; + void Serialize(ShaderNodes::AccessMember& node); void Serialize(ShaderNodes::AssignOp& node); void Serialize(ShaderNodes::BinaryOp& node); void Serialize(ShaderNodes::BuiltinVariable& var); @@ -51,6 +52,8 @@ namespace Nz virtual void Node(ShaderNodes::NodePtr& node) = 0; template void Node(std::shared_ptr& node); + virtual void Type(ShaderExpressionType& type) = 0; + virtual void Value(bool& val) = 0; virtual void Value(float& val) = 0; virtual void Value(std::string& val) = 0; @@ -78,6 +81,7 @@ namespace Nz bool IsWriting() const override; void Node(const ShaderNodes::NodePtr& node); void Node(ShaderNodes::NodePtr& node) override; + void Type(ShaderExpressionType& type) override; void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; @@ -103,7 +107,7 @@ namespace Nz private: bool IsWriting() const override; void Node(ShaderNodes::NodePtr& node) override; - void Type(ShaderExpressionType& type); + void Type(ShaderExpressionType& type) override; void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; diff --git a/include/Nazara/Renderer/ShaderValidator.hpp b/include/Nazara/Renderer/ShaderValidator.hpp index ecfea1290..09015b073 100644 --- a/include/Nazara/Renderer/ShaderValidator.hpp +++ b/include/Nazara/Renderer/ShaderValidator.hpp @@ -33,6 +33,7 @@ namespace Nz void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right); using ShaderVisitor::Visit; + void Visit(const ShaderNodes::AccessMember& node) override; void Visit(const ShaderNodes::AssignOp& node) override; void Visit(const ShaderNodes::BinaryOp& node) override; void Visit(const ShaderNodes::Branch& node) override; diff --git a/include/Nazara/Renderer/ShaderVariables.hpp b/include/Nazara/Renderer/ShaderVariables.hpp index f80386d0c..9b52d3913 100644 --- a/include/Nazara/Renderer/ShaderVariables.hpp +++ b/include/Nazara/Renderer/ShaderVariables.hpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include #include @@ -34,7 +34,7 @@ namespace Nz virtual VariableType GetType() const = 0; virtual void Visit(ShaderVarVisitor& visitor) = 0; - BasicType type; + ShaderExpressionType type; }; struct BuiltinVariable; @@ -48,7 +48,7 @@ namespace Nz VariableType GetType() const override; void Visit(ShaderVarVisitor& visitor) override; - static inline std::shared_ptr Build(BuiltinEntry entry, BasicType varType); + static inline std::shared_ptr Build(BuiltinEntry entry, ShaderExpressionType varType); }; struct NamedVariable; @@ -69,7 +69,7 @@ namespace Nz VariableType GetType() const override; void Visit(ShaderVarVisitor& visitor) override; - static inline std::shared_ptr Build(std::string varName, BasicType varType); + static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); }; struct LocalVariable; @@ -81,7 +81,7 @@ namespace Nz VariableType GetType() const override; void Visit(ShaderVarVisitor& visitor) override; - static inline std::shared_ptr Build(std::string varName, BasicType varType); + static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); }; struct OutputVariable; @@ -93,7 +93,7 @@ namespace Nz VariableType GetType() const override; void Visit(ShaderVarVisitor& visitor) override; - static inline std::shared_ptr Build(std::string varName, BasicType varType); + static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); }; struct ParameterVariable; @@ -105,7 +105,7 @@ namespace Nz VariableType GetType() const override; void Visit(ShaderVarVisitor& visitor) override; - static inline std::shared_ptr Build(std::string varName, BasicType varType); + static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); }; struct UniformVariable; @@ -117,7 +117,7 @@ namespace Nz VariableType GetType() const override; void Visit(ShaderVarVisitor& visitor) override; - static inline std::shared_ptr Build(std::string varName, BasicType varType); + static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); }; } } diff --git a/include/Nazara/Renderer/ShaderVariables.inl b/include/Nazara/Renderer/ShaderVariables.inl index e01ecb33e..917459c67 100644 --- a/include/Nazara/Renderer/ShaderVariables.inl +++ b/include/Nazara/Renderer/ShaderVariables.inl @@ -7,7 +7,7 @@ namespace Nz::ShaderNodes { - inline std::shared_ptr BuiltinVariable::Build(BuiltinEntry variable, BasicType varType) + inline std::shared_ptr BuiltinVariable::Build(BuiltinEntry variable, ShaderExpressionType varType) { auto node = std::make_shared(); node->entry = variable; @@ -16,7 +16,7 @@ namespace Nz::ShaderNodes return node; } - inline std::shared_ptr InputVariable::Build(std::string varName, BasicType varType) + inline std::shared_ptr InputVariable::Build(std::string varName, ShaderExpressionType varType) { auto node = std::make_shared(); node->name = std::move(varName); @@ -25,7 +25,7 @@ namespace Nz::ShaderNodes return node; } - inline std::shared_ptr LocalVariable::Build(std::string varName, BasicType varType) + inline std::shared_ptr LocalVariable::Build(std::string varName, ShaderExpressionType varType) { auto node = std::make_shared(); node->name = std::move(varName); @@ -34,7 +34,7 @@ namespace Nz::ShaderNodes return node; } - inline std::shared_ptr OutputVariable::Build(std::string varName, BasicType varType) + inline std::shared_ptr OutputVariable::Build(std::string varName, ShaderExpressionType varType) { auto node = std::make_shared(); node->name = std::move(varName); @@ -43,7 +43,7 @@ namespace Nz::ShaderNodes return node; } - inline std::shared_ptr ParameterVariable::Build(std::string varName, BasicType varType) + inline std::shared_ptr ParameterVariable::Build(std::string varName, ShaderExpressionType varType) { auto node = std::make_shared(); node->name = std::move(varName); @@ -52,7 +52,7 @@ namespace Nz::ShaderNodes return node; } - inline std::shared_ptr UniformVariable::Build(std::string varName, BasicType varType) + inline std::shared_ptr UniformVariable::Build(std::string varName, ShaderExpressionType varType) { auto node = std::make_shared(); node->name = std::move(varName); diff --git a/include/Nazara/Renderer/ShaderVisitor.hpp b/include/Nazara/Renderer/ShaderVisitor.hpp index 9db22882c..1cd98fac3 100644 --- a/include/Nazara/Renderer/ShaderVisitor.hpp +++ b/include/Nazara/Renderer/ShaderVisitor.hpp @@ -27,6 +27,7 @@ namespace Nz bool IsConditionEnabled(const std::string& name) const; + virtual void Visit(const ShaderNodes::AccessMember& node) = 0; virtual void Visit(const ShaderNodes::AssignOp& node) = 0; virtual void Visit(const ShaderNodes::BinaryOp& node) = 0; virtual void Visit(const ShaderNodes::Branch& node) = 0; diff --git a/src/Nazara/Renderer/GlslWriter.cpp b/src/Nazara/Renderer/GlslWriter.cpp index 4dcbcd4cc..33ae8c0db 100644 --- a/src/Nazara/Renderer/GlslWriter.cpp +++ b/src/Nazara/Renderer/GlslWriter.cpp @@ -21,6 +21,8 @@ namespace Nz if (!ValidateShader(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); + m_context.shader = &shader; + State state; m_currentState = &state; CallOnExit onExit([this]() @@ -294,7 +296,30 @@ namespace Nz AppendLine(); AppendLine("}"); } - + + void GlslWriter::Visit(const ShaderNodes::AccessMember& node) + { + Append("("); + Visit(node.structExpr); + Append(")"); + + const ShaderExpressionType& exprType = node.structExpr->GetExpressionType(); + assert(std::holds_alternative(exprType)); + + const std::string& structName = std::get(exprType); + + const auto& structs = m_context.shader->GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); + assert(it != structs.end()); + + const ShaderAst::Struct& s = *it; + assert(node.memberIndex < s.members.size()); + + const auto& member = s.members[node.memberIndex]; + Append("."); + Append(member.name); + } + void GlslWriter::Visit(const ShaderNodes::AssignOp& node) { Visit(node.left); @@ -374,9 +399,7 @@ namespace Nz Append(node.exprType); Append("("); - unsigned int i = 0; - unsigned int requiredComponents = ShaderNodes::Node::GetComponentCount(node.exprType); - while (requiredComponents > 0) + for (std::size_t i = 0; node.expressions[i]; ++i) { if (i != 0) m_currentState->stream << ", "; @@ -385,7 +408,6 @@ namespace Nz NazaraAssert(exprPtr, "Invalid expression"); Visit(exprPtr); - requiredComponents -= ShaderNodes::Node::GetComponentCount(exprPtr->GetExpressionType()); } Append(")"); diff --git a/src/Nazara/Renderer/ShaderNodes.cpp b/src/Nazara/Renderer/ShaderNodes.cpp index 57ac58ee8..794e1217a 100644 --- a/src/Nazara/Renderer/ShaderNodes.cpp +++ b/src/Nazara/Renderer/ShaderNodes.cpp @@ -47,7 +47,7 @@ namespace Nz::ShaderNodes return ExpressionCategory::LValue; } - BasicType Identifier::GetExpressionType() const + ShaderExpressionType Identifier::GetExpressionType() const { assert(var); return var->type; @@ -58,8 +58,22 @@ namespace Nz::ShaderNodes visitor.Visit(*this); } + ExpressionCategory ShaderNodes::AccessMember::GetExpressionCategory() const + { + return ExpressionCategory::LValue; + } - BasicType AssignOp::GetExpressionType() const + ShaderExpressionType AccessMember::GetExpressionType() const + { + return exprType; + } + + void AccessMember::Visit(ShaderVisitor& visitor) + { + visitor.Visit(*this); + } + + ShaderExpressionType AssignOp::GetExpressionType() const { return left->GetExpressionType(); } @@ -70,31 +84,39 @@ namespace Nz::ShaderNodes } - BasicType BinaryOp::GetExpressionType() const + ShaderExpressionType BinaryOp::GetExpressionType() const { - ShaderNodes::BasicType exprType = ShaderNodes::BasicType::Void; + std::optional exprType; switch (op) { - case ShaderNodes::BinaryType::Add: - case ShaderNodes::BinaryType::Substract: + case BinaryType::Add: + case BinaryType::Substract: exprType = left->GetExpressionType(); break; - case ShaderNodes::BinaryType::Divide: - case ShaderNodes::BinaryType::Multiply: - //FIXME - exprType = static_cast(std::max(UnderlyingCast(left->GetExpressionType()), UnderlyingCast(right->GetExpressionType()))); - break; + case BinaryType::Divide: + case BinaryType::Multiply: + { + const ShaderExpressionType& leftExprType = left->GetExpressionType(); + assert(std::holds_alternative(leftExprType)); - case ShaderNodes::BinaryType::Equality: + const ShaderExpressionType& rightExprType = right->GetExpressionType(); + assert(std::holds_alternative(rightExprType)); + + //FIXME + exprType = static_cast(std::max(UnderlyingCast(std::get(leftExprType)), UnderlyingCast(std::get(rightExprType)))); + break; + } + + case BinaryType::Equality: exprType = BasicType::Boolean; break; } - NazaraAssert(exprType != ShaderNodes::BasicType::Void, "Unhandled builtin"); + NazaraAssert(exprType.has_value(), "Unhandled builtin"); - return exprType; + return *exprType; } void BinaryOp::Visit(ShaderVisitor& visitor) @@ -109,7 +131,7 @@ namespace Nz::ShaderNodes } - BasicType Constant::GetExpressionType() const + ShaderExpressionType Constant::GetExpressionType() const { return exprType; } @@ -119,7 +141,7 @@ namespace Nz::ShaderNodes visitor.Visit(*this); } - BasicType Cast::GetExpressionType() const + ShaderExpressionType Cast::GetExpressionType() const { return exprType; } @@ -135,9 +157,12 @@ namespace Nz::ShaderNodes return ExpressionCategory::LValue; } - BasicType SwizzleOp::GetExpressionType() const + ShaderExpressionType SwizzleOp::GetExpressionType() const { - return static_cast(UnderlyingCast(GetComponentType(expression->GetExpressionType())) + componentCount - 1); + const ShaderExpressionType& exprType = expression->GetExpressionType(); + assert(std::holds_alternative(exprType)); + + return static_cast(UnderlyingCast(GetComponentType(std::get(exprType))) + componentCount - 1); } void SwizzleOp::Visit(ShaderVisitor& visitor) @@ -146,7 +171,7 @@ namespace Nz::ShaderNodes } - BasicType Sample2D::GetExpressionType() const + ShaderExpressionType Sample2D::GetExpressionType() const { return BasicType::Float4; } @@ -157,7 +182,7 @@ namespace Nz::ShaderNodes } - BasicType IntrinsicCall::GetExpressionType() const + ShaderExpressionType IntrinsicCall::GetExpressionType() const { switch (intrinsic) { diff --git a/src/Nazara/Renderer/ShaderSerializer.cpp b/src/Nazara/Renderer/ShaderSerializer.cpp index 17f369037..71beead60 100644 --- a/src/Nazara/Renderer/ShaderSerializer.cpp +++ b/src/Nazara/Renderer/ShaderSerializer.cpp @@ -22,6 +22,11 @@ namespace Nz { } + void Visit(const ShaderNodes::AccessMember& node) override + { + Serialize(node); + } + void Visit(const ShaderNodes::AssignOp& node) override { Serialize(node); @@ -125,6 +130,13 @@ namespace Nz }; } + void ShaderSerializerBase::Serialize(ShaderNodes::AccessMember& node) + { + Value(node.memberIndex); + Node(node.structExpr); + Type(node.exprType); + } + void ShaderSerializerBase::Serialize(ShaderNodes::AssignOp& node) { Enum(node.op); @@ -153,8 +165,8 @@ namespace Nz void ShaderSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node) { - Enum(node.type); - Enum(node.type); + Enum(node.entry); + Type(node.type); } void ShaderSerializerBase::Serialize(ShaderNodes::Cast& node) @@ -219,7 +231,7 @@ namespace Nz void ShaderSerializerBase::Serialize(ShaderNodes::NamedVariable& node) { Value(node.name); - Enum(node.type); + Type(node.type); } void ShaderSerializerBase::Serialize(ShaderNodes::Sample2D& node) @@ -348,6 +360,26 @@ namespace Nz } } + void ShaderSerializer::Type(ShaderExpressionType& type) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + m_stream << UInt8(0); + m_stream << UInt32(arg); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(1); + m_stream << arg; + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, type); + } + void ShaderSerializer::Node(const ShaderNodes::NodePtr& node) { Node(const_cast(node)); //< Yes const_cast is ugly but it won't be used for writing diff --git a/src/Nazara/Renderer/ShaderValidator.cpp b/src/Nazara/Renderer/ShaderValidator.cpp index 1445d4cf1..c181ecd62 100644 --- a/src/Nazara/Renderer/ShaderValidator.cpp +++ b/src/Nazara/Renderer/ShaderValidator.cpp @@ -21,7 +21,7 @@ namespace Nz struct Local { std::string name; - ShaderNodes::BasicType type; + ShaderExpressionType type; }; const ShaderAst::Function* currentFunction; @@ -83,6 +83,28 @@ namespace Nz throw AstError{ "Left expression type must match right expression type" }; } + void ShaderValidator::Visit(const ShaderNodes::AccessMember& node) + { + const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType(); + if (!std::holds_alternative(exprType)) + throw AstError{ "expression is not a structure" }; + + const std::string& structName = std::get(exprType); + + const auto& structs = m_shader.GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); + if (it == structs.end()) + throw AstError{ "invalid structure" }; + + const ShaderAst::Struct& s = *it; + if (node.memberIndex >= s.members.size()) + throw AstError{ "member index out of bounds" }; + + const auto& member = s.members[node.memberIndex]; + if (member.type != node.exprType) + throw AstError{ "member type does not match node type" }; + } + void ShaderValidator::Visit(const ShaderNodes::AssignOp& node) { MandatoryNode(node.left); @@ -101,8 +123,16 @@ namespace Nz MandatoryNode(node.left); MandatoryNode(node.right); - ShaderNodes::BasicType leftType = node.left->GetExpressionType(); - ShaderNodes::BasicType rightType = node.right->GetExpressionType(); + const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType(); + if (!std::holds_alternative(leftExprType)) + throw AstError{ "left expression type does not support binary operation" }; + + const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType(); + if (!std::holds_alternative(rightExprType)) + throw AstError{ "right expression type does not support binary operation" }; + + ShaderNodes::BasicType leftType = std::get(leftExprType); + ShaderNodes::BasicType rightType = std::get(rightExprType); switch (node.op) { @@ -179,7 +209,11 @@ namespace Nz if (!exprPtr) break; - componentCount += node.GetComponentCount(exprPtr->GetExpressionType()); + const ShaderExpressionType& exprType = exprPtr->GetExpressionType(); + if (!std::holds_alternative(exprType)) + throw AstError{ "incompatible type" }; + + componentCount += node.GetComponentCount(std::get(exprType)); Visit(exprPtr); } @@ -318,7 +352,7 @@ namespace Nz for (auto& param : node.parameters) MandatoryNode(param); - ShaderNodes::BasicType type = node.parameters.front()->GetExpressionType(); + ShaderExpressionType type = node.parameters.front()->GetExpressionType(); for (std::size_t i = 1; i < node.parameters.size(); ++i) { if (type != node.parameters[i]->GetExpressionType()) @@ -333,7 +367,7 @@ namespace Nz { case ShaderNodes::IntrinsicType::CrossProduct: { - if (node.parameters[0]->GetExpressionType() != ShaderNodes::BasicType::Float3) + if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 }) throw AstError{ "CrossProduct only works with Float3 expressions" }; break; @@ -349,10 +383,10 @@ namespace Nz void ShaderValidator::Visit(const ShaderNodes::Sample2D& node) { - if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderNodes::BasicType::Sampler2D) + if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D }) throw AstError{ "Sampler must be a Sampler2D" }; - if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderNodes::BasicType::Float2) + if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 }) throw AstError{ "Coordinates must be a Float2" }; Visit(node.sampler); @@ -378,7 +412,11 @@ namespace Nz if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; - switch (MandatoryExpr(node.expression)->GetExpressionType()) + const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType(); + if (!std::holds_alternative(exprType)) + throw AstError{ "Cannot swizzle this type" }; + + switch (std::get(exprType)) { case ShaderNodes::BasicType::Float1: case ShaderNodes::BasicType::Float2: