diff --git a/include/Nazara/Shader.hpp b/include/Nazara/Shader.hpp index 9c59a649d..cef4eada9 100644 --- a/include/Nazara/Shader.hpp +++ b/include/Nazara/Shader.hpp @@ -37,12 +37,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include #include diff --git a/include/Nazara/Shader/Ast/Attribute.hpp b/include/Nazara/Shader/Ast/Attribute.hpp new file mode 100644 index 000000000..9c1e0dca5 --- /dev/null +++ b/include/Nazara/Shader/Ast/Attribute.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERAST_ATTRIBUTES_HPP +#define NAZARA_SHADERAST_ATTRIBUTES_HPP + +#include +#include + +namespace Nz::ShaderAst +{ + struct Attribute + { + using Param = std::variant; + + AttributeType type; + Param args; + }; +} + +#endif diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp new file mode 100644 index 000000000..f42776fc3 --- /dev/null +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -0,0 +1,96 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_AST_EXPRESSIONTYPE_HPP +#define NAZARA_SHADER_AST_EXPRESSIONTYPE_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + struct IdentifierType //< Alias or struct + { + std::string name; + + inline bool operator==(const IdentifierType& rhs) const; + inline bool operator!=(const IdentifierType& rhs) const; + }; + + struct MatrixType + { + std::size_t columnCount; + std::size_t rowCount; + PrimitiveType type; + + inline bool operator==(const MatrixType& rhs) const; + inline bool operator!=(const MatrixType& rhs) const; + }; + + struct NoType + { + inline bool operator==(const NoType& rhs) const; + inline bool operator!=(const NoType& rhs) const; + }; + + struct SamplerType + { + ImageType dim; + PrimitiveType sampledType; + + inline bool operator==(const SamplerType& rhs) const; + inline bool operator!=(const SamplerType& rhs) const; + }; + + struct UniformType + { + IdentifierType containedType; + + inline bool operator==(const UniformType& rhs) const; + inline bool operator!=(const UniformType& rhs) const; + }; + + struct VectorType + { + std::size_t componentCount; + PrimitiveType type; + + inline bool operator==(const VectorType& rhs) const; + inline bool operator!=(const VectorType& rhs) const; + }; + + using ExpressionType = std::variant; + + struct StructDescription + { + struct StructMember + { + std::string name; + std::vector attributes; + ExpressionType type; + }; + + std::string name; + std::vector members; + }; + + inline bool IsIdentifierType(const ExpressionType& type); + inline bool IsMatrixType(const ExpressionType& type); + inline bool IsNoType(const ExpressionType& type); + inline bool IsPrimitiveType(const ExpressionType& type); + inline bool IsSamplerType(const ExpressionType& type); + inline bool IsUniformType(const ExpressionType& type); + inline bool IsVectorType(const ExpressionType& type); +} + +#include + +#endif diff --git a/include/Nazara/Shader/Ast/ExpressionType.inl b/include/Nazara/Shader/Ast/ExpressionType.inl new file mode 100644 index 000000000..b9b7734e7 --- /dev/null +++ b/include/Nazara/Shader/Ast/ExpressionType.inl @@ -0,0 +1,111 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace Nz::ShaderAst +{ + inline bool IdentifierType::operator==(const IdentifierType& rhs) const + { + return name == rhs.name; + } + + inline bool IdentifierType::operator!=(const IdentifierType& rhs) const + { + return !operator==(rhs); + } + + + inline bool MatrixType::operator==(const MatrixType& rhs) const + { + return columnCount == rhs.columnCount && rowCount == rhs.rowCount && type == rhs.type; + } + + inline bool MatrixType::operator!=(const MatrixType& rhs) const + { + return !operator==(rhs); + } + + + inline bool NoType::operator==(const NoType& /*rhs*/) const + { + return true; + } + + inline bool NoType::operator!=(const NoType& /*rhs*/) const + { + return false; + } + + + inline bool SamplerType::operator==(const SamplerType& rhs) const + { + return dim == rhs.dim && sampledType == rhs.sampledType; + } + + inline bool SamplerType::operator!=(const SamplerType& rhs) const + { + return !operator==(rhs); + } + + inline bool UniformType::operator==(const UniformType& rhs) const + { + return containedType == rhs.containedType; + } + + inline bool UniformType::operator!=(const UniformType& rhs) const + { + return !operator==(rhs); + } + + inline bool VectorType::operator==(const VectorType& rhs) const + { + return componentCount == rhs.componentCount && type == rhs.type; + } + + inline bool VectorType::operator!=(const VectorType& rhs) const + { + return !operator==(rhs); + } + + + inline bool IsIdentifierType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + inline bool IsMatrixType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + inline bool IsNoType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + inline bool IsPrimitiveType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + inline bool IsSamplerType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + bool IsUniformType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + bool IsVectorType(const ExpressionType& type) + { + return std::holds_alternative(type); + } +} + +#include diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index fd3c299ce..aab8f06c5 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -28,7 +28,7 @@ namespace Nz GlslWriter(GlslWriter&&) = delete; ~GlslWriter() = default; - std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {}); + std::string Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {}); void SetEnv(Environment environment); @@ -44,17 +44,26 @@ namespace Nz static const char* GetFlipYUniformName(); private: - void Append(ShaderAst::ShaderExpressionType type); + void Append(const ShaderAst::ExpressionType& type); void Append(ShaderAst::BuiltinEntry builtin); - void Append(ShaderAst::BasicType type); + void Append(const ShaderAst::IdentifierType& identifierType); + void Append(const ShaderAst::MatrixType& matrixType); void Append(ShaderAst::MemoryLayout layout); + void Append(ShaderAst::NoType); + void Append(ShaderAst::PrimitiveType type); + void Append(const ShaderAst::SamplerType& samplerType); + void Append(const ShaderAst::UniformType& uniformType); + void Append(const ShaderAst::VectorType& vecType); template void Append(const T& param); + template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); void AppendCommentSection(const std::string& section); + void AppendEntryPoint(ShaderStageType shaderStage); void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); void AppendLine(const std::string& txt = {}); + template void AppendLine(Args&&... params); void EnterScope(); - void LeaveScope(); + void LeaveScope(bool skipLine = true); void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); @@ -70,7 +79,9 @@ namespace Nz void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::ConditionalStatement& node) override; + void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; + void Visit(ShaderAst::DeclareStructStatement& node) override; void Visit(ShaderAst::DeclareVariableStatement& node) override; void Visit(ShaderAst::DiscardStatement& node) override; void Visit(ShaderAst::ExpressionStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstCache.hpp b/include/Nazara/Shader/ShaderAstCache.hpp index 6e1b0a144..2a18b173d 100644 --- a/include/Nazara/Shader/ShaderAstCache.hpp +++ b/include/Nazara/Shader/ShaderAstCache.hpp @@ -16,15 +16,22 @@ namespace Nz::ShaderAst { struct AstCache { + struct Identifier; + + struct Alias + { + std::variant value; + }; + struct Variable { - ShaderExpressionType type; + ExpressionType type; }; struct Identifier { std::string name; - std::variant value; + std::variant value; }; struct Scope @@ -33,12 +40,12 @@ namespace Nz::ShaderAst std::vector identifiers; }; + inline void Clear(); inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const; inline std::size_t GetScopeId(const Node* node) const; - ShaderStageType stageType = ShaderStageType::Undefined; std::array entryFunctions = {}; - std::unordered_map nodeExpressionType; + std::unordered_map nodeExpressionType; std::unordered_map scopeIdByNode; std::vector scopes; }; diff --git a/include/Nazara/Shader/ShaderAstCache.inl b/include/Nazara/Shader/ShaderAstCache.inl index 48cd769f9..f254eac0c 100644 --- a/include/Nazara/Shader/ShaderAstCache.inl +++ b/include/Nazara/Shader/ShaderAstCache.inl @@ -7,6 +7,14 @@ namespace Nz::ShaderAst { + inline void AstCache::Clear() + { + entryFunctions.fill(nullptr); + nodeExpressionType.clear(); + scopeIdByNode.clear(); + scopes.clear(); + } + inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier* { assert(startingScopeId < scopes.size()); @@ -28,7 +36,7 @@ namespace Nz::ShaderAst inline std::size_t AstCache::GetScopeId(const Node* node) const { auto it = scopeIdByNode.find(node); - assert(it == scopeIdByNode.end()); + assert(it != scopeIdByNode.end()); return it->second; } diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index 23e69234a..a2ee45f44 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -33,6 +33,8 @@ namespace Nz::ShaderAst ExpressionPtr CloneExpression(ExpressionPtr& expr); StatementPtr CloneStatement(StatementPtr& statement); + virtual std::unique_ptr Clone(DeclareFunctionStatement& node); + using AstExpressionVisitor::Visit; using AstStatementVisitor::Visit; @@ -45,8 +47,10 @@ namespace Nz::ShaderAst void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SwizzleExpression& node) override; + void Visit(BranchStatement& node) override; void Visit(ConditionalStatement& node) override; + void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; void Visit(DeclareVariableStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstExpressionType.hpp b/include/Nazara/Shader/ShaderAstExpressionType.hpp index 4ad7731e8..095b5e5ea 100644 --- a/include/Nazara/Shader/ShaderAstExpressionType.hpp +++ b/include/Nazara/Shader/ShaderAstExpressionType.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include namespace Nz::ShaderAst @@ -25,13 +25,14 @@ namespace Nz::ShaderAst ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete; ~ExpressionTypeVisitor() = default; - ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache); + ExpressionType GetExpressionType(Expression& expression, AstCache* cache); ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete; ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete; private: - ShaderExpressionType GetExpressionTypeInternal(Expression& expression); + ExpressionType GetExpressionTypeInternal(Expression& expression); + ExpressionType ResolveAlias(Expression& expression, ExpressionType expressionType); void Visit(Expression& expression); @@ -46,10 +47,10 @@ namespace Nz::ShaderAst void Visit(SwizzleExpression& node) override; AstCache* m_cache; - std::optional m_lastExpressionType; + std::optional m_lastExpressionType; }; - inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr); + inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr); } #include diff --git a/include/Nazara/Shader/ShaderAstExpressionType.inl b/include/Nazara/Shader/ShaderAstExpressionType.inl index f71146200..279a3909e 100644 --- a/include/Nazara/Shader/ShaderAstExpressionType.inl +++ b/include/Nazara/Shader/ShaderAstExpressionType.inl @@ -7,7 +7,7 @@ namespace Nz::ShaderAst { - inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache) + inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache) { ExpressionTypeVisitor visitor; return visitor.GetExpressionType(expression, cache); diff --git a/include/Nazara/Shader/ShaderAstNodes.hpp b/include/Nazara/Shader/ShaderAstNodes.hpp index 7bbce68a9..86b19521b 100644 --- a/include/Nazara/Shader/ShaderAstNodes.hpp +++ b/include/Nazara/Shader/ShaderAstNodes.hpp @@ -37,6 +37,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) NAZARA_SHADERAST_STATEMENT(BranchStatement) NAZARA_SHADERAST_STATEMENT(ConditionalStatement) +NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement) NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement) NAZARA_SHADERAST_STATEMENT(DeclareStructStatement) NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement) diff --git a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp index 68251ddc0..1bfbb65da 100644 --- a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp @@ -32,6 +32,7 @@ namespace Nz::ShaderAst void Visit(BranchStatement& node) override; void Visit(ConditionalStatement& node) override; + void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; void Visit(DeclareVariableStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstSerializer.hpp b/include/Nazara/Shader/ShaderAstSerializer.hpp index fe09ea604..2cb0bd5f1 100644 --- a/include/Nazara/Shader/ShaderAstSerializer.hpp +++ b/include/Nazara/Shader/ShaderAstSerializer.hpp @@ -35,6 +35,7 @@ namespace Nz::ShaderAst void Serialize(BranchStatement& node); void Serialize(ConditionalStatement& node); + void Serialize(DeclareExternalStatement& node); void Serialize(DeclareFunctionStatement& node); void Serialize(DeclareStructStatement& node); void Serialize(DeclareVariableStatement& node); @@ -45,6 +46,7 @@ namespace Nz::ShaderAst void Serialize(ReturnStatement& node); protected: + void Attributes(std::vector& attributes); template void Container(T& container); template void Enum(T& enumVal); template void OptEnum(std::optional& optVal); @@ -55,7 +57,7 @@ namespace Nz::ShaderAst virtual void Node(ExpressionPtr& node) = 0; virtual void Node(StatementPtr& node) = 0; - virtual void Type(ShaderExpressionType& type) = 0; + virtual void Type(ExpressionType& type) = 0; virtual void Value(bool& val) = 0; virtual void Value(float& val) = 0; @@ -86,7 +88,7 @@ namespace Nz::ShaderAst bool IsWriting() const override; void Node(ExpressionPtr& node) override; void Node(StatementPtr& node) override; - void Type(ShaderExpressionType& type) override; + void Type(ExpressionType& type) override; void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; @@ -117,7 +119,7 @@ namespace Nz::ShaderAst bool IsWriting() const override; void Node(ExpressionPtr& node) override; void Node(StatementPtr& node) override; - void Type(ShaderExpressionType& type) override; + void Type(ExpressionType& type) override; void Value(bool& val) override; void Value(float& val) override; void Value(std::string& val) override; diff --git a/include/Nazara/Shader/ShaderAstTypes.hpp b/include/Nazara/Shader/ShaderAstTypes.hpp deleted file mode 100644 index e1f50d4e7..000000000 --- a/include/Nazara/Shader/ShaderAstTypes.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADER_ASTTYPES_HPP -#define NAZARA_SHADER_ASTTYPES_HPP - -#include -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - using ShaderExpressionType = std::variant; - - struct StructDescription - { - struct StructMember - { - std::string name; - ShaderExpressionType type; - }; - - std::string name; - std::vector members; - }; - - inline bool IsBasicType(const ShaderExpressionType& type); - inline bool IsMatrixType(const ShaderExpressionType& type); - inline bool IsSamplerType(const ShaderExpressionType& type); - inline bool IsStructType(const ShaderExpressionType& type); -} - -#include - -#endif // NAZARA_SHADER_ASTTYPES_HPP diff --git a/include/Nazara/Shader/ShaderAstTypes.inl b/include/Nazara/Shader/ShaderAstTypes.inl deleted file mode 100644 index 6eed4a945..000000000 --- a/include/Nazara/Shader/ShaderAstTypes.inl +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include - -namespace Nz::ShaderAst -{ - inline bool IsBasicType(const ShaderExpressionType& type) - { - return std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - return true; - else if constexpr (std::is_same_v) - return false; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - - }, type); - } - - inline bool IsMatrixType(const ShaderExpressionType& type) - { - if (!IsBasicType(type)) - return false; - - switch (std::get(type)) - { - case BasicType::Mat4x4: - return true; - - case BasicType::Boolean: - case BasicType::Float1: - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - case BasicType::Int1: - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - case BasicType::Sampler2D: - case BasicType::Void: - case BasicType::UInt1: - case BasicType::UInt2: - case BasicType::UInt3: - case BasicType::UInt4: - return false; - } - - return false; - } - - inline bool IsSamplerType(const ShaderExpressionType& type) - { - if (!IsBasicType(type)) - return false; - - switch (std::get(type)) - { - case BasicType::Sampler2D: - return true; - - case BasicType::Boolean: - case BasicType::Float1: - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - case BasicType::Int1: - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - case BasicType::Mat4x4: - case BasicType::Void: - case BasicType::UInt1: - case BasicType::UInt2: - case BasicType::UInt3: - case BasicType::UInt4: - return false; - } - - return false; - } - - inline bool IsStructType(const ShaderExpressionType& type) - { - return std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - return false; - else if constexpr (std::is_same_v) - return true; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - - }, type); - } -} - -#include diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index 00d708d96..e3dd4d94e 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -29,14 +29,14 @@ namespace Nz::ShaderAst Expression& MandatoryExpr(ExpressionPtr& node); Statement& MandatoryStatement(StatementPtr& node); void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right); - void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right); + void TypeMustMatch(const ExpressionType& left, const ExpressionType& right); - ShaderExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); + ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); AstCache::Scope& EnterScope(); void ExitScope(); - void RegisterExpressionType(Expression& node, ShaderExpressionType expressionType); + void RegisterExpressionType(Expression& node, ExpressionType expressionType); void RegisterScope(Node& node); void Visit(AccessMemberExpression& node) override; @@ -51,6 +51,7 @@ namespace Nz::ShaderAst void Visit(BranchStatement& node) override; void Visit(ConditionalStatement& node) override; + void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; void Visit(DeclareVariableStatement& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 625875470..c6edb8041 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -15,6 +15,11 @@ namespace Nz::ShaderBuilder { namespace Impl { + struct Assign + { + inline std::unique_ptr operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const; + }; + struct Binary { inline std::unique_ptr operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const; @@ -26,6 +31,11 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(std::vector condStatements, ShaderAst::StatementPtr elseStatement = nullptr) const; }; + struct Cast + { + inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const; + }; + struct Constant { inline std::unique_ptr operator()(ShaderConstantValue value) const; @@ -33,13 +43,24 @@ namespace Nz::ShaderBuilder struct DeclareFunction { - inline std::unique_ptr operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const; - inline std::unique_ptr operator()(std::vector attributes, std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const; + inline std::unique_ptr operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const; + inline std::unique_ptr operator()(std::vector attributes, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const; + }; + + struct DeclareStruct + { + inline std::unique_ptr operator()(ShaderAst::StructDescription description) const; + inline std::unique_ptr operator()(std::vector attributes, ShaderAst::StructDescription description) const; }; struct DeclareVariable { - inline std::unique_ptr operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; + inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; + }; + + struct ExpressionStatement + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression) const; }; struct Identifier @@ -47,9 +68,9 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(std::string name) const; }; - struct Return + struct Intrinsic { - inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr = nullptr) const; + inline std::unique_ptr operator()(ShaderAst::IntrinsicType intrinsicType, std::vector parameters) const; }; template @@ -57,15 +78,25 @@ namespace Nz::ShaderBuilder { std::unique_ptr operator()() const; }; + + struct Return + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr = nullptr) const; + }; } + constexpr Impl::Assign Assign; constexpr Impl::Binary Binary; constexpr Impl::Branch Branch; + constexpr Impl::Cast Cast; constexpr Impl::Constant Constant; constexpr Impl::DeclareFunction DeclareFunction; + constexpr Impl::DeclareStruct DeclareStruct; constexpr Impl::DeclareVariable DeclareVariable; + constexpr Impl::ExpressionStatement ExpressionStatement; constexpr Impl::NoParam Discard; constexpr Impl::Identifier Identifier; + constexpr Impl::Intrinsic Intrinsic; constexpr Impl::NoParam NoOp; constexpr Impl::Return Return; } diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 31fb4dcf9..feff95b8a 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -7,14 +7,24 @@ namespace Nz::ShaderBuilder { + inline std::unique_ptr Impl::Assign::operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const + { + auto assignNode = std::make_unique(); + assignNode->op = op; + assignNode->left = std::move(left); + assignNode->right = std::move(right); + + return assignNode; + } + inline std::unique_ptr Impl::Binary::operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const { - auto constantNode = std::make_unique(); - constantNode->op = op; - constantNode->left = std::move(left); - constantNode->right = std::move(right); + auto binaryNode = std::make_unique(); + binaryNode->op = op; + binaryNode->left = std::move(left); + binaryNode->right = std::move(right); - return constantNode; + return binaryNode; } inline std::unique_ptr Impl::Branch::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const @@ -39,6 +49,18 @@ namespace Nz::ShaderBuilder return branchNode; } + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const + { + auto castNode = std::make_unique(); + castNode->targetType = std::move(targetType); + + assert(expressions.size() <= castNode->expressions.size()); + for (std::size_t i = 0; i < expressions.size(); ++i) + castNode->expressions[i] = std::move(expressions[i]); + + return castNode; + } + inline std::unique_ptr Impl::Constant::operator()(ShaderConstantValue value) const { auto constantNode = std::make_unique(); @@ -47,7 +69,7 @@ namespace Nz::ShaderBuilder return constantNode; } - inline std::unique_ptr Impl::DeclareFunction::operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType) const + inline std::unique_ptr Impl::DeclareFunction::operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType) const { auto declareFunctionNode = std::make_unique(); declareFunctionNode->name = std::move(name); @@ -58,7 +80,7 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } - inline std::unique_ptr Impl::DeclareFunction::operator()(std::vector attributes, std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType) const + inline std::unique_ptr Impl::DeclareFunction::operator()(std::vector attributes, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType) const { auto declareFunctionNode = std::make_unique(); declareFunctionNode->attributes = std::move(attributes); @@ -70,7 +92,24 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } - inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareStruct::operator()(ShaderAst::StructDescription description) const + { + auto declareStructNode = std::make_unique(); + declareStructNode->description = std::move(description); + + return declareStructNode; + } + + inline std::unique_ptr Impl::DeclareStruct::operator()(std::vector attributes, ShaderAst::StructDescription description) const + { + auto declareStructNode = std::make_unique(); + declareStructNode->attributes = std::move(attributes); + declareStructNode->description = std::move(description); + + return declareStructNode; + } + + inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const { auto declareVariableNode = std::make_unique(); declareVariableNode->varName = std::move(name); @@ -80,6 +119,14 @@ namespace Nz::ShaderBuilder return declareVariableNode; } + inline std::unique_ptr Impl::ExpressionStatement::operator()(ShaderAst::ExpressionPtr expression) const + { + auto expressionStatementNode = std::make_unique(); + expressionStatementNode->expression = std::move(expression); + + return expressionStatementNode; + } + inline std::unique_ptr Impl::Identifier::operator()(std::string name) const { auto identifierNode = std::make_unique(); @@ -88,6 +135,15 @@ namespace Nz::ShaderBuilder return identifierNode; } + inline std::unique_ptr Impl::Intrinsic::operator()(ShaderAst::IntrinsicType intrinsicType, std::vector parameters) const + { + auto intrinsicExpression = std::make_unique(); + intrinsicExpression->intrinsic = intrinsicType; + intrinsicExpression->parameters = std::move(parameters); + + return intrinsicExpression; + } + inline std::unique_ptr Impl::Return::operator()(ShaderAst::ExpressionPtr expr) const { auto returnNode = std::make_unique(); diff --git a/include/Nazara/Shader/ShaderEnums.hpp b/include/Nazara/Shader/ShaderEnums.hpp index 5c2d3136d..e56e71093 100644 --- a/include/Nazara/Shader/ShaderEnums.hpp +++ b/include/Nazara/Shader/ShaderEnums.hpp @@ -18,28 +18,19 @@ namespace Nz::ShaderAst enum class AttributeType { - Entry, //< Entry point (function only) - has argument type - Layout //< Struct layout (struct only) - has argument style + Binding, //< Binding (external var only) - has argument index + Builtin, //< Builtin (struct member only) - has argument type + Entry, //< Entry point (function only) - has argument type + Layout, //< Struct layout (struct only) - has argument style + Location //< Location (struct member only) - has argument index }; - enum class BasicType + enum class PrimitiveType { - Boolean, //< bool - Float1, //< float - Float2, //< vec2 - Float3, //< vec3 - Float4, //< vec4 - Int1, //< int - Int2, //< ivec2 - Int3, //< ivec3 - Int4, //< ivec4 - Mat4x4, //< mat4 - Sampler2D, //< sampler2D - Void, //< void - UInt1, //< uint - UInt2, //< uvec2 - UInt3, //< uvec3 - UInt4 //< uvec4 + Boolean, //< bool + Float32, //< f32 + Int32, //< i32 + UInt32, //< ui32 }; enum class BinaryType @@ -71,7 +62,8 @@ namespace Nz::ShaderAst enum class IntrinsicType { CrossProduct, - DotProduct + DotProduct, + SampleTexture }; enum class MemoryLayout @@ -107,9 +99,6 @@ namespace Nz::ShaderAst ParameterVariable, UniformVariable }; - - inline std::size_t GetComponentCount(BasicType type); - inline BasicType GetComponentType(BasicType type); } #include diff --git a/include/Nazara/Shader/ShaderEnums.inl b/include/Nazara/Shader/ShaderEnums.inl index fbd01ad49..a24138a6e 100644 --- a/include/Nazara/Shader/ShaderEnums.inl +++ b/include/Nazara/Shader/ShaderEnums.inl @@ -7,51 +7,6 @@ namespace Nz::ShaderAst { - inline std::size_t GetComponentCount(BasicType type) - { - switch (type) - { - case BasicType::Float2: - case BasicType::Int2: - return 2; - - case BasicType::Float3: - case BasicType::Int3: - return 3; - - case BasicType::Float4: - case BasicType::Int4: - return 4; - - case BasicType::Mat4x4: - return 4; - - default: - return 1; - } - } - - inline BasicType GetComponentType(BasicType type) - { - switch (type) - { - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - return BasicType::Float1; - - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - return BasicType::Int1; - - case BasicType::Mat4x4: - return BasicType::Float4; - - default: - return type; - } - } } #include diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index ba727b73a..3c61ec888 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -19,6 +19,12 @@ namespace Nz::ShaderLang public: using exception::exception; }; + + class DuplicateIdentifier : public std::exception + { + public: + using exception::exception; + }; class ReservedKeyword : public std::exception { @@ -56,17 +62,24 @@ namespace Nz::ShaderLang // Flow control const Token& Advance(); void Consume(std::size_t count = 1); + ShaderAst::ExpressionType DecodeType(const std::string& identifier); + void EnterScope(); const Token& Expect(const Token& token, TokenType type); const Token& ExpectNot(const Token& token, TokenType type); const Token& Expect(TokenType type); + void LeaveScope(); + bool IsVariableInScope(const std::string_view& identifier) const; + void RegisterVariable(std::string identifier); const Token& Peek(std::size_t advance = 0); - void HandleAttributes(); + std::vector ParseAttributes(); // Statements + ShaderAst::StatementPtr ParseExternalBlock(std::vector attributes = {}); std::vector ParseFunctionBody(); ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter(); + ShaderAst::StatementPtr ParseStructDeclaration(std::vector attributes = {}); ShaderAst::StatementPtr ParseReturnStatement(); ShaderAst::StatementPtr ParseStatement(); std::vector ParseStatementList(); @@ -75,22 +88,28 @@ namespace Nz::ShaderLang // Expressions ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); ShaderAst::ExpressionPtr ParseExpression(); + ShaderAst::ExpressionPtr ParseFloatingPointExpression(bool minus = false); ShaderAst::ExpressionPtr ParseIdentifier(); - ShaderAst::ExpressionPtr ParseIntegerExpression(); + ShaderAst::ExpressionPtr ParseIntegerExpression(bool minus = false); + std::vector ParseParameters(); ShaderAst::ExpressionPtr ParseParenthesisExpression(); ShaderAst::ExpressionPtr ParsePrimaryExpression(); + ShaderAst::ExpressionPtr ParseVariableAssignation(); ShaderAst::AttributeType ParseIdentifierAsAttributeType(); const std::string& ParseIdentifierAsName(); - ShaderAst::ShaderExpressionType ParseIdentifierAsType(); + ShaderAst::PrimitiveType ParsePrimitiveType(); + ShaderAst::ExpressionType ParseType(); static int GetTokenPrecedence(TokenType token); struct Context { - std::unique_ptr root; std::size_t tokenCount; std::size_t tokenIndex = 0; + std::vector scopeSizes; + std::vector identifiersInScope; + std::unique_ptr root; const Token* tokens; }; diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index 959c5df31..9b818cd0b 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -21,15 +21,22 @@ NAZARA_SHADERLANG_TOKEN(Colon) NAZARA_SHADERLANG_TOKEN(Comma) NAZARA_SHADERLANG_TOKEN(Divide) NAZARA_SHADERLANG_TOKEN(Dot) +NAZARA_SHADERLANG_TOKEN(Equal) +NAZARA_SHADERLANG_TOKEN(External) NAZARA_SHADERLANG_TOKEN(FloatingPointValue) NAZARA_SHADERLANG_TOKEN(EndOfStream) NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) NAZARA_SHADERLANG_TOKEN(FunctionReturn) +NAZARA_SHADERLANG_TOKEN(GreatherThan) +NAZARA_SHADERLANG_TOKEN(GreatherThanEqual) NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(Identifier) +NAZARA_SHADERLANG_TOKEN(LessThan) +NAZARA_SHADERLANG_TOKEN(LessThanEqual) NAZARA_SHADERLANG_TOKEN(Let) NAZARA_SHADERLANG_TOKEN(Multiply) NAZARA_SHADERLANG_TOKEN(Minus) +NAZARA_SHADERLANG_TOKEN(NotEqual) NAZARA_SHADERLANG_TOKEN(Plus) NAZARA_SHADERLANG_TOKEN(OpenAttribute) NAZARA_SHADERLANG_TOKEN(OpenCurlyBracket) @@ -37,6 +44,7 @@ NAZARA_SHADERLANG_TOKEN(OpenSquareBracket) NAZARA_SHADERLANG_TOKEN(OpenParenthesis) NAZARA_SHADERLANG_TOKEN(Semicolon) NAZARA_SHADERLANG_TOKEN(Return) +NAZARA_SHADERLANG_TOKEN(Struct) #undef NAZARA_SHADERLANG_TOKEN #undef NAZARA_SHADERLANG_TOKEN_LAST diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index 386f53168..bfd0ebb47 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -12,9 +12,10 @@ #include #include #include -#include #include #include +#include +#include #include #include #include @@ -25,12 +26,6 @@ namespace Nz::ShaderAst class AstExpressionVisitor; class AstStatementVisitor; - struct Attribute - { - AttributeType type; - std::string args; - }; - struct NAZARA_SHADER_API Node { Node() = default; @@ -97,7 +92,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - BasicType targetType; + ExpressionType targetType; std::array expressions; }; @@ -189,6 +184,22 @@ namespace Nz::ShaderAst StatementPtr statement; }; + struct NAZARA_SHADER_API DeclareExternalStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + struct ExternalVar + { + std::vector attributes; + std::string name; + ExpressionType type; + }; + + std::vector attributes; + std::vector externalVars; + }; + struct NAZARA_SHADER_API DeclareFunctionStatement : Statement { NodeType GetType() const override; @@ -197,14 +208,14 @@ namespace Nz::ShaderAst struct Parameter { std::string name; - ShaderExpressionType type; + ExpressionType type; }; std::string name; std::vector attributes; std::vector parameters; std::vector statements; - ShaderExpressionType returnType = BasicType::Void; + ExpressionType returnType; }; struct NAZARA_SHADER_API DeclareStructStatement : Statement @@ -212,6 +223,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; + std::vector attributes; StructDescription description; }; @@ -222,7 +234,7 @@ namespace Nz::ShaderAst std::string varName; ExpressionPtr initialExpression; - ShaderExpressionType varType; + ExpressionType varType; }; struct NAZARA_SHADER_API DiscardStatement : Statement diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index 7cc829518..5453e9f46 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -172,11 +172,16 @@ namespace Nz SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; static ConstantPtr BuildConstant(const ShaderConstantValue& value); - static TypePtr BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector& parameters); - static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass); - static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass); - static TypePtr BuildType(const ShaderAst::BasicType& type); - static TypePtr BuildType(const ShaderAst::ShaderExpressionType& type); + static TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters); + static TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass); + static TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass); + static TypePtr BuildType(const ShaderAst::ExpressionType& type); + static TypePtr BuildType(const ShaderAst::IdentifierType& type); + static TypePtr BuildType(const ShaderAst::MatrixType& type); + static TypePtr BuildType(const ShaderAst::NoType& type); + static TypePtr BuildType(const ShaderAst::PrimitiveType& type); + static TypePtr BuildType(const ShaderAst::SamplerType& type); + static TypePtr BuildType(const ShaderAst::VectorType& type); private: struct DepRegisterer; diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index d0c5f561d..17ac7d023 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -62,8 +62,8 @@ namespace Nz const ExtVar& GetInputVariable(const std::string& name) const; const ExtVar& GetOutputVariable(const std::string& name) const; const ExtVar& GetUniformVariable(const std::string& name) const; - UInt32 GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const; - UInt32 GetTypeId(const ShaderAst::ShaderExpressionType& type) const; + UInt32 GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; + UInt32 GetTypeId(const ShaderAst::ExpressionType& type) const; inline bool IsConditionEnabled(const std::string& condition) const; @@ -80,8 +80,8 @@ namespace Nz UInt32 RegisterConstant(const ShaderConstantValue& value); UInt32 RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode); - UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass); - UInt32 RegisterType(ShaderAst::ShaderExpressionType type); + UInt32 RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass); + UInt32 RegisterType(ShaderAst::ExpressionType type); void WriteLocalVariable(std::string name, UInt32 resultId); @@ -106,7 +106,7 @@ namespace Nz struct FunctionParameter { std::string name; - ShaderAst::ShaderExpressionType type; + ShaderAst::ExpressionType type; }; struct State; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index bb9ce8433..8580fd5b0 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -20,37 +20,43 @@ namespace Nz namespace { static const char* flipYUniformName = "_NzFlipValue"; + static const char* overridenMain = "_NzMain"; struct AstAdapter : ShaderAst::AstCloner { - void Visit(ShaderAst::AssignExpression& node) override + using AstCloner::Clone; + + std::unique_ptr Clone(ShaderAst::DeclareFunctionStatement& node) override { - if (!flipYPosition) - return AstCloner::Visit(node); + auto clone = AstCloner::Clone(node); + if (clone->name == "main") + clone->name = "_NzMain"; - if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression) - return AstCloner::Visit(node); - - /* - FIXME: - const auto& identifier = static_cast(*node.left); - if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable) - return ShaderAstCloner::Visit(node); - - const auto& builtinVar = static_cast(*identifier.var); - if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition) - return ShaderAstCloner::Visit(node); - - auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1); - - auto oneConstant = ShaderBuilder::Constant(1.f); - auto fixYValue = ShaderBuilder::Cast(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant); - auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue); - - PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/ + return clone; } - bool flipYPosition = false; + void Visit(ShaderAst::DeclareFunctionStatement& node) + { + if (removedEntryPoints.find(&node) != removedEntryPoints.end()) + { + PushStatement(ShaderBuilder::NoOp()); + return; + } + + AstCloner::Visit(node); + } + + std::unordered_set removedEntryPoints; + }; + + struct Builtin + { + std::string identifier; + ShaderStageTypeFlags stageFlags; + }; + + std::unordered_map builtinMapping = { + { "position", { "gl_Position", ShaderStageType::Vertex } } }; } @@ -59,6 +65,7 @@ namespace Nz { const States* states = nullptr; ShaderAst::AstCache cache; + ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; std::stringstream stream; unsigned int indentLevel = 0; }; @@ -69,19 +76,8 @@ namespace Nz { } - std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) + std::string GlslWriter::Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions) { - /*const ShaderAst* selectedShader = &inputShader; - std::optional modifiedShader; - if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition) - { - modifiedShader.emplace(inputShader); - - modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1); - - selectedShader = &modifiedShader.value(); - }*/ - State state; m_currentState = &state; CallOnExit onExit([this]() @@ -93,6 +89,27 @@ namespace Nz if (!ShaderAst::ValidateAst(shader, &error, &state.cache)) throw std::runtime_error("Invalid shader AST: " + error); + state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)]; + if (!state.entryFunc) + throw std::runtime_error("missing entry point"); + + AstAdapter adapter; + + for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions) + { + if (entryFunc != state.entryFunc) + adapter.removedEntryPoints.insert(entryFunc); + } + + ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader); + + state.cache.Clear(); + if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache)) + throw std::runtime_error("Internal error:" + error); + + state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)]; + assert(state.entryFunc); + unsigned int glslVersion; if (m_environment.glES) { @@ -141,14 +158,14 @@ namespace Nz if (!m_environment.glES && m_environment.extCallback) { // GL_ARB_shading_language_420pack (required for layout(binding = X)) - if (glslVersion < 420 && HasExplicitBinding(shader)) + if (glslVersion < 420 && HasExplicitBinding(adaptedShader)) { if (m_environment.extCallback("GL_ARB_shading_language_420pack")) requiredExtensions.emplace_back("GL_ARB_shading_language_420pack"); } // GL_ARB_separate_shader_objects (required for layout(location = X)) - if (glslVersion < 410 && HasExplicitLocation(shader)) + if (glslVersion < 410 && HasExplicitLocation(adaptedShader)) { if (m_environment.extCallback("GL_ARB_separate_shader_objects")) requiredExtensions.emplace_back("GL_ARB_separate_shader_objects"); @@ -173,7 +190,10 @@ namespace Nz AppendLine(); } - shader->Visit(*this); + adaptedShader->Visit(*this); + + // Append true GLSL entry point + AppendEntryPoint(shaderStage); return state.stream.str(); } @@ -188,7 +208,7 @@ namespace Nz return flipYUniformName; } - void GlslWriter::Append(ShaderAst::ShaderExpressionType type) + void GlslWriter::Append(const ShaderAst::ExpressionType& type) { std::visit([&](auto&& arg) { @@ -206,29 +226,82 @@ namespace Nz } } - void GlslWriter::Append(ShaderAst::BasicType type) + void GlslWriter::Append(const ShaderAst::IdentifierType& identifierType) + { + Append(identifierType.name); + } + + void GlslWriter::Append(const ShaderAst::MatrixType& matrixType) + { + if (matrixType.columnCount == matrixType.rowCount) + { + Append("mat"); + Append(matrixType.columnCount); + } + else + { + Append("mat"); + Append(matrixType.columnCount); + Append("x"); + Append(matrixType.rowCount); + } + } + + void GlslWriter::Append(ShaderAst::PrimitiveType type) { switch (type) { - case ShaderAst::BasicType::Boolean: return Append("bool"); - case ShaderAst::BasicType::Float1: return Append("float"); - case ShaderAst::BasicType::Float2: return Append("vec2"); - case ShaderAst::BasicType::Float3: return Append("vec3"); - case ShaderAst::BasicType::Float4: return Append("vec4"); - case ShaderAst::BasicType::Int1: return Append("int"); - case ShaderAst::BasicType::Int2: return Append("ivec2"); - case ShaderAst::BasicType::Int3: return Append("ivec3"); - case ShaderAst::BasicType::Int4: return Append("ivec4"); - case ShaderAst::BasicType::Mat4x4: return Append("mat4"); - case ShaderAst::BasicType::Sampler2D: return Append("sampler2D"); - case ShaderAst::BasicType::UInt1: return Append("uint"); - case ShaderAst::BasicType::UInt2: return Append("uvec2"); - case ShaderAst::BasicType::UInt3: return Append("uvec3"); - case ShaderAst::BasicType::UInt4: return Append("uvec4"); - case ShaderAst::BasicType::Void: return Append("void"); + case ShaderAst::PrimitiveType::Boolean: return Append("bool"); + case ShaderAst::PrimitiveType::Float32: return Append("float"); + case ShaderAst::PrimitiveType::Int32: return Append("ivec2"); + case ShaderAst::PrimitiveType::UInt32: return Append("uint"); } } + void GlslWriter::Append(const ShaderAst::SamplerType& samplerType) + { + switch (samplerType.sampledType) + { + case ShaderAst::PrimitiveType::Boolean: + case ShaderAst::PrimitiveType::Float32: + break; + + case ShaderAst::PrimitiveType::Int32: Append("i"); break; + case ShaderAst::PrimitiveType::UInt32: Append("u"); break; + } + + Append("sampler"); + + switch (samplerType.dim) + { + case ImageType_1D: Append("1D"); break; + case ImageType_1D_Array: Append("1DArray"); break; + case ImageType_2D: Append("2D"); break; + case ImageType_2D_Array: Append("2DArray"); break; + case ImageType_3D: Append("3D"); break; + case ImageType_Cubemap: Append("Cube"); break; + } + } + + void GlslWriter::Append(const ShaderAst::UniformType& uniformType) + { + /* TODO */ + } + + void GlslWriter::Append(const ShaderAst::VectorType& vecType) + { + switch (vecType.type) + { + case ShaderAst::PrimitiveType::Boolean: Append("b"); break; + case ShaderAst::PrimitiveType::Float32: break; + case ShaderAst::PrimitiveType::Int32: Append("i"); break; + case ShaderAst::PrimitiveType::UInt32: Append("u"); break; + } + + Append("vec"); + Append(vecType.componentCount); + } + void GlslWriter::Append(ShaderAst::MemoryLayout layout) { switch (layout) @@ -239,6 +312,11 @@ namespace Nz } } + void GlslWriter::Append(ShaderAst::NoType) + { + return Append("void"); + } + template void GlslWriter::Append(const T& param) { @@ -246,6 +324,12 @@ namespace Nz m_currentState->stream << param; } + template + void GlslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params) + { + Append(firstParam); + Append(secondParam, std::forward(params)...); + } void GlslWriter::AppendCommentSection(const std::string& section) { @@ -256,6 +340,152 @@ namespace Nz AppendLine(); } + void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage) + { + AppendLine(); + AppendLine("// Entry point handling"); + + struct InOutField + { + std::string name; + std::string targetName; + }; + + std::vector inputFields; + const ShaderAst::StructDescription* inputStruct = nullptr; + + auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription* + { + assert(IsIdentifierType(expressionType)); + + const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get(expressionType).name); + assert(identifier); + + assert(std::holds_alternative(identifier->value)); + const auto& s = std::get(identifier->value); + + for (const auto& member : s.members) + { + bool skip = false; + std::optional builtinName; + std::optional attributeLocation; + for (const auto& [attributeType, attributeParam] : member.attributes) + { + if (attributeType == ShaderAst::AttributeType::Builtin) + { + auto it = builtinMapping.find(std::get(attributeParam)); + if (it != builtinMapping.end()) + { + const Builtin& builtin = it->second; + if (!builtin.stageFlags.Test(shaderStage)) + { + skip = true; + break; + } + + builtinName = builtin.identifier; + break; + } + } + else if (attributeType == ShaderAst::AttributeType::Location) + { + attributeLocation = std::get(attributeParam); + break; + } + } + + if (!skip && attributeLocation) + { + Append("layout(location = "); + Append(*attributeLocation); + Append(") "); + Append(keyword); + Append(" "); + Append(member.type); + Append(" "); + Append(targetPrefix); + Append(member.name); + AppendLine(";"); + + fields.push_back({ + fromPrefix + member.name, + targetPrefix + member.name + }); + } + else if (builtinName) + { + fields.push_back({ + fromPrefix + member.name, + *builtinName + }); + } + } + AppendLine(); + + return &s; + }; + + if (!m_currentState->entryFunc->parameters.empty()) + { + assert(m_currentState->entryFunc->parameters.size() == 1); + const auto& parameter = m_currentState->entryFunc->parameters.front(); + + inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_"); + } + + std::vector outputFields; + const ShaderAst::StructDescription* outputStruct = nullptr; + if (!IsNoType(m_currentState->entryFunc->returnType)) + outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_"); + + if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition) + AppendLine("uniform float ", flipYUniformName, ";"); + + AppendLine("void main()"); + EnterScope(); + { + if (inputStruct) + { + Append(inputStruct->name); + AppendLine(" _nzInput;"); + for (const auto& [name, targetName] : inputFields) + { + AppendLine(name, " = ", targetName, ";"); + } + AppendLine(); + } + + if (outputStruct) + Append(outputStruct->name, " _nzOutput = "); + + Append(m_currentState->entryFunc->name); + + Append("("); + if (m_currentState->entryFunc) + Append("_nzInput"); + Append(");"); + + if (outputStruct) + { + AppendLine(); + + for (const auto& [name, targetName] : outputFields) + { + bool isOutputPosition = (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition && targetName == "gl_Position"); + + AppendLine(); + + Append(targetName, " = ", name); + if (isOutputPosition) + Append(" * vec4(1.0, ", flipYUniformName, ", 1.0, 1.0)"); + + Append(";"); + } + } + } + LeaveScope(); + } + void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { Append("."); @@ -273,7 +503,7 @@ namespace Nz const auto& member = *memberIt; if (remainingMembers > 1) - AppendField(scopeId, std::get(member.type), memberIdentifier + 1, remainingMembers - 1); + AppendField(scopeId, std::get(member.type).name, memberIdentifier + 1, remainingMembers - 1); } void GlslWriter::AppendLine(const std::string& txt) @@ -283,6 +513,13 @@ namespace Nz m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t'); } + template + void GlslWriter::AppendLine(Args&&... params) + { + (Append(std::forward(params)), ...); + AppendLine(); + } + void GlslWriter::EnterScope() { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); @@ -291,13 +528,17 @@ namespace Nz AppendLine("{"); } - void GlslWriter::LeaveScope() + void GlslWriter::LeaveScope(bool skipLine) { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); m_currentState->indentLevel--; AppendLine(); - AppendLine("}"); + + if (skipLine) + AppendLine("}"); + else + Append("}"); } void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired) @@ -317,12 +558,12 @@ namespace Nz { Visit(node.structExpr, true); - const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache); - assert(IsStructType(exprType)); + const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache); + assert(IsIdentifierType(exprType)); std::size_t scopeId = m_currentState->cache.GetScopeId(&node); - AppendField(scopeId, std::get(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size()); + AppendField(scopeId, std::get(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size()); } void GlslWriter::Visit(ShaderAst::AssignExpression& node) @@ -336,7 +577,7 @@ namespace Nz break; } - node.left->Visit(*this); + node.right->Visit(*this); } void GlslWriter::Visit(ShaderAst::BranchStatement& node) @@ -455,6 +696,71 @@ namespace Nz }, node.value); } + void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) + { + for (const auto& externalVar : node.externalVars) + { + std::optional bindingIndex; + bool isStd140 = false; + for (const auto& [attributeType, attributeParam] : externalVar.attributes) + { + if (attributeType == ShaderAst::AttributeType::Binding) + bindingIndex = std::get(attributeParam); + else if (attributeType == ShaderAst::AttributeType::Layout) + { + if (std::get(attributeParam) == "std140") + isStd140 = true; + } + } + + if (bindingIndex) + { + Append("layout(binding = "); + Append(*bindingIndex); + if (isStd140) + Append(", std140"); + + Append(") uniform "); + + if (IsUniformType(externalVar.type)) + { + Append("_NzBinding_"); + AppendLine(externalVar.name); + + EnterScope(); + { + const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get(externalVar.type).containedType.name); + assert(identifier); + + assert(std::holds_alternative(identifier->value)); + const auto& s = std::get(identifier->value); + + bool first = true; + for (const auto& [name, attribute, type] : s.members) + { + if (!first) + AppendLine(); + + first = false; + + Append(type); + Append(" "); + Append(name); + Append(";"); + } + } + LeaveScope(false); + } + else + Append(externalVar.type); + + Append(" "); + Append(externalVar.name); + AppendLine(";"); + } + } + } + void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node) { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); @@ -475,15 +781,36 @@ namespace Nz EnterScope(); { - AstAdapter adapter; - adapter.flipYPosition = m_environment.flipYPosition; - for (auto& statement : node.statements) - adapter.Clone(statement)->Visit(*this); + statement->Visit(*this); } LeaveScope(); } + void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) + { + Append("struct "); + AppendLine(node.description.name); + EnterScope(); + { + bool first = true; + for (const auto& [name, attribute, type] : node.description.members) + { + if (!first) + AppendLine(); + + first = false; + + Append(type); + Append(" "); + Append(name); + Append(";"); + } + } + LeaveScope(false); + AppendLine(";"); + } + void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node) { Append(node.varType); @@ -506,7 +833,7 @@ namespace Nz void GlslWriter::Visit(ShaderAst::ExpressionStatement& node) { node.expression->Visit(*this); - Append(";"); + AppendLine(";"); } void GlslWriter::Visit(ShaderAst::IdentifierExpression& node) @@ -525,6 +852,10 @@ namespace Nz case ShaderAst::IntrinsicType::DotProduct: Append("dot"); break; + + case ShaderAst::IntrinsicType::SampleTexture: + Append("texture"); + break; } Append("("); @@ -624,4 +955,5 @@ namespace Nz return false; } + } diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index ea7ed0276..92eae1621 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -42,6 +42,21 @@ namespace Nz::ShaderAst return PopStatement(); } + std::unique_ptr AstCloner::Clone(DeclareFunctionStatement& node) + { + auto clone = std::make_unique(); + clone->attributes = node.attributes; + clone->name = node.name; + clone->parameters = node.parameters; + clone->returnType = node.returnType; + + clone->statements.reserve(node.statements.size()); + for (auto& statement : node.statements) + clone->statements.push_back(CloneStatement(statement)); + + return clone; + } + void AstCloner::Visit(AccessMemberExpression& node) { auto clone = std::make_unique(); @@ -162,21 +177,20 @@ namespace Nz::ShaderAst PushStatement(std::move(clone)); } - void AstCloner::Visit(DeclareFunctionStatement& node) + void AstCloner::Visit(DeclareExternalStatement& node) { - auto clone = std::make_unique(); + auto clone = std::make_unique(); clone->attributes = node.attributes; - clone->name = node.name; - clone->parameters = node.parameters; - clone->returnType = node.returnType; - - clone->statements.reserve(node.statements.size()); - for (auto& statement : node.statements) - clone->statements.push_back(CloneStatement(statement)); + clone->externalVars = node.externalVars; PushStatement(std::move(clone)); } + void AstCloner::Visit(DeclareFunctionStatement& node) + { + PushStatement(Clone(node)); + } + void AstCloner::Visit(DeclareStructStatement& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/ShaderAstExpressionType.cpp b/src/Nazara/Shader/ShaderAstExpressionType.cpp index 71f831f69..bd87822f7 100644 --- a/src/Nazara/Shader/ShaderAstExpressionType.cpp +++ b/src/Nazara/Shader/ShaderAstExpressionType.cpp @@ -9,16 +9,16 @@ namespace Nz::ShaderAst { - ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr) + ExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache) { m_cache = cache; - ShaderExpressionType type = GetExpressionTypeInternal(expression); + ExpressionType type = GetExpressionTypeInternal(expression); m_cache = nullptr; return type; } - ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression) + ExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression) { m_lastExpressionType.reset(); @@ -28,6 +28,33 @@ namespace Nz::ShaderAst return std::move(*m_lastExpressionType); } + ExpressionType ExpressionTypeVisitor::ResolveAlias(Expression& expression, ExpressionType expressionType) + { + if (IsIdentifierType(expressionType)) + { + auto scopeIt = m_cache->scopeIdByNode.find(&expression); + if (scopeIt == m_cache->scopeIdByNode.end()) + throw std::runtime_error("internal error"); + + const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get(expressionType).name); + if (identifier && std::holds_alternative(identifier->value)) + { + const AstCache::Alias& alias = std::get(identifier->value); + return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return arg; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, alias.value); + } + } + + return expressionType; + } + void ExpressionTypeVisitor::Visit(Expression& expression) { if (m_cache) @@ -51,6 +78,16 @@ namespace Nz::ShaderAst void ExpressionTypeVisitor::Visit(AccessMemberExpression& node) { + auto scopeIt = m_cache->scopeIdByNode.find(&node); + if (scopeIt == m_cache->scopeIdByNode.end()) + throw std::runtime_error("internal error"); + + ExpressionType expressionType = ResolveAlias(node, GetExpressionTypeInternal(*node.structExpr)); + if (!IsIdentifierType(expressionType)) + throw std::runtime_error("internal error"); + + const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get(expressionType).name); + throw std::runtime_error("unhandled accessmember expression"); } @@ -70,38 +107,35 @@ namespace Nz::ShaderAst case BinaryType::Divide: case BinaryType::Multiply: { - ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left); - assert(IsBasicType(leftExprType)); + ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.left)); + ExpressionType rightExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.right)); - ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right); - assert(IsBasicType(rightExprType)); - - switch (std::get(leftExprType)) + if (IsPrimitiveType(leftExprType)) { - case BasicType::Boolean: - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - case BasicType::UInt2: - case BasicType::UInt3: - case BasicType::UInt4: - m_lastExpressionType = std::move(leftExprType); - break; + switch (std::get(leftExprType)) + { + case PrimitiveType::Boolean: + m_lastExpressionType = std::move(leftExprType); + break; - case BasicType::Float1: - case BasicType::Int1: - case BasicType::Mat4x4: - case BasicType::UInt1: - m_lastExpressionType = std::move(rightExprType); - break; - - case BasicType::Sampler2D: - case BasicType::Void: - break; + case PrimitiveType::Float32: + case PrimitiveType::Int32: + case PrimitiveType::UInt32: + m_lastExpressionType = std::move(rightExprType); + break; + } } + else if (IsMatrixType(leftExprType)) + { + if (IsVectorType(rightExprType)) + m_lastExpressionType = std::move(rightExprType); + else + m_lastExpressionType = std::move(leftExprType); + } + else if (IsVectorType(leftExprType)) + m_lastExpressionType = std::move(leftExprType); + else + throw std::runtime_error("validation failure"); break; } @@ -112,7 +146,7 @@ namespace Nz::ShaderAst case BinaryType::CompLe: case BinaryType::CompLt: case BinaryType::CompNe: - m_lastExpressionType = BasicType::Boolean; + m_lastExpressionType = PrimitiveType::Boolean; break; } } @@ -124,38 +158,38 @@ namespace Nz::ShaderAst void ExpressionTypeVisitor::Visit(ConditionalExpression& node) { - ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath); - assert(leftExprType == GetExpressionTypeInternal(*node.falsePath)); + ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.truePath)); + assert(leftExprType == ResolveAlias(node, GetExpressionTypeInternal(*node.falsePath))); m_lastExpressionType = std::move(leftExprType); } void ExpressionTypeVisitor::Visit(ConstantExpression& node) { - m_lastExpressionType = std::visit([&](auto&& arg) + m_lastExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType { using T = std::decay_t; if constexpr (std::is_same_v) - return BasicType::Boolean; + return PrimitiveType::Boolean; else if constexpr (std::is_same_v) - return BasicType::Float1; + return PrimitiveType::Float32; else if constexpr (std::is_same_v) - return BasicType::Int1; + return PrimitiveType::Int32; else if constexpr (std::is_same_v) - return BasicType::Int1; + return PrimitiveType::UInt32; else if constexpr (std::is_same_v) - return BasicType::Float2; + return VectorType{ 2, PrimitiveType::Float32 }; else if constexpr (std::is_same_v) - return BasicType::Float3; + return VectorType{ 3, PrimitiveType::Float32 }; else if constexpr (std::is_same_v) - return BasicType::Float4; + return VectorType{ 4, PrimitiveType::Float32 }; else if constexpr (std::is_same_v) - return BasicType::Int2; + return VectorType{ 2, PrimitiveType::Int32 }; else if constexpr (std::is_same_v) - return BasicType::Int3; + return VectorType{ 3, PrimitiveType::Int32 }; else if constexpr (std::is_same_v) - return BasicType::Int4; + return VectorType{ 4, PrimitiveType::Int32 }; else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, node.value); @@ -173,7 +207,7 @@ namespace Nz::ShaderAst if (!identifier || !std::holds_alternative(identifier->value)) throw std::runtime_error("internal error"); - m_lastExpressionType = std::get(identifier->value).type; + m_lastExpressionType = ResolveAlias(node, std::get(identifier->value).type); } void ExpressionTypeVisitor::Visit(IntrinsicExpression& node) @@ -185,16 +219,40 @@ namespace Nz::ShaderAst break; case IntrinsicType::DotProduct: - m_lastExpressionType = BasicType::Float1; + m_lastExpressionType = PrimitiveType::Float32; break; + + case IntrinsicType::SampleTexture: + { + if (node.parameters.empty()) + throw std::runtime_error("validation failure"); + + ExpressionType firstParamType = ResolveAlias(node, GetExpressionTypeInternal(*node.parameters.front())); + + if (!IsSamplerType(firstParamType)) + throw std::runtime_error("validation failure"); + + const auto& sampler = std::get(firstParamType); + + m_lastExpressionType = VectorType{ + 4, + sampler.sampledType + }; + + break; + } } } void ExpressionTypeVisitor::Visit(SwizzleExpression& node) { - ShaderExpressionType exprType = GetExpressionTypeInternal(*node.expression); - assert(IsBasicType(exprType)); + ExpressionType exprType = GetExpressionTypeInternal(*node.expression); - m_lastExpressionType = static_cast(UnderlyingCast(GetComponentType(std::get(exprType))) + node.componentCount - 1); + if (IsMatrixType(exprType)) + m_lastExpressionType = std::get(exprType).type; + else if (IsVectorType(exprType)) + m_lastExpressionType = std::get(exprType).type; + else + throw std::runtime_error("validation failure"); } } diff --git a/src/Nazara/Shader/ShaderAstOptimizer.cpp b/src/Nazara/Shader/ShaderAstOptimizer.cpp index 33c8d3f76..c83ae1412 100644 --- a/src/Nazara/Shader/ShaderAstOptimizer.cpp +++ b/src/Nazara/Shader/ShaderAstOptimizer.cpp @@ -453,8 +453,8 @@ namespace Nz::ShaderAst { auto& constant = static_cast(*cond); - assert(IsBasicType(GetExpressionType(constant))); - assert(std::get(GetExpressionType(constant)) == BasicType::Boolean); + assert(IsPrimitiveType(GetExpressionType(constant))); + assert(std::get(GetExpressionType(constant)) == PrimitiveType::Boolean); bool cValue = std::get(constant.value); if (!cValue) diff --git a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp index c0ef0d920..e0a87ad51 100644 --- a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp @@ -79,6 +79,11 @@ namespace Nz::ShaderAst node.statement->Visit(*this); } + void AstRecursiveVisitor::Visit(DeclareExternalStatement& node) + { + /* Nothing to do */ + } + void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node) { for (auto& statement : node.statements) diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index b078a53f8..20b0b4b63 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -58,7 +58,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(CastExpression& node) { - Enum(node.targetType); + Type(node.targetType); for (auto& expr : node.expressions) Node(expr); } @@ -152,17 +152,25 @@ namespace Nz::ShaderAst Node(node.statement); } + void AstSerializerBase::Serialize(DeclareExternalStatement& node) + { + Attributes(node.attributes); + + Container(node.externalVars); + for (auto& extVar : node.externalVars) + { + Attributes(extVar.attributes); + Value(extVar.name); + Type(extVar.type); + } + } + void AstSerializerBase::Serialize(DeclareFunctionStatement& node) { Value(node.name); Type(node.returnType); - Container(node.attributes); - for (auto& attribute : node.attributes) - { - Enum(attribute.type); - Value(attribute.args); - } + Attributes(node.attributes); Container(node.parameters); for (auto& parameter : node.parameters) @@ -223,6 +231,78 @@ namespace Nz::ShaderAst m_stream.FlushBits(); } + + void AstSerializerBase::Attributes(std::vector& attributes) + { + Container(attributes); + for (auto& attribute : attributes) + { + Enum(attribute.type); + + if (IsWriting()) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + { + UInt8 typeId = 0; + Value(typeId); + } + else if constexpr (std::is_same_v) + { + UInt8 typeId = 1; + UInt64 v = UInt64(arg); + Value(typeId); + Value(v); + } + else if constexpr (std::is_same_v) + { + UInt8 typeId = 2; + Value(typeId); + Value(arg); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + + }, attribute.args); + } + else + { + UInt8 typeId; + Value(typeId); + + switch (typeId) + { + case 0: + attribute.args.emplace(); + break; + + case 1: + { + UInt64 arg; + Value(arg); + + attribute.args = static_cast(arg); + break; + } + + case 2: + { + std::string arg; + Value(arg); + + attribute.args = std::move(arg); + break; + } + + default: + throw std::runtime_error("invalid attribute type id"); + } + } + } + } bool ShaderAstSerializer::IsWriting() const { @@ -253,20 +333,47 @@ namespace Nz::ShaderAst } } - void ShaderAstSerializer::Type(ShaderExpressionType& type) + void ShaderAstSerializer::Type(ExpressionType& type) { std::visit([&](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v) - { + + if constexpr (std::is_same_v) m_stream << UInt8(0); - m_stream << UInt32(arg); - } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { m_stream << UInt8(1); - m_stream << arg; + m_stream << UInt32(arg); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(2); + m_stream << arg.name; + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(3); + m_stream << UInt32(arg.columnCount); + m_stream << UInt32(arg.rowCount); + m_stream << UInt32(arg.type); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(4); + m_stream << UInt32(arg.dim); + m_stream << UInt32(arg.sampledType); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(5); + m_stream << arg.containedType.name; + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(6); + m_stream << UInt32(arg.componentCount); + m_stream << UInt32(arg.type); } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); @@ -421,28 +528,123 @@ namespace Nz::ShaderAst } } - void ShaderAstUnserializer::Type(ShaderExpressionType& type) + void ShaderAstUnserializer::Type(ExpressionType& type) { UInt8 typeIndex; Value(typeIndex); switch (typeIndex) { - case 0: //< Primitive + /* + if constexpr (std::is_same_v) + m_stream << UInt8(0); + else if constexpr (std::is_same_v) { - BasicType exprType; - Enum(exprType); + m_stream << UInt8(1); + m_stream << UInt32(arg); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(2); + m_stream << arg.name; + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(3); + m_stream << UInt32(arg.columnCount); + m_stream << UInt32(arg.rowCount); + m_stream << UInt32(arg.type); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(4); + m_stream << UInt32(arg.dim); + m_stream << UInt32(arg.sampledType); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(5); + m_stream << UInt32(arg.componentCount); + m_stream << UInt32(arg.type); + } + */ - type = exprType; + case 0: //< NoType + type = NoType{}; + break; + + case 1: //< PrimitiveType + { + PrimitiveType primitiveType; + Enum(primitiveType); + + type = primitiveType; break; } - case 1: //< Struct (name) + case 2: //< Identifier { - std::string structName; - Value(structName); + std::string identifier; + Value(identifier); - type = std::move(structName); + type = IdentifierType{ std::move(identifier) }; + break; + } + + case 3: //< MatrixType + { + UInt32 columnCount, rowCount; + PrimitiveType primitiveType; + Value(columnCount); + Value(rowCount); + Enum(primitiveType); + + type = MatrixType { + columnCount, + rowCount, + primitiveType + }; + break; + } + + case 4: //< SamplerType + { + ImageType dim; + PrimitiveType sampledType; + Enum(dim); + Enum(sampledType); + + type = SamplerType { + dim, + sampledType + }; + break; + } + + case 5: //< UniformType + { + std::string containedType; + Value(containedType); + + type = UniformType { + IdentifierType { + containedType + } + }; + break; + } + + case 6: //< VectorType + { + UInt32 componentCount; + PrimitiveType componentType; + Value(componentCount); + Enum(componentType); + + type = VectorType{ + componentCount, + componentType + }; break; } diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 2c41cf999..8293870e9 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -18,7 +18,6 @@ namespace Nz::ShaderAst { "frag", ShaderStageType::Fragment }, { "vert", ShaderStageType::Vertex }, }; - } struct AstError @@ -30,6 +29,8 @@ namespace Nz::ShaderAst { //const ShaderAst::Function* currentFunction; std::optional activeScopeId; + std::unordered_set declaredExternalVar; + std::unordered_set usedBindingIndexes;; AstCache* cache; }; @@ -81,31 +82,31 @@ namespace Nz::ShaderAst return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache)); } - void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right) + void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) { if (left != right) throw AstError{ "Left expression type must match right expression type" }; } - ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) + ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName); if (!identifier) throw AstError{ "unknown identifier " + structName }; - if (std::holds_alternative(identifier->value)) + if (!std::holds_alternative(identifier->value)) throw AstError{ "identifier is not a struct" }; const StructDescription& s = std::get(identifier->value); - auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); + auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); if (memberIt == s.members.end()) throw AstError{ "unknown field " + memberIdentifier[0]}; const auto& member = *memberIt; if (remainingMembers > 1) - return CheckField(std::get(member.type), memberIdentifier + 1, remainingMembers - 1); + return CheckField(std::get(member.type).name, memberIdentifier + 1, remainingMembers - 1); else return member.type; } @@ -130,7 +131,7 @@ namespace Nz::ShaderAst m_context->activeScopeId = previousScope.parentScopeIndex; } - void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType) + void AstValidator::RegisterExpressionType(Expression& node, ExpressionType expressionType) { m_context->cache->nodeExpressionType[&node] = std::move(expressionType); } @@ -145,11 +146,14 @@ namespace Nz::ShaderAst { RegisterScope(node); - ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); - if (!IsStructType(exprType)) + // Register expressions types + AstRecursiveVisitor::Visit(node); + + ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); + if (!IsIdentifierType(exprType)) throw AstError{ "expression is not a structure" }; - const std::string& structName = std::get(exprType); + const std::string& structName = std::get(exprType).name; RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size())); } @@ -160,12 +164,14 @@ namespace Nz::ShaderAst MandatoryExpr(node.left); MandatoryExpr(node.right); + + // Register expressions types + AstRecursiveVisitor::Visit(node); + TypeMustMatch(node.left, node.right); if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) throw AstError { "Assignation is only possible with a l-value" }; - - AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(BinaryExpression& node) @@ -175,80 +181,121 @@ namespace Nz::ShaderAst // Register expression type AstRecursiveVisitor::Visit(node); - ShaderExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); - if (!IsBasicType(leftExprType)) + ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); + if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; - ShaderExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); - if (!IsBasicType(rightExprType)) + ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); + if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; - BasicType leftType = std::get(leftExprType); - BasicType rightType = std::get(rightExprType); - - switch (node.op) + if (IsPrimitiveType(leftExprType)) { - case BinaryType::CompGe: - case BinaryType::CompGt: - case BinaryType::CompLe: - case BinaryType::CompLt: - if (leftType == BasicType::Boolean) - throw AstError{ "this operation is not supported for booleans" }; - - [[fallthrough]]; - case BinaryType::Add: - case BinaryType::CompEq: - case BinaryType::CompNe: - case BinaryType::Subtract: - TypeMustMatch(node.left, node.right); - break; - - case BinaryType::Multiply: - case BinaryType::Divide: + PrimitiveType leftType = std::get(leftExprType); + switch (node.op) { - switch (leftType) + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + if (leftType == PrimitiveType::Boolean) + throw AstError{ "this operation is not supported for booleans" }; + + [[fallthrough]]; + case BinaryType::Add: + case BinaryType::CompEq: + case BinaryType::CompNe: + case BinaryType::Subtract: + TypeMustMatch(node.left, node.right); + break; + + case BinaryType::Multiply: + case BinaryType::Divide: { - case BasicType::Float1: - case BasicType::Int1: + switch (leftType) { - if (GetComponentType(rightType) != leftType) - throw AstError{ "Left expression type is not compatible with right expression type" }; - - break; - } - - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - { - if (leftType != rightType && rightType != GetComponentType(leftType)) - throw AstError{ "Left expression type is not compatible with right expression type" }; - - break; - } - - case BasicType::Mat4x4: - { - switch (rightType) + case PrimitiveType::Float32: + case PrimitiveType::Int32: + case PrimitiveType::UInt32: { - case BasicType::Float1: - case BasicType::Float4: - case BasicType::Mat4x4: - break; + if (IsMatrixType(rightExprType)) + TypeMustMatch(leftType, std::get(rightExprType).type); + else if (IsVectorType(rightExprType)) + TypeMustMatch(leftType, std::get(rightExprType).type); + else + throw AstError{ "incompatible types" }; - default: - TypeMustMatch(node.left, node.right); + break; } - break; - } + case PrimitiveType::Boolean: + throw AstError{ "this operation is not supported for booleans" }; - default: - TypeMustMatch(node.left, node.right); - break; + default: + throw AstError{ "incompatible types" }; + } + } + } + } + else if (IsMatrixType(leftExprType)) + { + const MatrixType& leftType = std::get(leftExprType); + switch (node.op) + { + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + case BinaryType::CompEq: + case BinaryType::CompNe: + case BinaryType::Add: + case BinaryType::Subtract: + TypeMustMatch(node.left, node.right); + break; + + case BinaryType::Multiply: + case BinaryType::Divide: + { + if (IsMatrixType(rightExprType)) + TypeMustMatch(leftExprType, rightExprType); + else if (IsPrimitiveType(rightExprType)) + TypeMustMatch(leftType.type, rightExprType); + else if (IsVectorType(rightExprType)) + { + const VectorType& rightType = std::get(rightExprType); + TypeMustMatch(leftType.type, rightType.type); + + if (leftType.columnCount != rightType.componentCount) + throw AstError{ "incompatible types" }; + } + else + throw AstError{ "incompatible types" }; + } + } + } + else if (IsVectorType(leftExprType)) + { + const MatrixType& leftType = std::get(leftExprType); + switch (node.op) + { + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + case BinaryType::CompEq: + case BinaryType::CompNe: + case BinaryType::Add: + case BinaryType::Subtract: + TypeMustMatch(node.left, node.right); + break; + + case BinaryType::Multiply: + case BinaryType::Divide: + { + if (IsPrimitiveType(rightExprType)) + TypeMustMatch(leftType.type, rightExprType); + else + throw AstError{ "incompatible types" }; } } } @@ -258,24 +305,35 @@ namespace Nz::ShaderAst { RegisterScope(node); + AstRecursiveVisitor::Visit(node); + + auto GetComponentCount = [](const ExpressionType& exprType) -> unsigned int + { + if (IsPrimitiveType(exprType)) + return 1; + else if (IsVectorType(exprType)) + return std::get(exprType).componentCount; + else + throw AstError{ "wut" }; + }; + unsigned int componentCount = 0; unsigned int requiredComponents = GetComponentCount(node.targetType); + for (auto& exprPtr : node.expressions) { if (!exprPtr) break; - ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache); - if (!IsBasicType(exprType)) + ExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache); + if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) throw AstError{ "incompatible type" }; - componentCount += GetComponentCount(std::get(exprType)); + componentCount += GetComponentCount(exprType); } if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; - - AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(ConditionalExpression& node) @@ -313,34 +371,51 @@ namespace Nz::ShaderAst { RegisterScope(node); + AstRecursiveVisitor::Visit(node); + switch (node.intrinsic) { case IntrinsicType::CrossProduct: case IntrinsicType::DotProduct: { if (node.parameters.size() != 2) - throw AstError { "Expected 2 parameters" }; + throw AstError { "Expected two parameters" }; for (auto& param : node.parameters) MandatoryExpr(param); - ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache); + ExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache); for (std::size_t i = 1; i < node.parameters.size(); ++i) { - if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache) + if (type != GetExpressionType(MandatoryExpr(node.parameters[i]), m_context->cache)) throw AstError{ "All type must match" }; } break; } + + case IntrinsicType::SampleTexture: + { + if (node.parameters.size() != 2) + throw AstError{ "Expected two parameters" }; + + for (auto& param : node.parameters) + MandatoryExpr(param); + + if (!IsSamplerType(GetExpressionType(*node.parameters[0], m_context->cache))) + throw AstError{ "First parameter must be a sampler" }; + + if (!IsVectorType(GetExpressionType(*node.parameters[1], m_context->cache))) + throw AstError{ "First parameter must be a vector" }; + } } switch (node.intrinsic) { case IntrinsicType::CrossProduct: { - if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache) - throw AstError{ "CrossProduct only works with Float3 expressions" }; + if (GetExpressionType(*node.parameters[0]) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) + throw AstError{ "CrossProduct only works with vec3 expressions" }; break; } @@ -348,8 +423,6 @@ namespace Nz::ShaderAst case IntrinsicType::DotProduct: break; } - - AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(SwizzleExpression& node) @@ -359,26 +432,10 @@ namespace Nz::ShaderAst if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; - ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); - if (!IsBasicType(exprType)) + ExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); + if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) throw AstError{ "Cannot swizzle this type" }; - switch (std::get(exprType)) - { - case BasicType::Float1: - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - case BasicType::Int1: - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - break; - - default: - throw AstError{ "Cannot swizzle this type" }; - } - AstRecursiveVisitor::Visit(node); } @@ -388,8 +445,8 @@ namespace Nz::ShaderAst for (auto& condStatement : node.condStatements) { - ShaderExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); - if (!IsBasicType(condType) || std::get(condType) != BasicType::Boolean) + ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); + if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) throw AstError{ "if expression must resolve to boolean type" }; MandatoryStatement(condStatement.statement); @@ -409,6 +466,78 @@ namespace Nz::ShaderAst // throw AstError{ "condition not found" }; } + void AstValidator::Visit(DeclareExternalStatement& node) + { + RegisterScope(node); + auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; + + for (const auto& [attributeType, arg] : node.attributes) + { + switch (attributeType) + { + default: + throw AstError{ "unhandled attribute for external block" }; + } + } + + for (const auto& extVar : node.externalVars) + { + bool hasBinding = false; + bool hasLayout = false; + for (const auto& [attributeType, arg] : extVar.attributes) + { + switch (attributeType) + { + case AttributeType::Binding: + { + if (hasBinding) + throw AstError{ "attribute binding must be present once" }; + + if (!std::holds_alternative(arg)) + throw AstError{ "attribute binding requires a string parameter" }; + + long long bindingIndex = std::get(arg); + if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end()) + throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" }; + + m_context->usedBindingIndexes.insert(bindingIndex); + break; + } + + case AttributeType::Layout: + { + if (hasLayout) + throw AstError{ "attribute layout must be present once" }; + + if (!std::holds_alternative(arg)) + throw AstError{ "attribute layout requires a string parameter" }; + + if (std::get(arg) != "std140") + throw AstError{ "unknow layout type" }; + + hasLayout = true; + break; + } + + default: + throw AstError{ "unhandled attribute for external variable" }; + } + } + + if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) + throw AstError{ "External variable " + extVar.name + " is already declared" }; + + m_context->declaredExternalVar.insert(extVar.name); + + ExpressionType subType = extVar.type; + if (IsUniformType(subType)) + subType = IdentifierType{ std::get(subType).containedType }; + + auto& identifier = scope.identifiers.emplace_back(); + identifier = AstCache::Identifier{ extVar.name, AstCache::Variable { std::move(subType) } }; + } + } + void AstValidator::Visit(DeclareFunctionStatement& node) { bool hasEntry = false; @@ -421,12 +550,14 @@ namespace Nz::ShaderAst if (hasEntry) throw AstError{ "attribute entry must be present once" }; - if (arg.empty()) - throw AstError{ "attribute entry requires a parameter" }; + if (!std::holds_alternative(arg)) + throw AstError{ "attribute entry requires a string parameter" }; - auto it = entryPoints.find(arg); + const std::string& argStr = std::get(arg); + + auto it = entryPoints.find(argStr); if (it == entryPoints.end()) - throw AstError{ "invalid parameter " + arg + " for entry attribute" }; + throw AstError{ "invalid parameter " + argStr + " for entry attribute" }; ShaderStageType stageType = it->second; @@ -435,6 +566,9 @@ namespace Nz::ShaderAst m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node; + if (node.parameters.size() > 1) + throw AstError{ "entry functions can either take one struct parameter or no parameter" }; + hasEntry = true; break; } @@ -468,6 +602,8 @@ namespace Nz::ShaderAst RegisterScope(node); + //TODO: check members attributes + auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; auto& identifier = scope.identifiers.emplace_back(); diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index 93b6bda1c..fbede99d0 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -36,22 +36,24 @@ namespace Nz::ShaderLang std::vector Tokenize(const std::string_view& str) { - // Can't use std::from_chars for double thanks to libc++ and libstdc++ developers for being lazy + // Can't use std::from_chars for double, thanks to libc++ and libstdc++ developers for being lazy ForceCLocale forceCLocale; std::unordered_map reservedKeywords = { - { "false", TokenType::BoolFalse }, - { "fn", TokenType::FunctionDeclaration }, - { "let", TokenType::Let }, - { "return", TokenType::Return }, - { "true", TokenType::BoolTrue } + { "external", TokenType::External }, + { "false", TokenType::BoolFalse }, + { "fn", TokenType::FunctionDeclaration }, + { "let", TokenType::Let }, + { "return", TokenType::Return }, + { "struct", TokenType::Struct }, + { "true", TokenType::BoolTrue } }; std::size_t currentPos = 0; auto Peek = [&](std::size_t advance = 1) -> char { - if (currentPos + advance < str.size()) + if (currentPos + advance < str.size() && str[currentPos + advance] != '\0') return str[currentPos + advance]; else return char(-1); @@ -134,7 +136,10 @@ namespace Nz::ShaderLang { currentPos++; if (Peek() == '/') + { + currentPos++; break; + } } else if (next == '\n') { @@ -250,7 +255,48 @@ namespace Nz::ShaderLang break; } - case '=': tokenType = TokenType::Assign; break; + case '=': + { + char next = Peek(); + if (next == '=') + { + currentPos++; + tokenType = TokenType::Equal; + } + else + tokenType = TokenType::Assign; + + break; + } + + case '<': + { + char next = Peek(); + if (next == '=') + { + currentPos++; + tokenType = TokenType::LessThanEqual; + } + else + tokenType = TokenType::LessThan; + + break; + } + + case '>': + { + char next = Peek(); + if (next == '=') + { + currentPos++; + tokenType = TokenType::GreatherThanEqual; + } + else + tokenType = TokenType::GreatherThan; + + break; + } + case '+': tokenType = TokenType::Plus; break; case '*': tokenType = TokenType::Multiply; break; case ':': tokenType = TokenType::Colon; break; diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index c2a05aad3..3a493ab30 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -11,32 +11,24 @@ namespace Nz::ShaderLang { namespace { - std::unordered_map identifierToBasicType = { - { "bool", ShaderAst::BasicType::Boolean }, + std::unordered_map identifierToBasicType = { + { "bool", ShaderAst::PrimitiveType::Boolean }, + { "i32", ShaderAst::PrimitiveType::Int32 }, + { "f32", ShaderAst::PrimitiveType::Float32 }, + { "u32", ShaderAst::PrimitiveType::UInt32 } + }; - { "i32", ShaderAst::BasicType::Int1 }, - { "vec2i32", ShaderAst::BasicType::Int2 }, - { "vec3i32", ShaderAst::BasicType::Int3 }, - { "vec4i32", ShaderAst::BasicType::Int4 }, - - { "f32", ShaderAst::BasicType::Float1 }, - { "vec2f32", ShaderAst::BasicType::Float2 }, - { "vec3f32", ShaderAst::BasicType::Float3 }, - { "vec4f32", ShaderAst::BasicType::Float4 }, - - { "mat4x4f32", ShaderAst::BasicType::Mat4x4 }, - { "sampler2D", ShaderAst::BasicType::Sampler2D }, - { "void", ShaderAst::BasicType::Void }, - - { "u32", ShaderAst::BasicType::UInt1 }, - { "vec2u32", ShaderAst::BasicType::UInt3 }, - { "vec3u32", ShaderAst::BasicType::UInt3 }, - { "vec4u32", ShaderAst::BasicType::UInt4 }, + std::unordered_map identifierToIntrinsic = { + { "cross", ShaderAst::IntrinsicType::CrossProduct }, + { "dot", ShaderAst::IntrinsicType::DotProduct }, }; std::unordered_map identifierToAttributeType = { - { "entry", ShaderAst::AttributeType::Entry }, - { "layout", ShaderAst::AttributeType::Layout }, + { "binding", ShaderAst::AttributeType::Binding }, + { "builtin", ShaderAst::AttributeType::Builtin }, + { "entry", ShaderAst::AttributeType::Entry }, + { "layout", ShaderAst::AttributeType::Layout }, + { "location", ShaderAst::AttributeType::Location }, }; } @@ -50,22 +42,41 @@ namespace Nz::ShaderLang m_context = &context; + std::vector attributes; + + EnterScope(); + bool reachedEndOfStream = false; while (!reachedEndOfStream) { const Token& nextToken = Peek(); switch (nextToken.type) { + case TokenType::EndOfStream: + if (!attributes.empty()) + throw UnexpectedToken{}; + + reachedEndOfStream = true; + break; + + case TokenType::External: + context.root->statements.push_back(ParseExternalBlock(std::move(attributes))); + attributes.clear(); + break; + case TokenType::OpenAttribute: - HandleAttributes(); + assert(attributes.empty()); + attributes = ParseAttributes(); break; case TokenType::FunctionDeclaration: - context.root->statements.push_back(ParseFunctionDeclaration()); + context.root->statements.push_back(ParseFunctionDeclaration(std::move(attributes))); + attributes.clear(); break; - case TokenType::EndOfStream: - reachedEndOfStream = true; + case TokenType::Struct: + context.root->statements.push_back(ParseStructDeclaration(std::move(attributes))); + attributes.clear(); break; default: @@ -73,6 +84,8 @@ namespace Nz::ShaderLang } } + LeaveScope(); + return std::move(context.root); } @@ -90,6 +103,92 @@ namespace Nz::ShaderLang m_context->tokenIndex += count; } + ShaderAst::ExpressionType Parser::DecodeType(const std::string& identifier) + { + if (auto it = identifierToBasicType.find(identifier); it != identifierToBasicType.end()) + return it->second; + + //FIXME: Handle this better + if (identifier == "mat4") + { + ShaderAst::MatrixType matrixType; + matrixType.columnCount = 4; + matrixType.rowCount = 4; + + Expect(Advance(), TokenType::LessThan); //< '<' + matrixType.type = ParsePrimitiveType(); + Expect(Advance(), TokenType::GreatherThan); //< '>' + + return matrixType; + } + else if (identifier == "sampler2D") + { + ShaderAst::SamplerType samplerType; + samplerType.dim = ImageType_2D; + + Expect(Advance(), TokenType::LessThan); //< '<' + samplerType.sampledType = ParsePrimitiveType(); + Expect(Advance(), TokenType::GreatherThan); //< '>' + + return samplerType; + } + else if (identifier == "uniform") + { + ShaderAst::UniformType uniformType; + + Expect(Advance(), TokenType::LessThan); //< '<' + uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() }; + Expect(Advance(), TokenType::GreatherThan); //< '>' + + return uniformType; + } + else if (identifier == "vec2") + { + ShaderAst::VectorType vectorType; + vectorType.componentCount = 2; + + Expect(Advance(), TokenType::LessThan); //< '<' + vectorType.type = ParsePrimitiveType(); + Expect(Advance(), TokenType::GreatherThan); //< '>' + + return vectorType; + } + else if (identifier == "vec3") + { + ShaderAst::VectorType vectorType; + vectorType.componentCount = 3; + + Expect(Advance(), TokenType::LessThan); //< '<' + vectorType.type = ParsePrimitiveType(); + Expect(Advance(), TokenType::GreatherThan); //< '>' + + return vectorType; + } + else if (identifier == "vec4") + { + ShaderAst::VectorType vectorType; + vectorType.componentCount = 4; + + Expect(Advance(), TokenType::LessThan); //< '<' + vectorType.type = ParsePrimitiveType(); + Expect(Advance(), TokenType::GreatherThan); //< '>' + + return vectorType; + } + else + { + ShaderAst::IdentifierType identifierType; + identifierType.name = identifier; + + return identifierType; + } + } + + void Parser::EnterScope() + { + m_context->scopeSizes.push_back(m_context->identifiersInScope.size()); + } + const Token& Parser::Expect(const Token& token, TokenType type) { if (token.type != type) @@ -114,13 +213,34 @@ namespace Nz::ShaderLang return token; } + void Parser::LeaveScope() + { + assert(!m_context->scopeSizes.empty()); + m_context->identifiersInScope.resize(m_context->scopeSizes.back()); + m_context->scopeSizes.pop_back(); + } + + bool Parser::IsVariableInScope(const std::string_view& identifier) const + { + return std::find(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), identifier) != m_context->identifiersInScope.rend(); + } + + void Parser::RegisterVariable(std::string identifier) + { + if (IsVariableInScope(identifier)) + throw DuplicateIdentifier{ ("identifier name " + identifier + " is already taken").c_str() }; + + assert(!m_context->scopeSizes.empty()); + m_context->identifiersInScope.push_back(std::move(identifier)); + } + const Token& Parser::Peek(std::size_t advance) { assert(m_context->tokenIndex + advance < m_context->tokenCount); return m_context->tokens[m_context->tokenIndex + advance]; } - void Parser::HandleAttributes() + std::vector Parser::ParseAttributes() { std::vector attributes; @@ -150,13 +270,22 @@ namespace Nz::ShaderLang ShaderAst::AttributeType attributeType = ParseIdentifierAsAttributeType(); - std::string arg; + ShaderAst::Attribute::Param arg; if (Peek().type == TokenType::OpenParenthesis) { Consume(); - if (Peek().type == TokenType::Identifier) - arg = std::get(Advance().data); + const Token& n = Peek(); + if (n.type == TokenType::Identifier) + { + arg = std::get(n.data); + Consume(); + } + else if (n.type == TokenType::IntegerValue) + { + arg = std::get(n.data); + Consume(); + } Expect(Advance(), TokenType::ClosingParenthesis); } @@ -171,16 +300,54 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::ClosingAttribute); - const Token& nextToken = Peek(); - switch (nextToken.type) + return attributes; + } + + ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector attributes) + { + Expect(Advance(), TokenType::External); + Expect(Advance(), TokenType::OpenCurlyBracket); + + std::unique_ptr externalStatement = std::make_unique(); + externalStatement->attributes = std::move(attributes); + + bool first = true; + + for (;;) { - case TokenType::FunctionDeclaration: - m_context->root->statements.push_back(ParseFunctionDeclaration(std::move(attributes))); + if (!first) + { + const Token& nextToken = Peek(); + if (nextToken.type == TokenType::Comma) + Consume(); + else + { + Expect(nextToken, TokenType::ClosingCurlyBracket); + break; + } + } + + first = false; + + const Token& token = Peek(); + if (token.type == TokenType::ClosingCurlyBracket) break; - default: - throw UnexpectedToken{}; + auto& extVar = externalStatement->externalVars.emplace_back(); + + if (token.type == TokenType::OpenAttribute) + extVar.attributes = ParseAttributes(); + + extVar.name = ParseIdentifierAsName(); + Expect(Advance(), TokenType::Colon); + extVar.type = ParseType(); + + RegisterVariable(extVar.name); } + + Expect(Advance(), TokenType::ClosingCurlyBracket); + + return externalStatement; } std::vector Parser::ParseFunctionBody() @@ -216,17 +383,23 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::ClosingParenthesis); - ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void; + ShaderAst::ExpressionType returnType; if (Peek().type == TokenType::FunctionReturn) { Consume(); - returnType = ParseIdentifierAsType(); + returnType = ParseType(); } Expect(Advance(), TokenType::OpenCurlyBracket); + EnterScope(); + for (const auto& parameter : parameters) + RegisterVariable(parameter.name); + std::vector functionBody = ParseFunctionBody(); + LeaveScope(); + Expect(Advance(), TokenType::ClosingCurlyBracket); return ShaderBuilder::DeclareFunction(std::move(attributes), std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); @@ -238,11 +411,59 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Colon); - ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType(); + ShaderAst::ExpressionType parameterType = ParseType(); return { parameterName, parameterType }; } + ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector attributes) + { + Expect(Advance(), TokenType::Struct); + + ShaderAst::StructDescription description; + description.name = ParseIdentifierAsName(); + + Expect(Advance(), TokenType::OpenCurlyBracket); + + bool first = true; + + for (;;) + { + if (!first) + { + const Token& nextToken = Peek(); + if (nextToken.type == TokenType::Comma) + Consume(); + else + { + Expect(nextToken, TokenType::ClosingCurlyBracket); + break; + } + } + + first = false; + + const Token& token = Peek(); + if (token.type == TokenType::ClosingCurlyBracket) + break; + + auto& structField = description.members.emplace_back(); + + if (token.type == TokenType::OpenAttribute) + structField.attributes = ParseAttributes(); + + structField.name = ParseIdentifierAsName(); + + Expect(Advance(), TokenType::Colon); + + structField.type = ParseType(); + } + + Expect(Advance(), TokenType::ClosingCurlyBracket); + + return ShaderBuilder::DeclareStruct(std::move(attributes), std::move(description)); + } + ShaderAst::StatementPtr Parser::ParseReturnStatement() { Expect(Advance(), TokenType::Return); @@ -265,6 +486,10 @@ namespace Nz::ShaderLang statement = ParseVariableDeclaration(); break; + case TokenType::Identifier: + statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation()); + break; + case TokenType::Return: statement = ParseReturnStatement(); break; @@ -290,15 +515,26 @@ namespace Nz::ShaderLang return statements; } + ShaderAst::ExpressionPtr Parser::ParseVariableAssignation() + { + ShaderAst::ExpressionPtr left = ParseIdentifier(); + Expect(Advance(), TokenType::Assign); + + ShaderAst::ExpressionPtr right = ParseExpression(); + + return ShaderBuilder::Assign(ShaderAst::AssignType::Simple, std::move(left), std::move(right)); + } + ShaderAst::StatementPtr Parser::ParseVariableDeclaration() { Expect(Advance(), TokenType::Let); std::string variableName = ParseIdentifierAsName(); + RegisterVariable(variableName); Expect(Advance(), TokenType::Colon); - ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType(); + ShaderAst::ExpressionType variableType = ParseType(); ShaderAst::ExpressionPtr expression; if (Peek().type == TokenType::Assign) @@ -351,18 +587,61 @@ namespace Nz::ShaderLang return ParseBinOpRhs(0, ParsePrimaryExpression()); } + ShaderAst::ExpressionPtr Parser::ParseFloatingPointExpression(bool minus) + { + const Token& floatingPointToken = Expect(Advance(), TokenType::FloatingPointValue); + return ShaderBuilder::Constant(((minus) ? -1.f : 1.f) * float(std::get(floatingPointToken.data))); //< FIXME + } + ShaderAst::ExpressionPtr Parser::ParseIdentifier() { const Token& identifierToken = Expect(Advance(), TokenType::Identifier); const std::string& identifier = std::get(identifierToken.data); - return ShaderBuilder::Identifier(identifier); + ShaderAst::ExpressionPtr identifierExpr = ShaderBuilder::Identifier(identifier); + + if (Peek().type == TokenType::Dot) + { + std::unique_ptr accessMemberNode = std::make_unique(); + accessMemberNode->structExpr = std::move(identifierExpr); + + do + { + Consume(); + + accessMemberNode->memberIdentifiers.push_back(ParseIdentifierAsName()); + } while (Peek().type == TokenType::Dot); + + identifierExpr = std::move(accessMemberNode); + } + + return identifierExpr; } - ShaderAst::ExpressionPtr Parser::ParseIntegerExpression() + ShaderAst::ExpressionPtr Parser::ParseIntegerExpression(bool minus) { - const Token& integerToken = Expect(Advance(), TokenType::Identifier); - return ShaderBuilder::Constant(static_cast(std::get(integerToken.data))); + const Token& integerToken = Expect(Advance(), TokenType::IntegerValue); + return ShaderBuilder::Constant(((minus) ? -1 : 1) * static_cast(std::get(integerToken.data))); + } + + std::vector Parser::ParseParameters() + { + Expect(Advance(), TokenType::OpenParenthesis); + + std::vector parameters; + bool first = true; + while (Peek().type != TokenType::ClosingParenthesis) + { + if (!first) + Expect(Advance(), TokenType::Comma); + + first = false; + parameters.push_back(ParseExpression()); + } + + Expect(Advance(), TokenType::ClosingParenthesis); + + return parameters; } ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression() @@ -388,15 +667,69 @@ namespace Nz::ShaderLang return ShaderBuilder::Constant(true); case TokenType::FloatingPointValue: - Consume(); - return ShaderBuilder::Constant(float(std::get(token.data))); //< FIXME + return ParseFloatingPointExpression(); case TokenType::Identifier: - return ParseIdentifier(); + { + const std::string& identifier = std::get(token.data); + + if (auto it = identifierToIntrinsic.find(identifier); it != identifierToIntrinsic.end()) + { + if (Peek(1).type == TokenType::OpenParenthesis) + { + Consume(); + return ShaderBuilder::Intrinsic(it->second, ParseParameters()); + } + } + + if (IsVariableInScope(identifier)) + { + auto node = ParseIdentifier(); + if (node->GetType() == ShaderAst::NodeType::AccessMemberExpression) + { + ShaderAst::AccessMemberExpression* memberExpr = static_cast(node.get()); + if (!memberExpr->memberIdentifiers.empty() && memberExpr->memberIdentifiers.front() == "Sample") + { + if (Peek().type == TokenType::OpenParenthesis) + { + auto parameters = ParseParameters(); + parameters.insert(parameters.begin(), std::move(memberExpr->structExpr)); + + return ShaderBuilder::Intrinsic(ShaderAst::IntrinsicType::SampleTexture, std::move(parameters)); + } + } + } + + return node; + } + + Consume(); + + ShaderAst::ExpressionType exprType = DecodeType(identifier); + + return ShaderBuilder::Cast(std::move(exprType), ParseParameters()); + } case TokenType::IntegerValue: return ParseIntegerExpression(); + case TokenType::Minus: + //< FIXME: Handle this with an unary node + if (Peek(1).type == TokenType::FloatingPointValue) + { + Consume(); + return ParseFloatingPointExpression(true); + } + else if (Peek(1).type == TokenType::IntegerValue) + { + Consume(); + return ParseIntegerExpression(true); + } + else + throw UnexpectedToken{}; + + break; + case TokenType::OpenParenthesis: return ParseParenthesisExpression(); @@ -429,7 +762,7 @@ namespace Nz::ShaderLang return identifier; } - ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType() + ShaderAst::PrimitiveType Parser::ParsePrimitiveType() { const Token& identifierToken = Expect(Advance(), TokenType::Identifier); const std::string& identifier = std::get(identifierToken.data); @@ -441,6 +774,23 @@ namespace Nz::ShaderLang return it->second; } + ShaderAst::ExpressionType Parser::ParseType() + { + // Handle () as no type + if (Peek().type == TokenType::OpenParenthesis) + { + Consume(); + Expect(Advance(), TokenType::ClosingParenthesis); + + return ShaderAst::NoType{}; + } + + const Token& identifierToken = Expect(Advance(), TokenType::Identifier); + const std::string& identifier = std::get(identifierToken.data); + + return DecodeType(identifier); + } + int Parser::GetTokenPrecedence(TokenType token) { switch (token) @@ -452,4 +802,5 @@ namespace Nz::ShaderLang default: return -1; } } + } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 0dd4737a3..d559e7965 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -39,18 +39,18 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node) { - ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache); - assert(IsBasicType(resultExprType)); + ShaderAst::ExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache); + assert(IsPrimitiveType(resultExprType)); - ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache); - assert(IsBasicType(leftExprType)); + ShaderAst::ExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache); + assert(IsPrimitiveType(leftExprType)); - ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache); - assert(IsBasicType(rightExprType)); + ShaderAst::ExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache); + assert(IsPrimitiveType(rightExprType)); - ShaderAst::BasicType resultType = std::get(resultExprType); - ShaderAst::BasicType leftType = std::get(leftExprType); - ShaderAst::BasicType rightType = std::get(rightExprType); + ShaderAst::PrimitiveType resultType = std::get(resultExprType); + ShaderAst::PrimitiveType leftType = std::get(leftExprType); + ShaderAst::PrimitiveType rightType = std::get(rightExprType); UInt32 leftOperand = EvaluateExpression(node.left); @@ -67,26 +67,26 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFAdd; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpIAdd; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -97,26 +97,26 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFSub; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpISub; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -127,28 +127,28 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFDiv; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSDiv; - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpUDiv; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -159,29 +159,29 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Boolean: + case ShaderAst::PrimitiveType::Boolean: return SpirvOp::OpLogicalEqual; - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdEqual; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpIEqual; - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: - break; +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: +// break; } break; @@ -191,28 +191,28 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdGreaterThan; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSGreaterThan; - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpUGreaterThan; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -223,28 +223,28 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdGreaterThanEqual; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSGreaterThanEqual; - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpUGreaterThanEqual; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -255,28 +255,28 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdLessThanEqual; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSLessThanEqual; - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpULessThanEqual; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -287,28 +287,28 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdLessThan; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSLessThan; - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpULessThan; - case ShaderAst::BasicType::Boolean: - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: + case ShaderAst::PrimitiveType::Boolean: +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: break; } @@ -319,29 +319,29 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Boolean: + case ShaderAst::PrimitiveType::Boolean: return SpirvOp::OpLogicalNotEqual; - case ShaderAst::BasicType::Float1: - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Mat4x4: + case ShaderAst::PrimitiveType::Float32: +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdNotEqual; - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpINotEqual; - case ShaderAst::BasicType::Sampler2D: - case ShaderAst::BasicType::Void: - break; +// case ShaderAst::PrimitiveType::Sampler2D: +// case ShaderAst::PrimitiveType::Void: +// break; } break; @@ -351,22 +351,22 @@ namespace Nz { switch (leftType) { - case ShaderAst::BasicType::Float1: + case ShaderAst::PrimitiveType::Float32: { switch (rightType) { - case ShaderAst::BasicType::Float1: + case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpFMul; - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - swapOperands = true; - return SpirvOp::OpVectorTimesScalar; - - case ShaderAst::BasicType::Mat4x4: - swapOperands = true; - return SpirvOp::OpMatrixTimesScalar; +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// swapOperands = true; +// return SpirvOp::OpVectorTimesScalar; +// +// case ShaderAst::PrimitiveType::Mat4x4: +// swapOperands = true; +// return SpirvOp::OpMatrixTimesScalar; default: break; @@ -375,54 +375,54 @@ namespace Nz break; } - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - { - switch (rightType) - { - case ShaderAst::BasicType::Float1: - return SpirvOp::OpVectorTimesScalar; +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// { +// switch (rightType) +// { +// case ShaderAst::PrimitiveType::Float32: +// return SpirvOp::OpVectorTimesScalar; +// +// case ShaderAst::PrimitiveType::Float2: +// case ShaderAst::PrimitiveType::Float3: +// case ShaderAst::PrimitiveType::Float4: +// return SpirvOp::OpFMul; +// +// case ShaderAst::PrimitiveType::Mat4x4: +// return SpirvOp::OpVectorTimesMatrix; +// +// default: +// break; +// } +// +// break; +// } - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - return SpirvOp::OpFMul; - - case ShaderAst::BasicType::Mat4x4: - return SpirvOp::OpVectorTimesMatrix; - - default: - break; - } - - break; - } - - case ShaderAst::BasicType::Int1: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: - case ShaderAst::BasicType::UInt1: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: + case ShaderAst::PrimitiveType::Int32: +// case ShaderAst::PrimitiveType::Int2: +// case ShaderAst::PrimitiveType::Int3: +// case ShaderAst::PrimitiveType::Int4: + case ShaderAst::PrimitiveType::UInt32: +// case ShaderAst::PrimitiveType::UInt2: +// case ShaderAst::PrimitiveType::UInt3: +// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpIMul; - case ShaderAst::BasicType::Mat4x4: - { - switch (rightType) - { - case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar; - case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector; - case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; - - default: - break; - } - - break; - } +// case ShaderAst::PrimitiveType::Mat4x4: +// { +// switch (rightType) +// { +// case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpMatrixTimesScalar; +// case ShaderAst::PrimitiveType::Float4: return SpirvOp::OpMatrixTimesVector; +// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; +// +// default: +// break; +// } +// +// break; +// } default: break; @@ -501,10 +501,10 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node) { - const ShaderAst::ShaderExpressionType& targetExprType = node.targetType; - assert(IsBasicType(targetExprType)); + const ShaderAst::ExpressionType& targetExprType = node.targetType; + assert(IsPrimitiveType(targetExprType)); - ShaderAst::BasicType targetType = std::get(targetExprType); + ShaderAst::PrimitiveType targetType = std::get(targetExprType); StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); @@ -582,12 +582,12 @@ namespace Nz { case ShaderAst::IntrinsicType::DotProduct: { - ShaderAst::ShaderExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache); - assert(IsBasicType(vecExprType)); + ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache); + assert(IsVectorType(vecExprType)); - ShaderAst::BasicType vecType = std::get(vecExprType); + const ShaderAst::VectorType& vecType = std::get(vecExprType); - UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType)); + UInt32 typeId = m_writer.GetTypeId(vecType.type); UInt32 vec1 = EvaluateExpression(node.parameters[0]); UInt32 vec2 = EvaluateExpression(node.parameters[1]); @@ -626,10 +626,10 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { - ShaderAst::ShaderExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache); - assert(IsBasicType(targetExprType)); + ShaderAst::ExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache); + assert(IsPrimitiveType(targetExprType)); - ShaderAst::BasicType targetType = std::get(targetExprType); + ShaderAst::PrimitiveType targetType = std::get(targetExprType); UInt32 exprResultId = EvaluateExpression(node.expression); UInt32 resultId = m_writer.AllocateResultId(); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 6146c29e5..cc3421ac2 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -535,7 +535,7 @@ namespace Nz else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ - BuildType((std::is_same_v) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2), + BuildType(ShaderAst::VectorType{ 2, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }), { BuildConstant(arg.x), BuildConstant(arg.y) @@ -545,7 +545,7 @@ namespace Nz else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ - BuildType((std::is_same_v) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3), + BuildType(ShaderAst::VectorType{ 3, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }), { BuildConstant(arg.x), BuildConstant(arg.y), @@ -556,7 +556,7 @@ namespace Nz else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ - BuildType((std::is_same_v) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4), + BuildType(ShaderAst::VectorType{ 4, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }), { BuildConstant(arg.x), BuildConstant(arg.y), @@ -570,7 +570,7 @@ namespace Nz }, value)); } - auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector& parameters) -> TypePtr + auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) -> TypePtr { std::vector parameterTypes; parameterTypes.reserve(parameters.size()); @@ -584,7 +584,7 @@ namespace Nz }); } - auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr + auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) -> TypePtr { return std::make_shared(Pointer{ BuildType(type), @@ -592,85 +592,22 @@ namespace Nz }); } - auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr + auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) -> TypePtr { return std::make_shared(Pointer{ BuildType(type), storageClass - }); + }); } - auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr - { - return std::make_shared([&]() -> AnyType - { - switch (type) - { - case ShaderAst::BasicType::Boolean: - return Bool{}; - - case ShaderAst::BasicType::Float1: - return Float{ 32 }; - - case ShaderAst::BasicType::Int1: - return Integer{ 32, true }; - - case ShaderAst::BasicType::Float2: - case ShaderAst::BasicType::Float3: - case ShaderAst::BasicType::Float4: - case ShaderAst::BasicType::Int2: - case ShaderAst::BasicType::Int3: - case ShaderAst::BasicType::Int4: - case ShaderAst::BasicType::UInt2: - case ShaderAst::BasicType::UInt3: - case ShaderAst::BasicType::UInt4: - { - auto vecType = BuildType(ShaderAst::GetComponentType(type)); - UInt32 componentCount = ShaderAst::GetComponentCount(type); - - return Vector{ vecType, componentCount }; - } - - case ShaderAst::BasicType::Mat4x4: - return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u }; - - case ShaderAst::BasicType::UInt1: - return Integer{ 32, false }; - - case ShaderAst::BasicType::Void: - return Void{}; - - case ShaderAst::BasicType::Sampler2D: - { - auto imageType = Image{ - {}, //< qualifier - {}, //< depth - {}, //< sampled - SpirvDim::Dim2D, //< dim - SpirvImageFormat::Unknown, //< format - BuildType(ShaderAst::BasicType::Float1), //< sampledType - false, //< arrayed, - false //< multisampled - }; - - return SampledImage{ std::make_shared(imageType) }; - } - } - - throw std::runtime_error("unexpected type"); - }()); - } - - auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr { return std::visit([&](auto&& arg) -> TypePtr { - using T = std::decay_t; - if constexpr (std::is_same_v) - return BuildType(arg); - else if constexpr (std::is_same_v) + return BuildType(arg); + /*else if constexpr (std::is_same_v) { - /*// Register struct members type + // Register struct members type const auto& structs = shader.GetStructs(); auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); if (it == structs.end()) @@ -688,14 +625,77 @@ namespace Nz sMembers.type = BuildType(shader, member.type); } - return std::make_shared(std::move(sType));*/ + return std::make_shared(std::move(sType)); return nullptr; } else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + static_assert(AlwaysFalse::value, "non-exhaustive visitor");*/ }, type); } + auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr + { + throw std::runtime_error("unexpected type"); + } + + auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr + { + return std::make_shared([&]() -> AnyType + { + switch (type) + { + case ShaderAst::PrimitiveType::Boolean: + return Bool{}; + + case ShaderAst::PrimitiveType::Float32: + return Float{ 32 }; + + case ShaderAst::PrimitiveType::Int32: + return Integer{ 32, true }; + } + + throw std::runtime_error("unexpected type"); + }()); + } + + auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) -> TypePtr + { + return std::make_shared( + Matrix{ + BuildType(ShaderAst::VectorType { + UInt32(type.rowCount), type.type + }), + UInt32(type.columnCount) + }); + } + + auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) -> TypePtr + { + return std::make_shared(Void{}); + } + + auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) -> TypePtr + { + //TODO + auto imageType = Image{ + {}, //< qualifier + {}, //< depth + {}, //< sampled + SpirvDim::Dim2D, //< dim + SpirvImageFormat::Unknown, //< format + BuildType(ShaderAst::PrimitiveType::Float32), //< sampledType + false, //< arrayed, + false //< multisampled + }; + + return std::make_shared(SampledImage{ std::make_shared(imageType) }); + } + + auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr + { + return std::make_shared(Vector{ BuildType(type.type), UInt32(type.componentCount) }); + } + void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants) { std::visit([&](auto&& arg) diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index eb52edeca..fb1ad3725 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -29,7 +29,7 @@ namespace Nz { public: using ExtInstList = std::unordered_set; - using LocalContainer = std::unordered_set; + using LocalContainer = std::unordered_set; using FunctionContainer = std::vector>; PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : @@ -81,7 +81,7 @@ namespace Nz { funcs.emplace_back(node); - std::vector parameterTypes; + std::vector parameterTypes; for (auto& parameter : node.parameters) parameterTypes.push_back(parameter.type); @@ -92,8 +92,17 @@ namespace Nz void Visit(ShaderAst::DeclareStructStatement& node) override { - for (auto& field : node.description.members) - m_constantCache.Register(*SpirvConstantCache::BuildType(field.type)); + SpirvConstantCache::Structure sType; + sType.name = node.description.name; + + for (const auto& [name, attribute, type] : node.description.members) + { + auto& sMembers = sType.members.emplace_back(); + sMembers.name = name; + sMembers.type = SpirvConstantCache::BuildType(type); + } + + m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) }); } void Visit(ShaderAst::DeclareVariableStatement& node) override @@ -137,26 +146,26 @@ namespace Nz }; template - constexpr ShaderAst::BasicType GetBasicType() + constexpr ShaderAst::PrimitiveType GetBasicType() { if constexpr (std::is_same_v) - return ShaderAst::BasicType::Boolean; + return ShaderAst::PrimitiveType::Boolean; else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Float1); + return(ShaderAst::PrimitiveType::Float32); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Int1); + return(ShaderAst::PrimitiveType::Int32); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Float2); + return(ShaderAst::PrimitiveType::Float2); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Float3); + return(ShaderAst::PrimitiveType::Float3); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Float4); + return(ShaderAst::PrimitiveType::Float4); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Int2); + return(ShaderAst::PrimitiveType::Int2); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Int3); + return(ShaderAst::PrimitiveType::Int3); else if constexpr (std::is_same_v) - return(ShaderAst::BasicType::Int4); + return(ShaderAst::PrimitiveType::Int4); else static_assert(AlwaysFalse::value, "unhandled type"); } @@ -394,7 +403,7 @@ namespace Nz if (!state.functionBlocks.back().IsTerminated()) { - assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void)); + assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} }); state.functionBlocks.back().Append(SpirvOp::OpReturn); } @@ -537,12 +546,12 @@ namespace Nz return it.value(); } - UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const + UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const { return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass)); } - UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const + UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const { return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type)); } @@ -643,12 +652,12 @@ namespace Nz return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) }); } - UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass) + UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass) { return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass)); } - UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type) + UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type) { assert(m_currentState); return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type)); @@ -662,7 +671,7 @@ namespace Nz SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) { - std::vector parameterTypes; + std::vector parameterTypes; parameterTypes.reserve(functionNode.parameters.size()); for (const auto& parameter : functionNode.parameters)