From ea99c6a19ed3ae4225378f9b30d924c5cdf4a44e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Mon, 12 Apr 2021 15:38:20 +0200 Subject: [PATCH] Shader: First working version on both Vulkan & OpenGL (ES) --- include/Nazara/Renderer/Enums.hpp | 1 + include/Nazara/Shader.hpp | 1 - .../ConstantValue.hpp} | 4 +- include/Nazara/Shader/Ast/ExpressionType.hpp | 12 +- include/Nazara/Shader/Ast/ExpressionType.inl | 12 + .../Nazara/Shader/Ast/TransformVisitor.hpp | 90 ++ .../Nazara/Shader/Ast/TransformVisitor.inl | 62 ++ include/Nazara/Shader/GlslWriter.hpp | 3 +- include/Nazara/Shader/ShaderAstCloner.hpp | 17 +- include/Nazara/Shader/ShaderAstNodes.hpp | 4 +- .../Shader/ShaderAstRecursiveVisitor.hpp | 4 +- include/Nazara/Shader/ShaderAstSerializer.hpp | 4 +- include/Nazara/Shader/ShaderAstUtils.hpp | 4 +- include/Nazara/Shader/ShaderAstValidator.hpp | 2 +- include/Nazara/Shader/ShaderBuilder.hpp | 26 +- include/Nazara/Shader/ShaderBuilder.inl | 42 +- include/Nazara/Shader/ShaderNodes.hpp | 32 +- include/Nazara/Shader/ShaderNodes.inl | 5 + include/Nazara/Shader/SpirvAstVisitor.hpp | 95 +- include/Nazara/Shader/SpirvAstVisitor.inl | 37 +- include/Nazara/Shader/SpirvConstantCache.hpp | 47 +- include/Nazara/Shader/SpirvExpressionLoad.hpp | 10 +- include/Nazara/Shader/SpirvExpressionLoad.inl | 5 +- .../Nazara/Shader/SpirvExpressionStore.hpp | 10 +- .../Nazara/Shader/SpirvExpressionStore.inl | 5 +- include/Nazara/Shader/SpirvWriter.hpp | 43 +- src/Nazara/Shader/Ast/TransformVisitor.cpp | 225 +++++ src/Nazara/Shader/GlslWriter.cpp | 12 +- src/Nazara/Shader/ShaderAstCloner.cpp | 144 ++- .../Shader/ShaderAstRecursiveVisitor.cpp | 12 +- src/Nazara/Shader/ShaderAstScopedVisitor.cpp | 2 +- src/Nazara/Shader/ShaderAstSerializer.cpp | 42 +- src/Nazara/Shader/ShaderAstUtils.cpp | 12 +- src/Nazara/Shader/ShaderAstValidator.cpp | 11 +- src/Nazara/Shader/ShaderLangParser.cpp | 6 +- src/Nazara/Shader/SpirvAstVisitor.cpp | 591 ++++++++----- src/Nazara/Shader/SpirvConstantCache.cpp | 302 ++++--- src/Nazara/Shader/SpirvDecoder.cpp | 4 +- src/Nazara/Shader/SpirvExpressionLoad.cpp | 39 +- src/Nazara/Shader/SpirvExpressionStore.cpp | 38 +- src/Nazara/Shader/SpirvPrinter.cpp | 4 +- src/Nazara/Shader/SpirvWriter.cpp | 835 ++++++++---------- 42 files changed, 1803 insertions(+), 1053 deletions(-) rename include/Nazara/Shader/{ShaderConstantValue.hpp => Ast/ConstantValue.hpp} (90%) create mode 100644 include/Nazara/Shader/Ast/TransformVisitor.hpp create mode 100644 include/Nazara/Shader/Ast/TransformVisitor.inl create mode 100644 src/Nazara/Shader/Ast/TransformVisitor.cpp diff --git a/include/Nazara/Renderer/Enums.hpp b/include/Nazara/Renderer/Enums.hpp index 3f9de6791..405e50968 100644 --- a/include/Nazara/Renderer/Enums.hpp +++ b/include/Nazara/Renderer/Enums.hpp @@ -140,6 +140,7 @@ namespace Nz HLSL, MSL, NazaraBinary, + NazaraShader, SpirV }; diff --git a/include/Nazara/Shader.hpp b/include/Nazara/Shader.hpp index 2bd6df584..65cd2025e 100644 --- a/include/Nazara/Shader.hpp +++ b/include/Nazara/Shader.hpp @@ -43,7 +43,6 @@ #include #include #include -#include #include #include #include diff --git a/include/Nazara/Shader/ShaderConstantValue.hpp b/include/Nazara/Shader/Ast/ConstantValue.hpp similarity index 90% rename from include/Nazara/Shader/ShaderConstantValue.hpp rename to include/Nazara/Shader/Ast/ConstantValue.hpp index 27c9e1d7e..0f8ed1fc4 100644 --- a/include/Nazara/Shader/ShaderConstantValue.hpp +++ b/include/Nazara/Shader/Ast/ConstantValue.hpp @@ -13,9 +13,9 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - using ShaderConstantValue = std::variant< + using ConstantValue = std::variant< bool, float, Int32, diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index f42776fc3..a54b03395 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -50,9 +50,17 @@ namespace Nz::ShaderAst inline bool operator!=(const SamplerType& rhs) const; }; + struct StructType + { + std::size_t structIndex; + + inline bool operator==(const StructType& rhs) const; + inline bool operator!=(const StructType& rhs) const; + }; + struct UniformType { - IdentifierType containedType; + std::variant containedType; inline bool operator==(const UniformType& rhs) const; inline bool operator!=(const UniformType& rhs) const; @@ -67,7 +75,7 @@ namespace Nz::ShaderAst inline bool operator!=(const VectorType& rhs) const; }; - using ExpressionType = std::variant; + using ExpressionType = std::variant; struct StructDescription { diff --git a/include/Nazara/Shader/Ast/ExpressionType.inl b/include/Nazara/Shader/Ast/ExpressionType.inl index b9b7734e7..2fdb36674 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.inl +++ b/include/Nazara/Shader/Ast/ExpressionType.inl @@ -51,6 +51,17 @@ namespace Nz::ShaderAst return !operator==(rhs); } + + inline bool StructType::operator==(const StructType& rhs) const + { + return structIndex == rhs.structIndex; + } + + inline bool StructType::operator!=(const StructType& rhs) const + { + return !operator==(rhs); + } + inline bool UniformType::operator==(const UniformType& rhs) const { return containedType == rhs.containedType; @@ -61,6 +72,7 @@ namespace Nz::ShaderAst return !operator==(rhs); } + inline bool VectorType::operator==(const VectorType& rhs) const { return componentCount == rhs.componentCount && type == rhs.type; diff --git a/include/Nazara/Shader/Ast/TransformVisitor.hpp b/include/Nazara/Shader/Ast/TransformVisitor.hpp new file mode 100644 index 000000000..d2483a095 --- /dev/null +++ b/include/Nazara/Shader/Ast/TransformVisitor.hpp @@ -0,0 +1,90 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERAST_TRANSFORMVISITOR_HPP +#define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API TransformVisitor : AstCloner + { + public: + inline TransformVisitor(); + TransformVisitor(const TransformVisitor&) = delete; + TransformVisitor(TransformVisitor&&) = delete; + ~TransformVisitor() = default; + + StatementPtr Transform(StatementPtr& statement); + + TransformVisitor& operator=(const TransformVisitor&) = delete; + TransformVisitor& operator=(TransformVisitor&&) = delete; + + private: + struct Identifier; + + ExpressionPtr Clone(AccessMemberIdentifierExpression& node) override; + ExpressionPtr Clone(CastExpression& node) override; + ExpressionPtr Clone(IdentifierExpression& node) override; + ExpressionPtr CloneExpression(ExpressionPtr& expr) override; + + inline const Identifier* FindIdentifier(const std::string_view& identifierName) const; + + void PushScope(); + void PopScope(); + + inline std::size_t RegisterFunction(std::string name); + inline std::size_t RegisterStruct(std::string name, StructDescription description); + inline std::size_t RegisterVariable(std::string name); + + ExpressionType ResolveType(const ExpressionType& exprType); + + using AstCloner::Visit; + void Visit(BranchStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(DeclareExternalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(MultiStatement& node) override; + + struct Alias + { + std::variant value; + }; + + struct Struct + { + std::size_t structIndex; + }; + + struct Variable + { + std::size_t varIndex; + }; + + struct Identifier + { + std::string name; + std::variant value; + }; + + private: + std::size_t m_nextFuncIndex; + std::size_t m_nextVarIndex; + std::vector m_identifiersInScope; + std::vector m_structs; + std::vector m_scopeSizes; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/Ast/TransformVisitor.inl b/include/Nazara/Shader/Ast/TransformVisitor.inl new file mode 100644 index 000000000..07601b026 --- /dev/null +++ b/include/Nazara/Shader/Ast/TransformVisitor.inl @@ -0,0 +1,62 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline TransformVisitor::TransformVisitor() : + m_nextFuncIndex(0), + m_nextVarIndex(0) + { + } + + inline auto TransformVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* + { + auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); + if (it == m_identifiersInScope.rend()) + return nullptr; + + return &*it; + } + + inline std::size_t TransformVisitor::RegisterFunction(std::string name) + { + std::size_t funcIndex = m_nextFuncIndex++; + return funcIndex; + } + + + inline std::size_t TransformVisitor::RegisterStruct(std::string name, StructDescription description) + { + std::size_t structIndex = m_structs.size(); + m_structs.emplace_back(std::move(description)); + + m_identifiersInScope.push_back({ + std::move(name), + Struct { + structIndex + } + }); + + return structIndex; + } + + inline std::size_t TransformVisitor::RegisterVariable(std::string name) + { + std::size_t varIndex = m_nextVarIndex++; + + m_identifiersInScope.push_back({ + std::move(name), + Variable { + varIndex + } + }); + + return varIndex; + } +} + +#include diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 8b5f964e8..7b4741cd4 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -52,6 +52,7 @@ namespace Nz void Append(ShaderAst::NoType); void Append(ShaderAst::PrimitiveType type); void Append(const ShaderAst::SamplerType& samplerType); + void Append(const ShaderAst::StructType& structType); void Append(const ShaderAst::UniformType& uniformType); void Append(const ShaderAst::VectorType& vecType); template void Append(const T& param); @@ -67,7 +68,7 @@ namespace Nz void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); - void Visit(ShaderAst::AccessMemberExpression& node) override; + void Visit(ShaderAst::AccessMemberIdentifierExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index 6b7d6b2b7..33b6acf3a 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -30,15 +30,25 @@ namespace Nz::ShaderAst AstCloner& operator=(AstCloner&&) = delete; protected: - ExpressionPtr CloneExpression(ExpressionPtr& expr); - StatementPtr CloneStatement(StatementPtr& statement); + virtual ExpressionPtr CloneExpression(ExpressionPtr& expr); + virtual StatementPtr CloneStatement(StatementPtr& statement); + virtual StatementPtr Clone(DeclareExternalStatement& node); virtual StatementPtr Clone(DeclareFunctionStatement& node); + virtual StatementPtr Clone(DeclareStructStatement& node); + virtual StatementPtr Clone(DeclareVariableStatement& node); + + virtual ExpressionPtr Clone(AccessMemberIdentifierExpression& node); + virtual ExpressionPtr Clone(AccessMemberIndexExpression& node); + virtual ExpressionPtr Clone(CastExpression& node); + virtual ExpressionPtr Clone(IdentifierExpression& node); + virtual ExpressionPtr Clone(VariableExpression& node); using AstExpressionVisitor::Visit; using AstStatementVisitor::Visit; - void Visit(AccessMemberExpression& node) override; + void Visit(AccessMemberIdentifierExpression& node) override; + void Visit(AccessMemberIndexExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CastExpression& node) override; @@ -47,6 +57,7 @@ namespace Nz::ShaderAst void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SwizzleExpression& node) override; + void Visit(VariableExpression& node) override; void Visit(BranchStatement& node) override; void Visit(ConditionalStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstNodes.hpp b/include/Nazara/Shader/ShaderAstNodes.hpp index 86b19521b..0bd48175c 100644 --- a/include/Nazara/Shader/ShaderAstNodes.hpp +++ b/include/Nazara/Shader/ShaderAstNodes.hpp @@ -26,7 +26,8 @@ #define NAZARA_SHADERAST_STATEMENT_LAST(X) NAZARA_SHADERAST_STATEMENT(X) #endif -NAZARA_SHADERAST_EXPRESSION(AccessMemberExpression) +NAZARA_SHADERAST_EXPRESSION(AccessMemberIdentifierExpression) +NAZARA_SHADERAST_EXPRESSION(AccessMemberIndexExpression) NAZARA_SHADERAST_EXPRESSION(AssignExpression) NAZARA_SHADERAST_EXPRESSION(BinaryExpression) NAZARA_SHADERAST_EXPRESSION(CastExpression) @@ -35,6 +36,7 @@ NAZARA_SHADERAST_EXPRESSION(ConstantExpression) NAZARA_SHADERAST_EXPRESSION(IdentifierExpression) NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) +NAZARA_SHADERAST_EXPRESSION(VariableExpression) NAZARA_SHADERAST_STATEMENT(BranchStatement) NAZARA_SHADERAST_STATEMENT(ConditionalStatement) NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement) diff --git a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp index 1bfbb65da..f05a20d22 100644 --- a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp @@ -20,7 +20,8 @@ namespace Nz::ShaderAst AstRecursiveVisitor() = default; ~AstRecursiveVisitor() = default; - void Visit(AccessMemberExpression& node) override; + void Visit(AccessMemberIdentifierExpression& node) override; + void Visit(AccessMemberIndexExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CastExpression& node) override; @@ -29,6 +30,7 @@ namespace Nz::ShaderAst void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SwizzleExpression& node) override; + void Visit(VariableExpression& node) override; void Visit(BranchStatement& node) override; void Visit(ConditionalStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstSerializer.hpp b/include/Nazara/Shader/ShaderAstSerializer.hpp index 2cb0bd5f1..20ac281bf 100644 --- a/include/Nazara/Shader/ShaderAstSerializer.hpp +++ b/include/Nazara/Shader/ShaderAstSerializer.hpp @@ -23,7 +23,8 @@ namespace Nz::ShaderAst AstSerializerBase(AstSerializerBase&&) = delete; ~AstSerializerBase() = default; - void Serialize(AccessMemberExpression& node); + void Serialize(AccessMemberIdentifierExpression& node); + void Serialize(AccessMemberIndexExpression& node); void Serialize(AssignExpression& node); void Serialize(BinaryExpression& node); void Serialize(CastExpression& node); @@ -32,6 +33,7 @@ namespace Nz::ShaderAst void Serialize(IdentifierExpression& node); void Serialize(IntrinsicExpression& node); void Serialize(SwizzleExpression& node); + void Serialize(VariableExpression& node); void Serialize(BranchStatement& node); void Serialize(ConditionalStatement& node); diff --git a/include/Nazara/Shader/ShaderAstUtils.hpp b/include/Nazara/Shader/ShaderAstUtils.hpp index 3f577ed3e..a560bb5b6 100644 --- a/include/Nazara/Shader/ShaderAstUtils.hpp +++ b/include/Nazara/Shader/ShaderAstUtils.hpp @@ -31,7 +31,8 @@ namespace Nz::ShaderAst private: using AstExpressionVisitor::Visit; - void Visit(AccessMemberExpression& node) override; + void Visit(AccessMemberIdentifierExpression& node) override; + void Visit(AccessMemberIndexExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CastExpression& node) override; @@ -40,6 +41,7 @@ namespace Nz::ShaderAst void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SwizzleExpression& node) override; + void Visit(VariableExpression& node) override; ExpressionCategory m_expressionCategory; }; diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index b617e27c7..9556eee0f 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -34,7 +34,7 @@ namespace Nz::ShaderAst ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); const ExpressionType& ResolveAlias(const ExpressionType& expressionType); - void Visit(AccessMemberExpression& node) override; + void Visit(AccessMemberIdentifierExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CastExpression& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index c6edb8041..c9f0b50de 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -15,6 +15,11 @@ namespace Nz::ShaderBuilder { namespace Impl { + struct AccessMember + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr structExpr, std::vector memberIdentifiers) const; + }; + struct Assign { inline std::unique_ptr operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const; @@ -36,9 +41,19 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const; }; + struct ConditionalExpression + { + inline std::unique_ptr operator()(std::string conditionName, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const; + }; + + struct ConditionalStatement + { + inline std::unique_ptr operator()(std::string conditionName, ShaderAst::StatementPtr statement) const; + }; + struct Constant { - inline std::unique_ptr operator()(ShaderConstantValue value) const; + inline std::unique_ptr operator()(ShaderAst::ConstantValue value) const; }; struct DeclareFunction @@ -83,12 +98,20 @@ namespace Nz::ShaderBuilder { inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr = nullptr) const; }; + + struct Swizzle + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const; + }; } + constexpr Impl::AccessMember AccessMember; constexpr Impl::Assign Assign; constexpr Impl::Binary Binary; constexpr Impl::Branch Branch; constexpr Impl::Cast Cast; + constexpr Impl::ConditionalExpression ConditionalExpression; + constexpr Impl::ConditionalStatement ConditionalStatement; constexpr Impl::Constant Constant; constexpr Impl::DeclareFunction DeclareFunction; constexpr Impl::DeclareStruct DeclareStruct; @@ -99,6 +122,7 @@ namespace Nz::ShaderBuilder constexpr Impl::Intrinsic Intrinsic; constexpr Impl::NoParam NoOp; constexpr Impl::Return Return; + constexpr Impl::Swizzle Swizzle; } #include diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index feff95b8a..ec83815a7 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -7,6 +7,15 @@ namespace Nz::ShaderBuilder { + inline std::unique_ptr Impl::AccessMember::operator()(ShaderAst::ExpressionPtr structExpr, std::vector memberIdentifiers) const + { + auto accessMemberNode = std::make_unique(); + accessMemberNode->structExpr = std::move(structExpr); + accessMemberNode->memberIdentifiers = std::move(memberIdentifiers); + + return accessMemberNode; + } + inline std::unique_ptr Impl::Assign::operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const { auto assignNode = std::make_unique(); @@ -61,7 +70,26 @@ namespace Nz::ShaderBuilder return castNode; } - inline std::unique_ptr Impl::Constant::operator()(ShaderConstantValue value) const + inline std::unique_ptr Impl::ConditionalExpression::operator()(std::string conditionName, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const + { + auto condExprNode = std::make_unique(); + condExprNode->conditionName = std::move(conditionName); + condExprNode->falsePath = std::move(falsePath); + condExprNode->truePath = std::move(truePath); + + return condExprNode; + } + + inline std::unique_ptr Impl::ConditionalStatement::operator()(std::string conditionName, ShaderAst::StatementPtr statement) const + { + auto condStatementNode = std::make_unique(); + condStatementNode->conditionName = std::move(conditionName); + condStatementNode->statement = std::move(statement); + + return condStatementNode; + } + + inline std::unique_ptr Impl::Constant::operator()(ShaderAst::ConstantValue value) const { auto constantNode = std::make_unique(); constantNode->value = std::move(value); @@ -157,6 +185,18 @@ namespace Nz::ShaderBuilder { return std::make_unique(); } + + inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const + { + auto swizzleNode = std::make_unique(); + swizzleNode->expression = std::move(expression); + + assert(swizzleComponents.size() <= swizzleNode->components.size()); + for (std::size_t i = 0; i < swizzleComponents.size(); ++i) + swizzleNode->components[i] = swizzleComponents[i]; + + return swizzleNode; + } } #include diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index 7b442c07c..5413457f2 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include #include @@ -60,7 +60,7 @@ namespace Nz::ShaderAst std::optional cachedExpressionType; }; - struct NAZARA_SHADER_API AccessMemberExpression : public Expression + struct NAZARA_SHADER_API AccessMemberIdentifierExpression : public Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -69,6 +69,15 @@ namespace Nz::ShaderAst std::vector memberIdentifiers; }; + struct NAZARA_SHADER_API AccessMemberIndexExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + ExpressionPtr structExpr; + std::vector memberIndices; + }; + struct NAZARA_SHADER_API AssignExpression : public Expression { NodeType GetType() const override; @@ -113,7 +122,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - ShaderConstantValue value; + ShaderAst::ConstantValue value; }; struct NAZARA_SHADER_API IdentifierExpression : public Expression @@ -143,6 +152,14 @@ namespace Nz::ShaderAst ExpressionPtr expression; }; + struct NAZARA_SHADER_API VariableExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t variableId; + }; + // Statements struct Statement; @@ -193,11 +210,12 @@ namespace Nz::ShaderAst struct ExternalVar { - std::vector attributes; std::string name; + std::vector attributes; ExpressionType type; }; + std::optional varIndex; std::vector attributes; std::vector externalVars; }; @@ -213,6 +231,8 @@ namespace Nz::ShaderAst ExpressionType type; }; + std::optional funcIndex; + std::optional varIndex; std::string name; std::vector attributes; std::vector parameters; @@ -225,6 +245,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; + std::optional structIndex; std::vector attributes; StructDescription description; }; @@ -234,6 +255,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; + std::optional varIndex; std::string varName; ExpressionPtr initialExpression; ExpressionType varType; @@ -274,6 +296,8 @@ namespace Nz::ShaderAst ExpressionPtr returnExpr; }; + + inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr); } #include diff --git a/include/Nazara/Shader/ShaderNodes.inl b/include/Nazara/Shader/ShaderNodes.inl index 6c702a06e..d8b320b72 100644 --- a/include/Nazara/Shader/ShaderNodes.inl +++ b/include/Nazara/Shader/ShaderNodes.inl @@ -7,6 +7,11 @@ namespace Nz::ShaderAst { + const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) + { + assert(expr.cachedExpressionType); + return expr.cachedExpressionType.value(); + } } #include diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index d868dab41..0f2d06495 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -11,7 +11,9 @@ #include #include #include +#include #include +#include #include namespace Nz @@ -21,17 +23,25 @@ namespace Nz class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept { public: - inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks); + struct EntryPoint; + struct FuncData; + struct Variable; + + inline SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector& funcData); SpirvAstVisitor(const SpirvAstVisitor&) = delete; SpirvAstVisitor(SpirvAstVisitor&&) = delete; ~SpirvAstVisitor() = default; + UInt32 AllocateResultId(); + UInt32 EvaluateExpression(ShaderAst::ExpressionPtr& expr); + const Variable& GetVariable(std::size_t varIndex) const; + using ExpressionVisitorExcept::Visit; using StatementVisitorExcept::Visit; - void Visit(ShaderAst::AccessMemberExpression& node) override; + void Visit(ShaderAst::AccessMemberIndexExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BranchStatement& node) override; @@ -39,27 +49,102 @@ namespace Nz void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConditionalStatement& node) override; void Visit(ShaderAst::ConstantExpression& node) override; + void Visit(ShaderAst::DeclareExternalStatement& node) override; + void Visit(ShaderAst::DeclareFunctionStatement& node) override; + void Visit(ShaderAst::DeclareStructStatement& node) override; void Visit(ShaderAst::DeclareVariableStatement& node) override; void Visit(ShaderAst::DiscardStatement& node) override; void Visit(ShaderAst::ExpressionStatement& node) override; - void Visit(ShaderAst::IdentifierExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override; void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; + void Visit(ShaderAst::VariableExpression& node) override; SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete; SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete; + struct EntryPoint + { + struct Input + { + UInt32 memberIndexConstantId; + UInt32 memberPointerId; + UInt32 varId; + }; + + struct Output + { + Int32 memberIndex; + UInt32 typeId; + UInt32 varId; + }; + + struct InputStruct + { + UInt32 pointerId; + UInt32 typeId; + }; + + ShaderStageType stageType; + std::optional inputStruct; + std::optional outputStructTypeId; + std::size_t funcIndex; + std::vector inputs; + std::vector outputs; + }; + + struct FuncData + { + std::optional entryPointData; + + struct Parameter + { + UInt32 pointerTypeId; + UInt32 typeId; + }; + + struct Variable + { + UInt32 typeId; + UInt32 varId; + }; + + std::string name; + std::vector parameters; + std::vector variables; + std::unordered_map varIndexToVarId; + UInt32 funcId; + UInt32 funcTypeId; + UInt32 returnTypeId; + }; + + struct Variable + { + SpirvStorageClass storage; + UInt32 pointerId; + UInt32 pointedTypeId; + }; + private: - inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) const; void PushResultId(UInt32 value); UInt32 PopResultId(); - std::vector& m_blocks; + inline void RegisterExternalVariable(std::size_t varIndex, const ShaderAst::ExpressionType& type); + inline void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription structDesc); + inline void RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass); + + std::size_t m_extVarIndex; + std::size_t m_funcIndex; + std::vector m_scopeSizes; + std::vector& m_funcData; + std::vector m_structs; + std::vector> m_variables; + std::vector m_functionBlocks; std::vector m_resultIds; SpirvBlock* m_currentBlock; + SpirvSection& m_instructions; SpirvWriter& m_writer; }; } diff --git a/include/Nazara/Shader/SpirvAstVisitor.inl b/include/Nazara/Shader/SpirvAstVisitor.inl index bb54eb594..f67692f6f 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.inl +++ b/include/Nazara/Shader/SpirvAstVisitor.inl @@ -7,17 +7,42 @@ namespace Nz { - inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks) : - m_blocks(blocks), + inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector& funcData) : + m_extVarIndex(0), + m_funcIndex(0), + m_funcData(funcData), + m_currentBlock(nullptr), + m_instructions(instructions), m_writer(writer) { - m_currentBlock = &m_blocks.back(); } - inline const ShaderAst::ExpressionType& SpirvAstVisitor::GetExpressionType(ShaderAst::Expression& expr) const + void SpirvAstVisitor::RegisterExternalVariable(std::size_t varIndex, const ShaderAst::ExpressionType& type) { - assert(expr.cachedExpressionType); - return expr.cachedExpressionType.value(); + UInt32 pointerId = m_writer.GetExtVarPointerId(varIndex); + SpirvStorageClass storageClass = (IsSamplerType(type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform; + + RegisterVariable(varIndex, m_writer.GetTypeId(type), pointerId, storageClass); + } + + inline void SpirvAstVisitor::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription structDesc) + { + if (structIndex >= m_structs.size()) + m_structs.resize(structIndex + 1); + + m_structs[structIndex] = std::move(structDesc); + } + + inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass) + { + if (varIndex >= m_variables.size()) + m_variables.resize(varIndex + 1); + + m_variables[varIndex] = Variable{ + storageClass, + pointerId, + typeId + }; } } diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index c01ab8d68..3025c013d 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -8,8 +8,8 @@ #define NAZARA_SPIRVCONSTANTCACHE_HPP #include -#include #include +#include #include #include #include @@ -25,6 +25,8 @@ namespace Nz class NAZARA_SHADER_API SpirvConstantCache { public: + using StructCallback = std::function; + SpirvConstantCache(UInt32& resultId); SpirvConstantCache(const SpirvConstantCache& cache) = delete; SpirvConstantCache(SpirvConstantCache&& cache) noexcept; @@ -37,8 +39,6 @@ namespace Nz using ConstantPtr = std::shared_ptr; using TypePtr = std::shared_ptr; - using IdentifierCallback = std::function; - struct Bool {}; struct Float @@ -66,11 +66,6 @@ namespace Nz UInt32 columnCount; }; - struct Identifier - { - std::string name; - }; - struct Image { std::optional qualifier; @@ -112,7 +107,7 @@ namespace Nz std::vector members; }; - using AnyType = std::variant; + using AnyType = std::variant; struct ConstantBool { @@ -134,10 +129,11 @@ namespace Nz struct Variable { + std::optional funcId; //< For inputs/outputs + std::optional initializer; std::string debugName; TypePtr type; SpirvStorageClass storageClass; - std::optional initializer; }; using BaseType = std::variant; @@ -166,6 +162,21 @@ namespace Nz AnyType type; }; + ConstantPtr BuildConstant(const ShaderAst::ConstantValue& value) const; + TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) const; + TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const; + TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; + TypePtr BuildType(const ShaderAst::ExpressionType& type) const; + TypePtr BuildType(const ShaderAst::IdentifierType& type) const; + TypePtr BuildType(const ShaderAst::MatrixType& type) const; + TypePtr BuildType(const ShaderAst::NoType& type) const; + TypePtr BuildType(const ShaderAst::PrimitiveType& type) const; + TypePtr BuildType(const ShaderAst::SamplerType& type) const; + TypePtr BuildType(const ShaderAst::StructType& type) const; + TypePtr BuildType(const ShaderAst::StructDescription& structDesc) const; + TypePtr BuildType(const ShaderAst::VectorType& type) const; + TypePtr BuildType(const ShaderAst::UniformType& type) const; + UInt32 GetId(const Constant& c); UInt32 GetId(const Type& t); UInt32 GetId(const Variable& v); @@ -174,26 +185,13 @@ namespace Nz UInt32 Register(Type t); UInt32 Register(Variable v); - void SetIdentifierCallback(IdentifierCallback callback); + void SetStructCallback(StructCallback callback); void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete; SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; - static ConstantPtr BuildConstant(const ShaderConstantValue& value); - static TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters); - static TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass); - static TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass); - static TypePtr BuildType(const ShaderAst::ExpressionType& type); - static TypePtr BuildType(const ShaderAst::IdentifierType& type); - static TypePtr BuildType(const ShaderAst::MatrixType& type); - static TypePtr BuildType(const ShaderAst::NoType& type); - static TypePtr BuildType(const ShaderAst::PrimitiveType& type); - static TypePtr BuildType(const ShaderAst::SamplerType& type); - static TypePtr BuildType(const ShaderAst::StructDescription& structDesc); - static TypePtr BuildType(const ShaderAst::VectorType& type); - private: struct DepRegisterer; struct Eq; @@ -204,7 +202,6 @@ namespace Nz void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); - IdentifierCallback m_identifierCallback; std::unique_ptr m_internal; }; } diff --git a/include/Nazara/Shader/SpirvExpressionLoad.hpp b/include/Nazara/Shader/SpirvExpressionLoad.hpp index ddc9551fd..bff0047cf 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.hpp +++ b/include/Nazara/Shader/SpirvExpressionLoad.hpp @@ -15,13 +15,14 @@ namespace Nz { + class SpirvAstVisitor; class SpirvBlock; class SpirvWriter; class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAst::ExpressionVisitorExcept { public: - inline SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block); + inline SpirvExpressionLoad(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block); SpirvExpressionLoad(const SpirvExpressionLoad&) = delete; SpirvExpressionLoad(SpirvExpressionLoad&&) = delete; ~SpirvExpressionLoad() = default; @@ -29,8 +30,8 @@ namespace Nz UInt32 Evaluate(ShaderAst::Expression& node); using ExpressionVisitorExcept::Visit; - //void Visit(ShaderAst::AccessMemberExpression& node) override; - void Visit(ShaderAst::IdentifierExpression& node) override; + void Visit(ShaderAst::AccessMemberIndexExpression& node) override; + void Visit(ShaderAst::VariableExpression& node) override; SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete; SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete; @@ -39,7 +40,7 @@ namespace Nz struct Pointer { SpirvStorageClass storage; - UInt32 resultId; + UInt32 pointerId; UInt32 pointedTypeId; }; @@ -48,6 +49,7 @@ namespace Nz UInt32 resultId; }; + SpirvAstVisitor& m_visitor; SpirvBlock& m_block; SpirvWriter& m_writer; std::variant m_value; diff --git a/include/Nazara/Shader/SpirvExpressionLoad.inl b/include/Nazara/Shader/SpirvExpressionLoad.inl index 6d5aff9cb..1522515b9 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.inl +++ b/include/Nazara/Shader/SpirvExpressionLoad.inl @@ -7,9 +7,10 @@ namespace Nz { - inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block) : + inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block) : m_block(block), - m_writer(writer) + m_writer(writer), + m_visitor(visitor) { } } diff --git a/include/Nazara/Shader/SpirvExpressionStore.hpp b/include/Nazara/Shader/SpirvExpressionStore.hpp index ee0d96f6a..e66d545d3 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.hpp +++ b/include/Nazara/Shader/SpirvExpressionStore.hpp @@ -14,13 +14,14 @@ namespace Nz { + class SpirvAstVisitor; class SpirvBlock; class SpirvWriter; class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAst::ExpressionVisitorExcept { public: - inline SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block); + inline SpirvExpressionStore(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block); SpirvExpressionStore(const SpirvExpressionStore&) = delete; SpirvExpressionStore(SpirvExpressionStore&&) = delete; ~SpirvExpressionStore() = default; @@ -28,9 +29,9 @@ namespace Nz void Store(ShaderAst::ExpressionPtr& node, UInt32 resultId); using ExpressionVisitorExcept::Visit; - //void Visit(ShaderAst::AccessMemberExpression& node) override; - void Visit(ShaderAst::IdentifierExpression& node) override; + void Visit(ShaderAst::AccessMemberIndexExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; + void Visit(ShaderAst::VariableExpression& node) override; SpirvExpressionStore& operator=(const SpirvExpressionStore&) = delete; SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete; @@ -44,9 +45,10 @@ namespace Nz struct Pointer { SpirvStorageClass storage; - UInt32 resultId; + UInt32 pointerId; }; + SpirvAstVisitor& m_visitor; SpirvBlock& m_block; SpirvWriter& m_writer; std::variant m_value; diff --git a/include/Nazara/Shader/SpirvExpressionStore.inl b/include/Nazara/Shader/SpirvExpressionStore.inl index 771624788..16326a438 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.inl +++ b/include/Nazara/Shader/SpirvExpressionStore.inl @@ -7,9 +7,10 @@ namespace Nz { - inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block) : + inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block) : m_block(block), - m_writer(writer) + m_writer(writer), + m_visitor(visitor) { } } diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 3b770ae87..2bab892ca 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include #include @@ -27,7 +27,6 @@ namespace Nz friend class SpirvBlock; friend class SpirvExpressionLoad; friend class SpirvExpressionStore; - friend class SpirvVisitor; public: struct Environment; @@ -48,7 +47,6 @@ namespace Nz }; private: - struct ExtVar; struct FunctionParameter; struct OnlyCache {}; @@ -56,36 +54,21 @@ namespace Nz void AppendHeader(); - UInt32 GetConstantId(const ShaderConstantValue& value) const; + SpirvConstantCache::TypePtr BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode); + + UInt32 GetConstantId(const ShaderAst::ConstantValue& value) const; + UInt32 GetExtVarPointerId(std::size_t varIndex) const; UInt32 GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode); - const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const; - const ExtVar& GetInputVariable(const std::string& name) const; - const ExtVar& GetOutputVariable(const std::string& name) const; - const ExtVar& GetUniformVariable(const std::string& name) const; UInt32 GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; UInt32 GetTypeId(const ShaderAst::ExpressionType& type) const; inline bool IsConditionEnabled(const std::string& condition) const; - UInt32 ReadInputVariable(const std::string& name); - std::optional ReadInputVariable(const std::string& name, OnlyCache); - UInt32 ReadLocalVariable(const std::string& name); - std::optional ReadLocalVariable(const std::string& name, OnlyCache); - UInt32 ReadParameterVariable(const std::string& name); - std::optional ReadParameterVariable(const std::string& name, OnlyCache); - UInt32 ReadUniformVariable(const std::string& name); - std::optional ReadUniformVariable(const std::string& name, OnlyCache); - UInt32 ReadVariable(ExtVar& var); - std::optional ReadVariable(const ExtVar& var, OnlyCache); - - UInt32 RegisterConstant(const ShaderConstantValue& value); + UInt32 RegisterConstant(const ShaderAst::ConstantValue& value); UInt32 RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode); UInt32 RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass); UInt32 RegisterType(ShaderAst::ExpressionType type); - void WriteLocalVariable(std::string name, UInt32 resultId); - - static SpirvConstantCache::TypePtr BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode); static void MergeSections(std::vector& output, const SpirvSection& from); struct Context @@ -93,20 +76,6 @@ namespace Nz const States* states = nullptr; }; - struct ExtVar - { - UInt32 pointerTypeId; - UInt32 typeId; - UInt32 varId; - std::optional valueId; - }; - - struct FunctionParameter - { - std::string name; - ShaderAst::ExpressionType type; - }; - struct State; Context m_context; diff --git a/src/Nazara/Shader/Ast/TransformVisitor.cpp b/src/Nazara/Shader/Ast/TransformVisitor.cpp new file mode 100644 index 000000000..14f127697 --- /dev/null +++ b/src/Nazara/Shader/Ast/TransformVisitor.cpp @@ -0,0 +1,225 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace Nz::ShaderAst +{ + StatementPtr TransformVisitor::Transform(StatementPtr& nodePtr) + { + StatementPtr clone; + + PushScope(); //< Global scope + { + clone = AstCloner::Clone(nodePtr); + } + PopScope(); + + return clone; + } + + void TransformVisitor::Visit(BranchStatement& node) + { + for (auto& cond : node.condStatements) + { + PushScope(); + { + cond.condition->Visit(*this); + cond.statement->Visit(*this); + } + PopScope(); + } + + if (node.elseStatement) + { + PushScope(); + { + node.elseStatement->Visit(*this); + } + PopScope(); + } + } + + void TransformVisitor::Visit(ConditionalStatement& node) + { + PushScope(); + { + AstCloner::Visit(node); + } + PopScope(); + } + + ExpressionType TransformVisitor::ResolveType(const ExpressionType& exprType) + { + return std::visit([&](auto&& arg) -> ExpressionType + { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return exprType; + } + else if constexpr (std::is_same_v) + { + const Identifier* identifier = FindIdentifier(arg.name); + assert(identifier); + assert(std::holds_alternative(identifier->value)); + + return StructType{ std::get(identifier->value).structIndex }; + } + else if constexpr (std::is_same_v) + { + return std::visit([&](auto&& containedArg) + { + ExpressionType resolvedType = ResolveType(containedArg); + assert(std::holds_alternative(resolvedType)); + + return UniformType{ std::get(resolvedType) }; + }, arg.containedType); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, exprType); + } + + void TransformVisitor::Visit(DeclareExternalStatement& node) + { + for (auto& extVar : node.externalVars) + { + extVar.type = ResolveType(extVar.type); + + std::size_t varIndex = RegisterVariable(extVar.name); + if (!node.varIndex) + node.varIndex = varIndex; + } + + AstCloner::Visit(node); + } + + void TransformVisitor::Visit(DeclareFunctionStatement& node) + { + node.funcIndex = m_nextFuncIndex++; + node.returnType = ResolveType(node.returnType); + for (auto& parameter : node.parameters) + parameter.type = ResolveType(parameter.type); + + PushScope(); + { + for (auto& parameter : node.parameters) + { + std::size_t varIndex = RegisterVariable(parameter.name); + if (!node.varIndex) + node.varIndex = varIndex; + } + + AstCloner::Visit(node); + } + PopScope(); + } + + void TransformVisitor::Visit(DeclareStructStatement& node) + { + node.structIndex = RegisterStruct(node.description.name, node.description); + + AstCloner::Visit(node); + } + + void TransformVisitor::Visit(DeclareVariableStatement& node) + { + node.varType = ResolveType(node.varType); + node.varIndex = RegisterVariable(node.varName); + + AstCloner::Visit(node); + } + + void TransformVisitor::Visit(MultiStatement& node) + { + PushScope(); + { + AstCloner::Visit(node); + } + PopScope(); + } + + ExpressionPtr TransformVisitor::Clone(AccessMemberIdentifierExpression& node) + { + auto accessMemberIndex = std::make_unique(); + accessMemberIndex->structExpr = CloneExpression(node.structExpr); + accessMemberIndex->cachedExpressionType = node.cachedExpressionType; + accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size()); + + ExpressionType exprType = GetExpressionType(*node.structExpr); + for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i) + { + exprType = ResolveType(exprType); + assert(std::holds_alternative(exprType)); + + std::size_t structIndex = std::get(exprType).structIndex; + assert(structIndex < m_structs.size()); + const StructDescription& structDesc = m_structs[structIndex]; + + auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; }); + assert(it != structDesc.members.end()); + + accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it); + exprType = it->type; + } + + return accessMemberIndex; + } + + ExpressionPtr TransformVisitor::Clone(CastExpression& node) + { + ExpressionPtr expr = AstCloner::Clone(node); + + CastExpression* castExpr = static_cast(expr.get()); + castExpr->targetType = ResolveType(castExpr->targetType); + + return expr; + } + + ExpressionPtr TransformVisitor::Clone(IdentifierExpression& node) + { + const Identifier* identifier = FindIdentifier(node.identifier); + assert(identifier); + assert(std::holds_alternative(identifier->value)); + + auto varExpr = std::make_unique(); + varExpr->cachedExpressionType = node.cachedExpressionType; + varExpr->variableId = std::get(identifier->value).varIndex; + + return varExpr; + } + + ExpressionPtr TransformVisitor::CloneExpression(ExpressionPtr& expr) + { + ExpressionPtr exprPtr = AstCloner::CloneExpression(expr); + if (exprPtr) + { + assert(exprPtr->cachedExpressionType); + *exprPtr->cachedExpressionType = ResolveType(*exprPtr->cachedExpressionType); + } + + return exprPtr; + } + + void TransformVisitor::PushScope() + { + m_scopeSizes.push_back(m_identifiersInScope.size()); + } + + void TransformVisitor::PopScope() + { + assert(!m_scopeSizes.empty()); + m_identifiersInScope.resize(m_scopeSizes.back()); + m_scopeSizes.pop_back(); + } +} diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 0f1aaac4a..46c63e5f7 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -330,6 +330,11 @@ namespace Nz } } + void GlslWriter::Append(const ShaderAst::StructType& structType) + { + throw std::runtime_error("unexpected struct type"); + } + void GlslWriter::Append(const ShaderAst::UniformType& uniformType) { /* TODO */ @@ -371,6 +376,7 @@ namespace Nz m_currentState->stream << param; } + template void GlslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params) { @@ -595,7 +601,7 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node) + void GlslWriter::Visit(ShaderAst::AccessMemberIdentifierExpression& node) { Visit(node.structExpr, true); @@ -741,8 +747,6 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) { - - for (const auto& externalVar : node.externalVars) { std::optional bindingIndex; @@ -774,7 +778,7 @@ namespace Nz EnterScope(); { - const Identifier* identifier = FindIdentifier(std::get(externalVar.type).containedType.name); + const Identifier* identifier = FindIdentifier(std::get(std::get(externalVar.type).containedType).name); assert(identifier); assert(std::holds_alternative(identifier->value)); diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index 4ebfa9548..7e1b3256c 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -42,13 +42,25 @@ namespace Nz::ShaderAst return PopStatement(); } + StatementPtr AstCloner::Clone(DeclareExternalStatement& node) + { + auto clone = std::make_unique(); + clone->attributes = node.attributes; + clone->externalVars = node.externalVars; + clone->varIndex = node.varIndex; + + return clone; + } + StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) { auto clone = std::make_unique(); clone->attributes = node.attributes; + clone->funcIndex = node.funcIndex; clone->name = node.name; clone->parameters = node.parameters; clone->returnType = node.returnType; + clone->varIndex = node.varIndex; clone->statements.reserve(node.statements.size()); for (auto& statement : node.statements) @@ -57,15 +69,95 @@ namespace Nz::ShaderAst return clone; } - void AstCloner::Visit(AccessMemberExpression& node) + StatementPtr AstCloner::Clone(DeclareStructStatement& node) { - auto clone = std::make_unique(); + auto clone = std::make_unique(); + clone->structIndex = node.structIndex; + clone->description = node.description; + + return clone; + } + + StatementPtr AstCloner::Clone(DeclareVariableStatement& node) + { + auto clone = std::make_unique(); + clone->varIndex = node.varIndex; + clone->varName = node.varName; + clone->varType = node.varType; + clone->initialExpression = CloneExpression(node.initialExpression); + + return clone; + } + + ExpressionPtr AstCloner::Clone(AccessMemberIdentifierExpression& node) + { + auto clone = std::make_unique(); clone->memberIdentifiers = node.memberIdentifiers; clone->structExpr = CloneExpression(node.structExpr); clone->cachedExpressionType = node.cachedExpressionType; - PushExpression(std::move(clone)); + return clone; + } + + ExpressionPtr AstCloner::Clone(AccessMemberIndexExpression& node) + { + auto clone = std::make_unique(); + clone->memberIndices = node.memberIndices; + clone->structExpr = CloneExpression(node.structExpr); + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + + ExpressionPtr AstCloner::Clone(CastExpression& node) + { + auto clone = std::make_unique(); + clone->targetType = node.targetType; + + std::size_t expressionCount = 0; + for (auto& expr : node.expressions) + { + if (!expr) + break; + + clone->expressions[expressionCount++] = CloneExpression(expr); + } + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + + ExpressionPtr AstCloner::Clone(IdentifierExpression& node) + { + auto clone = std::make_unique(); + clone->identifier = node.identifier; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + + ExpressionPtr AstCloner::Clone(VariableExpression& node) + { + auto clone = std::make_unique(); + clone->variableId = node.variableId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + + void AstCloner::Visit(AccessMemberIdentifierExpression& node) + { + return PushExpression(Clone(node)); + } + + void AstCloner::Visit(AccessMemberIndexExpression& node) + { + return PushExpression(Clone(node)); } void AstCloner::Visit(AssignExpression& node) @@ -94,21 +186,7 @@ namespace Nz::ShaderAst void AstCloner::Visit(CastExpression& node) { - auto clone = std::make_unique(); - clone->targetType = node.targetType; - - std::size_t expressionCount = 0; - for (auto& expr : node.expressions) - { - if (!expr) - break; - - clone->expressions[expressionCount++] = CloneExpression(expr); - } - - clone->cachedExpressionType = node.cachedExpressionType; - - PushExpression(std::move(clone)); + PushExpression(Clone(node)); } void AstCloner::Visit(ConditionalExpression& node) @@ -135,12 +213,7 @@ namespace Nz::ShaderAst void AstCloner::Visit(IdentifierExpression& node) { - auto clone = std::make_unique(); - clone->identifier = node.identifier; - - clone->cachedExpressionType = node.cachedExpressionType; - - PushExpression(std::move(clone)); + PushExpression(Clone(node)); } void AstCloner::Visit(IntrinsicExpression& node) @@ -169,6 +242,11 @@ namespace Nz::ShaderAst PushExpression(std::move(clone)); } + void AstCloner::Visit(VariableExpression& node) + { + PushExpression(Clone(node)); + } + void AstCloner::Visit(BranchStatement& node) { auto clone = std::make_unique(); @@ -197,11 +275,7 @@ namespace Nz::ShaderAst void AstCloner::Visit(DeclareExternalStatement& node) { - auto clone = std::make_unique(); - clone->attributes = node.attributes; - clone->externalVars = node.externalVars; - - PushStatement(std::move(clone)); + PushStatement(Clone(node)); } void AstCloner::Visit(DeclareFunctionStatement& node) @@ -211,20 +285,12 @@ namespace Nz::ShaderAst void AstCloner::Visit(DeclareStructStatement& node) { - auto clone = std::make_unique(); - clone->description = node.description; - - PushStatement(std::move(clone)); + PushStatement(Clone(node)); } void AstCloner::Visit(DeclareVariableStatement& node) { - auto clone = std::make_unique(); - clone->varName = node.varName; - clone->varType = node.varType; - clone->initialExpression = CloneExpression(node.initialExpression); - - PushStatement(std::move(clone)); + PushStatement(Clone(node)); } void AstCloner::Visit(DiscardStatement& /*node*/) diff --git a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp index e0a87ad51..52e6dc3eb 100644 --- a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp @@ -7,7 +7,12 @@ namespace Nz::ShaderAst { - void AstRecursiveVisitor::Visit(AccessMemberExpression& node) + void AstRecursiveVisitor::Visit(AccessMemberIdentifierExpression& node) + { + node.structExpr->Visit(*this); + } + + void AstRecursiveVisitor::Visit(AccessMemberIndexExpression& node) { node.structExpr->Visit(*this); } @@ -62,6 +67,11 @@ namespace Nz::ShaderAst node.expression->Visit(*this); } + void AstRecursiveVisitor::Visit(VariableExpression& node) + { + /* Nothing to do */ + } + void AstRecursiveVisitor::Visit(BranchStatement& node) { for (auto& cond : node.condStatements) diff --git a/src/Nazara/Shader/ShaderAstScopedVisitor.cpp b/src/Nazara/Shader/ShaderAstScopedVisitor.cpp index cc1aec980..39e80470e 100644 --- a/src/Nazara/Shader/ShaderAstScopedVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstScopedVisitor.cpp @@ -53,7 +53,7 @@ namespace Nz::ShaderAst { ExpressionType subType = extVar.type; if (IsUniformType(subType)) - subType = IdentifierType{ std::get(subType).containedType }; + subType = std::get(std::get(subType).containedType); RegisterVariable(extVar.name, std::move(subType)); } diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index 20b0b4b63..7102e45fa 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -33,7 +33,7 @@ namespace Nz::ShaderAst }; } - void AstSerializerBase::Serialize(AccessMemberExpression& node) + void AstSerializerBase::Serialize(AccessMemberIdentifierExpression& node) { Node(node.structExpr); @@ -42,6 +42,15 @@ namespace Nz::ShaderAst Value(identifier); } + void AstSerializerBase::Serialize(AccessMemberIndexExpression& node) + { + Node(node.structExpr); + + Container(node.memberIndices); + for (std::size_t& identifier : node.memberIndices) + SizeT(identifier); + } + void AstSerializerBase::Serialize(AssignExpression& node) { Enum(node.op); @@ -133,6 +142,11 @@ namespace Nz::ShaderAst Enum(node.components[i]); } + void AstSerializerBase::Serialize(VariableExpression& node) + { + SizeT(node.variableId); + } + void AstSerializerBase::Serialize(BranchStatement& node) { @@ -364,14 +378,19 @@ namespace Nz::ShaderAst m_stream << UInt32(arg.dim); m_stream << UInt32(arg.sampledType); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { m_stream << UInt8(5); - m_stream << arg.containedType.name; + m_stream << UInt32(arg.structIndex); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(6); + m_stream << std::get(arg.containedType).name; } else if constexpr (std::is_same_v) { - m_stream << UInt8(6); + m_stream << UInt8(7); m_stream << UInt32(arg.componentCount); m_stream << UInt32(arg.type); } @@ -621,7 +640,18 @@ namespace Nz::ShaderAst break; } - case 5: //< UniformType + case 5: //< StructType + { + UInt32 structIndex; + Value(structIndex); + + type = StructType{ + structIndex + }; + break; + } + + case 6: //< UniformType { std::string containedType; Value(containedType); @@ -634,7 +664,7 @@ namespace Nz::ShaderAst break; } - case 6: //< VectorType + case 7: //< VectorType { UInt32 componentCount; PrimitiveType componentType; diff --git a/src/Nazara/Shader/ShaderAstUtils.cpp b/src/Nazara/Shader/ShaderAstUtils.cpp index c26d38ae5..47f73cdeb 100644 --- a/src/Nazara/Shader/ShaderAstUtils.cpp +++ b/src/Nazara/Shader/ShaderAstUtils.cpp @@ -13,7 +13,12 @@ namespace Nz::ShaderAst return m_expressionCategory; } - void ShaderAstValueCategory::Visit(AccessMemberExpression& node) + void ShaderAstValueCategory::Visit(AccessMemberIdentifierExpression& node) + { + node.structExpr->Visit(*this); + } + + void ShaderAstValueCategory::Visit(AccessMemberIndexExpression& node) { node.structExpr->Visit(*this); } @@ -66,4 +71,9 @@ namespace Nz::ShaderAst { node.expression->Visit(*this); } + + void ShaderAstValueCategory::Visit(VariableExpression& node) + { + m_expressionCategory = ExpressionCategory::LValue; + } } diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 57610a357..8f998341e 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -131,7 +131,7 @@ namespace Nz::ShaderAst return expressionType; } - void AstValidator::Visit(AccessMemberExpression& node) + void AstValidator::Visit(AccessMemberIdentifierExpression& node) { // Register expressions types AstScopedVisitor::Visit(node); @@ -351,7 +351,7 @@ namespace Nz::ShaderAst if (!exprPtr) break; - ExpressionType exprType = GetExpressionType(*exprPtr); + const ExpressionType& exprType = GetExpressionType(*exprPtr); if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) throw AstError{ "incompatible type" }; @@ -552,14 +552,17 @@ namespace Nz::ShaderAst void AstValidator::Visit(DeclareExternalStatement& node) { - for (const auto& [attributeType, arg] : node.attributes) + if (!node.attributes.empty()) + throw AstError{ "unhandled attribute for external block" }; + + /*for (const auto& [attributeType, arg] : node.attributes) { switch (attributeType) { default: throw AstError{ "unhandled attribute for external block" }; } - } + }*/ for (const auto& extVar : node.externalVars) { diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 3a493ab30..b896055a0 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -602,7 +602,7 @@ namespace Nz::ShaderLang if (Peek().type == TokenType::Dot) { - std::unique_ptr accessMemberNode = std::make_unique(); + std::unique_ptr accessMemberNode = std::make_unique(); accessMemberNode->structExpr = std::move(identifierExpr); do @@ -685,9 +685,9 @@ namespace Nz::ShaderLang if (IsVariableInScope(identifier)) { auto node = ParseIdentifier(); - if (node->GetType() == ShaderAst::NodeType::AccessMemberExpression) + if (node->GetType() == ShaderAst::NodeType::AccessMemberIdentifierExpression) { - ShaderAst::AccessMemberExpression* memberExpr = static_cast(node.get()); + ShaderAst::AccessMemberIdentifierExpression* memberExpr = static_cast(node.get()); if (!memberExpr->memberIdentifiers.empty() && memberExpr->memberIdentifiers.front() == "Sample") { if (Peek().type == TokenType::OpenParenthesis) diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 28aca39ca..9e0795751 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include #include #include @@ -12,6 +13,11 @@ namespace Nz { + UInt32 SpirvAstVisitor::AllocateResultId() + { + return m_writer.AllocateResultId(); + } + UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr) { expr->Visit(*this); @@ -20,9 +26,16 @@ namespace Nz return PopResultId(); } - void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node) + auto SpirvAstVisitor::GetVariable(std::size_t varIndex) const -> const Variable& { - SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock); + assert(varIndex < m_variables.size()); + assert(m_variables[varIndex]); + return *m_variables[varIndex]; + } + + void SpirvAstVisitor::Visit(ShaderAst::AccessMemberIndexExpression& node) + { + SpirvExpressionLoad accessMemberVisitor(m_writer, *this, *m_currentBlock); PushResultId(accessMemberVisitor.Evaluate(node)); } @@ -30,7 +43,7 @@ namespace Nz { UInt32 resultId = EvaluateExpression(node.right); - SpirvExpressionStore storeVisitor(m_writer, *m_currentBlock); + SpirvExpressionStore storeVisitor(m_writer, *this, *m_currentBlock); storeVisitor.Store(node.left, resultId); PushResultId(resultId); @@ -38,18 +51,24 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node) { - ShaderAst::ExpressionType resultExprType = GetExpressionType(node); - assert(IsPrimitiveType(resultExprType)); + auto RetrieveBaseType = [](const ShaderAst::ExpressionType& exprType) + { + if (IsPrimitiveType(exprType)) + return std::get(exprType); + else if (IsVectorType(exprType)) + return std::get(exprType).type; + else if (IsMatrixType(exprType)) + return std::get(exprType).type; + else + throw std::runtime_error("unexpected type"); + }; - ShaderAst::ExpressionType leftExprType = GetExpressionType(*node.left); - assert(IsPrimitiveType(leftExprType)); + const ShaderAst::ExpressionType& resultType = GetExpressionType(node); + const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left); + const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right); - ShaderAst::ExpressionType rightExprType = GetExpressionType(*node.right); - assert(IsPrimitiveType(rightExprType)); - - ShaderAst::PrimitiveType resultType = std::get(resultExprType); - ShaderAst::PrimitiveType leftType = std::get(leftExprType); - ShaderAst::PrimitiveType rightType = std::get(rightExprType); + ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType); + ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType); UInt32 leftOperand = EvaluateExpression(node.left); @@ -64,28 +83,16 @@ namespace Nz { case ShaderAst::BinaryType::Add: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFAdd; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpIAdd; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -94,28 +101,16 @@ namespace Nz case ShaderAst::BinaryType::Subtract: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFSub; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpISub; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -124,30 +119,18 @@ namespace Nz case ShaderAst::BinaryType::Divide: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFDiv; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSDiv; case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpUDiv; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -156,31 +139,17 @@ namespace Nz case ShaderAst::BinaryType::CompEq: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Boolean: return SpirvOp::OpLogicalEqual; case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdEqual; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpIEqual; - -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: -// break; } break; @@ -188,30 +157,18 @@ namespace Nz case ShaderAst::BinaryType::CompGe: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdGreaterThan; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSGreaterThan; case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpUGreaterThan; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -220,30 +177,18 @@ namespace Nz case ShaderAst::BinaryType::CompGt: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdGreaterThanEqual; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSGreaterThanEqual; case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpUGreaterThanEqual; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -252,30 +197,18 @@ namespace Nz case ShaderAst::BinaryType::CompLe: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdLessThanEqual; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSLessThanEqual; case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpULessThanEqual; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -284,30 +217,18 @@ namespace Nz case ShaderAst::BinaryType::CompLt: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdLessThan; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: return SpirvOp::OpSLessThan; case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpULessThan; case ShaderAst::PrimitiveType::Boolean: -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: break; } @@ -316,31 +237,17 @@ namespace Nz case ShaderAst::BinaryType::CompNe: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Boolean: return SpirvOp::OpLogicalNotEqual; case ShaderAst::PrimitiveType::Float32: -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpFOrdNotEqual; case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpINotEqual; - -// case ShaderAst::PrimitiveType::Sampler2D: -// case ShaderAst::PrimitiveType::Void: -// break; } break; @@ -348,81 +255,51 @@ namespace Nz case ShaderAst::BinaryType::Multiply: { - switch (leftType) + switch (leftTypeBase) { case ShaderAst::PrimitiveType::Float32: { - switch (rightType) + if (IsPrimitiveType(leftType)) { - case ShaderAst::PrimitiveType::Float32: - return SpirvOp::OpFMul; - -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// swapOperands = true; -// return SpirvOp::OpVectorTimesScalar; -// -// case ShaderAst::PrimitiveType::Mat4x4: -// swapOperands = true; -// return SpirvOp::OpMatrixTimesScalar; - - default: - break; + // Handle float * matrix|vector as matrix|vector * float + if (IsMatrixType(rightType)) + { + swapOperands = true; + return SpirvOp::OpMatrixTimesScalar; + } + else if (IsVectorType(rightType)) + { + swapOperands = true; + return SpirvOp::OpVectorTimesScalar; + } + } + else if (IsPrimitiveType(rightType)) + { + if (IsMatrixType(leftType)) + return SpirvOp::OpMatrixTimesScalar; + else if (IsVectorType(leftType)) + return SpirvOp::OpVectorTimesScalar; + } + else if (IsMatrixType(leftType)) + { + if (IsMatrixType(rightType)) + return SpirvOp::OpMatrixTimesMatrix; + else if (IsVectorType(rightType)) + return SpirvOp::OpMatrixTimesVector; + } + else if (IsMatrixType(rightType)) + { + assert(IsVectorType(leftType)); + return SpirvOp::OpVectorTimesMatrix; } - break; + return SpirvOp::OpFMul; } -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// { -// switch (rightType) -// { -// case ShaderAst::PrimitiveType::Float32: -// return SpirvOp::OpVectorTimesScalar; -// -// case ShaderAst::PrimitiveType::Float2: -// case ShaderAst::PrimitiveType::Float3: -// case ShaderAst::PrimitiveType::Float4: -// return SpirvOp::OpFMul; -// -// case ShaderAst::PrimitiveType::Mat4x4: -// return SpirvOp::OpVectorTimesMatrix; -// -// default: -// break; -// } -// -// break; -// } - case ShaderAst::PrimitiveType::Int32: -// case ShaderAst::PrimitiveType::Int2: -// case ShaderAst::PrimitiveType::Int3: -// case ShaderAst::PrimitiveType::Int4: case ShaderAst::PrimitiveType::UInt32: -// case ShaderAst::PrimitiveType::UInt2: -// case ShaderAst::PrimitiveType::UInt3: -// case ShaderAst::PrimitiveType::UInt4: return SpirvOp::OpIMul; -// case ShaderAst::PrimitiveType::Mat4x4: -// { -// switch (rightType) -// { -// case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpMatrixTimesScalar; -// case ShaderAst::PrimitiveType::Float4: return SpirvOp::OpMatrixTimesVector; -// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; -// -// default: -// break; -// } -// -// break; -// } - default: break; } @@ -454,7 +331,7 @@ namespace Nz firstCond.statement->Visit(*this); SpirvBlock mergeBlock(m_writer); - m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); + m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); std::optional nextBlock; for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex) @@ -463,10 +340,10 @@ namespace Nz SpirvBlock contentBlock(m_writer); - m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId()); + m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId()); previousConditionId = EvaluateExpression(statement.condition); - m_blocks.emplace_back(std::move(previousContentBlock)); + m_functionBlocks.emplace_back(std::move(previousContentBlock)); previousContentBlock = std::move(contentBlock); m_currentBlock = &previousContentBlock; @@ -479,54 +356,148 @@ namespace Nz SpirvBlock elseBlock(m_writer); m_currentBlock = &elseBlock; + node.elseStatement->Visit(*this); elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice - m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId()); - m_blocks.emplace_back(std::move(previousContentBlock)); - m_blocks.emplace_back(std::move(elseBlock)); + m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId()); + m_functionBlocks.emplace_back(std::move(previousContentBlock)); + m_functionBlocks.emplace_back(std::move(elseBlock)); } else { - m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId()); - m_blocks.emplace_back(std::move(previousContentBlock)); + m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId()); + m_functionBlocks.emplace_back(std::move(previousContentBlock)); } - m_blocks.emplace_back(std::move(mergeBlock)); + m_functionBlocks.emplace_back(std::move(mergeBlock)); - m_currentBlock = &m_blocks.back(); + m_currentBlock = &m_functionBlocks.back(); } void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node) { const ShaderAst::ExpressionType& targetExprType = node.targetType; - assert(IsPrimitiveType(targetExprType)); - - ShaderAst::PrimitiveType targetType = std::get(targetExprType); - - StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); - - for (auto& exprPtr : node.expressions) + if (IsPrimitiveType(targetExprType)) { - if (!exprPtr) - break; + ShaderAst::PrimitiveType targetType = std::get(targetExprType); - exprResults.push_back(EvaluateExpression(exprPtr)); + assert(node.expressions[0] && !node.expressions[1]); + ShaderAst::ExpressionPtr& expression = node.expressions[0]; + + assert(expression->cachedExpressionType.has_value()); + const ShaderAst::ExpressionType& exprType = expression->cachedExpressionType.value(); + assert(IsPrimitiveType(exprType)); + ShaderAst::PrimitiveType fromType = std::get(exprType); + + UInt32 fromId = EvaluateExpression(expression); + if (targetType == fromType) + return PushResultId(fromId); + + std::optional castOp; + switch (targetType) + { + case ShaderAst::PrimitiveType::Boolean: + throw std::runtime_error("unsupported cast to boolean"); + + case ShaderAst::PrimitiveType::Float32: + { + switch (fromType) + { + case ShaderAst::PrimitiveType::Boolean: + throw std::runtime_error("unsupported cast from boolean"); + + case ShaderAst::PrimitiveType::Float32: + break; //< Already handled + + case ShaderAst::PrimitiveType::Int32: + castOp = SpirvOp::OpConvertSToF; + break; + + case ShaderAst::PrimitiveType::UInt32: + castOp = SpirvOp::OpConvertUToF; + break; + } + break; + } + + case ShaderAst::PrimitiveType::Int32: + { + switch (fromType) + { + case ShaderAst::PrimitiveType::Boolean: + throw std::runtime_error("unsupported cast from boolean"); + + case ShaderAst::PrimitiveType::Float32: + castOp = SpirvOp::OpConvertFToS; + break; + + case ShaderAst::PrimitiveType::Int32: + break; //< Already handled + + case ShaderAst::PrimitiveType::UInt32: + castOp = SpirvOp::OpSConvert; + break; + } + break; + } + + case ShaderAst::PrimitiveType::UInt32: + { + switch (fromType) + { + case ShaderAst::PrimitiveType::Boolean: + throw std::runtime_error("unsupported cast from boolean"); + + case ShaderAst::PrimitiveType::Float32: + castOp = SpirvOp::OpConvertFToU; + break; + + case ShaderAst::PrimitiveType::Int32: + castOp = SpirvOp::OpUConvert; + break; + + case ShaderAst::PrimitiveType::UInt32: + break; //< Already handled + } + break; + } + } + + assert(castOp); + + UInt32 resultId = m_writer.AllocateResultId(); + m_currentBlock->Append(*castOp, m_writer.GetTypeId(targetType), resultId, fromId); + + throw std::runtime_error("toudou"); } - - UInt32 resultId = m_writer.AllocateResultId(); - - m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) + else { - appender(m_writer.GetTypeId(targetType)); - appender(resultId); + assert(IsVectorType(targetExprType)); + StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); - for (UInt32 exprResultId : exprResults) - appender(exprResultId); - }); + for (auto& exprPtr : node.expressions) + { + if (!exprPtr) + break; - PushResultId(resultId); + exprResults.push_back(EvaluateExpression(exprPtr)); + } + + UInt32 resultId = m_writer.AllocateResultId(); + + m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) + { + appender(m_writer.GetTypeId(targetExprType)); + appender(resultId); + + for (UInt32 exprResultId : exprResults) + appender(exprResultId); + }); + + PushResultId(resultId); + } } void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node) @@ -551,10 +522,108 @@ namespace Nz }, node.value); } + void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node) + { + assert(node.varIndex); + + std::size_t varIndex = *node.varIndex; + for (auto&& extVar : node.externalVars) + RegisterExternalVariable(varIndex++, extVar.type); + } + + void SpirvAstVisitor::Visit(ShaderAst::DeclareFunctionStatement& node) + { + assert(node.funcIndex); + m_funcIndex = *node.funcIndex; + + auto& func = m_funcData[m_funcIndex]; + func.funcId = m_writer.AllocateResultId(); + + m_instructions.Append(SpirvOp::OpFunction, func.returnTypeId, func.funcId, 0, func.funcTypeId); + + if (!func.parameters.empty()) + { + std::size_t varIndex = *node.varIndex; + for (const auto& param : func.parameters) + { + UInt32 paramResultId = m_writer.AllocateResultId(); + m_instructions.Append(SpirvOp::OpFunctionParameter, param.typeId, paramResultId); + + RegisterVariable(varIndex++, param.typeId, paramResultId, SpirvStorageClass::Function); + } + } + + m_functionBlocks.clear(); + + m_currentBlock = &m_functionBlocks.emplace_back(m_writer); + CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; }); + + for (auto& var : func.variables) + { + var.varId = m_writer.AllocateResultId(); + m_currentBlock->Append(SpirvOp::OpVariable, var.typeId, var.varId, SpirvStorageClass::Function); + } + + if (func.entryPointData) + { + auto& entryPointData = *func.entryPointData; + if (entryPointData.inputStruct) + { + auto& inputStruct = *entryPointData.inputStruct; + + std::size_t varIndex = *node.varIndex; + + UInt32 paramId = m_writer.AllocateResultId(); + m_currentBlock->Append(SpirvOp::OpVariable, inputStruct.pointerId, paramId, SpirvStorageClass::Function); + + for (const auto& input : entryPointData.inputs) + { + UInt32 resultId = m_writer.AllocateResultId(); + m_currentBlock->Append(SpirvOp::OpAccessChain, input.memberPointerId, resultId, paramId, input.memberIndexConstantId); + m_currentBlock->Append(SpirvOp::OpCopyMemory, resultId, input.varId); + } + + RegisterVariable(varIndex, inputStruct.typeId, paramId, SpirvStorageClass::Function); + } + } + + for (auto& statementPtr : node.statements) + statementPtr->Visit(*this); + + // Add implicit return + if (!m_functionBlocks.back().IsTerminated()) + m_functionBlocks.back().Append(SpirvOp::OpReturn); + + for (SpirvBlock& block : m_functionBlocks) + m_instructions.AppendSection(block); + + m_instructions.Append(SpirvOp::OpFunctionEnd); + } + + void SpirvAstVisitor::Visit(ShaderAst::DeclareStructStatement& node) + { + assert(node.structIndex); + RegisterStruct(*node.structIndex, node.description); + } + void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node) { + const auto& func = m_funcData[m_funcIndex]; + + UInt32 pointerTypeId = m_writer.GetPointerTypeId(node.varType, SpirvStorageClass::Function); + UInt32 typeId = m_writer.GetTypeId(node.varType); + + assert(node.varIndex); + auto varIt = func.varIndexToVarId.find(*node.varIndex); + UInt32 varId = func.variables[varIt->second].varId; + + RegisterVariable(*node.varIndex, typeId, varId, SpirvStorageClass::Function); + if (node.initialExpression) - m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression)); + { + UInt32 value = EvaluateExpression(node.initialExpression); + m_currentBlock->Append(SpirvOp::OpStore, varId, value); + } } void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/) @@ -569,19 +638,13 @@ namespace Nz PopResultId(); } - void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node) - { - SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock); - PushResultId(loadVisitor.Evaluate(node)); - } - void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node) { switch (node.intrinsic) { case ShaderAst::IntrinsicType::DotProduct: { - ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0]); + const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); assert(IsVectorType(vecExprType)); const ShaderAst::VectorType& vecType = std::get(vecExprType); @@ -598,6 +661,19 @@ namespace Nz break; } + case ShaderAst::IntrinsicType::SampleTexture: + { + UInt32 typeId = m_writer.GetTypeId(ShaderAst::VectorType{4, ShaderAst::PrimitiveType::Float32}); + + UInt32 samplerId = EvaluateExpression(node.parameters[0]); + UInt32 coordinatesId = EvaluateExpression(node.parameters[1]); + UInt32 resultId = m_writer.AllocateResultId(); + + m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId); + PushResultId(resultId); + break; + } + case ShaderAst::IntrinsicType::CrossProduct: default: throw std::runtime_error("not yet implemented"); @@ -609,23 +685,44 @@ namespace Nz // nothing to do } - void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node) - { - if (node.returnExpr) - m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr)); - else - m_currentBlock->Append(SpirvOp::OpReturn); - } - void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node) { for (auto& statement : node.statements) statement->Visit(*this); } + void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node) + { + if (node.returnExpr) + { + // Handle entry point return + const auto& func = m_funcData[m_funcIndex]; + if (func.entryPointData) + { + auto& entryPointData = *func.entryPointData; + if (entryPointData.outputStructTypeId) + { + UInt32 paramId = EvaluateExpression(node.returnExpr); + for (const auto& output : entryPointData.outputs) + { + UInt32 resultId = m_writer.AllocateResultId(); + m_currentBlock->Append(SpirvOp::OpCompositeExtract, output.typeId, resultId, paramId, output.memberIndex); + m_currentBlock->Append(SpirvOp::OpStore, output.varId, resultId); + } + } + + m_currentBlock->Append(SpirvOp::OpReturn); + } + else + m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr)); + } + else + m_currentBlock->Append(SpirvOp::OpReturn); + } + void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { - ShaderAst::ExpressionType targetExprType = GetExpressionType(node); + const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node); assert(IsPrimitiveType(targetExprType)); ShaderAst::PrimitiveType targetType = std::get(targetExprType); @@ -658,6 +755,12 @@ namespace Nz PushResultId(resultId); } + void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node) + { + SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock); + PushResultId(loadVisitor.Evaluate(node)); + } + void SpirvAstVisitor::PushResultId(UInt32 value) { m_resultIds.push_back(value); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 2aff0d642..ad2fa52d8 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -50,11 +50,6 @@ namespace Nz return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType); } - bool Compare(const Identifier& lhs, const Identifier& rhs) const - { - return lhs.name == rhs.name; - } - bool Compare(const Image& lhs, const Image& rhs) const { return lhs.arrayed == rhs.arrayed @@ -114,6 +109,9 @@ namespace Nz if (lhs.debugName != rhs.debugName) return false; + if (lhs.funcId != rhs.funcId) + return false; + if (!Compare(lhs.initializer, rhs.initializer)) return false; @@ -231,11 +229,6 @@ namespace Nz void Register(const Integer&) {} void Register(const Void&) {} - void Register(const Identifier& identifier) - { - Register(identifier); - } - void Register(const Image& image) { Register(image.sampledType); @@ -406,6 +399,7 @@ namespace Nz tsl::ordered_map, UInt32 /*id*/, AnyHasher, Eq> ids; tsl::ordered_map variableIds; tsl::ordered_map structureSizes; + StructCallback structCallback; UInt32& nextResultId; }; @@ -417,132 +411,8 @@ namespace Nz SpirvConstantCache::SpirvConstantCache(SpirvConstantCache&& cache) noexcept = default; SpirvConstantCache::~SpirvConstantCache() = default; - - UInt32 SpirvConstantCache::GetId(const Constant& c) - { - auto it = m_internal->ids.find(c.constant); - if (it == m_internal->ids.end()) - throw std::runtime_error("constant is not registered"); - - return it->second; - } - - UInt32 SpirvConstantCache::GetId(const Type& t) - { - auto it = m_internal->ids.find(t.type); - if (it == m_internal->ids.end()) - throw std::runtime_error("constant is not registered"); - - return it->second; - } - - UInt32 SpirvConstantCache::GetId(const Variable& v) - { - auto it = m_internal->variableIds.find(v); - if (it == m_internal->variableIds.end()) - throw std::runtime_error("variable is not registered"); - - return it->second; - } - - UInt32 SpirvConstantCache::Register(Constant c) - { - AnyConstant& constant = c.constant; - - DepRegisterer registerer(*this); - registerer.Register(constant); - - std::size_t h = m_internal->ids.hash_function()(constant); - auto it = m_internal->ids.find(constant, h); - if (it == m_internal->ids.end()) - { - UInt32 resultId = m_internal->nextResultId++; - it = m_internal->ids.emplace(std::move(constant), resultId).first; - } - - return it.value(); - } - - UInt32 SpirvConstantCache::Register(Type t) - { - AnyType& type = t.type; - if (std::holds_alternative(type)) - { - assert(m_identifierCallback); - return Register(*m_identifierCallback(std::get(type).name)); - } - - DepRegisterer registerer(*this); - registerer.Register(type); - - std::size_t h = m_internal->ids.hash_function()(type); - auto it = m_internal->ids.find(type, h); - if (it == m_internal->ids.end()) - { - UInt32 resultId = m_internal->nextResultId++; - it = m_internal->ids.emplace(std::move(type), resultId).first; - } - - return it.value(); - } - - UInt32 SpirvConstantCache::Register(Variable v) - { - DepRegisterer registerer(*this); - registerer.Register(v); - - std::size_t h = m_internal->variableIds.hash_function()(v); - auto it = m_internal->variableIds.find(v, h); - if (it == m_internal->variableIds.end()) - { - UInt32 resultId = m_internal->nextResultId++; - it = m_internal->variableIds.emplace(std::move(v), resultId).first; - } - - return it.value(); - } - - void SpirvConstantCache::SetIdentifierCallback(IdentifierCallback callback) - { - m_identifierCallback = std::move(callback); - } - - void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) - { - for (auto&& [object, id] : m_internal->ids) - { - UInt32 resultId = id; - - std::visit(overloaded - { - [&](const AnyConstant& constant) { Write(constant, resultId, constants); }, - [&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); }, - }, object); - } - - for (auto&& [variable, id] : m_internal->variableIds) - { - const auto& var = variable; - UInt32 resultId = id; - - if (!variable.debugName.empty()) - debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName); - - constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender) - { - appender(GetId(*var.type)); - appender(resultId); - appender(var.storageClass); - - if (var.initializer) - appender(GetId((*var.initializer)->constant)); - }); - } - } - - SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default; - - auto SpirvConstantCache::BuildConstant(const ShaderConstantValue& value) -> ConstantPtr + + auto SpirvConstantCache::BuildConstant(const ShaderAst::ConstantValue& value) const -> ConstantPtr { return std::make_shared(std::visit([&](auto&& arg) -> SpirvConstantCache::AnyConstant { @@ -590,7 +460,7 @@ namespace Nz }, value)); } - auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) -> TypePtr + auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) const -> TypePtr { std::vector parameterTypes; parameterTypes.reserve(parameters.size()); @@ -604,7 +474,7 @@ namespace Nz }); } - auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) -> TypePtr + auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr { return std::make_shared(Pointer{ BuildType(type), @@ -612,7 +482,7 @@ namespace Nz }); } - auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) -> TypePtr + auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr { return std::make_shared(Pointer{ BuildType(type), @@ -620,7 +490,7 @@ namespace Nz }); } - auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr { return std::visit([&](auto&& arg) -> TypePtr { @@ -628,16 +498,13 @@ namespace Nz }, type); } - auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) const -> TypePtr { - return std::make_shared( - Identifier{ - type.name - } - ); + // No IdentifierType is expected (as they should have been resolved by now) + throw std::runtime_error("unexpected identifier"); } - auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) const -> TypePtr { return std::make_shared([&]() -> AnyType { @@ -657,7 +524,7 @@ namespace Nz }()); } - auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) const -> TypePtr { return std::make_shared( Matrix{ @@ -668,12 +535,12 @@ namespace Nz }); } - auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) const -> TypePtr { return std::make_shared(Void{}); } - auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) const -> TypePtr { //TODO auto imageType = Image{ @@ -690,7 +557,13 @@ namespace Nz return std::make_shared(SampledImage{ std::make_shared(imageType) }); } - auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::StructType& type) const -> TypePtr + { + assert(m_internal->structCallback); + return BuildType(m_internal->structCallback(type.structIndex)); + } + + auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) const -> TypePtr { Structure sType; sType.name = structDesc.name; @@ -705,11 +578,136 @@ namespace Nz return std::make_shared(std::move(sType)); } - auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) const -> TypePtr { return std::make_shared(Vector{ BuildType(type.type), UInt32(type.componentCount) }); } + auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr + { + assert(std::holds_alternative(type.containedType)); + return BuildType(std::get(type.containedType)); + } + + UInt32 SpirvConstantCache::GetId(const Constant& c) + { + auto it = m_internal->ids.find(c.constant); + if (it == m_internal->ids.end()) + throw std::runtime_error("constant is not registered"); + + return it->second; + } + + UInt32 SpirvConstantCache::GetId(const Type& t) + { + auto it = m_internal->ids.find(t.type); + if (it == m_internal->ids.end()) + throw std::runtime_error("type is not registered"); + + return it->second; + } + + UInt32 SpirvConstantCache::GetId(const Variable& v) + { + auto it = m_internal->variableIds.find(v); + if (it == m_internal->variableIds.end()) + throw std::runtime_error("variable is not registered"); + + return it->second; + } + + UInt32 SpirvConstantCache::Register(Constant c) + { + AnyConstant& constant = c.constant; + + DepRegisterer registerer(*this); + registerer.Register(constant); + + std::size_t h = m_internal->ids.hash_function()(constant); + auto it = m_internal->ids.find(constant, h); + if (it == m_internal->ids.end()) + { + UInt32 resultId = m_internal->nextResultId++; + it = m_internal->ids.emplace(std::move(constant), resultId).first; + } + + return it.value(); + } + + UInt32 SpirvConstantCache::Register(Type t) + { + AnyType& type = t.type; + + DepRegisterer registerer(*this); + registerer.Register(type); + + std::size_t h = m_internal->ids.hash_function()(type); + auto it = m_internal->ids.find(type, h); + if (it == m_internal->ids.end()) + { + UInt32 resultId = m_internal->nextResultId++; + it = m_internal->ids.emplace(std::move(type), resultId).first; + } + + return it.value(); + } + + UInt32 SpirvConstantCache::Register(Variable v) + { + DepRegisterer registerer(*this); + registerer.Register(v); + + std::size_t h = m_internal->variableIds.hash_function()(v); + auto it = m_internal->variableIds.find(v, h); + if (it == m_internal->variableIds.end()) + { + UInt32 resultId = m_internal->nextResultId++; + it = m_internal->variableIds.emplace(std::move(v), resultId).first; + } + + return it.value(); + } + + void SpirvConstantCache::SetStructCallback(StructCallback callback) + { + m_internal->structCallback = std::move(callback); + } + + void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) + { + for (auto&& [object, id] : m_internal->ids) + { + UInt32 resultId = id; + + std::visit(overloaded + { + [&](const AnyConstant& constant) { Write(constant, resultId, constants); }, + [&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); }, + }, object); + } + + for (auto&& [variable, id] : m_internal->variableIds) + { + const auto& var = variable; + UInt32 resultId = id; + + if (!variable.debugName.empty()) + debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName); + + constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender) + { + appender(GetId(*var.type)); + appender(resultId); + appender(var.storageClass); + + if (var.initializer) + appender(GetId((*var.initializer)->constant)); + }); + } + } + + SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default; + void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants) { std::visit([&](auto&& arg) diff --git a/src/Nazara/Shader/SpirvDecoder.cpp b/src/Nazara/Shader/SpirvDecoder.cpp index 141b009f1..cef949889 100644 --- a/src/Nazara/Shader/SpirvDecoder.cpp +++ b/src/Nazara/Shader/SpirvDecoder.cpp @@ -38,6 +38,8 @@ namespace Nz while (m_currentCodepoint < m_codepointEnd) { + const UInt32* instructionBegin = m_currentCodepoint; + UInt32 firstWord = ReadWord(); UInt16 wordCount = static_cast((firstWord >> 16) & 0xFFFF); @@ -50,7 +52,7 @@ namespace Nz if (!HandleOpcode(*inst, wordCount)) break; - m_currentCodepoint += wordCount - 1; + m_currentCodepoint = instructionBegin + wordCount; } } diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index 068280bd1..12c0d2838 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -3,7 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include -#include +#include #include #include #include @@ -24,9 +24,8 @@ namespace Nz { [this](const Pointer& pointer) -> UInt32 { - UInt32 resultId = m_writer.AllocateResultId(); - - m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId); + UInt32 resultId = m_visitor.AllocateResultId(); + m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.pointerId); return resultId; }, @@ -41,25 +40,26 @@ namespace Nz }, m_value); } - /*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node) + void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberIndexExpression& node) { - Visit(node.structExpr); + node.structExpr->Visit(*this); + + const ShaderAst::ExpressionType& exprType = GetExpressionType(node); + + UInt32 resultId = m_visitor.AllocateResultId(); + UInt32 typeId = m_writer.GetTypeId(exprType); std::visit(overloaded { [&](const Pointer& pointer) { - ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr); - - UInt32 resultId = m_writer.AllocateResultId(); - UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME - UInt32 typeId = m_writer.GetTypeId(node.exprType); + UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) { appender(pointerType); appender(resultId); - appender(pointer.resultId); + appender(pointer.pointerId); for (std::size_t index : node.memberIndices) appender(m_writer.GetConstantId(Int32(index))); @@ -69,9 +69,6 @@ namespace Nz }, [&](const Value& value) { - UInt32 resultId = m_writer.AllocateResultId(); - UInt32 typeId = m_writer.GetTypeId(node.exprType); - m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender) { appender(typeId); @@ -89,15 +86,11 @@ namespace Nz throw std::runtime_error("an internal error occurred"); } }, m_value); - }*/ + } - void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node) + void SpirvExpressionLoad::Visit(ShaderAst::VariableExpression& node) { - if (node.identifier == "d") - m_value = Value{ m_writer.ReadLocalVariable(node.identifier) }; - else - m_value = Value{ m_writer.ReadParameterVariable(node.identifier) }; - - //Visit(node.var); + const auto& var = m_visitor.GetVariable(node.variableId); + m_value = Pointer{ var.storage, var.pointerId, var.pointedTypeId }; } } diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp index 8655b3a94..6015319c0 100644 --- a/src/Nazara/Shader/SpirvExpressionStore.cpp +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include #include #include @@ -23,11 +24,11 @@ namespace Nz { [&](const Pointer& pointer) { - m_block.Append(SpirvOp::OpStore, pointer.resultId, resultId); + m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId); }, [&](const LocalVar& value) { - m_writer.WriteLocalVariable(value.varName, resultId); + throw std::runtime_error("not yet implemented"); }, [](std::monostate) { @@ -36,49 +37,50 @@ namespace Nz }, m_value); } - /*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node) + void SpirvExpressionStore::Visit(ShaderAst::AccessMemberIndexExpression& node) { - Visit(node.structExpr); + node.structExpr->Visit(*this); + + const ShaderAst::ExpressionType& exprType = GetExpressionType(node); std::visit(overloaded { - [&](const Pointer& pointer) -> UInt32 + [&](const Pointer& pointer) { - UInt32 resultId = m_writer.AllocateResultId(); - UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME + UInt32 resultId = m_visitor.AllocateResultId(); + UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) { appender(pointerType); appender(resultId); - appender(pointer.resultId); + appender(pointer.pointerId); for (std::size_t index : node.memberIndices) appender(m_writer.GetConstantId(Int32(index))); }); - m_value = Pointer{ pointer.storage, resultId }; - - return resultId; + m_value = Pointer { pointer.storage, resultId }; }, - [](const LocalVar& value) -> UInt32 + [&](const LocalVar& value) { throw std::runtime_error("not yet implemented"); }, - [](std::monostate) -> UInt32 + [](std::monostate) { throw std::runtime_error("an internal error occurred"); } }, m_value); - }*/ - - void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node) - { - m_value = LocalVar{ node.identifier }; } void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node) { throw std::runtime_error("not yet implemented"); } + + void SpirvExpressionStore::Visit(ShaderAst::VariableExpression& node) + { + const auto& var = m_visitor.GetVariable(node.variableId); + m_value = Pointer{ var.storage, var.pointerId }; + } } diff --git a/src/Nazara/Shader/SpirvPrinter.cpp b/src/Nazara/Shader/SpirvPrinter.cpp index cd618c8bb..373d7e2c9 100644 --- a/src/Nazara/Shader/SpirvPrinter.cpp +++ b/src/Nazara/Shader/SpirvPrinter.cpp @@ -68,7 +68,7 @@ namespace Nz UInt32 resultId = 0; std::size_t currentOperand = 0; - const UInt32* endPtr = startPtr + wordCount; + const UInt32* endPtr = startPtr + wordCount - 1; while (GetCurrentPtr() < endPtr) { const SpirvInstruction::Operand* operand = &instruction.operands[currentOperand]; @@ -209,7 +209,7 @@ namespace Nz m_currentState->stream << "\n"; - assert(GetCurrentPtr() == startPtr + wordCount); + assert(GetCurrentPtr() == startPtr + wordCount - 1); return true; } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 0627c569e..8ef74aef7 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -12,10 +12,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -25,34 +27,61 @@ namespace Nz { namespace { + //FIXME: Have this only once + std::unordered_map s_entryPoints = { + { "frag", ShaderStageType::Fragment }, + { "vert", ShaderStageType::Vertex }, + }; + + struct Builtin + { + const char* debugName; + ShaderStageTypeFlags compatibleStages; + SpirvBuiltIn decoration; + }; + + std::unordered_map s_builtinMapping = { + { "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } } + }; + class PreVisitor : public ShaderAst::AstScopedVisitor { public: + struct UniformVar + { + std::optional bindingIndex; + UInt32 pointerId; + }; + + using BuiltinDecoration = std::map; + using LocationDecoration = std::map; using ExtInstList = std::unordered_set; + using ExtVarContainer = std::unordered_map; using LocalContainer = std::unordered_set; using FunctionContainer = std::vector>; + using StructContainer = std::vector; - PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : + PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector& funcs) : m_conditions(conditions), - m_constantCache(constantCache) + m_constantCache(constantCache), + m_externalBlockIndex(0), + m_funcs(funcs) { - m_constantCache.SetIdentifierCallback([&](const std::string& identifierName) + m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription& { - const Identifier* identifier = FindIdentifier(identifierName); - if (!identifier) - throw std::runtime_error("invalid identifier " + identifierName); - - assert(std::holds_alternative(identifier->value)); - return SpirvConstantCache::BuildType(std::get(identifier->value)); + assert(structIndex < declaredStructs.size()); + return declaredStructs[structIndex]; }); } - void Visit(ShaderAst::AccessMemberExpression& node) override + void Visit(ShaderAst::AccessMemberIndexExpression& node) override { - /*for (std::size_t index : node.memberIdentifiers) - m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/ - AstRecursiveVisitor::Visit(node); + + for (std::size_t index : node.memberIndices) + m_constantCache.Register(*m_constantCache.BuildConstant(Int32(index))); + + m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); } void Visit(ShaderAst::ConditionalExpression& node) override @@ -64,6 +93,8 @@ namespace Nz Visit(node.truePath); else Visit(node.falsePath);*/ + + m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); } void Visit(ShaderAst::ConditionalStatement& node) override @@ -79,52 +110,189 @@ namespace Nz { std::visit([&](auto&& arg) { - m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg)); + m_constantCache.Register(*m_constantCache.BuildConstant(arg)); }, node.value); AstScopedVisitor::Visit(node); } + void Visit(ShaderAst::DeclareExternalStatement& node) override + { + assert(node.varIndex); + std::size_t varIndex = *node.varIndex; + for (auto& extVar : node.externalVars) + { + SpirvConstantCache::Variable variable; + variable.debugName = extVar.name; + variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform; + variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass); + + UniformVar& uniformVar = extVars[varIndex++]; + uniformVar.pointerId = m_constantCache.Register(variable); + + for (const auto& [attributeType, attributeParam] : extVar.attributes) + { + if (attributeType == ShaderAst::AttributeType::Binding) + { + uniformVar.bindingIndex = std::get(attributeParam); + break; + } + } + } + } + void Visit(ShaderAst::DeclareFunctionStatement& node) override { - funcs.emplace_back(node); + std::optional entryPointType; + for (auto& attribute : node.attributes) + { + if (attribute.type == ShaderAst::AttributeType::Entry) + { + auto it = s_entryPoints.find(std::get(attribute.args)); + assert(it != s_entryPoints.end()); - std::vector parameterTypes; - for (auto& parameter : node.parameters) - parameterTypes.push_back(parameter.type); + entryPointType = it->second; + break; + } + } - m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes)); + assert(node.funcIndex); + std::size_t funcIndex = *node.funcIndex; + if (funcIndex >= m_funcs.size()) + m_funcs.resize(funcIndex + 1); + + auto& funcData = m_funcs[funcIndex]; + funcData.name = node.name; + + if (!entryPointType) + { + std::vector parameterTypes; + for (auto& parameter : node.parameters) + parameterTypes.push_back(parameter.type); + + funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType)); + funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes)); + + for (auto& parameter : node.parameters) + { + auto& funcParam = funcData.parameters.emplace_back(); + funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)); + funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type)); + } + } + else + { + using EntryPoint = SpirvAstVisitor::EntryPoint; + + funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{})); + funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {})); + + std::optional inputStruct; + std::vector inputs; + if (!node.parameters.empty()) + { + assert(node.parameters.size() == 1); + auto& parameter = node.parameters.front(); + assert(std::holds_alternative(parameter.type)); + + std::size_t structIndex = std::get(parameter.type).structIndex; + const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex]; + + std::size_t memberIndex = 0; + for (const auto& member : structDesc.members) + { + if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0) + { + inputs.push_back({ + m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))), + m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)), + varId + }); + } + + memberIndex++; + } + + inputStruct = EntryPoint::InputStruct{ + m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)), + m_constantCache.Register(*m_constantCache.BuildType(parameter.type)) + }; + } + + std::optional outputStructId; + std::vector outputs; + if (!IsNoType(node.returnType)) + { + assert(std::holds_alternative(node.returnType)); + + std::size_t structIndex = std::get(node.returnType).structIndex; + const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex]; + + std::size_t memberIndex = 0; + for (const auto& member : structDesc.members) + { + if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0) + { + outputs.push_back({ + Int32(memberIndex), + m_constantCache.Register(*m_constantCache.BuildType(member.type)), + varId + }); + } + + memberIndex++; + } + + outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType)); + } + + funcData.entryPointData = EntryPoint{ + *entryPointType, + inputStruct, + outputStructId, + funcIndex, + std::move(inputs), + std::move(outputs) + }; + } + + m_funcIndex = funcIndex; AstScopedVisitor::Visit(node); + m_funcIndex.reset(); } void Visit(ShaderAst::DeclareStructStatement& node) override { AstScopedVisitor::Visit(node); - SpirvConstantCache::Structure sType; - sType.name = node.description.name; + assert(node.structIndex); + std::size_t structIndex = *node.structIndex; + if (structIndex >= declaredStructs.size()) + declaredStructs.resize(structIndex + 1); - for (const auto& [name, attribute, type] : node.description.members) - { - auto& sMembers = sType.members.emplace_back(); - sMembers.name = name; - sMembers.type = SpirvConstantCache::BuildType(type); - } + declaredStructs[structIndex] = node.description; - m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) }); + m_constantCache.Register(*m_constantCache.BuildType(node.description)); } void Visit(ShaderAst::DeclareVariableStatement& node) override { AstScopedVisitor::Visit(node); - m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType)); + assert(m_funcIndex); + auto& func = m_funcs[*m_funcIndex]; + + assert(node.varIndex); + func.varIndexToVarId[*node.varIndex] = func.variables.size(); + + auto& var = func.variables.emplace_back(); + var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function)); } void Visit(ShaderAst::IdentifierExpression& node) override { - m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value())); + m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); AstScopedVisitor::Visit(node); } @@ -144,40 +312,88 @@ namespace Nz case ShaderAst::IntrinsicType::DotProduct: break; } + + m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); } + void Visit(ShaderAst::SwizzleExpression& node) override + { + AstScopedVisitor::Visit(node); + + m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); + } + + UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass) + { + std::optional> builtinOpt; + std::optional attributeLocation; + for (const auto& [attributeType, attributeParam] : member.attributes) + { + if (attributeType == ShaderAst::AttributeType::Builtin) + { + auto it = s_builtinMapping.find(std::get(attributeParam)); + if (it != s_builtinMapping.end()) + { + builtinOpt = it->second; + break; + } + } + else if (attributeType == ShaderAst::AttributeType::Location) + { + attributeLocation = std::get(attributeParam); + break; + } + } + + if (builtinOpt) + { + Builtin& builtin = *builtinOpt; + if ((builtin.compatibleStages & entryPointType) == 0) + return 0; + + SpirvBuiltIn builtinDecoration = builtin.decoration; + + SpirvConstantCache::Variable variable; + variable.debugName = builtin.debugName; + variable.funcId = funcIndex; + variable.storageClass = storageClass; + variable.type = m_constantCache.BuildPointerType(member.type, storageClass); + + UInt32 varId = m_constantCache.Register(variable); + builtinDecorations[varId] = builtinDecoration; + + return varId; + } + else if (attributeLocation) + { + SpirvConstantCache::Variable variable; + variable.debugName = member.name; + variable.funcId = funcIndex; + variable.storageClass = storageClass; + variable.type = m_constantCache.BuildPointerType(member.type, storageClass); + + UInt32 varId = m_constantCache.Register(variable); + locationDecorations[varId] = *attributeLocation; + + return varId; + } + + return 0; + } + + BuiltinDecoration builtinDecorations; ExtInstList extInsts; - FunctionContainer funcs; + ExtVarContainer extVars; + LocationDecoration locationDecorations; + StructContainer declaredStructs; private: const SpirvWriter::States& m_conditions; SpirvConstantCache& m_constantCache; + std::optional m_funcIndex; + std::size_t m_externalBlockIndex; + std::vector& m_funcs; }; - - template - constexpr ShaderAst::PrimitiveType GetBasicType() - { - if constexpr (std::is_same_v) - return ShaderAst::PrimitiveType::Boolean; - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Float32); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Int32); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Float2); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Float3); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Float4); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Int2); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Int3); - else if constexpr (std::is_same_v) - return(ShaderAst::PrimitiveType::Int4); - else - static_assert(AlwaysFalse::value, "unhandled type"); - } } struct SpirvWriter::State @@ -194,18 +410,13 @@ namespace Nz UInt32 id; }; - tsl::ordered_map inputIds; - tsl::ordered_map outputIds; - tsl::ordered_map parameterIds; - tsl::ordered_map uniformIds; std::unordered_map extensionInstructions; - std::unordered_map builtinIds; std::unordered_map varToResult; - std::vector funcs; - std::vector functionBlocks; + std::vector funcs; std::vector resultIds; UInt32 nextVarIndex = 1; SpirvConstantCache constantTypeCache; //< init after nextVarIndex + PreVisitor* preVisitor; // Output SpirvSection header; @@ -226,6 +437,9 @@ namespace Nz if (!ShaderAst::ValidateAst(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); + ShaderAst::TransformVisitor transformVisitor; + ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader); + m_context.states = &conditions; State state; @@ -235,245 +449,37 @@ namespace Nz m_currentState = nullptr; }); - ShaderAst::AstCloner cloner; - // Register all extended instruction sets - PreVisitor preVisitor(conditions, state.constantTypeCache); - shader->Visit(preVisitor); + PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs); + transformedShader->Visit(preVisitor); + + m_currentState->preVisitor = &preVisitor; for (const std::string& extInst : preVisitor.extInsts) state.extensionInstructions[extInst] = AllocateResultId(); - // Register all types - /*for (const auto& func : shader.GetFunctions()) - { - RegisterType(func.returnType); - for (const auto& param : func.parameters) - RegisterType(param.type); - } - - for (const auto& input : shader.GetInputs()) - RegisterPointerType(input.type, SpirvStorageClass::Input); - - for (const auto& output : shader.GetOutputs()) - RegisterPointerType(output.type, SpirvStorageClass::Output); - - for (const auto& uniform : shader.GetUniforms()) - RegisterPointerType(uniform.type, (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform); - - for (const auto& func : shader.GetFunctions()) - RegisterFunctionType(func.returnType, func.parameters); - - for (const auto& type : preVisitor.variableTypes) - RegisterType(type); - - for (const auto& builtin : preVisitor.builtinVars) - RegisterType(builtin->type); - - // Register result id and debug infos for global variables/functions - for (const auto& builtin : preVisitor.builtinVars) - { - SpirvConstantCache::Variable variable; - SpirvBuiltIn builtinDecoration; - switch (builtin->entry) - { - case ShaderAst::BuiltinEntry::VertexPosition: - variable.debugName = "builtin_VertexPosition"; - variable.storageClass = SpirvStorageClass::Output; - - builtinDecoration = SpirvBuiltIn::Position; - break; - - default: - throw std::runtime_error("unexpected builtin type"); - } - - const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type; - assert(IsBasicType(builtinExprType)); - - ShaderAst::BasicType builtinType = std::get(builtinExprType); - - variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass); - - UInt32 varId = m_currentState->constantTypeCache.Register(variable); - - ExtVar builtinData; - builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass); - builtinData.typeId = GetTypeId(builtinType); - builtinData.varId = varId; - - state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration); - - state.builtinIds.emplace(builtin->entry, builtinData); - } - - for (const auto& input : shader.GetInputs()) - { - SpirvConstantCache::Variable variable; - variable.debugName = input.name; - variable.storageClass = SpirvStorageClass::Input; - variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass); - - UInt32 varId = m_currentState->constantTypeCache.Register(variable); - - ExtVar inputData; - inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass); - inputData.typeId = GetTypeId(input.type); - inputData.varId = varId; - - state.inputIds.emplace(input.name, std::move(inputData)); - - if (input.locationIndex) - state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex); - } - - for (const auto& output : shader.GetOutputs()) - { - SpirvConstantCache::Variable variable; - variable.debugName = output.name; - variable.storageClass = SpirvStorageClass::Output; - variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass); - - UInt32 varId = m_currentState->constantTypeCache.Register(variable); - - ExtVar outputData; - outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass); - outputData.typeId = GetTypeId(output.type); - outputData.varId = varId; - - state.outputIds.emplace(output.name, std::move(outputData)); - - if (output.locationIndex) - state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex); - } - - for (const auto& uniform : shader.GetUniforms()) - { - SpirvConstantCache::Variable variable; - variable.debugName = uniform.name; - variable.storageClass = (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform; - variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass); - - UInt32 varId = m_currentState->constantTypeCache.Register(variable); - - ExtVar uniformData; - uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass); - uniformData.typeId = GetTypeId(uniform.type); - uniformData.varId = varId; - - state.uniformIds.emplace(uniform.name, std::move(uniformData)); - - if (uniform.bindingIndex) - { - state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex); - state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0); - } - }*/ - - for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs) - { - auto& funcData = state.funcs.emplace_back(); - funcData.statement = &func; - funcData.id = AllocateResultId(); - funcData.typeId = GetFunctionTypeId(func); - - state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name); - } - - std::size_t funcIndex = 0; - - for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs) - { - auto& funcData = state.funcs[funcIndex++]; - - state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); - - state.functionBlocks.clear(); - state.functionBlocks.emplace_back(*this); - - state.parameterIds.clear(); - - for (const auto& param : func.parameters) - { - UInt32 paramResultId = AllocateResultId(); - state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId); - - ExtVar parameterData; - parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function); - parameterData.typeId = GetTypeId(param.type); - parameterData.varId = paramResultId; - - state.parameterIds.emplace(param.name, std::move(parameterData)); - } - - SpirvAstVisitor visitor(*this, state.functionBlocks); - for (const auto& statement : func.statements) - statement->Visit(visitor); - - if (!state.functionBlocks.back().IsTerminated()) - { - assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} }); - state.functionBlocks.back().Append(SpirvOp::OpReturn); - } - - for (SpirvBlock& block : state.functionBlocks) - state.instructions.AppendSection(block); - - state.instructions.Append(SpirvOp::OpFunctionEnd); - } - - m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo); + SpirvAstVisitor visitor(*this, state.instructions, state.funcs); + transformedShader->Visit(visitor); AppendHeader(); - for (std::size_t i = 0; i < ShaderStageTypeCount; ++i) + for (auto&& [varIndex, extVar] : preVisitor.extVars) { - /*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i]; - if (!statement) - continue; - - auto it = std::find_if(state.funcs.begin(), state.funcs.end(), [&](const auto& funcData) { return funcData.statement == statement; }); - assert(it != state.funcs.end()); - - const auto& entryFunc = *it; - - SpirvExecutionModel execModel; - - ShaderStageType stage = static_cast(i); - switch (stage) + if (extVar.bindingIndex) { - case ShaderStageType::Fragment: - execModel = SpirvExecutionModel::Fragment; - break; - - case ShaderStageType::Vertex: - execModel = SpirvExecutionModel::Vertex; - break; - - default: - throw std::runtime_error("not yet implemented"); + state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex); + state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, 0); } - - state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender) - { - appender(execModel); - appender(entryFunc.id); - appender(statement->name); - - for (const auto& [name, varData] : state.builtinIds) - appender(varData.varId); - - for (const auto& [name, varData] : state.inputIds) - appender(varData.varId); - - for (const auto& [name, varData] : state.outputIds) - appender(varData.varId); - }); - - if (stage == ShaderStageType::Fragment) - state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/ } + for (auto&& [varId, builtin] : preVisitor.builtinDecorations) + state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin); + + for (auto&& [varId, location] : preVisitor.locationDecorations) + state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location); + + m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo); + std::vector ret; MergeSections(ret, state.header); MergeSections(ret, state.debugInfo); @@ -511,171 +517,53 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst); m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450); - } - UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const - { - return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); - } - - UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode) - { - return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) }); - } - - auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar& - { - auto it = m_currentState->builtinIds.find(builtin); - assert(it != m_currentState->builtinIds.end()); - - return it->second; - } - - auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar& - { - auto it = m_currentState->inputIds.find(name); - assert(it != m_currentState->inputIds.end()); - - return it->second; - } - - auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar& - { - auto it = m_currentState->outputIds.find(name); - assert(it != m_currentState->outputIds.end()); - - return it->second; - } - - auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar& - { - auto it = m_currentState->uniformIds.find(name); - assert(it != m_currentState->uniformIds.end()); - - return it.value(); - } - - UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const - { - return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass)); - } - - UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const - { - return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type)); - } - - UInt32 SpirvWriter::ReadInputVariable(const std::string& name) - { - auto it = m_currentState->inputIds.find(name); - assert(it != m_currentState->inputIds.end()); - - return ReadVariable(it.value()); - } - - std::optional SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache) - { - auto it = m_currentState->inputIds.find(name); - assert(it != m_currentState->inputIds.end()); - - return ReadVariable(it.value(), OnlyCache{}); - } - - UInt32 SpirvWriter::ReadLocalVariable(const std::string& name) - { - auto it = m_currentState->varToResult.find(name); - assert(it != m_currentState->varToResult.end()); - - return it->second; - } - - std::optional SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache) - { - auto it = m_currentState->varToResult.find(name); - if (it == m_currentState->varToResult.end()) - return {}; - - return it->second; - } - - UInt32 SpirvWriter::ReadParameterVariable(const std::string& name) - { - auto it = m_currentState->parameterIds.find(name); - assert(it != m_currentState->parameterIds.end()); - - return ReadVariable(it.value()); - } - - std::optional SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache) - { - auto it = m_currentState->parameterIds.find(name); - assert(it != m_currentState->parameterIds.end()); - - return ReadVariable(it.value(), OnlyCache{}); - } - - UInt32 SpirvWriter::ReadUniformVariable(const std::string& name) - { - auto it = m_currentState->uniformIds.find(name); - assert(it != m_currentState->uniformIds.end()); - - return ReadVariable(it.value()); - } - - std::optional SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache) - { - auto it = m_currentState->uniformIds.find(name); - assert(it != m_currentState->uniformIds.end()); - - return ReadVariable(it.value(), OnlyCache{}); - } - - UInt32 SpirvWriter::ReadVariable(ExtVar& var) - { - if (!var.valueId.has_value()) + std::optional fragmentFuncId; + for (auto& func : m_currentState->funcs) { - UInt32 resultId = AllocateResultId(); - m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId); + m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name); - var.valueId = resultId; + if (func.entryPointData) + { + auto& entryPointData = func.entryPointData.value(); + + SpirvExecutionModel execModel; + + switch (entryPointData.stageType) + { + case ShaderStageType::Fragment: + execModel = SpirvExecutionModel::Fragment; + break; + + case ShaderStageType::Vertex: + execModel = SpirvExecutionModel::Vertex; + break; + + default: + throw std::runtime_error("not yet implemented"); + } + + m_currentState->header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender) + { + appender(execModel); + appender(func.funcId); + appender(func.name); + + for (const auto& input : entryPointData.inputs) + appender(input.varId); + + for (const auto& output : entryPointData.outputs) + appender(output.varId); + }); + + if (entryPointData.stageType == ShaderStageType::Fragment) + fragmentFuncId = func.funcId; + } } - return var.valueId.value(); - } + if (fragmentFuncId) + m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft); - std::optional SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache) - { - if (!var.valueId.has_value()) - return {}; - - return var.valueId.value(); - } - - UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value) - { - return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value)); - } - - UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) - { - return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) }); - } - - UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass) - { - return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass)); - } - - UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type) - { - assert(m_currentState); - return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type)); - } - - void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId) - { - assert(m_currentState); - m_currentState->varToResult.insert_or_assign(std::move(name), resultId); } SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) @@ -686,7 +574,56 @@ namespace Nz for (const auto& parameter : functionNode.parameters) parameterTypes.push_back(parameter.type); - return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes); + return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes); + } + + UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const + { + return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildConstant(value)); + } + + UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const + { + auto it = m_currentState->preVisitor->extVars.find(extVarIndex); + assert(it != m_currentState->preVisitor->extVars.end()); + + return it->second.pointerId; + } + + UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode) + { + return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) }); + } + + UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const + { + return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass)); + } + + UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const + { + return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type)); + } + + UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value) + { + return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value)); + } + + UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) + { + return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) }); + } + + UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass) + { + return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass)); + } + + UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type) + { + assert(m_currentState); + return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildType(type)); } void SpirvWriter::MergeSections(std::vector& output, const SpirvSection& from)