diff --git a/examples/VulkanTest/main.cpp b/examples/VulkanTest/main.cpp index 9a5313f9a..706cf047c 100644 --- a/examples/VulkanTest/main.cpp +++ b/examples/VulkanTest/main.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -7,8 +8,8 @@ #define SPIRV 0 int main() -{ - { +{ + /*{ Nz::File file("frag.shader"); if (!file.Open(Nz::OpenMode_ReadOnly)) return __LINE__; @@ -39,7 +40,7 @@ int main() return 0; } - + */ Nz::Initializer loader; if (!loader) { diff --git a/examples/bin/frag.shader b/examples/bin/frag.shader index 1fbcce530..cd71692bd 100644 Binary files a/examples/bin/frag.shader and b/examples/bin/frag.shader differ diff --git a/examples/bin/test.spirv b/examples/bin/test.spirv new file mode 100644 index 000000000..3d47fe4e8 Binary files /dev/null and b/examples/bin/test.spirv differ diff --git a/examples/bin/vert.shader b/examples/bin/vert.shader index b1fa132bb..2d5c2daca 100644 Binary files a/examples/bin/vert.shader and b/examples/bin/vert.shader differ diff --git a/include/Nazara/Renderer/GlslWriter.hpp b/include/Nazara/Renderer/GlslWriter.hpp index 8178c73d4..d7229f1d9 100644 --- a/include/Nazara/Renderer/GlslWriter.hpp +++ b/include/Nazara/Renderer/GlslWriter.hpp @@ -61,26 +61,26 @@ namespace Nz using ShaderVarVisitor::Visit; using ShaderAstVisitor::Visit; - void Visit(const ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired = false); - void Visit(const ShaderNodes::AccessMember& node) override; - void Visit(const ShaderNodes::AssignOp& node) override; - void Visit(const ShaderNodes::Branch& node) override; - void Visit(const ShaderNodes::BinaryOp& node) override; - void Visit(const ShaderNodes::BuiltinVariable& var) override; - void Visit(const ShaderNodes::Cast& node) override; - void Visit(const ShaderNodes::Constant& node) override; - void Visit(const ShaderNodes::DeclareVariable& node) override; - void Visit(const ShaderNodes::ExpressionStatement& node) override; - void Visit(const ShaderNodes::Identifier& node) override; - void Visit(const ShaderNodes::InputVariable& var) override; - void Visit(const ShaderNodes::IntrinsicCall& node) override; - void Visit(const ShaderNodes::LocalVariable& var) override; - void Visit(const ShaderNodes::ParameterVariable& var) override; - void Visit(const ShaderNodes::OutputVariable& var) override; - void Visit(const ShaderNodes::Sample2D& node) override; - void Visit(const ShaderNodes::StatementBlock& node) override; - void Visit(const ShaderNodes::SwizzleOp& node) override; - void Visit(const ShaderNodes::UniformVariable& var) override; + void Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired = false); + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::OutputVariable& var) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; + void Visit(ShaderNodes::UniformVariable& var) override; static bool HasExplicitBinding(const ShaderAst& shader); static bool HasExplicitLocation(const ShaderAst& shader); diff --git a/include/Nazara/Renderer/ShaderAstCloner.hpp b/include/Nazara/Renderer/ShaderAstCloner.hpp index 20d1ed52b..ff1f9c5a4 100644 --- a/include/Nazara/Renderer/ShaderAstCloner.hpp +++ b/include/Nazara/Renderer/ShaderAstCloner.hpp @@ -33,26 +33,26 @@ namespace Nz ShaderNodes::StatementPtr CloneStatement(const ShaderNodes::StatementPtr& statement); ShaderNodes::VariablePtr CloneVariable(const ShaderNodes::VariablePtr& statement); - void Visit(const ShaderNodes::AccessMember& node) override; - void Visit(const ShaderNodes::AssignOp& node) override; - void Visit(const ShaderNodes::BinaryOp& node) override; - void Visit(const ShaderNodes::Branch& node) override; - void Visit(const ShaderNodes::Cast& node) override; - void Visit(const ShaderNodes::Constant& node) override; - void Visit(const ShaderNodes::DeclareVariable& node) override; - void Visit(const ShaderNodes::ExpressionStatement& node) override; - void Visit(const ShaderNodes::Identifier& node) override; - void Visit(const ShaderNodes::IntrinsicCall& node) override; - void Visit(const ShaderNodes::Sample2D& node) override; - void Visit(const ShaderNodes::StatementBlock& node) override; - void Visit(const ShaderNodes::SwizzleOp& node) override; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; - void Visit(const ShaderNodes::BuiltinVariable& var) override; - void Visit(const ShaderNodes::InputVariable& var) override; - void Visit(const ShaderNodes::LocalVariable& var) override; - void Visit(const ShaderNodes::OutputVariable& var) override; - void Visit(const ShaderNodes::ParameterVariable& var) override; - void Visit(const ShaderNodes::UniformVariable& var) override; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::OutputVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::UniformVariable& var) override; void PushExpression(ShaderNodes::ExpressionPtr expression); void PushStatement(ShaderNodes::StatementPtr statement); diff --git a/include/Nazara/Renderer/ShaderAstRecursiveVisitor.hpp b/include/Nazara/Renderer/ShaderAstRecursiveVisitor.hpp index 622c1aef0..0482ea009 100644 --- a/include/Nazara/Renderer/ShaderAstRecursiveVisitor.hpp +++ b/include/Nazara/Renderer/ShaderAstRecursiveVisitor.hpp @@ -21,19 +21,19 @@ namespace Nz using ShaderAstVisitor::Visit; - void Visit(const ShaderNodes::AccessMember& node) override; - void Visit(const ShaderNodes::AssignOp& node) override; - void Visit(const ShaderNodes::BinaryOp& node) override; - void Visit(const ShaderNodes::Branch& node) override; - void Visit(const ShaderNodes::Cast& node) override; - void Visit(const ShaderNodes::Constant& node) override; - void Visit(const ShaderNodes::DeclareVariable& node) override; - void Visit(const ShaderNodes::ExpressionStatement& node) override; - void Visit(const ShaderNodes::Identifier& node) override; - void Visit(const ShaderNodes::IntrinsicCall& node) override; - void Visit(const ShaderNodes::Sample2D& node) override; - void Visit(const ShaderNodes::StatementBlock& node) override; - void Visit(const ShaderNodes::SwizzleOp& node) override; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; }; } diff --git a/include/Nazara/Renderer/ShaderAstSerializer.hpp b/include/Nazara/Renderer/ShaderAstSerializer.hpp index 68380054c..3298d81e6 100644 --- a/include/Nazara/Renderer/ShaderAstSerializer.hpp +++ b/include/Nazara/Renderer/ShaderAstSerializer.hpp @@ -57,9 +57,13 @@ namespace Nz virtual void Value(bool& val) = 0; virtual void Value(float& val) = 0; virtual void Value(std::string& val) = 0; + virtual void Value(Int32& val) = 0; virtual void Value(Vector2f& val) = 0; virtual void Value(Vector3f& val) = 0; virtual void Value(Vector4f& val) = 0; + virtual void Value(Vector2i32& val) = 0; + virtual void Value(Vector3i32& val) = 0; + virtual void Value(Vector4i32& val) = 0; virtual void Value(UInt8& val) = 0; virtual void Value(UInt16& val) = 0; virtual void Value(UInt32& val) = 0; @@ -85,9 +89,13 @@ namespace Nz void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; + void Value(Int32& val) override; void Value(Vector2f& val) override; void Value(Vector3f& val) override; void Value(Vector4f& val) override; + void Value(Vector2i32& val) override; + void Value(Vector3i32& val) override; + void Value(Vector4i32& val) override; void Value(UInt8& val) override; void Value(UInt16& val) override; void Value(UInt32& val) override; @@ -111,9 +119,13 @@ namespace Nz void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; + void Value(Int32& val) override; void Value(Vector2f& val) override; void Value(Vector3f& val) override; void Value(Vector4f& val) override; + void Value(Vector2i32& val) override; + void Value(Vector3i32& val) override; + void Value(Vector4i32& val) override; void Value(UInt8& val) override; void Value(UInt16& val) override; void Value(UInt32& val) override; diff --git a/include/Nazara/Renderer/ShaderAstValidator.hpp b/include/Nazara/Renderer/ShaderAstValidator.hpp index aa586f99b..90aec52aa 100644 --- a/include/Nazara/Renderer/ShaderAstValidator.hpp +++ b/include/Nazara/Renderer/ShaderAstValidator.hpp @@ -34,27 +34,27 @@ namespace Nz void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right); using ShaderAstRecursiveVisitor::Visit; - void Visit(const ShaderNodes::AccessMember& node) override; - void Visit(const ShaderNodes::AssignOp& node) override; - void Visit(const ShaderNodes::BinaryOp& node) override; - void Visit(const ShaderNodes::Branch& node) override; - void Visit(const ShaderNodes::Cast& node) override; - void Visit(const ShaderNodes::Constant& node) override; - void Visit(const ShaderNodes::DeclareVariable& node) override; - void Visit(const ShaderNodes::ExpressionStatement& node) override; - void Visit(const ShaderNodes::Identifier& node) override; - void Visit(const ShaderNodes::IntrinsicCall& node) override; - void Visit(const ShaderNodes::Sample2D& node) override; - void Visit(const ShaderNodes::StatementBlock& node) override; - void Visit(const ShaderNodes::SwizzleOp& node) override; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; using ShaderVarVisitor::Visit; - void Visit(const ShaderNodes::BuiltinVariable& var) override; - void Visit(const ShaderNodes::InputVariable& var) override; - void Visit(const ShaderNodes::LocalVariable& var) override; - void Visit(const ShaderNodes::OutputVariable& var) override; - void Visit(const ShaderNodes::ParameterVariable& var) override; - void Visit(const ShaderNodes::UniformVariable& var) override; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::OutputVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::UniformVariable& var) override; struct Context; diff --git a/include/Nazara/Renderer/ShaderAstVisitor.hpp b/include/Nazara/Renderer/ShaderAstVisitor.hpp index d955f17e9..3ce83f406 100644 --- a/include/Nazara/Renderer/ShaderAstVisitor.hpp +++ b/include/Nazara/Renderer/ShaderAstVisitor.hpp @@ -27,20 +27,20 @@ namespace Nz bool IsConditionEnabled(const std::string& name) const; - virtual void Visit(const ShaderNodes::AccessMember& node) = 0; - virtual void Visit(const ShaderNodes::AssignOp& node) = 0; - virtual void Visit(const ShaderNodes::BinaryOp& node) = 0; - virtual void Visit(const ShaderNodes::Branch& node) = 0; - virtual void Visit(const ShaderNodes::Cast& node) = 0; - virtual void Visit(const ShaderNodes::Constant& node) = 0; - virtual void Visit(const ShaderNodes::DeclareVariable& node) = 0; - virtual void Visit(const ShaderNodes::ExpressionStatement& node) = 0; - virtual void Visit(const ShaderNodes::Identifier& node) = 0; - virtual void Visit(const ShaderNodes::IntrinsicCall& node) = 0; void Visit(const ShaderNodes::NodePtr& node); - virtual void Visit(const ShaderNodes::Sample2D& node) = 0; - virtual void Visit(const ShaderNodes::StatementBlock& node) = 0; - virtual void Visit(const ShaderNodes::SwizzleOp& node) = 0; + virtual void Visit(ShaderNodes::AccessMember& node) = 0; + virtual void Visit(ShaderNodes::AssignOp& node) = 0; + virtual void Visit(ShaderNodes::BinaryOp& node) = 0; + virtual void Visit(ShaderNodes::Branch& node) = 0; + virtual void Visit(ShaderNodes::Cast& node) = 0; + virtual void Visit(ShaderNodes::Constant& node) = 0; + virtual void Visit(ShaderNodes::DeclareVariable& node) = 0; + virtual void Visit(ShaderNodes::ExpressionStatement& node) = 0; + virtual void Visit(ShaderNodes::Identifier& node) = 0; + virtual void Visit(ShaderNodes::IntrinsicCall& node) = 0; + virtual void Visit(ShaderNodes::Sample2D& node) = 0; + virtual void Visit(ShaderNodes::StatementBlock& node) = 0; + virtual void Visit(ShaderNodes::SwizzleOp& node) = 0; private: std::unordered_set m_conditions; diff --git a/include/Nazara/Renderer/ShaderEnums.hpp b/include/Nazara/Renderer/ShaderEnums.hpp index ab9ed4953..46e6a3d74 100644 --- a/include/Nazara/Renderer/ShaderEnums.hpp +++ b/include/Nazara/Renderer/ShaderEnums.hpp @@ -18,15 +18,19 @@ namespace Nz::ShaderNodes enum class BasicType { - Boolean, // bool - Float1, // float - Float2, // vec2 - Float3, // vec3 - Float4, // vec4 - Mat4x4, // mat4 - Sampler2D, // sampler2D + Boolean, //< bool + Float1, //< float + Float2, //< vec2 + Float3, //< vec3 + Float4, //< vec4 + Int1, //< int + Int2, //< ivec2 + Int3, //< ivec3 + Int4, //< ivec4 + Mat4x4, //< mat4 + Sampler2D, //< sampler2D - Void // void + Void //< void }; enum class BinaryType diff --git a/include/Nazara/Renderer/ShaderNodes.hpp b/include/Nazara/Renderer/ShaderNodes.hpp index c0b8688e2..985cde039 100644 --- a/include/Nazara/Renderer/ShaderNodes.hpp +++ b/include/Nazara/Renderer/ShaderNodes.hpp @@ -145,7 +145,7 @@ namespace Nz std::size_t memberIndex; ExpressionPtr structExpr; - ShaderExpressionType exprType; //< FIXME: Use ShaderAst to automate + ShaderExpressionType exprType; static inline std::shared_ptr Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType); }; @@ -225,9 +225,13 @@ namespace Nz using Variant = std::variant< bool, float, + Int32, Vector2f, Vector3f, - Vector4f + Vector4f, + Vector2i32, + Vector3i32, + Vector4i32 >; Variant value; diff --git a/include/Nazara/Renderer/ShaderNodes.inl b/include/Nazara/Renderer/ShaderNodes.inl index 0e9307ff4..e6dbec2ba 100644 --- a/include/Nazara/Renderer/ShaderNodes.inl +++ b/include/Nazara/Renderer/ShaderNodes.inl @@ -28,12 +28,15 @@ namespace Nz::ShaderNodes switch (type) { case BasicType::Float2: + case BasicType::Int2: return 2; case BasicType::Float3: + case BasicType::Int3: return 3; case BasicType::Float4: + case BasicType::Int4: return 4; case BasicType::Mat4x4: @@ -53,6 +56,11 @@ namespace Nz::ShaderNodes case BasicType::Float4: return BasicType::Float1; + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: + return BasicType::Int1; + case BasicType::Mat4x4: return BasicType::Float4; diff --git a/include/Nazara/Renderer/ShaderVarVisitor.hpp b/include/Nazara/Renderer/ShaderVarVisitor.hpp index 6df035d74..bda5e0b4d 100644 --- a/include/Nazara/Renderer/ShaderVarVisitor.hpp +++ b/include/Nazara/Renderer/ShaderVarVisitor.hpp @@ -21,13 +21,14 @@ namespace Nz ShaderVarVisitor(ShaderVarVisitor&&) = delete; virtual ~ShaderVarVisitor(); - virtual void Visit(const ShaderNodes::BuiltinVariable& var) = 0; - virtual void Visit(const ShaderNodes::InputVariable& var) = 0; - virtual void Visit(const ShaderNodes::LocalVariable& var) = 0; - virtual void Visit(const ShaderNodes::OutputVariable& var) = 0; - virtual void Visit(const ShaderNodes::ParameterVariable& var) = 0; - virtual void Visit(const ShaderNodes::UniformVariable& var) = 0; void Visit(const ShaderNodes::VariablePtr& node); + + virtual void Visit(ShaderNodes::BuiltinVariable& var) = 0; + virtual void Visit(ShaderNodes::InputVariable& var) = 0; + virtual void Visit(ShaderNodes::LocalVariable& var) = 0; + virtual void Visit(ShaderNodes::OutputVariable& var) = 0; + virtual void Visit(ShaderNodes::ParameterVariable& var) = 0; + virtual void Visit(ShaderNodes::UniformVariable& var) = 0; }; } diff --git a/include/Nazara/Renderer/SpirvWriter.hpp b/include/Nazara/Renderer/SpirvWriter.hpp index 2301d6caf..b9e48dd12 100644 --- a/include/Nazara/Renderer/SpirvWriter.hpp +++ b/include/Nazara/Renderer/SpirvWriter.hpp @@ -10,8 +10,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -20,7 +20,7 @@ namespace Nz { - class NAZARA_RENDERER_API SpirvWriter : public ShaderAstVisitor + class NAZARA_RENDERER_API SpirvWriter : public ShaderAstVisitor, public ShaderVarVisitor { public: struct Environment; @@ -41,6 +41,7 @@ namespace Nz }; private: + struct ExtVar; struct Opcode; struct Raw; struct WordCount; @@ -76,28 +77,39 @@ namespace Nz void AppendStructType(std::size_t structIndex, UInt32 resultId); void AppendTypes(); + UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr); + UInt32 GetConstantId(const ShaderNodes::Constant::Variant& value) const; UInt32 GetTypeId(const ShaderExpressionType& type) const; void PushResultId(UInt32 value); UInt32 PopResultId(); + UInt32 ReadVariable(ExtVar& var); UInt32 RegisterType(ShaderExpressionType type); using ShaderAstVisitor::Visit; - void Visit(const ShaderNodes::AccessMember& node) override; - void Visit(const ShaderNodes::AssignOp& node) override; - void Visit(const ShaderNodes::Branch& node) override; - void Visit(const ShaderNodes::BinaryOp& node) override; - void Visit(const ShaderNodes::Cast& node) override; - void Visit(const ShaderNodes::Constant& node) override; - void Visit(const ShaderNodes::DeclareVariable& node) override; - void Visit(const ShaderNodes::ExpressionStatement& node) override; - void Visit(const ShaderNodes::Identifier& node) override; - void Visit(const ShaderNodes::IntrinsicCall& node) override; - void Visit(const ShaderNodes::Sample2D& node) override; - void Visit(const ShaderNodes::StatementBlock& node) override; - void Visit(const ShaderNodes::SwizzleOp& node) override; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; + + using ShaderVarVisitor::Visit; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::OutputVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::UniformVariable& var) override; static void MergeBlocks(std::vector& output, const Section& from); diff --git a/src/Nazara/Renderer/GlslWriter.cpp b/src/Nazara/Renderer/GlslWriter.cpp index 8fa23755a..3d651a84f 100644 --- a/src/Nazara/Renderer/GlslWriter.cpp +++ b/src/Nazara/Renderer/GlslWriter.cpp @@ -188,30 +188,18 @@ namespace Nz { switch (type) { - case ShaderNodes::BasicType::Boolean: - Append("bool"); - break; - case ShaderNodes::BasicType::Float1: - Append("float"); - break; - case ShaderNodes::BasicType::Float2: - Append("vec2"); - break; - case ShaderNodes::BasicType::Float3: - Append("vec3"); - break; - case ShaderNodes::BasicType::Float4: - Append("vec4"); - break; - case ShaderNodes::BasicType::Mat4x4: - Append("mat4"); - break; - case ShaderNodes::BasicType::Sampler2D: - Append("sampler2D"); - break; - case ShaderNodes::BasicType::Void: - Append("void"); - break; + case ShaderNodes::BasicType::Boolean: return Append("bool"); + case ShaderNodes::BasicType::Float1: return Append("float"); + case ShaderNodes::BasicType::Float2: return Append("vec2"); + case ShaderNodes::BasicType::Float3: return Append("vec3"); + case ShaderNodes::BasicType::Float4: return Append("vec4"); + case ShaderNodes::BasicType::Int1: return Append("int"); + case ShaderNodes::BasicType::Int2: return Append("ivec2"); + case ShaderNodes::BasicType::Int3: return Append("ivec3"); + case ShaderNodes::BasicType::Int4: return Append("ivec4"); + case ShaderNodes::BasicType::Mat4x4: return Append("mat4"); + case ShaderNodes::BasicType::Sampler2D: return Append("sampler2D"); + case ShaderNodes::BasicType::Void: return Append("void"); } } @@ -298,7 +286,7 @@ namespace Nz AppendLine("}"); } - void GlslWriter::Visit(const ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired) + void GlslWriter::Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired) { bool enclose = encloseIfRequired && (expr->GetExpressionCategory() != ShaderNodes::ExpressionCategory::LValue); @@ -311,7 +299,7 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(const ShaderNodes::AccessMember& node) + void GlslWriter::Visit(ShaderNodes::AccessMember& node) { Visit(node.structExpr, true); @@ -332,7 +320,7 @@ namespace Nz Append(member.name); } - void GlslWriter::Visit(const ShaderNodes::AssignOp& node) + void GlslWriter::Visit(ShaderNodes::AssignOp& node) { Visit(node.left); @@ -346,7 +334,7 @@ namespace Nz Visit(node.right); } - void GlslWriter::Visit(const ShaderNodes::Branch& node) + void GlslWriter::Visit(ShaderNodes::Branch& node) { bool first = true; for (const auto& statement : node.condStatements) @@ -375,7 +363,7 @@ namespace Nz } } - void GlslWriter::Visit(const ShaderNodes::BinaryOp& node) + void GlslWriter::Visit(ShaderNodes::BinaryOp& node) { Visit(node.left, true); @@ -401,12 +389,12 @@ namespace Nz Visit(node.right, true); } - void GlslWriter::Visit(const ShaderNodes::BuiltinVariable& var) + void GlslWriter::Visit(ShaderNodes::BuiltinVariable& var) { Append(var.entry); } - void GlslWriter::Visit(const ShaderNodes::Cast& node) + void GlslWriter::Visit(ShaderNodes::Cast& node) { Append(node.exprType); Append("("); @@ -425,28 +413,31 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(const ShaderNodes::Constant& node) + void GlslWriter::Visit(ShaderNodes::Constant& node) { std::visit([&](auto&& arg) { using T = std::decay_t; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) + Append("i"); //< for ivec + if constexpr (std::is_same_v) Append((arg) ? "true" : "false"); - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) Append(std::to_string(arg)); - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) Append("vec2(" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ")"); - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) Append("vec3(" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ", " + std::to_string(arg.z) + ")"); - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) Append("vec4(" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ", " + std::to_string(arg.z) + ", " + std::to_string(arg.w) + ")"); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, node.value); } - void GlslWriter::Visit(const ShaderNodes::DeclareVariable& node) + void GlslWriter::Visit(ShaderNodes::DeclareVariable& node) { assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); @@ -464,23 +455,23 @@ namespace Nz AppendLine(";"); } - void GlslWriter::Visit(const ShaderNodes::ExpressionStatement& node) + void GlslWriter::Visit(ShaderNodes::ExpressionStatement& node) { Visit(node.expression); Append(";"); } - void GlslWriter::Visit(const ShaderNodes::Identifier& node) + void GlslWriter::Visit(ShaderNodes::Identifier& node) { Visit(node.var); } - void GlslWriter::Visit(const ShaderNodes::InputVariable& var) + void GlslWriter::Visit(ShaderNodes::InputVariable& var) { Append(var.name); } - void GlslWriter::Visit(const ShaderNodes::IntrinsicCall& node) + void GlslWriter::Visit(ShaderNodes::IntrinsicCall& node) { switch (node.intrinsic) { @@ -504,22 +495,22 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(const ShaderNodes::LocalVariable& var) + void GlslWriter::Visit(ShaderNodes::LocalVariable& var) { Append(var.name); } - void GlslWriter::Visit(const ShaderNodes::ParameterVariable& var) + void GlslWriter::Visit(ShaderNodes::ParameterVariable& var) { Append(var.name); } - void GlslWriter::Visit(const ShaderNodes::OutputVariable& var) + void GlslWriter::Visit(ShaderNodes::OutputVariable& var) { Append(var.name); } - void GlslWriter::Visit(const ShaderNodes::Sample2D& node) + void GlslWriter::Visit(ShaderNodes::Sample2D& node) { Append("texture("); Visit(node.sampler); @@ -528,7 +519,7 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(const ShaderNodes::StatementBlock& node) + void GlslWriter::Visit(ShaderNodes::StatementBlock& node) { bool first = true; for (const ShaderNodes::StatementPtr& statement : node.statements) @@ -542,7 +533,7 @@ namespace Nz } } - void GlslWriter::Visit(const ShaderNodes::SwizzleOp& node) + void GlslWriter::Visit(ShaderNodes::SwizzleOp& node) { Visit(node.expression); Append("."); @@ -570,7 +561,7 @@ namespace Nz } } - void GlslWriter::Visit(const ShaderNodes::UniformVariable& var) + void GlslWriter::Visit(ShaderNodes::UniformVariable& var) { Append(var.name); } diff --git a/src/Nazara/Renderer/Renderer.cpp b/src/Nazara/Renderer/Renderer.cpp index 9df137326..6954f5cf0 100644 --- a/src/Nazara/Renderer/Renderer.cpp +++ b/src/Nazara/Renderer/Renderer.cpp @@ -69,7 +69,7 @@ namespace Nz }; RegisterImpl("NazaraOpenGLRenderer" NazaraRendererDebugSuffix, [] { return 50; }); - //RegisterImpl("NazaraVulkanRenderer" NazaraRendererDebugSuffix, [] { return 100; }); + RegisterImpl("NazaraVulkanRenderer" NazaraRendererDebugSuffix, [] { return 100; }); std::sort(implementations.begin(), implementations.end(), [](const auto& lhs, const auto& rhs) { return lhs.score > rhs.score; }); diff --git a/src/Nazara/Renderer/ShaderAstCloner.cpp b/src/Nazara/Renderer/ShaderAstCloner.cpp index 75ac0634c..2b712075c 100644 --- a/src/Nazara/Renderer/ShaderAstCloner.cpp +++ b/src/Nazara/Renderer/ShaderAstCloner.cpp @@ -45,22 +45,22 @@ namespace Nz return PopVariable(); } - void ShaderAstCloner::Visit(const ShaderNodes::AccessMember& node) + void ShaderAstCloner::Visit(ShaderNodes::AccessMember& node) { PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndex, node.exprType)); } - void ShaderAstCloner::Visit(const ShaderNodes::AssignOp& node) + void ShaderAstCloner::Visit(ShaderNodes::AssignOp& node) { PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right))); } - void ShaderAstCloner::Visit(const ShaderNodes::BinaryOp& node) + void ShaderAstCloner::Visit(ShaderNodes::BinaryOp& node) { PushExpression(ShaderNodes::BinaryOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right))); } - void ShaderAstCloner::Visit(const ShaderNodes::Branch& node) + void ShaderAstCloner::Visit(ShaderNodes::Branch& node) { std::vector condStatements; condStatements.reserve(node.condStatements.size()); @@ -75,7 +75,7 @@ namespace Nz PushStatement(ShaderNodes::Branch::Build(std::move(condStatements), CloneStatement(node.elseStatement))); } - void ShaderAstCloner::Visit(const ShaderNodes::Cast& node) + void ShaderAstCloner::Visit(ShaderNodes::Cast& node) { std::size_t expressionCount = 0; std::array expressions; @@ -91,27 +91,27 @@ namespace Nz PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount)); } - void ShaderAstCloner::Visit(const ShaderNodes::Constant& node) + void ShaderAstCloner::Visit(ShaderNodes::Constant& node) { PushExpression(ShaderNodes::Constant::Build(node.value)); } - void ShaderAstCloner::Visit(const ShaderNodes::DeclareVariable& node) + void ShaderAstCloner::Visit(ShaderNodes::DeclareVariable& node) { PushStatement(ShaderNodes::DeclareVariable::Build(CloneVariable(node.variable), CloneExpression(node.expression))); } - void ShaderAstCloner::Visit(const ShaderNodes::ExpressionStatement& node) + void ShaderAstCloner::Visit(ShaderNodes::ExpressionStatement& node) { PushStatement(ShaderNodes::ExpressionStatement::Build(CloneExpression(node.expression))); } - void ShaderAstCloner::Visit(const ShaderNodes::Identifier& node) + void ShaderAstCloner::Visit(ShaderNodes::Identifier& node) { PushExpression(ShaderNodes::Identifier::Build(CloneVariable(node.var))); } - void ShaderAstCloner::Visit(const ShaderNodes::IntrinsicCall& node) + void ShaderAstCloner::Visit(ShaderNodes::IntrinsicCall& node) { std::vector parameters; parameters.reserve(node.parameters.size()); @@ -122,12 +122,12 @@ namespace Nz PushExpression(ShaderNodes::IntrinsicCall::Build(node.intrinsic, std::move(parameters))); } - void ShaderAstCloner::Visit(const ShaderNodes::Sample2D& node) + void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node) { PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates))); } - void ShaderAstCloner::Visit(const ShaderNodes::StatementBlock& node) + void ShaderAstCloner::Visit(ShaderNodes::StatementBlock& node) { std::vector statements; statements.reserve(node.statements.size()); @@ -138,37 +138,37 @@ namespace Nz PushStatement(ShaderNodes::StatementBlock::Build(std::move(statements))); } - void ShaderAstCloner::Visit(const ShaderNodes::SwizzleOp& node) + void ShaderAstCloner::Visit(ShaderNodes::SwizzleOp& node) { PushExpression(ShaderNodes::SwizzleOp::Build(PopExpression(), node.components.data(), node.componentCount)); } - void ShaderAstCloner::Visit(const ShaderNodes::BuiltinVariable& var) + void ShaderAstCloner::Visit(ShaderNodes::BuiltinVariable& var) { PushVariable(ShaderNodes::BuiltinVariable::Build(var.entry, var.type)); } - void ShaderAstCloner::Visit(const ShaderNodes::InputVariable& var) + void ShaderAstCloner::Visit(ShaderNodes::InputVariable& var) { PushVariable(ShaderNodes::InputVariable::Build(var.name, var.type)); } - void ShaderAstCloner::Visit(const ShaderNodes::LocalVariable& var) + void ShaderAstCloner::Visit(ShaderNodes::LocalVariable& var) { PushVariable(ShaderNodes::LocalVariable::Build(var.name, var.type)); } - void ShaderAstCloner::Visit(const ShaderNodes::OutputVariable& var) + void ShaderAstCloner::Visit(ShaderNodes::OutputVariable& var) { PushVariable(ShaderNodes::OutputVariable::Build(var.name, var.type)); } - void ShaderAstCloner::Visit(const ShaderNodes::ParameterVariable& var) + void ShaderAstCloner::Visit(ShaderNodes::ParameterVariable& var) { PushVariable(ShaderNodes::ParameterVariable::Build(var.name, var.type)); } - void ShaderAstCloner::Visit(const ShaderNodes::UniformVariable& var) + void ShaderAstCloner::Visit(ShaderNodes::UniformVariable& var) { PushVariable(ShaderNodes::UniformVariable::Build(var.name, var.type)); } diff --git a/src/Nazara/Renderer/ShaderAstRecursiveVisitor.cpp b/src/Nazara/Renderer/ShaderAstRecursiveVisitor.cpp index 81344bcf0..5a39e68c5 100644 --- a/src/Nazara/Renderer/ShaderAstRecursiveVisitor.cpp +++ b/src/Nazara/Renderer/ShaderAstRecursiveVisitor.cpp @@ -7,24 +7,24 @@ namespace Nz { - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::AccessMember& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AccessMember& node) { Visit(node.structExpr); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::AssignOp& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AssignOp& node) { Visit(node.left); Visit(node.right); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::BinaryOp& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::BinaryOp& node) { Visit(node.left); Visit(node.right); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::Branch& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Branch& node) { for (auto& cond : node.condStatements) { @@ -36,7 +36,7 @@ namespace Nz Visit(node.elseStatement); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::Cast& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Cast& node) { for (auto& expr : node.expressions) { @@ -47,46 +47,46 @@ namespace Nz } } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::Constant& /*node*/) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/) { /* Nothing to do */ } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::DeclareVariable& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::DeclareVariable& node) { if (node.expression) Visit(node.expression); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::ExpressionStatement& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ExpressionStatement& node) { Visit(node.expression); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::Identifier& /*node*/) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Identifier& /*node*/) { /* Nothing to do */ } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::IntrinsicCall& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::IntrinsicCall& node) { for (auto& param : node.parameters) Visit(param); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::Sample2D& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node) { Visit(node.sampler); Visit(node.coordinates); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::StatementBlock& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::StatementBlock& node) { for (auto& statement : node.statements) Visit(statement); } - void ShaderAstRecursiveVisitor::Visit(const ShaderNodes::SwizzleOp& node) + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::SwizzleOp& node) { Visit(node.expression); } diff --git a/src/Nazara/Renderer/ShaderAstSerializer.cpp b/src/Nazara/Renderer/ShaderAstSerializer.cpp index b4de31cc7..f8afa178c 100644 --- a/src/Nazara/Renderer/ShaderAstSerializer.cpp +++ b/src/Nazara/Renderer/ShaderAstSerializer.cpp @@ -22,98 +22,98 @@ namespace Nz { } - void Visit(const ShaderNodes::AccessMember& node) override + void Visit(ShaderNodes::AccessMember& node) override { Serialize(node); } - void Visit(const ShaderNodes::AssignOp& node) override + void Visit(ShaderNodes::AssignOp& node) override { Serialize(node); } - void Visit(const ShaderNodes::BinaryOp& node) override + void Visit(ShaderNodes::BinaryOp& node) override { Serialize(node); } - void Visit(const ShaderNodes::Branch& node) override + void Visit(ShaderNodes::Branch& node) override { Serialize(node); } - void Visit(const ShaderNodes::Cast& node) override + void Visit(ShaderNodes::Cast& node) override { Serialize(node); } - void Visit(const ShaderNodes::Constant& node) override + void Visit(ShaderNodes::Constant& node) override { Serialize(node); } - void Visit(const ShaderNodes::DeclareVariable& node) override + void Visit(ShaderNodes::DeclareVariable& node) override { Serialize(node); } - void Visit(const ShaderNodes::ExpressionStatement& node) override + void Visit(ShaderNodes::ExpressionStatement& node) override { Serialize(node); } - void Visit(const ShaderNodes::Identifier& node) override + void Visit(ShaderNodes::Identifier& node) override { Serialize(node); } - void Visit(const ShaderNodes::IntrinsicCall& node) override + void Visit(ShaderNodes::IntrinsicCall& node) override { Serialize(node); } - void Visit(const ShaderNodes::Sample2D& node) override + void Visit(ShaderNodes::Sample2D& node) override { Serialize(node); } - void Visit(const ShaderNodes::StatementBlock& node) override + void Visit(ShaderNodes::StatementBlock& node) override { Serialize(node); } - void Visit(const ShaderNodes::SwizzleOp& node) override + void Visit(ShaderNodes::SwizzleOp& node) override { Serialize(node); } - void Visit(const ShaderNodes::BuiltinVariable& var) override + void Visit(ShaderNodes::BuiltinVariable& var) override { Serialize(var); } - void Visit(const ShaderNodes::InputVariable& var) override + void Visit(ShaderNodes::InputVariable& var) override { Serialize(var); } - void Visit(const ShaderNodes::LocalVariable& var) override + void Visit(ShaderNodes::LocalVariable& var) override { Serialize(var); } - void Visit(const ShaderNodes::OutputVariable& var) override + void Visit(ShaderNodes::OutputVariable& var) override { Serialize(var); } - void Visit(const ShaderNodes::ParameterVariable& var) override + void Visit(ShaderNodes::ParameterVariable& var) override { Serialize(var); } - void Visit(const ShaderNodes::UniformVariable& var) override + void Visit(ShaderNodes::UniformVariable& var) override { Serialize(var); } @@ -193,14 +193,18 @@ namespace Nz Value(value); }; - static_assert(std::variant_size_v == 5); + static_assert(std::variant_size_v == 9); switch (typeIndex) { case 0: SerializeValue(bool()); break; case 1: SerializeValue(float()); break; - case 2: SerializeValue(Vector2f()); break; - case 3: SerializeValue(Vector3f()); break; - case 4: SerializeValue(Vector4f()); break; + case 2: SerializeValue(Int32()); break; + case 3: SerializeValue(Vector2f()); break; + case 4: SerializeValue(Vector3f()); break; + case 5: SerializeValue(Vector4f()); break; + case 6: SerializeValue(Vector2i32()); break; + case 7: SerializeValue(Vector3i32()); break; + case 8: SerializeValue(Vector4i32()); break; default: throw std::runtime_error("unexpected data type"); } } @@ -403,6 +407,11 @@ namespace Nz m_stream << val; } + void ShaderAstSerializer::Value(Int32& val) + { + m_stream << val; + } + void ShaderAstSerializer::Value(Vector2f& val) { m_stream << val; @@ -418,6 +427,21 @@ namespace Nz m_stream << val; } + void ShaderAstSerializer::Value(Vector2i32& val) + { + m_stream << val; + } + + void ShaderAstSerializer::Value(Vector3i32& val) + { + m_stream << val; + } + + void ShaderAstSerializer::Value(Vector4i32& val) + { + m_stream << val; + } + void ShaderAstSerializer::Value(UInt8& val) { m_stream << val; @@ -644,6 +668,11 @@ namespace Nz m_stream >> val; } + void ShaderAstUnserializer::Value(Int32& val) + { + m_stream >> val; + } + void ShaderAstUnserializer::Value(Vector2f& val) { m_stream >> val; @@ -659,6 +688,21 @@ namespace Nz m_stream >> val; } + void ShaderAstUnserializer::Value(Vector2i32& val) + { + m_stream >> val; + } + + void ShaderAstUnserializer::Value(Vector3i32& val) + { + m_stream >> val; + } + + void ShaderAstUnserializer::Value(Vector4i32& val) + { + m_stream >> val; + } + void ShaderAstUnserializer::Value(UInt8& val) { m_stream >> val; @@ -689,6 +733,7 @@ namespace Nz HandleType(BuiltinVariable); HandleType(InputVariable); HandleType(LocalVariable); + HandleType(ParameterVariable); HandleType(OutputVariable); HandleType(UniformVariable); } diff --git a/src/Nazara/Renderer/ShaderAstValidator.cpp b/src/Nazara/Renderer/ShaderAstValidator.cpp index 8e161d532..81970e766 100644 --- a/src/Nazara/Renderer/ShaderAstValidator.cpp +++ b/src/Nazara/Renderer/ShaderAstValidator.cpp @@ -83,7 +83,7 @@ namespace Nz throw AstError{ "Left expression type must match right expression type" }; } - void ShaderAstValidator::Visit(const ShaderNodes::AccessMember& node) + void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node) { const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType(); if (!std::holds_alternative(exprType)) @@ -105,7 +105,7 @@ namespace Nz throw AstError{ "member type does not match node type" }; } - void ShaderAstValidator::Visit(const ShaderNodes::AssignOp& node) + void ShaderAstValidator::Visit(ShaderNodes::AssignOp& node) { MandatoryNode(node.left); MandatoryNode(node.right); @@ -117,7 +117,7 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::BinaryOp& node) + void ShaderAstValidator::Visit(ShaderNodes::BinaryOp& node) { MandatoryNode(node.left); MandatoryNode(node.right); @@ -147,8 +147,9 @@ namespace Nz switch (leftType) { case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Int1: { - if (ShaderNodes::Node::GetComponentType(rightType) != ShaderNodes::BasicType::Float1) + if (ShaderNodes::Node::GetComponentType(rightType) != leftType) throw AstError{ "Left expression type is not compatible with right expression type" }; break; @@ -157,8 +158,11 @@ namespace Nz case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float3: case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: { - if (leftType != rightType && rightType != ShaderNodes::BasicType::Float1) + if (leftType != rightType && rightType != ShaderNodes::Node::GetComponentType(leftType)) throw AstError{ "Left expression type is not compatible with right expression type" }; break; @@ -189,7 +193,7 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::Branch& node) + void ShaderAstValidator::Visit(ShaderNodes::Branch& node) { for (const auto& condStatement : node.condStatements) { @@ -200,7 +204,7 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::Cast& node) + void ShaderAstValidator::Visit(ShaderNodes::Cast& node) { unsigned int componentCount = 0; unsigned int requiredComponents = node.GetComponentCount(node.exprType); @@ -222,11 +226,11 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::Constant& /*node*/) + void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/) { } - void ShaderAstValidator::Visit(const ShaderNodes::DeclareVariable& node) + void ShaderAstValidator::Visit(ShaderNodes::DeclareVariable& node) { assert(m_context); @@ -242,14 +246,14 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::ExpressionStatement& node) + void ShaderAstValidator::Visit(ShaderNodes::ExpressionStatement& node) { MandatoryNode(node.expression); ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::Identifier& node) + void ShaderAstValidator::Visit(ShaderNodes::Identifier& node) { assert(m_context); @@ -259,7 +263,7 @@ namespace Nz Visit(node.var); } - void ShaderAstValidator::Visit(const ShaderNodes::IntrinsicCall& node) + void ShaderAstValidator::Visit(ShaderNodes::IntrinsicCall& node) { switch (node.intrinsic) { @@ -300,7 +304,7 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::Sample2D& node) + void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node) { if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D }) throw AstError{ "Sampler must be a Sampler2D" }; @@ -311,7 +315,7 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::StatementBlock& node) + void ShaderAstValidator::Visit(ShaderNodes::StatementBlock& node) { assert(m_context); @@ -327,7 +331,7 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::SwizzleOp& node) + void ShaderAstValidator::Visit(ShaderNodes::SwizzleOp& node) { if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; @@ -342,6 +346,10 @@ namespace Nz case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float3: case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: break; default: @@ -351,12 +359,23 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(const ShaderNodes::BuiltinVariable& /*var*/) + void ShaderAstValidator::Visit(ShaderNodes::BuiltinVariable& var) { - /* Nothing to do */ + switch (var.entry) + { + case ShaderNodes::BuiltinEntry::VertexPosition: + if (!std::holds_alternative(var.type) || + std::get(var.type) != ShaderNodes::BasicType::Float4) + throw AstError{ "Builtin is not of the expected type" }; + + break; + + default: + break; + } } - void ShaderAstValidator::Visit(const ShaderNodes::InputVariable& var) + void ShaderAstValidator::Visit(ShaderNodes::InputVariable& var) { for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i) { @@ -371,7 +390,7 @@ namespace Nz throw AstError{ "Input not found" }; } - void ShaderAstValidator::Visit(const ShaderNodes::LocalVariable& var) + void ShaderAstValidator::Visit(ShaderNodes::LocalVariable& var) { const auto& vars = m_context->declaredLocals; @@ -382,7 +401,7 @@ namespace Nz TypeMustMatch(it->type, var.type); } - void ShaderAstValidator::Visit(const ShaderNodes::OutputVariable& var) + void ShaderAstValidator::Visit(ShaderNodes::OutputVariable& var) { for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i) { @@ -397,7 +416,7 @@ namespace Nz throw AstError{ "Output not found" }; } - void ShaderAstValidator::Visit(const ShaderNodes::ParameterVariable& var) + void ShaderAstValidator::Visit(ShaderNodes::ParameterVariable& var) { assert(m_context->currentFunction); @@ -410,7 +429,7 @@ namespace Nz TypeMustMatch(it->type, var.type); } - void ShaderAstValidator::Visit(const ShaderNodes::UniformVariable& var) + void ShaderAstValidator::Visit(ShaderNodes::UniformVariable& var) { for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i) { diff --git a/src/Nazara/Renderer/ShaderNodes.cpp b/src/Nazara/Renderer/ShaderNodes.cpp index 7ea148cb6..29ffb8508 100644 --- a/src/Nazara/Renderer/ShaderNodes.cpp +++ b/src/Nazara/Renderer/ShaderNodes.cpp @@ -59,7 +59,7 @@ namespace Nz::ShaderNodes visitor.Visit(*this); } - ExpressionCategory ShaderNodes::AccessMember::GetExpressionCategory() const + ExpressionCategory AccessMember::GetExpressionCategory() const { return ExpressionCategory::LValue; } @@ -111,10 +111,14 @@ namespace Nz::ShaderNodes case BasicType::Float2: case BasicType::Float3: case BasicType::Float4: + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: exprType = leftExprType; break; case BasicType::Float1: + case BasicType::Int1: case BasicType::Mat4x4: exprType = rightExprType; break; @@ -159,12 +163,20 @@ namespace Nz::ShaderNodes return ShaderNodes::BasicType::Boolean; else if constexpr (std::is_same_v) return ShaderNodes::BasicType::Float1; + else if constexpr (std::is_same_v) + return ShaderNodes::BasicType::Int1; else if constexpr (std::is_same_v) return ShaderNodes::BasicType::Float2; else if constexpr (std::is_same_v) return ShaderNodes::BasicType::Float3; else if constexpr (std::is_same_v) return ShaderNodes::BasicType::Float4; + else if constexpr (std::is_same_v) + return ShaderNodes::BasicType::Int2; + else if constexpr (std::is_same_v) + return ShaderNodes::BasicType::Int3; + else if constexpr (std::is_same_v) + return ShaderNodes::BasicType::Int4; else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, value); diff --git a/src/Nazara/Renderer/SpirvWriter.cpp b/src/Nazara/Renderer/SpirvWriter.cpp index a94cd6ec9..1e9218e34 100644 --- a/src/Nazara/Renderer/SpirvWriter.cpp +++ b/src/Nazara/Renderer/SpirvWriter.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include #include @@ -34,28 +36,35 @@ namespace Nz using ShaderAstRecursiveVisitor::Visit; using ShaderVarVisitor::Visit; - void Visit(const ShaderNodes::Constant& node) override + void Visit(ShaderNodes::AccessMember& node) override + { + constants.emplace(Int32(node.memberIndex)); + + ShaderAstRecursiveVisitor::Visit(node); + } + + void Visit(ShaderNodes::Constant& node) override { std::visit([&](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v || std::is_same_v) + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) constants.emplace(arg); - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) { constants.emplace(arg.x); constants.emplace(arg.y); constants.emplace(arg); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) { constants.emplace(arg.x); constants.emplace(arg.y); constants.emplace(arg.z); constants.emplace(arg); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) { constants.emplace(arg.x); constants.emplace(arg.y); @@ -71,21 +80,21 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } - void Visit(const ShaderNodes::DeclareVariable& node) override + void Visit(ShaderNodes::DeclareVariable& node) override { Visit(node.variable); ShaderAstRecursiveVisitor::Visit(node); } - void Visit(const ShaderNodes::Identifier& node) override + void Visit(ShaderNodes::Identifier& node) override { Visit(node.var); ShaderAstRecursiveVisitor::Visit(node); } - void Visit(const ShaderNodes::IntrinsicCall& node) override + void Visit(ShaderNodes::IntrinsicCall& node) override { ShaderAstRecursiveVisitor::Visit(node); @@ -102,32 +111,32 @@ namespace Nz } } - void Visit(const ShaderNodes::BuiltinVariable& var) override + void Visit(ShaderNodes::BuiltinVariable& var) override { builtinVars.insert(std::static_pointer_cast(var.shared_from_this())); } - void Visit(const ShaderNodes::InputVariable& var) override + void Visit(ShaderNodes::InputVariable& var) override { /* Handled by ShaderAst */ } - void Visit(const ShaderNodes::LocalVariable& var) override + void Visit(ShaderNodes::LocalVariable& var) override { localVars.insert(std::static_pointer_cast(var.shared_from_this())); } - void Visit(const ShaderNodes::OutputVariable& var) override + void Visit(ShaderNodes::OutputVariable& var) override { /* Handled by ShaderAst */ } - void Visit(const ShaderNodes::ParameterVariable& var) override + void Visit(ShaderNodes::ParameterVariable& var) override { paramVars.insert(std::static_pointer_cast(var.shared_from_this())); } - void Visit(const ShaderNodes::UniformVariable& var) override + void Visit(ShaderNodes::UniformVariable& var) override { /* Handled by ShaderAst */ } @@ -138,8 +147,57 @@ namespace Nz LocalContainer localVars; ParameterContainer paramVars; }; + + class AssignVisitor : public ShaderAstRecursiveVisitor + { + public: + void Visit(ShaderNodes::AccessMember& node) override + { + } + + void Visit(ShaderNodes::Identifier& node) override + { + } + + void Visit(ShaderNodes::SwizzleOp& node) override + { + } + }; + + template + constexpr ShaderNodes::BasicType GetBasicType() + { + if constexpr (std::is_same_v) + return ShaderNodes::BasicType::Boolean; + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Float1); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Int1); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Float2); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Float3); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Float4); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Int2); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Int3); + else if constexpr (std::is_same_v) + return(ShaderNodes::BasicType::Int4); + else + static_assert(AlwaysFalse::value, "unhandled type"); + } } + struct SpirvWriter::ExtVar + { + UInt32 pointerTypeId; + UInt32 typeId; + UInt32 varId; + std::optional valueId; + }; + struct SpirvWriter::Opcode { SpvOp op; @@ -165,20 +223,15 @@ namespace Nz std::vector paramsId; }; - struct ExtVar - { - UInt32 pointerTypeId; - UInt32 varId; - }; - std::unordered_map extensionInstructions; - std::unordered_map builtinIds; + std::unordered_map builtinIds; + std::unordered_map varToResult; tsl::ordered_map constantIds; tsl::ordered_map typeIds; std::vector funcs; - std::vector inputIds; - std::vector outputIds; - std::vector uniformIds; + tsl::ordered_map inputIds; + tsl::ordered_map outputIds; + tsl::ordered_map uniformIds; std::vector> structFields; std::vector resultIds; UInt32 nextVarIndex = 1; @@ -213,14 +266,17 @@ namespace Nz }); state.structFields.resize(shader.GetStructCount()); - state.annotations.Append(Opcode{ SpvOpNop }); - state.constants.Append(Opcode{ SpvOpNop }); - state.debugInfo.Append(Opcode{ SpvOpNop }); - state.types.Append(Opcode{ SpvOpNop }); + + std::vector functionStatements; + + ShaderAstCloner cloner; PreVisitor preVisitor; for (const auto& func : shader.GetFunctions()) + { + functionStatements.emplace_back(cloner.Clone(func.statement)); preVisitor.Visit(func.statement); + } // Register all extended instruction sets for (const std::string& extInst : preVisitor.extInsts) @@ -246,39 +302,67 @@ namespace Nz for (const auto& local : preVisitor.localVars) RegisterType(local->type); + for (const auto& builtin : preVisitor.builtinVars) + RegisterType(builtin->type); + // Register constant types for (const auto& constant : preVisitor.constants) { std::visit([&](auto&& arg) { using T = std::decay_t; - - if constexpr (std::is_same_v) - RegisterType(ShaderNodes::BasicType::Boolean); - else if constexpr (std::is_same_v) - RegisterType(ShaderNodes::BasicType::Float1); - else if constexpr (std::is_same_v) - RegisterType(ShaderNodes::BasicType::Float2); - else if constexpr (std::is_same_v) - RegisterType(ShaderNodes::BasicType::Float3); - else if constexpr (std::is_same_v) - RegisterType(ShaderNodes::BasicType::Float4); - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + RegisterType(GetBasicType()); }, constant); } AppendTypes(); // Register result id and debug infos for global variables/functions + for (const auto& builtin : preVisitor.builtinVars) + { + const ShaderExpressionType& builtinExprType = builtin->type; + assert(std::holds_alternative(builtinExprType)); + + ShaderNodes::BasicType builtinType = std::get(builtinExprType); + + ExtVar builtinData; + builtinData.pointerTypeId = AllocateResultId(); + builtinData.typeId = GetTypeId(builtinType); + builtinData.varId = AllocateResultId(); + + SpvBuiltIn spvBuiltin; + std::string debugName; + switch (builtin->entry) + { + case ShaderNodes::BuiltinEntry::VertexPosition: + debugName = "builtin_VertexPosition"; + spvBuiltin = SpvBuiltInPosition; + break; + + default: + throw std::runtime_error("unexpected builtin type"); + } + + state.debugInfo.Append(Opcode{ SpvOpName }, builtinData.varId, debugName); + state.types.Append(Opcode{ SpvOpTypePointer }, builtinData.pointerTypeId, SpvStorageClassOutput, builtinData.typeId); + state.types.Append(Opcode{ SpvOpVariable }, builtinData.pointerTypeId, builtinData.varId, SpvStorageClassOutput); + + state.annotations.Append(Opcode{ SpvOpDecorate }, builtinData.varId, SpvDecorationBuiltIn, spvBuiltin); + + state.builtinIds.emplace(builtin->entry, builtinData); + } + for (const auto& input : shader.GetInputs()) { - auto& inputData = state.inputIds.emplace_back(); + ExtVar inputData; inputData.pointerTypeId = AllocateResultId(); + inputData.typeId = GetTypeId(input.type); inputData.varId = AllocateResultId(); + state.inputIds.emplace(input.name, inputData); + state.debugInfo.Append(Opcode{ SpvOpName }, inputData.varId, input.name); - state.types.Append(Opcode{ SpvOpTypePointer }, inputData.pointerTypeId, SpvStorageClassInput, GetTypeId(input.type)); + state.types.Append(Opcode{ SpvOpTypePointer }, inputData.pointerTypeId, SpvStorageClassInput, inputData.typeId); state.types.Append(Opcode{ SpvOpVariable }, inputData.pointerTypeId, inputData.varId, SpvStorageClassInput); if (input.locationIndex) @@ -287,12 +371,15 @@ namespace Nz for (const auto& output : shader.GetOutputs()) { - auto& outputData = state.outputIds.emplace_back(); + ExtVar outputData; outputData.pointerTypeId = AllocateResultId(); + outputData.typeId = GetTypeId(output.type); outputData.varId = AllocateResultId(); + state.outputIds.emplace(output.name, outputData); + state.debugInfo.Append(Opcode{ SpvOpName }, outputData.varId, output.name); - state.types.Append(Opcode{ SpvOpTypePointer }, outputData.pointerTypeId, SpvStorageClassOutput, GetTypeId(output.type)); + state.types.Append(Opcode{ SpvOpTypePointer }, outputData.pointerTypeId, SpvStorageClassOutput, outputData.typeId); state.types.Append(Opcode{ SpvOpVariable }, outputData.pointerTypeId, outputData.varId, SpvStorageClassOutput); if (output.locationIndex) @@ -301,12 +388,15 @@ namespace Nz for (const auto& uniform : shader.GetUniforms()) { - auto& uniformData = state.uniformIds.emplace_back(); + ExtVar uniformData; uniformData.pointerTypeId = AllocateResultId(); + uniformData.typeId = GetTypeId(uniform.type); uniformData.varId = AllocateResultId(); + state.uniformIds.emplace(uniform.name, uniformData); + state.debugInfo.Append(Opcode{ SpvOpName }, uniformData.varId, uniform.name); - state.types.Append(Opcode{ SpvOpTypePointer }, uniformData.pointerTypeId, SpvStorageClassUniform, GetTypeId(uniform.type)); + state.types.Append(Opcode{ SpvOpTypePointer }, uniformData.pointerTypeId, SpvStorageClassUniform, uniformData.typeId); state.types.Append(Opcode{ SpvOpVariable }, uniformData.pointerTypeId, uniformData.varId, SpvStorageClassUniform); if (uniform.bindingIndex) @@ -338,16 +428,20 @@ namespace Nz AppendConstants(); + std::size_t entryPointIndex = std::numeric_limits::max(); + for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex) { const auto& func = shader.GetFunction(funcIndex); + if (func.name == "main") + entryPointIndex = funcIndex; auto& funcData = state.funcs[funcIndex]; - state.instructions.Append(Opcode{ SpvOpNop }); - state.instructions.Append(Opcode{ SpvOpFunction }, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); + state.instructions.Append(Opcode{ SpvOpLabel }, AllocateResultId()); + for (const auto& param : func.parameters) { UInt32 paramResultId = AllocateResultId(); @@ -356,24 +450,56 @@ namespace Nz state.instructions.Append(Opcode{ SpvOpFunctionParameter }, GetTypeId(param.type), paramResultId); } - Visit(func.statement); + Visit(functionStatements[funcIndex]); + + if (func.returnType == ShaderNodes::BasicType::Void) + state.instructions.Append(Opcode{ SpvOpReturn }); state.instructions.Append(Opcode{ SpvOpFunctionEnd }); } + assert(entryPointIndex != std::numeric_limits::max()); + AppendHeader(); - /*assert(m_context.shader); + SpvExecutionModel execModel; + const auto& entryFuncData = shader.GetFunction(entryPointIndex); + const auto& entryFunc = m_currentState->funcs[entryPointIndex]; + + assert(m_context.shader); switch (m_context.shader->GetStage()) { case ShaderStageType::Fragment: + execModel = SpvExecutionModelFragment; break; + case ShaderStageType::Vertex: + execModel = SpvExecutionModelVertex; break; default: - break; - }*/ + throw std::runtime_error("not yet implemented"); + } + + // OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos + + std::size_t nameSize = state.header.CountWord(entryFuncData.name); + + state.header.Append(Opcode{ SpvOpEntryPoint }, WordCount{ static_cast(3 + nameSize + m_currentState->builtinIds.size() + m_currentState->inputIds.size() + m_currentState->outputIds.size()) }); + state.header.Append(execModel); + state.header.Append(entryFunc.id); + state.header.Append(entryFuncData.name); + for (const auto& [name, varData] : m_currentState->builtinIds) + state.header.Append(varData.varId); + + for (const auto& [name, varData] : m_currentState->inputIds) + state.header.Append(varData.varId); + + for (const auto& [name, varData] : m_currentState->outputIds) + state.header.Append(varData.varId); + + if (m_context.shader->GetStage() == ShaderStageType::Fragment) + state.header.Append(Opcode{ SpvOpExecutionMode }, entryFunc.id, SpvExecutionModeOriginUpperLeft); std::vector ret; MergeBlocks(ret, state.header); @@ -407,14 +533,14 @@ namespace Nz if constexpr (std::is_same_v) m_currentState->constants.Append(Opcode{ (arg) ? SpvOpConstantTrue : SpvOpConstantFalse }, constantId); - else if constexpr (std::is_same_v) - m_currentState->constants.Append(Opcode{ SpvOpConstant }, GetTypeId(ShaderNodes::BasicType::Float1), constantId, Raw{ &arg, sizeof(arg) }); - else if constexpr (std::is_same_v) - m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(ShaderNodes::BasicType::Float2), constantId, GetConstantId(arg.x), GetConstantId(arg.y)); - else if constexpr (std::is_same_v) - m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(ShaderNodes::BasicType::Float3), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z)); - else if constexpr (std::is_same_v) - m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(ShaderNodes::BasicType::Float3), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z), GetConstantId(arg.w)); + else if constexpr (std::is_same_v || std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstant }, GetTypeId(GetBasicType()), constantId, Raw{ &arg, sizeof(arg) }); + else if constexpr (std::is_same_v || std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(GetBasicType()), constantId, GetConstantId(arg.x), GetConstantId(arg.y)); + else if constexpr (std::is_same_v || std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(GetBasicType()), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z)); + else if constexpr (std::is_same_v || std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(GetBasicType()), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z), GetConstantId(arg.w)); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, value); @@ -467,14 +593,18 @@ namespace Nz std::size_t offset = [&] { switch (arg) { - case ShaderNodes::BasicType::Boolean: return structOffsets.AddField(StructFieldType_Bool1); - case ShaderNodes::BasicType::Float1: return structOffsets.AddField(StructFieldType_Float1); - case ShaderNodes::BasicType::Float2: return structOffsets.AddField(StructFieldType_Float2); - case ShaderNodes::BasicType::Float3: return structOffsets.AddField(StructFieldType_Float3); - case ShaderNodes::BasicType::Float4: return structOffsets.AddField(StructFieldType_Float4); - case ShaderNodes::BasicType::Mat4x4: return structOffsets.AddMatrix(StructFieldType_Float1, 4, 4, true); + case ShaderNodes::BasicType::Boolean: return structOffsets.AddField(StructFieldType_Bool1); + case ShaderNodes::BasicType::Float1: return structOffsets.AddField(StructFieldType_Float1); + case ShaderNodes::BasicType::Float2: return structOffsets.AddField(StructFieldType_Float2); + case ShaderNodes::BasicType::Float3: return structOffsets.AddField(StructFieldType_Float3); + case ShaderNodes::BasicType::Float4: return structOffsets.AddField(StructFieldType_Float4); + case ShaderNodes::BasicType::Int1: return structOffsets.AddField(StructFieldType_Int1); + case ShaderNodes::BasicType::Int2: return structOffsets.AddField(StructFieldType_Int2); + case ShaderNodes::BasicType::Int3: return structOffsets.AddField(StructFieldType_Int3); + case ShaderNodes::BasicType::Int4: return structOffsets.AddField(StructFieldType_Int4); + case ShaderNodes::BasicType::Mat4x4: return structOffsets.AddMatrix(StructFieldType_Float1, 4, 4, true); case ShaderNodes::BasicType::Sampler2D: throw std::runtime_error("unexpected sampler2D as struct member"); - case ShaderNodes::BasicType::Void: throw std::runtime_error("unexpected void as struct member"); + case ShaderNodes::BasicType::Void: throw std::runtime_error("unexpected void as struct member"); } assert(false); @@ -537,13 +667,22 @@ namespace Nz case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float3: case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: { - UInt32 vecSize = UInt32(arg) - UInt32(ShaderNodes::BasicType::Float2) + 1; + ShaderNodes::BasicType baseType = ShaderNodes::Node::GetComponentType(arg); - m_currentState->types.Append(Opcode{ SpvOpTypeVector }, resultId, GetTypeId(ShaderNodes::BasicType::Float1), vecSize); + UInt32 vecSize = UInt32(arg) - UInt32(baseType) + 1; + + m_currentState->types.Append(Opcode{ SpvOpTypeVector }, resultId, GetTypeId(baseType), vecSize); break; } + case ShaderNodes::BasicType::Int1: + m_currentState->types.Append(Opcode{ SpvOpTypeInt }, resultId, 32, 1); + break; + case ShaderNodes::BasicType::Mat4x4: { m_currentState->types.Append(Opcode{ SpvOpTypeMatrix }, resultId, GetTypeId(ShaderNodes::BasicType::Float4), 4); @@ -581,6 +720,12 @@ namespace Nz } } + UInt32 SpirvWriter::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr) + { + Visit(expr); + return PopResultId(); + } + UInt32 SpirvWriter::GetConstantId(const ShaderNodes::Constant::Variant& value) const { auto typeIt = m_currentState->constantIds.find(value); @@ -613,6 +758,19 @@ namespace Nz return resultId; } + UInt32 SpirvWriter::ReadVariable(ExtVar& var) + { + if (!var.valueId.has_value()) + { + UInt32 resultId = AllocateResultId(); + m_currentState->instructions.Append(Opcode{ SpvOpLoad }, var.typeId, resultId, var.varId); + + var.valueId = resultId; + } + + return var.valueId.value(); + } + UInt32 SpirvWriter::RegisterType(ShaderExpressionType type) { auto it = m_currentState->typeIds.find(type); @@ -628,6 +786,7 @@ namespace Nz { case ShaderNodes::BasicType::Boolean: case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Int1: case ShaderNodes::BasicType::Void: break; //< Nothing to do @@ -635,11 +794,11 @@ namespace Nz case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float3: case ShaderNodes::BasicType::Float4: - RegisterType(ShaderNodes::BasicType::Float1); - break; - + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: case ShaderNodes::BasicType::Mat4x4: - RegisterType(ShaderNodes::BasicType::Float4); + RegisterType(ShaderNodes::Node::GetComponentType(arg)); break; case ShaderNodes::BasicType::Sampler2D: @@ -670,60 +829,392 @@ namespace Nz return it->second; } - void SpirvWriter::Visit(const ShaderNodes::AccessMember& node) + void SpirvWriter::Visit(ShaderNodes::AccessMember& node) { - Visit(node.structExpr); + UInt32 pointerId; + SpvStorageClass storage; + + switch (node.structExpr->GetType()) + { + case ShaderNodes::NodeType::Identifier: + { + auto& identifier = static_cast(*node.structExpr); + switch (identifier.var->GetType()) + { + case ShaderNodes::VariableType::BuiltinVariable: + { + auto& builtinvar = static_cast(*identifier.var); + auto it = m_currentState->builtinIds.find(builtinvar.entry); + assert(it != m_currentState->builtinIds.end()); + + pointerId = it->second.varId; + break; + } + + case ShaderNodes::VariableType::InputVariable: + { + auto& inputVar = static_cast(*identifier.var); + auto it = m_currentState->inputIds.find(inputVar.name); + assert(it != m_currentState->inputIds.end()); + + storage = SpvStorageClassInput; + + pointerId = it->second.varId; + break; + } + + case ShaderNodes::VariableType::OutputVariable: + { + auto& outputVar = static_cast(*identifier.var); + auto it = m_currentState->outputIds.find(outputVar.name); + assert(it != m_currentState->outputIds.end()); + + storage = SpvStorageClassOutput; + + pointerId = it->second.varId; + break; + } + + case ShaderNodes::VariableType::UniformVariable: + { + auto& uniformVar = static_cast(*identifier.var); + auto it = m_currentState->uniformIds.find(uniformVar.name); + assert(it != m_currentState->uniformIds.end()); + + storage = SpvStorageClassUniform; + + pointerId = it->second.varId; + break; + } + + case ShaderNodes::VariableType::LocalVariable: + case ShaderNodes::VariableType::ParameterVariable: + default: + throw std::runtime_error("not yet implemented"); + } + break; + } + + case ShaderNodes::NodeType::SwizzleOp: //< TODO + default: + throw std::runtime_error("not yet implemented"); + } + + UInt32 memberPointerId = AllocateResultId(); + UInt32 pointerType = AllocateResultId(); + UInt32 typeId = GetTypeId(node.exprType); + UInt32 indexId = GetConstantId(Int32(node.memberIndex)); + + m_currentState->types.Append(Opcode{ SpvOpTypePointer }, pointerType, storage, typeId); + + m_currentState->instructions.Append(Opcode{ SpvOpAccessChain }, pointerType, memberPointerId, pointerId, indexId); + + UInt32 resultId = AllocateResultId(); + + m_currentState->instructions.Append(Opcode{ SpvOpLoad }, typeId, resultId, memberPointerId); + + PushResultId(resultId); } - void SpirvWriter::Visit(const ShaderNodes::AssignOp& node) + void SpirvWriter::Visit(ShaderNodes::AssignOp& node) { - Visit(node.left); - Visit(node.right); + UInt32 result = EvaluateExpression(node.right); + + switch (node.left->GetType()) + { + case ShaderNodes::NodeType::Identifier: + { + auto& identifier = static_cast(*node.left); + switch (identifier.var->GetType()) + { + case ShaderNodes::VariableType::BuiltinVariable: + { + auto& builtinvar = static_cast(*identifier.var); + auto it = m_currentState->builtinIds.find(builtinvar.entry); + assert(it != m_currentState->builtinIds.end()); + + m_currentState->instructions.Append(Opcode{ SpvOpStore }, it->second.varId, result); + PushResultId(result); + break; + } + + case ShaderNodes::VariableType::OutputVariable: + { + auto& outputVar = static_cast(*identifier.var); + auto it = m_currentState->outputIds.find(outputVar.name); + assert(it != m_currentState->outputIds.end()); + + m_currentState->instructions.Append(Opcode{ SpvOpStore }, it->second.varId, result); + PushResultId(result); + break; + } + + case ShaderNodes::VariableType::InputVariable: + case ShaderNodes::VariableType::LocalVariable: + case ShaderNodes::VariableType::ParameterVariable: + case ShaderNodes::VariableType::UniformVariable: + default: + throw std::runtime_error("not yet implemented"); + } + break; + } + + case ShaderNodes::NodeType::SwizzleOp: //< TODO + default: + throw std::runtime_error("not yet implemented"); + } } - void SpirvWriter::Visit(const ShaderNodes::Branch& node) + void SpirvWriter::Visit(ShaderNodes::Branch& node) { throw std::runtime_error("not yet implemented"); } - void SpirvWriter::Visit(const ShaderNodes::BinaryOp& node) + void SpirvWriter::Visit(ShaderNodes::BinaryOp& node) { - Visit(node.left); - Visit(node.right); + ShaderExpressionType resultExprType = node.GetExpressionType(); + assert(std::holds_alternative(resultExprType)); + const ShaderExpressionType& leftExprType = node.left->GetExpressionType(); + assert(std::holds_alternative(leftExprType)); + + const ShaderExpressionType& rightExprType = node.right->GetExpressionType(); + assert(std::holds_alternative(rightExprType)); + + ShaderNodes::BasicType resultType = std::get(resultExprType); + ShaderNodes::BasicType leftType = std::get(leftExprType); + ShaderNodes::BasicType rightType = std::get(rightExprType); + + + UInt32 leftOperand = EvaluateExpression(node.left); + UInt32 rightOperand = EvaluateExpression(node.right); UInt32 resultId = AllocateResultId(); - UInt32 leftOperand = PopResultId(); - UInt32 rightOperand = PopResultId(); - SpvOp op = [&] { + bool swapOperands = false; + + SpvOp op = [&] + { switch (node.op) { - case ShaderNodes::BinaryType::Add: return SpvOpFAdd; - case ShaderNodes::BinaryType::Substract: return SpvOpFSub; - case ShaderNodes::BinaryType::Multiply: return SpvOpFMul; - case ShaderNodes::BinaryType::Divide: return SpvOpFDiv; - case ShaderNodes::BinaryType::Equality: return SpvOpFOrdEqual; + case ShaderNodes::BinaryType::Add: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpvOpFAdd; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + return SpvOpIAdd; + + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Substract: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpvOpFSub; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + return SpvOpISub; + + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Divide: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpvOpFDiv; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + return SpvOpSDiv; + + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Equality: + { + switch (leftType) + { + case ShaderNodes::BasicType::Boolean: + return SpvOpLogicalEqual; + + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpvOpFOrdEqual; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + return SpvOpIEqual; + + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Multiply: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + { + switch (rightType) + { + case ShaderNodes::BasicType::Float1: + return SpvOpFMul; + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + swapOperands = true; + return SpvOpVectorTimesScalar; + + case ShaderNodes::BasicType::Mat4x4: + swapOperands = true; + return SpvOpMatrixTimesScalar; + + default: + break; + } + + break; + } + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + { + switch (rightType) + { + case ShaderNodes::BasicType::Float1: + return SpvOpVectorTimesScalar; + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + return SpvOpFMul; + + case ShaderNodes::BasicType::Mat4x4: + return SpvOpVectorTimesMatrix; + + default: + break; + } + + break; + } + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + return SpvOpIMul; + + case ShaderNodes::BasicType::Mat4x4: + { + switch (rightType) + { + case ShaderNodes::BasicType::Float1: return SpvOpMatrixTimesScalar; + case ShaderNodes::BasicType::Float4: return SpvOpMatrixTimesVector; + case ShaderNodes::BasicType::Mat4x4: return SpvOpMatrixTimesMatrix; + + default: + break; + } + + break; + } + + default: + break; + } + break; + } } assert(false); throw std::runtime_error("unexpected binary operation"); }(); - m_currentState->instructions.Append(Opcode{ op }, GetTypeId(ShaderNodes::BasicType::Float3), resultId, leftOperand, rightOperand); + if (swapOperands) + std::swap(leftOperand, rightOperand); + + m_currentState->instructions.Append(Opcode{ op }, GetTypeId(resultType), resultId, leftOperand, rightOperand); + PushResultId(resultId); } - void SpirvWriter::Visit(const ShaderNodes::Cast& node) + void SpirvWriter::Visit(ShaderNodes::Cast& node) { - for (auto& expr : node.expressions) + const ShaderExpressionType& targetExprType = node.exprType; + assert(std::holds_alternative(targetExprType)); + + ShaderNodes::BasicType targetType = std::get(targetExprType); + + StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); + + for (const auto& exprPtr : node.expressions) { - if (!expr) + if (!exprPtr) break; - Visit(expr); + exprResults.push_back(EvaluateExpression(exprPtr)); } + + UInt32 resultId = AllocateResultId(); + + m_currentState->instructions.Append(Opcode{ SpvOpCompositeConstruct }, WordCount { static_cast(3 + exprResults.size()) }); + m_currentState->instructions.Append(GetTypeId(targetType)); + m_currentState->instructions.Append(resultId); + + for (UInt32 resultId : exprResults) + m_currentState->instructions.Append(resultId); + + PushResultId(resultId); } - void SpirvWriter::Visit(const ShaderNodes::Constant& node) + void SpirvWriter::Visit(ShaderNodes::Constant& node) { std::visit([&] (const auto& value) { @@ -731,43 +1222,150 @@ namespace Nz }, node.value); } - void SpirvWriter::Visit(const ShaderNodes::DeclareVariable& node) + void SpirvWriter::Visit(ShaderNodes::DeclareVariable& node) { if (node.expression) - Visit(node.expression); + { + assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); + + const auto& localVar = static_cast(*node.variable); + m_currentState->varToResult[localVar.name] = EvaluateExpression(node.expression); + } } - void SpirvWriter::Visit(const ShaderNodes::ExpressionStatement& node) + void SpirvWriter::Visit(ShaderNodes::ExpressionStatement& node) { Visit(node.expression); + PopResultId(); } - void SpirvWriter::Visit(const ShaderNodes::Identifier& node) + void SpirvWriter::Visit(ShaderNodes::Identifier& node) { - PushResultId(42); + Visit(node.var); } - void SpirvWriter::Visit(const ShaderNodes::IntrinsicCall& node) + void SpirvWriter::Visit(ShaderNodes::IntrinsicCall& node) { - for (auto& param : node.parameters) - Visit(param); + switch (node.intrinsic) + { + case ShaderNodes::IntrinsicType::DotProduct: + { + const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType(); + assert(std::holds_alternative(vecExprType)); + + ShaderNodes::BasicType vecType = std::get(vecExprType); + + UInt32 typeId = GetTypeId(node.GetComponentType(vecType)); + + UInt32 vec1 = EvaluateExpression(node.parameters[0]); + UInt32 vec2 = EvaluateExpression(node.parameters[1]); + + UInt32 resultId = AllocateResultId(); + + m_currentState->instructions.Append(Opcode{ SpvOpDot }, typeId, resultId, vec1, vec2); + PushResultId(resultId); + break; + } + + case ShaderNodes::IntrinsicType::CrossProduct: + default: + throw std::runtime_error("not yet implemented"); + } } - void SpirvWriter::Visit(const ShaderNodes::Sample2D& node) + void SpirvWriter::Visit(ShaderNodes::Sample2D& node) { - Visit(node.sampler); - Visit(node.coordinates); + // OpImageSampleImplicitLod %v4float %31 %35 + + UInt32 typeId = GetTypeId(ShaderNodes::BasicType::Float4); + + UInt32 samplerId = EvaluateExpression(node.sampler); + UInt32 coordinatesId = EvaluateExpression(node.coordinates); + UInt32 resultId = AllocateResultId(); + + m_currentState->instructions.Append(Opcode{ SpvOpImageSampleImplicitLod }, typeId, resultId, samplerId, coordinatesId); + PushResultId(resultId); } - void SpirvWriter::Visit(const ShaderNodes::StatementBlock& node) + void SpirvWriter::Visit(ShaderNodes::StatementBlock& node) { for (auto& statement : node.statements) Visit(statement); } - void SpirvWriter::Visit(const ShaderNodes::SwizzleOp& node) + void SpirvWriter::Visit(ShaderNodes::SwizzleOp& node) { - Visit(node.expression); + const ShaderExpressionType& targetExprType = node.GetExpressionType(); + assert(std::holds_alternative(targetExprType)); + + ShaderNodes::BasicType targetType = std::get(targetExprType); + + UInt32 exprResultId = EvaluateExpression(node.expression); + UInt32 resultId = AllocateResultId(); + + if (node.componentCount > 1) + { + // Swizzling is implemented via SpvOpVectorShuffle using the same vector twice as operands + m_currentState->instructions.Append(Opcode{ SpvOpVectorShuffle }, WordCount{ static_cast(5 + node.componentCount) }); + m_currentState->instructions.Append(GetTypeId(targetType)); + m_currentState->instructions.Append(resultId); + m_currentState->instructions.Append(exprResultId); + m_currentState->instructions.Append(exprResultId); + + for (std::size_t i = 0; i < node.componentCount; ++i) + m_currentState->instructions.Append(UInt32(node.components[0]) - UInt32(node.components[i])); + } + else + { + // Extract a single component from the vector + assert(node.componentCount == 1); + + m_currentState->instructions.Append(Opcode{ SpvOpCompositeExtract }, GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) ); + } + + PushResultId(resultId); + } + + void SpirvWriter::Visit(ShaderNodes::BuiltinVariable& var) + { + throw std::runtime_error("not implemented yet"); + } + + void SpirvWriter::Visit(ShaderNodes::InputVariable& var) + { + auto it = m_currentState->inputIds.find(var.name); + assert(it != m_currentState->inputIds.end()); + + PushResultId(ReadVariable(it.value())); + } + + void SpirvWriter::Visit(ShaderNodes::LocalVariable& var) + { + auto it = m_currentState->varToResult.find(var.name); + assert(it != m_currentState->varToResult.end()); + + PushResultId(it->second); + } + + void SpirvWriter::Visit(ShaderNodes::OutputVariable& var) + { + auto it = m_currentState->outputIds.find(var.name); + assert(it != m_currentState->outputIds.end()); + + PushResultId(ReadVariable(it.value())); + } + + void SpirvWriter::Visit(ShaderNodes::ParameterVariable& var) + { + throw std::runtime_error("not implemented yet"); + } + + void SpirvWriter::Visit(ShaderNodes::UniformVariable& var) + { + auto it = m_currentState->uniformIds.find(var.name); + assert(it != m_currentState->uniformIds.end()); + + PushResultId(ReadVariable(it.value())); } void SpirvWriter::MergeBlocks(std::vector& output, const Section& from) @@ -794,15 +1392,16 @@ namespace Nz UInt32 codepoint = 0; for (std::size_t j = 0; j < 4; ++j) { +#ifdef NAZARA_BIG_ENDIAN + std::size_t pos = i * 4 + (3 - j); +#else std::size_t pos = i * 4 + j; +#endif + if (pos < raw.size) codepoint |= UInt32(ptr[pos]) << (j * 8); } -#ifdef NAZARA_BIG_ENDIAN - SwapBytes(codepoint); -#endif - Append(codepoint); } diff --git a/src/Nazara/VulkanRenderer/Vulkan.cpp b/src/Nazara/VulkanRenderer/Vulkan.cpp index e6e2d4cfd..4f55223e1 100644 --- a/src/Nazara/VulkanRenderer/Vulkan.cpp +++ b/src/Nazara/VulkanRenderer/Vulkan.cpp @@ -57,7 +57,7 @@ namespace Nz String appName = "Another application made with Nazara Engine"; String engineName = "Nazara Engine - Vulkan Renderer"; - constexpr UInt32 appVersion = VK_MAKE_VERSION(1, 0, 0); + UInt32 appVersion = VK_MAKE_VERSION(1, 0, 0); UInt32 engineVersion = VK_MAKE_VERSION(1, 0, 0); parameters.GetStringParameter("VkAppInfo_OverrideApplicationName", &appName); diff --git a/src/Nazara/VulkanRenderer/VulkanShaderStage.cpp b/src/Nazara/VulkanRenderer/VulkanShaderStage.cpp index 238bb6468..c0b510dd8 100644 --- a/src/Nazara/VulkanRenderer/VulkanShaderStage.cpp +++ b/src/Nazara/VulkanRenderer/VulkanShaderStage.cpp @@ -3,25 +3,61 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include +#include +#include #include namespace Nz { bool VulkanShaderStage::Create(Vk::Device& device, ShaderStageType type, ShaderLanguage lang, const void* source, std::size_t sourceSize) { - if (lang != ShaderLanguage::SpirV) - { - NazaraError("Only Spir-V is supported for now"); - return false; - } - - if (!m_shaderModule.Create(device, reinterpret_cast(source), sourceSize)) - { - NazaraError("Failed to create shader module"); - return false; - } - m_stage = type; + + switch (lang) + { + case ShaderLanguage::NazaraBinary: + { + ByteStream byteStream(source, sourceSize); + auto shader = Nz::UnserializeShader(byteStream); + + if (shader.GetStage() != type) + throw std::runtime_error("incompatible shader stage"); + + SpirvWriter::Environment env; + + SpirvWriter writer; + writer.SetEnv(env); + + std::vector code = writer.Generate(shader); + + if (!m_shaderModule.Create(device, code.data(), code.size() * sizeof(UInt32))) + { + NazaraError("Failed to create shader module"); + return false; + } + + break; + } + + case ShaderLanguage::SpirV: + { + if (!m_shaderModule.Create(device, reinterpret_cast(source), sourceSize)) + { + NazaraError("Failed to create shader module"); + return false; + } + + break; + } + + default: + { + NazaraError("this language is not supported"); + return false; + } + } + return true; } }