From f93a5bbdc143e5885f180823f0dee88ad76a1552 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sun, 4 Apr 2021 20:31:09 +0200 Subject: [PATCH] Shader: Rework scope handling --- include/Nazara/Shader.hpp | 3 - include/Nazara/Shader/GlslWriter.hpp | 8 +- include/Nazara/Shader/ShaderAstCache.hpp | 56 --- include/Nazara/Shader/ShaderAstCache.inl | 45 --- include/Nazara/Shader/ShaderAstCloner.hpp | 2 +- .../Nazara/Shader/ShaderAstExpressionType.hpp | 58 --- .../Nazara/Shader/ShaderAstExpressionType.inl | 17 - .../Nazara/Shader/ShaderAstScopedVisitor.hpp | 68 ++++ .../Nazara/Shader/ShaderAstScopedVisitor.inl | 38 ++ include/Nazara/Shader/ShaderAstValidator.hpp | 21 +- include/Nazara/Shader/ShaderNodes.hpp | 2 + include/Nazara/Shader/SpirvAstVisitor.hpp | 4 +- include/Nazara/Shader/SpirvAstVisitor.inl | 9 +- include/Nazara/Shader/SpirvConstantCache.hpp | 14 +- src/Nazara/Shader/GlslWriter.cpp | 177 ++++++--- src/Nazara/Shader/ShaderAstCloner.cpp | 20 +- src/Nazara/Shader/ShaderAstExpressionType.cpp | 258 ------------ src/Nazara/Shader/ShaderAstOptimizer.cpp | 8 +- src/Nazara/Shader/ShaderAstScopedVisitor.cpp | 110 ++++++ src/Nazara/Shader/ShaderAstValidator.cpp | 374 +++++++++--------- src/Nazara/Shader/SpirvAstVisitor.cpp | 11 +- src/Nazara/Shader/SpirvConstantCache.cpp | 72 ++-- src/Nazara/Shader/SpirvWriter.cpp | 41 +- 23 files changed, 661 insertions(+), 755 deletions(-) delete mode 100644 include/Nazara/Shader/ShaderAstCache.hpp delete mode 100644 include/Nazara/Shader/ShaderAstCache.inl delete mode 100644 include/Nazara/Shader/ShaderAstExpressionType.hpp delete mode 100644 include/Nazara/Shader/ShaderAstExpressionType.inl create mode 100644 include/Nazara/Shader/ShaderAstScopedVisitor.hpp create mode 100644 include/Nazara/Shader/ShaderAstScopedVisitor.inl delete mode 100644 src/Nazara/Shader/ShaderAstExpressionType.cpp create mode 100644 src/Nazara/Shader/ShaderAstScopedVisitor.cpp diff --git a/include/Nazara/Shader.hpp b/include/Nazara/Shader.hpp index cef4eada9..2bd6df584 100644 --- a/include/Nazara/Shader.hpp +++ b/include/Nazara/Shader.hpp @@ -32,12 +32,9 @@ #include #include #include -#include #include -#include #include #include -#include #include #include #include diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index aab8f06c5..8b5f964e8 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include #include @@ -17,7 +17,7 @@ namespace Nz { - class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstRecursiveVisitor + class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstScopedVisitor { public: struct Environment; @@ -57,8 +57,8 @@ namespace Nz template void Append(const T& param); template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); void AppendCommentSection(const std::string& section); - void AppendEntryPoint(ShaderStageType shaderStage); - void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); + void AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader); + void AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); diff --git a/include/Nazara/Shader/ShaderAstCache.hpp b/include/Nazara/Shader/ShaderAstCache.hpp deleted file mode 100644 index 2a18b173d..000000000 --- a/include/Nazara/Shader/ShaderAstCache.hpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERASTCACHE_HPP -#define NAZARA_SHADERASTCACHE_HPP - -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - struct AstCache - { - struct Identifier; - - struct Alias - { - std::variant value; - }; - - struct Variable - { - ExpressionType type; - }; - - struct Identifier - { - std::string name; - std::variant value; - }; - - struct Scope - { - std::optional parentScopeIndex; - std::vector identifiers; - }; - - inline void Clear(); - inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const; - inline std::size_t GetScopeId(const Node* node) const; - - std::array entryFunctions = {}; - std::unordered_map nodeExpressionType; - std::unordered_map scopeIdByNode; - std::vector scopes; - }; -} - -#include - -#endif diff --git a/include/Nazara/Shader/ShaderAstCache.inl b/include/Nazara/Shader/ShaderAstCache.inl deleted file mode 100644 index f254eac0c..000000000 --- a/include/Nazara/Shader/ShaderAstCache.inl +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz::ShaderAst -{ - inline void AstCache::Clear() - { - entryFunctions.fill(nullptr); - nodeExpressionType.clear(); - scopeIdByNode.clear(); - scopes.clear(); - } - - inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier* - { - assert(startingScopeId < scopes.size()); - - std::optional scopeId = startingScopeId; - do - { - const auto& scope = scopes[*scopeId]; - auto it = std::find_if(scope.identifiers.rbegin(), scope.identifiers.rend(), [&](const auto& identifier) { return identifier.name == identifierName; }); - if (it != scope.identifiers.rend()) - return &*it; - - scopeId = scope.parentScopeIndex; - } while (scopeId); - - return nullptr; - } - - inline std::size_t AstCache::GetScopeId(const Node* node) const - { - auto it = scopeIdByNode.find(node); - assert(it != scopeIdByNode.end()); - - return it->second; - } -} - -#include diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index a2ee45f44..6b7d6b2b7 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -33,7 +33,7 @@ namespace Nz::ShaderAst ExpressionPtr CloneExpression(ExpressionPtr& expr); StatementPtr CloneStatement(StatementPtr& statement); - virtual std::unique_ptr Clone(DeclareFunctionStatement& node); + virtual StatementPtr Clone(DeclareFunctionStatement& node); using AstExpressionVisitor::Visit; using AstStatementVisitor::Visit; diff --git a/include/Nazara/Shader/ShaderAstExpressionType.hpp b/include/Nazara/Shader/ShaderAstExpressionType.hpp deleted file mode 100644 index 095b5e5ea..000000000 --- a/include/Nazara/Shader/ShaderAstExpressionType.hpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERASTEXPRESSIONTYPE_HPP -#define NAZARA_SHADERASTEXPRESSIONTYPE_HPP - -#include -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - struct AstCache; - - class NAZARA_SHADER_API ExpressionTypeVisitor : public AstExpressionVisitor - { - public: - ExpressionTypeVisitor() = default; - ExpressionTypeVisitor(const ExpressionTypeVisitor&) = delete; - ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete; - ~ExpressionTypeVisitor() = default; - - ExpressionType GetExpressionType(Expression& expression, AstCache* cache); - - ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete; - ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete; - - private: - ExpressionType GetExpressionTypeInternal(Expression& expression); - ExpressionType ResolveAlias(Expression& expression, ExpressionType expressionType); - - void Visit(Expression& expression); - - void Visit(AccessMemberExpression& node) override; - void Visit(AssignExpression& node) override; - void Visit(BinaryExpression& node) override; - void Visit(CastExpression& node) override; - void Visit(ConditionalExpression& node) override; - void Visit(ConstantExpression& node) override; - void Visit(IdentifierExpression& node) override; - void Visit(IntrinsicExpression& node) override; - void Visit(SwizzleExpression& node) override; - - AstCache* m_cache; - std::optional m_lastExpressionType; - }; - - inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr); -} - -#include - -#endif diff --git a/include/Nazara/Shader/ShaderAstExpressionType.inl b/include/Nazara/Shader/ShaderAstExpressionType.inl deleted file mode 100644 index 279a3909e..000000000 --- a/include/Nazara/Shader/ShaderAstExpressionType.inl +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz::ShaderAst -{ - inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache) - { - ExpressionTypeVisitor visitor; - return visitor.GetExpressionType(expression, cache); - } -} - -#include diff --git a/include/Nazara/Shader/ShaderAstScopedVisitor.hpp b/include/Nazara/Shader/ShaderAstScopedVisitor.hpp new file mode 100644 index 000000000..e212a30b5 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstScopedVisitor.hpp @@ -0,0 +1,68 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_SCOPED_VISITOR_HPP +#define NAZARA_SHADER_SCOPED_VISITOR_HPP + +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API AstScopedVisitor : public AstRecursiveVisitor + { + public: + struct Identifier; + + AstScopedVisitor() = default; + ~AstScopedVisitor() = default; + + inline const Identifier* FindIdentifier(const std::string_view& identifierName) const; + + void ScopedVisit(StatementPtr& nodePtr); + + using AstRecursiveVisitor::Visit; + void Visit(BranchStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(DeclareExternalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(MultiStatement& node) override; + + struct Alias + { + std::variant value; + }; + + struct Variable + { + ExpressionType type; + }; + + struct Identifier + { + std::string name; + std::variant value; + }; + + protected: + void PushScope(); + void PopScope(); + + inline void RegisterStruct(StructDescription structDesc); + inline void RegisterVariable(std::string name, ExpressionType type); + + private: + std::vector m_identifiersInScope; + std::vector m_scopeSizes; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/ShaderAstScopedVisitor.inl b/include/Nazara/Shader/ShaderAstScopedVisitor.inl new file mode 100644 index 000000000..67042069a --- /dev/null +++ b/include/Nazara/Shader/ShaderAstScopedVisitor.inl @@ -0,0 +1,38 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline auto AstScopedVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* + { + auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); + if (it == m_identifiersInScope.rend()) + return nullptr; + + return &*it; + } + + inline void AstScopedVisitor::RegisterStruct(StructDescription structDesc) + { + std::string name = structDesc.name; + + m_identifiersInScope.push_back({ + std::move(name), + std::move(structDesc) + }); + } + + inline void AstScopedVisitor::RegisterVariable(std::string name, ExpressionType type) + { + m_identifiersInScope.push_back({ + std::move(name), + Variable { std::move(type) } + }); + } +} + +#include diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index e3dd4d94e..b617e27c7 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -9,13 +9,12 @@ #include #include -#include -#include +#include #include namespace Nz::ShaderAst { - class NAZARA_SHADER_API AstValidator : public AstRecursiveVisitor + class NAZARA_SHADER_API AstValidator final : public AstScopedVisitor { public: inline AstValidator(); @@ -23,28 +22,24 @@ namespace Nz::ShaderAst AstValidator(AstValidator&&) = delete; ~AstValidator() = default; - bool Validate(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr); + bool Validate(StatementPtr& node, std::string* error = nullptr); private: + const ExpressionType& GetExpressionType(Expression& expression); Expression& MandatoryExpr(ExpressionPtr& node); Statement& MandatoryStatement(StatementPtr& node); void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right); void TypeMustMatch(const ExpressionType& left, const ExpressionType& right); ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); - - AstCache::Scope& EnterScope(); - void ExitScope(); - - void RegisterExpressionType(Expression& node, ExpressionType expressionType); - void RegisterScope(Node& node); + const ExpressionType& ResolveAlias(const ExpressionType& expressionType); void Visit(AccessMemberExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CastExpression& node) override; - void Visit(ConditionalExpression& node) override; void Visit(ConstantExpression& node) override; + void Visit(ConditionalExpression& node) override; void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SwizzleExpression& node) override; @@ -54,17 +49,15 @@ namespace Nz::ShaderAst void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; - void Visit(DeclareVariableStatement& node) override; void Visit(ExpressionStatement& node) override; void Visit(MultiStatement& node) override; - void Visit(ReturnStatement& node) override; struct Context; Context* m_context; }; - NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr); + NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr); } #include diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index bfd0ebb47..7b442c07c 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -56,6 +56,8 @@ namespace Nz::ShaderAst Expression& operator=(const Expression&) = delete; Expression& operator=(Expression&&) noexcept = default; + + std::optional cachedExpressionType; }; struct NAZARA_SHADER_API AccessMemberExpression : public Expression diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index b6536ea2e..d868dab41 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -21,7 +21,7 @@ namespace Nz class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept { public: - inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks, ShaderAst::AstCache* cache); + inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks); SpirvAstVisitor(const SpirvAstVisitor&) = delete; SpirvAstVisitor(SpirvAstVisitor&&) = delete; ~SpirvAstVisitor() = default; @@ -53,10 +53,10 @@ namespace Nz SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete; private: + inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) const; void PushResultId(UInt32 value); UInt32 PopResultId(); - ShaderAst::AstCache* m_cache; std::vector& m_blocks; std::vector m_resultIds; SpirvBlock* m_currentBlock; diff --git a/include/Nazara/Shader/SpirvAstVisitor.inl b/include/Nazara/Shader/SpirvAstVisitor.inl index 8694244be..bb54eb594 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.inl +++ b/include/Nazara/Shader/SpirvAstVisitor.inl @@ -7,13 +7,18 @@ namespace Nz { - inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks, ShaderAst::AstCache* cache) : - m_cache(cache), + inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks) : m_blocks(blocks), m_writer(writer) { m_currentBlock = &m_blocks.back(); } + + inline const ShaderAst::ExpressionType& SpirvAstVisitor::GetExpressionType(ShaderAst::Expression& expr) const + { + assert(expr.cachedExpressionType); + return expr.cachedExpressionType.value(); + } } #include diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index 5453e9f46..c01ab8d68 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -31,11 +31,14 @@ namespace Nz ~SpirvConstantCache(); struct Constant; + struct Identifier; struct Type; using ConstantPtr = std::shared_ptr; using TypePtr = std::shared_ptr; + using IdentifierCallback = std::function; + struct Bool {}; struct Float @@ -63,6 +66,11 @@ namespace Nz UInt32 columnCount; }; + struct Identifier + { + std::string name; + }; + struct Image { std::optional qualifier; @@ -104,7 +112,7 @@ namespace Nz std::vector members; }; - using AnyType = std::variant; + using AnyType = std::variant; struct ConstantBool { @@ -166,6 +174,8 @@ namespace Nz UInt32 Register(Type t); UInt32 Register(Variable v); + void SetIdentifierCallback(IdentifierCallback callback); + void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete; @@ -181,6 +191,7 @@ namespace Nz static TypePtr BuildType(const ShaderAst::NoType& type); static TypePtr BuildType(const ShaderAst::PrimitiveType& type); static TypePtr BuildType(const ShaderAst::SamplerType& type); + static TypePtr BuildType(const ShaderAst::StructDescription& structDesc); static TypePtr BuildType(const ShaderAst::VectorType& type); private: @@ -193,6 +204,7 @@ namespace Nz void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); + IdentifierCallback m_identifierCallback; std::unique_ptr m_internal; }; } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 8580fd5b0..0f1aaac4a 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -22,31 +21,87 @@ namespace Nz static const char* flipYUniformName = "_NzFlipValue"; static const char* overridenMain = "_NzMain"; - struct AstAdapter : ShaderAst::AstCloner + //FIXME: Have this only once + std::unordered_map s_entryPoints = { + { "frag", ShaderStageType::Fragment }, + { "vert", ShaderStageType::Vertex }, + }; + + struct PreVisitor : ShaderAst::AstCloner { using AstCloner::Clone; - std::unique_ptr Clone(ShaderAst::DeclareFunctionStatement& node) override + ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override { auto clone = AstCloner::Clone(node); - if (clone->name == "main") - clone->name = "_NzMain"; + assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement); + + ShaderAst::DeclareFunctionStatement* func = static_cast(clone.get()); + + bool hasEntryPoint = false; + + for (auto& attribute : func->attributes) + { + if (attribute.type == ShaderAst::AttributeType::Entry) + { + auto it = s_entryPoints.find(std::get(attribute.args)); + assert(it != s_entryPoints.end()); + + if (it->second == selectedEntryPoint) + { + hasEntryPoint = true; + break; + } + } + } + + if (!hasEntryPoint) + return ShaderBuilder::NoOp(); + + entryPoint = func; + + if (func->name == "main") + func->name = "_NzMain"; return clone; } - void Visit(ShaderAst::DeclareFunctionStatement& node) - { - if (removedEntryPoints.find(&node) != removedEntryPoints.end()) - { - PushStatement(ShaderBuilder::NoOp()); - return; - } + ShaderStageType selectedEntryPoint; + ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; + }; - AstCloner::Visit(node); + struct EntryFuncResolver : ShaderAst::AstScopedVisitor + { + void Visit(ShaderAst::DeclareFunctionStatement& node) override + { + + + if (&node != entryPoint) + return; + + assert(node.parameters.size() == 1); + + const ShaderAst::ExpressionType& inputType = node.parameters.front().type; + const ShaderAst::ExpressionType& outputType = node.returnType; + + const Identifier* identifier; + + assert(IsIdentifierType(node.parameters.front().type)); + identifier = FindIdentifier(std::get(inputType).name); + assert(identifier); + + inputIdentifier = *identifier; + + assert(IsIdentifierType(outputType)); + identifier = FindIdentifier(std::get(outputType).name); + assert(identifier); + + outputIdentifier = *identifier; } - std::unordered_set removedEntryPoints; + Identifier inputIdentifier; + Identifier outputIdentifier; + ShaderAst::DeclareFunctionStatement* entryPoint; }; struct Builtin @@ -64,7 +119,6 @@ namespace Nz struct GlslWriter::State { const States* states = nullptr; - ShaderAst::AstCache cache; ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; std::stringstream stream; unsigned int indentLevel = 0; @@ -86,29 +140,18 @@ namespace Nz }); std::string error; - if (!ShaderAst::ValidateAst(shader, &error, &state.cache)) + if (!ShaderAst::ValidateAst(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); - state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)]; - if (!state.entryFunc) + PreVisitor previsitor; + previsitor.selectedEntryPoint = shaderStage; + + ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader); + + if (!previsitor.entryPoint) throw std::runtime_error("missing entry point"); - AstAdapter adapter; - - for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions) - { - if (entryFunc != state.entryFunc) - adapter.removedEntryPoints.insert(entryFunc); - } - - ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader); - - state.cache.Clear(); - if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache)) - throw std::runtime_error("Internal error:" + error); - - state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)]; - assert(state.entryFunc); + state.entryFunc = previsitor.entryPoint; unsigned int glslVersion; if (m_environment.glES) @@ -190,10 +233,14 @@ namespace Nz AppendLine(); } - adaptedShader->Visit(*this); + PushScope(); + { + adaptedShader->Visit(*this); - // Append true GLSL entry point - AppendEntryPoint(shaderStage); + // Append true GLSL entry point + AppendEntryPoint(shaderStage, adaptedShader); + } + PopScope(); return state.stream.str(); } @@ -340,8 +387,12 @@ namespace Nz AppendLine(); } - void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage) + void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader) { + EntryFuncResolver entryResolver; + entryResolver.entryPoint = m_currentState->entryFunc; + entryResolver.ScopedVisit(shader); + AppendLine(); AppendLine("// Entry point handling"); @@ -354,15 +405,10 @@ namespace Nz std::vector inputFields; const ShaderAst::StructDescription* inputStruct = nullptr; - auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription* + auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription* { - assert(IsIdentifierType(expressionType)); - - const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get(expressionType).name); - assert(identifier); - - assert(std::holds_alternative(identifier->value)); - const auto& s = std::get(identifier->value); + assert(std::holds_alternative(identifier.value)); + const auto& s = std::get(identifier.value); for (const auto& member : s.members) { @@ -426,17 +472,12 @@ namespace Nz }; if (!m_currentState->entryFunc->parameters.empty()) - { - assert(m_currentState->entryFunc->parameters.size() == 1); - const auto& parameter = m_currentState->entryFunc->parameters.front(); - - inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_"); - } + inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_"); std::vector outputFields; const ShaderAst::StructDescription* outputStruct = nullptr; if (!IsNoType(m_currentState->entryFunc->returnType)) - outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_"); + outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_"); if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition) AppendLine("uniform float ", flipYUniformName, ";"); @@ -486,12 +527,12 @@ namespace Nz LeaveScope(); } - void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) + void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { Append("."); Append(memberIdentifier[0]); - const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName); + const Identifier* identifier = FindIdentifier(structName); assert(identifier); assert(std::holds_alternative(identifier->value)); @@ -503,7 +544,7 @@ namespace Nz const auto& member = *memberIt; if (remainingMembers > 1) - AppendField(scopeId, std::get(member.type).name, memberIdentifier + 1, remainingMembers - 1); + AppendField(std::get(member.type).name, memberIdentifier + 1, remainingMembers - 1); } void GlslWriter::AppendLine(const std::string& txt) @@ -558,12 +599,10 @@ namespace Nz { Visit(node.structExpr, true); - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache); + const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value(); assert(IsIdentifierType(exprType)); - std::size_t scopeId = m_currentState->cache.GetScopeId(&node); - - AppendField(scopeId, std::get(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size()); + AppendField(std::get(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size()); } void GlslWriter::Visit(ShaderAst::AssignExpression& node) @@ -593,7 +632,9 @@ namespace Nz AppendLine(")"); EnterScope(); + PushScope(); statement.statement->Visit(*this); + PopScope(); LeaveScope(); first = false; @@ -604,7 +645,9 @@ namespace Nz AppendLine("else"); EnterScope(); + PushScope(); node.elseStatement->Visit(*this); + PopScope(); LeaveScope(); } } @@ -698,6 +741,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) { + + for (const auto& externalVar : node.externalVars) { std::optional bindingIndex; @@ -729,7 +774,7 @@ namespace Nz EnterScope(); { - const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get(externalVar.type).containedType.name); + const Identifier* identifier = FindIdentifier(std::get(externalVar.type).containedType.name); assert(identifier); assert(std::holds_alternative(identifier->value)); @@ -780,15 +825,19 @@ namespace Nz Append(")\n"); EnterScope(); + PushScope(); { for (auto& statement : node.statements) statement->Visit(*this); } + PopScope(); LeaveScope(); } void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) { + RegisterStruct(node.description); + Append("struct "); AppendLine(node.description.name); EnterScope(); @@ -813,6 +862,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node) { + RegisterVariable(node.varName, node.varType); + Append(node.varType); Append(" "); Append(node.varName); @@ -871,6 +922,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::MultiStatement& node) { + PushScope(); + bool first = true; for (const ShaderAst::StatementPtr& statement : node.statements) { @@ -881,6 +934,8 @@ namespace Nz first = false; } + + PopScope(); } void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/) diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index 92eae1621..4ebfa9548 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -42,7 +42,7 @@ namespace Nz::ShaderAst return PopStatement(); } - std::unique_ptr AstCloner::Clone(DeclareFunctionStatement& node) + StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) { auto clone = std::make_unique(); clone->attributes = node.attributes; @@ -63,6 +63,8 @@ namespace Nz::ShaderAst clone->memberIdentifiers = node.memberIdentifiers; clone->structExpr = CloneExpression(node.structExpr); + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -73,6 +75,8 @@ namespace Nz::ShaderAst clone->left = CloneExpression(node.left); clone->right = CloneExpression(node.right); + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -83,6 +87,8 @@ namespace Nz::ShaderAst clone->left = CloneExpression(node.left); clone->right = CloneExpression(node.right); + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -100,6 +106,8 @@ namespace Nz::ShaderAst clone->expressions[expressionCount++] = CloneExpression(expr); } + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -110,6 +118,8 @@ namespace Nz::ShaderAst clone->falsePath = CloneExpression(node.falsePath); clone->truePath = CloneExpression(node.truePath); + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -118,6 +128,8 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->value = node.value; + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -126,6 +138,8 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->identifier = node.identifier; + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -138,6 +152,8 @@ namespace Nz::ShaderAst for (auto& parameter : node.parameters) clone->parameters.push_back(CloneExpression(parameter)); + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } @@ -148,6 +164,8 @@ namespace Nz::ShaderAst clone->components = node.components; clone->expression = CloneExpression(node.expression); + clone->cachedExpressionType = node.cachedExpressionType; + PushExpression(std::move(clone)); } diff --git a/src/Nazara/Shader/ShaderAstExpressionType.cpp b/src/Nazara/Shader/ShaderAstExpressionType.cpp deleted file mode 100644 index bd87822f7..000000000 --- a/src/Nazara/Shader/ShaderAstExpressionType.cpp +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - ExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache) - { - m_cache = cache; - ExpressionType type = GetExpressionTypeInternal(expression); - m_cache = nullptr; - - return type; - } - - ExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression) - { - m_lastExpressionType.reset(); - - Visit(expression); - - assert(m_lastExpressionType.has_value()); - return std::move(*m_lastExpressionType); - } - - ExpressionType ExpressionTypeVisitor::ResolveAlias(Expression& expression, ExpressionType expressionType) - { - if (IsIdentifierType(expressionType)) - { - auto scopeIt = m_cache->scopeIdByNode.find(&expression); - if (scopeIt == m_cache->scopeIdByNode.end()) - throw std::runtime_error("internal error"); - - const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get(expressionType).name); - if (identifier && std::holds_alternative(identifier->value)) - { - const AstCache::Alias& alias = std::get(identifier->value); - return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return arg; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, alias.value); - } - } - - return expressionType; - } - - void ExpressionTypeVisitor::Visit(Expression& expression) - { - if (m_cache) - { - auto it = m_cache->nodeExpressionType.find(&expression); - if (it != m_cache->nodeExpressionType.end()) - { - m_lastExpressionType = it->second; - return; - } - } - - expression.Visit(*this); - - if (m_cache) - { - assert(m_lastExpressionType.has_value()); - m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType); - } - } - - void ExpressionTypeVisitor::Visit(AccessMemberExpression& node) - { - auto scopeIt = m_cache->scopeIdByNode.find(&node); - if (scopeIt == m_cache->scopeIdByNode.end()) - throw std::runtime_error("internal error"); - - ExpressionType expressionType = ResolveAlias(node, GetExpressionTypeInternal(*node.structExpr)); - if (!IsIdentifierType(expressionType)) - throw std::runtime_error("internal error"); - - const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get(expressionType).name); - - throw std::runtime_error("unhandled accessmember expression"); - } - - void ExpressionTypeVisitor::Visit(AssignExpression& node) - { - Visit(*node.left); - } - - void ExpressionTypeVisitor::Visit(BinaryExpression& node) - { - switch (node.op) - { - case BinaryType::Add: - case BinaryType::Subtract: - return Visit(*node.left); - - case BinaryType::Divide: - case BinaryType::Multiply: - { - ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.left)); - ExpressionType rightExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.right)); - - if (IsPrimitiveType(leftExprType)) - { - switch (std::get(leftExprType)) - { - case PrimitiveType::Boolean: - m_lastExpressionType = std::move(leftExprType); - break; - - case PrimitiveType::Float32: - case PrimitiveType::Int32: - case PrimitiveType::UInt32: - m_lastExpressionType = std::move(rightExprType); - break; - } - } - else if (IsMatrixType(leftExprType)) - { - if (IsVectorType(rightExprType)) - m_lastExpressionType = std::move(rightExprType); - else - m_lastExpressionType = std::move(leftExprType); - } - else if (IsVectorType(leftExprType)) - m_lastExpressionType = std::move(leftExprType); - else - throw std::runtime_error("validation failure"); - - break; - } - - case BinaryType::CompEq: - case BinaryType::CompGe: - case BinaryType::CompGt: - case BinaryType::CompLe: - case BinaryType::CompLt: - case BinaryType::CompNe: - m_lastExpressionType = PrimitiveType::Boolean; - break; - } - } - - void ExpressionTypeVisitor::Visit(CastExpression& node) - { - m_lastExpressionType = node.targetType; - } - - void ExpressionTypeVisitor::Visit(ConditionalExpression& node) - { - ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.truePath)); - assert(leftExprType == ResolveAlias(node, GetExpressionTypeInternal(*node.falsePath))); - - m_lastExpressionType = std::move(leftExprType); - } - - void ExpressionTypeVisitor::Visit(ConstantExpression& node) - { - m_lastExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return PrimitiveType::Boolean; - else if constexpr (std::is_same_v) - return PrimitiveType::Float32; - else if constexpr (std::is_same_v) - return PrimitiveType::Int32; - else if constexpr (std::is_same_v) - return PrimitiveType::UInt32; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Int32 }; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, node.value); - } - - void ExpressionTypeVisitor::Visit(IdentifierExpression& node) - { - assert(m_cache); - - auto scopeIt = m_cache->scopeIdByNode.find(&node); - if (scopeIt == m_cache->scopeIdByNode.end()) - throw std::runtime_error("internal error"); - - const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier); - if (!identifier || !std::holds_alternative(identifier->value)) - throw std::runtime_error("internal error"); - - m_lastExpressionType = ResolveAlias(node, std::get(identifier->value).type); - } - - void ExpressionTypeVisitor::Visit(IntrinsicExpression& node) - { - switch (node.intrinsic) - { - case IntrinsicType::CrossProduct: - Visit(*node.parameters.front()); - break; - - case IntrinsicType::DotProduct: - m_lastExpressionType = PrimitiveType::Float32; - break; - - case IntrinsicType::SampleTexture: - { - if (node.parameters.empty()) - throw std::runtime_error("validation failure"); - - ExpressionType firstParamType = ResolveAlias(node, GetExpressionTypeInternal(*node.parameters.front())); - - if (!IsSamplerType(firstParamType)) - throw std::runtime_error("validation failure"); - - const auto& sampler = std::get(firstParamType); - - m_lastExpressionType = VectorType{ - 4, - sampler.sampledType - }; - - break; - } - } - } - - void ExpressionTypeVisitor::Visit(SwizzleExpression& node) - { - ExpressionType exprType = GetExpressionTypeInternal(*node.expression); - - if (IsMatrixType(exprType)) - m_lastExpressionType = std::get(exprType).type; - else if (IsVectorType(exprType)) - m_lastExpressionType = std::get(exprType).type; - else - throw std::runtime_error("validation failure"); - } -} diff --git a/src/Nazara/Shader/ShaderAstOptimizer.cpp b/src/Nazara/Shader/ShaderAstOptimizer.cpp index c83ae1412..3c67240a7 100644 --- a/src/Nazara/Shader/ShaderAstOptimizer.cpp +++ b/src/Nazara/Shader/ShaderAstOptimizer.cpp @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -453,8 +452,11 @@ namespace Nz::ShaderAst { auto& constant = static_cast(*cond); - assert(IsPrimitiveType(GetExpressionType(constant))); - assert(std::get(GetExpressionType(constant)) == PrimitiveType::Boolean); + assert(constant.cachedExpressionType); + const ExpressionType& constantType = constant.cachedExpressionType.value(); + + assert(IsPrimitiveType(constantType)); + assert(std::get(constantType) == PrimitiveType::Boolean); bool cValue = std::get(constant.value); if (!cValue) diff --git a/src/Nazara/Shader/ShaderAstScopedVisitor.cpp b/src/Nazara/Shader/ShaderAstScopedVisitor.cpp new file mode 100644 index 000000000..cc1aec980 --- /dev/null +++ b/src/Nazara/Shader/ShaderAstScopedVisitor.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + void AstScopedVisitor::ScopedVisit(StatementPtr& nodePtr) + { + PushScope(); //< Global scope + { + nodePtr->Visit(*this); + } + PopScope(); + } + + void AstScopedVisitor::Visit(BranchStatement& node) + { + for (auto& cond : node.condStatements) + { + PushScope(); + { + cond.condition->Visit(*this); + cond.statement->Visit(*this); + } + PopScope(); + } + + if (node.elseStatement) + { + PushScope(); + { + node.elseStatement->Visit(*this); + } + PopScope(); + } + } + + void AstScopedVisitor::Visit(ConditionalStatement& node) + { + PushScope(); + { + AstRecursiveVisitor::Visit(node); + } + PopScope(); + } + + void AstScopedVisitor::Visit(DeclareExternalStatement& node) + { + for (auto& extVar : node.externalVars) + { + ExpressionType subType = extVar.type; + if (IsUniformType(subType)) + subType = IdentifierType{ std::get(subType).containedType }; + + RegisterVariable(extVar.name, std::move(subType)); + } + + AstRecursiveVisitor::Visit(node); + } + + void AstScopedVisitor::Visit(DeclareFunctionStatement& node) + { + PushScope(); + { + for (auto& parameter : node.parameters) + RegisterVariable(parameter.name, parameter.type); + + AstRecursiveVisitor::Visit(node); + } + PopScope(); + } + + void AstScopedVisitor::Visit(DeclareStructStatement& node) + { + RegisterStruct(node.description); + + AstRecursiveVisitor::Visit(node); + } + + void AstScopedVisitor::Visit(DeclareVariableStatement& node) + { + RegisterVariable(node.varName, node.varType); + + AstRecursiveVisitor::Visit(node); + } + + void AstScopedVisitor::Visit(MultiStatement& node) + { + PushScope(); + { + AstRecursiveVisitor::Visit(node); + } + PopScope(); + } + + void AstScopedVisitor::PushScope() + { + m_scopeSizes.push_back(m_identifiersInScope.size()); + } + + void AstScopedVisitor::PopScope() + { + assert(!m_scopeSizes.empty()); + m_identifiersInScope.resize(m_scopeSizes.back()); + m_scopeSizes.pop_back(); + } +} diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 8293870e9..57610a357 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -27,29 +26,21 @@ namespace Nz::ShaderAst struct AstValidator::Context { - //const ShaderAst::Function* currentFunction; - std::optional activeScopeId; + std::array entryFunctions = {}; std::unordered_set declaredExternalVar; - std::unordered_set usedBindingIndexes;; - AstCache* cache; + std::unordered_set usedBindingIndexes; }; - bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache) + bool AstValidator::Validate(StatementPtr& node, std::string* error) { try { - AstCache dummy; - Context currentContext; - currentContext.cache = (cache) ? cache : &dummy; m_context = ¤tContext; CallOnExit resetContext([&] { m_context = nullptr; }); - EnterScope(); - node->Visit(*this); - ExitScope(); - + ScopedVisit(node); return true; } catch (const AstError& e) @@ -61,6 +52,12 @@ namespace Nz::ShaderAst } } + const ExpressionType& AstValidator::GetExpressionType(Expression& expression) + { + assert(expression.cachedExpressionType); + return ResolveAlias(expression.cachedExpressionType.value()); + } + Expression& AstValidator::MandatoryExpr(ExpressionPtr& node) { if (!node) @@ -79,7 +76,7 @@ namespace Nz::ShaderAst void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right) { - return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache)); + return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); } void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) @@ -90,7 +87,7 @@ namespace Nz::ShaderAst ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { - const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName); + const Identifier* identifier = FindIdentifier(structName); if (!identifier) throw AstError{ "unknown identifier " + structName }; @@ -111,81 +108,69 @@ namespace Nz::ShaderAst return member.type; } - AstCache::Scope& AstValidator::EnterScope() + const ExpressionType& AstValidator::ResolveAlias(const ExpressionType& expressionType) { - std::size_t newScopeId = m_context->cache->scopes.size(); + if (!IsIdentifierType(expressionType)) + return expressionType; - std::optional previousScope = m_context->activeScopeId; + const Identifier* identifier = FindIdentifier(std::get(expressionType).name); + if (identifier && std::holds_alternative(identifier->value)) + { + const Alias& alias = std::get(identifier->value); + return std::visit([&](auto&& arg) -> const ShaderAst::ExpressionType& + { + using T = std::decay_t; - auto& newScope = m_context->cache->scopes.emplace_back(); - newScope.parentScopeIndex = previousScope; + if constexpr (std::is_same_v) + return arg; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, alias.value); + } - m_context->activeScopeId = newScopeId; - return m_context->cache->scopes[newScopeId]; - } - - void AstValidator::ExitScope() - { - assert(m_context->activeScopeId); - auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId]; - m_context->activeScopeId = previousScope.parentScopeIndex; - } - - void AstValidator::RegisterExpressionType(Expression& node, ExpressionType expressionType) - { - m_context->cache->nodeExpressionType[&node] = std::move(expressionType); - } - - void AstValidator::RegisterScope(Node& node) - { - if (m_context->activeScopeId) - m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId; + return expressionType; } void AstValidator::Visit(AccessMemberExpression& node) { - RegisterScope(node); - // Register expressions types - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); - ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); + ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr)); if (!IsIdentifierType(exprType)) throw AstError{ "expression is not a structure" }; const std::string& structName = std::get(exprType).name; - RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size())); + node.cachedExpressionType = CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()); } void AstValidator::Visit(AssignExpression& node) { - RegisterScope(node); - MandatoryExpr(node.left); MandatoryExpr(node.right); // Register expressions types - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); TypeMustMatch(node.left, node.right); if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) throw AstError { "Assignation is only possible with a l-value" }; + + node.cachedExpressionType = GetExpressionType(*node.right); } void AstValidator::Visit(BinaryExpression& node) { - RegisterScope(node); - // Register expression type - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); - ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); + ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left)); if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; - ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); + ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right)); if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; @@ -201,12 +186,18 @@ namespace Nz::ShaderAst if (leftType == PrimitiveType::Boolean) throw AstError{ "this operation is not supported for booleans" }; - [[fallthrough]]; + TypeMustMatch(node.left, node.right); + + node.cachedExpressionType = PrimitiveType::Boolean; + break; + case BinaryType::Add: case BinaryType::CompEq: case BinaryType::CompNe: case BinaryType::Subtract: TypeMustMatch(node.left, node.right); + + node.cachedExpressionType = leftExprType; break; case BinaryType::Multiply: @@ -219,9 +210,20 @@ namespace Nz::ShaderAst case PrimitiveType::UInt32: { if (IsMatrixType(rightExprType)) + { TypeMustMatch(leftType, std::get(rightExprType).type); + node.cachedExpressionType = rightExprType; + } + else if (IsPrimitiveType(rightExprType)) + { + TypeMustMatch(leftType, rightExprType); + node.cachedExpressionType = leftExprType; + } else if (IsVectorType(rightExprType)) + { TypeMustMatch(leftType, std::get(rightExprType).type); + node.cachedExpressionType = rightExprType; + } else throw AstError{ "incompatible types" }; @@ -248,18 +250,29 @@ namespace Nz::ShaderAst case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: + TypeMustMatch(node.left, node.right); + node.cachedExpressionType = PrimitiveType::Boolean; + break; + case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(node.left, node.right); + node.cachedExpressionType = leftExprType; break; case BinaryType::Multiply: case BinaryType::Divide: { if (IsMatrixType(rightExprType)) + { TypeMustMatch(leftExprType, rightExprType); + node.cachedExpressionType = leftExprType; //< FIXME + } else if (IsPrimitiveType(rightExprType)) + { TypeMustMatch(leftType.type, rightExprType); + node.cachedExpressionType = leftExprType; + } else if (IsVectorType(rightExprType)) { const VectorType& rightType = std::get(rightExprType); @@ -267,6 +280,8 @@ namespace Nz::ShaderAst if (leftType.columnCount != rightType.componentCount) throw AstError{ "incompatible types" }; + + node.cachedExpressionType = rightExprType; } else throw AstError{ "incompatible types" }; @@ -275,7 +290,7 @@ namespace Nz::ShaderAst } else if (IsVectorType(leftExprType)) { - const MatrixType& leftType = std::get(leftExprType); + const VectorType& leftType = std::get(leftExprType); switch (node.op) { case BinaryType::CompGe: @@ -284,16 +299,29 @@ namespace Nz::ShaderAst case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: + TypeMustMatch(node.left, node.right); + node.cachedExpressionType = PrimitiveType::Boolean; + break; + case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(node.left, node.right); + node.cachedExpressionType = leftExprType; break; case BinaryType::Multiply: case BinaryType::Divide: { if (IsPrimitiveType(rightExprType)) + { TypeMustMatch(leftType.type, rightExprType); + node.cachedExpressionType = rightExprType; + } + else if (IsVectorType(rightExprType)) + { + TypeMustMatch(leftType, rightExprType); + node.cachedExpressionType = rightExprType; + } else throw AstError{ "incompatible types" }; } @@ -303,11 +331,9 @@ namespace Nz::ShaderAst void AstValidator::Visit(CastExpression& node) { - RegisterScope(node); + AstScopedVisitor::Visit(node); - AstRecursiveVisitor::Visit(node); - - auto GetComponentCount = [](const ExpressionType& exprType) -> unsigned int + auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t { if (IsPrimitiveType(exprType)) return 1; @@ -317,15 +343,15 @@ namespace Nz::ShaderAst throw AstError{ "wut" }; }; - unsigned int componentCount = 0; - unsigned int requiredComponents = GetComponentCount(node.targetType); + std::size_t componentCount = 0; + std::size_t requiredComponents = GetComponentCount(node.targetType); for (auto& exprPtr : node.expressions) { if (!exprPtr) break; - ExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache); + ExpressionType exprType = GetExpressionType(*exprPtr); if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) throw AstError{ "incompatible type" }; @@ -334,6 +360,40 @@ namespace Nz::ShaderAst if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; + + node.cachedExpressionType = node.targetType; + } + + void AstValidator::Visit(ConstantExpression& node) + { + node.cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return PrimitiveType::Boolean; + else if constexpr (std::is_same_v) + return PrimitiveType::Float32; + else if constexpr (std::is_same_v) + return PrimitiveType::Int32; + else if constexpr (std::is_same_v) + return PrimitiveType::UInt32; + else if constexpr (std::is_same_v) + return VectorType{ 2, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 3, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 4, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 2, PrimitiveType::Int32 }; + else if constexpr (std::is_same_v) + return VectorType{ 3, PrimitiveType::Int32 }; + else if constexpr (std::is_same_v) + return VectorType{ 4, PrimitiveType::Int32 }; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, node.value); + } void AstValidator::Visit(ConditionalExpression& node) @@ -341,37 +401,31 @@ namespace Nz::ShaderAst MandatoryExpr(node.truePath); MandatoryExpr(node.falsePath); - RegisterScope(node); + AstScopedVisitor::Visit(node); - AstRecursiveVisitor::Visit(node); + ExpressionType leftExprType = GetExpressionType(*node.truePath); + if (leftExprType != GetExpressionType(*node.falsePath)) + throw AstError{ "true path type must match false path type" }; + + node.cachedExpressionType = leftExprType; //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) // throw AstError{ "condition not found" }; } - void AstValidator::Visit(ConstantExpression& node) - { - RegisterScope(node); - } - void AstValidator::Visit(IdentifierExpression& node) { assert(m_context); - if (!m_context->activeScopeId) - throw AstError{ "no scope" }; - - RegisterScope(node); - - const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier); + const Identifier* identifier = FindIdentifier(node.identifier); if (!identifier) - throw AstError{ "Unknown variable " + node.identifier }; + throw AstError{ "Unknown identifier " + node.identifier }; + + node.cachedExpressionType = ResolveAlias(std::get(identifier->value).type); } void AstValidator::Visit(IntrinsicExpression& node) { - RegisterScope(node); - - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); switch (node.intrinsic) { @@ -384,10 +438,11 @@ namespace Nz::ShaderAst for (auto& param : node.parameters) MandatoryExpr(param); - ExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache); + ExpressionType type = GetExpressionType(*node.parameters.front()); + for (std::size_t i = 1; i < node.parameters.size(); ++i) { - if (type != GetExpressionType(MandatoryExpr(node.parameters[i]), m_context->cache)) + if (type != GetExpressionType(MandatoryExpr(node.parameters[i]))) throw AstError{ "All type must match" }; } @@ -402,11 +457,11 @@ namespace Nz::ShaderAst for (auto& param : node.parameters) MandatoryExpr(param); - if (!IsSamplerType(GetExpressionType(*node.parameters[0], m_context->cache))) + if (!IsSamplerType(GetExpressionType(*node.parameters[0]))) throw AstError{ "First parameter must be a sampler" }; - if (!IsVectorType(GetExpressionType(*node.parameters[1], m_context->cache))) - throw AstError{ "First parameter must be a vector" }; + if (!IsVectorType(GetExpressionType(*node.parameters[1]))) + throw AstError{ "Second parameter must be a vector" }; } } @@ -414,63 +469,89 @@ namespace Nz::ShaderAst { case IntrinsicType::CrossProduct: { - if (GetExpressionType(*node.parameters[0]) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) + ExpressionType type = GetExpressionType(*node.parameters.front()); + if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) throw AstError{ "CrossProduct only works with vec3 expressions" }; + node.cachedExpressionType = std::move(type); break; } case IntrinsicType::DotProduct: + { + ExpressionType type = GetExpressionType(*node.parameters.front()); + if (!IsVectorType(type)) + throw AstError{ "DotProduct expects vector types" }; + + node.cachedExpressionType = std::get(type).type; break; + } + + case IntrinsicType::SampleTexture: + { + node.cachedExpressionType = VectorType{ 4, std::get(GetExpressionType(*node.parameters.front())).sampledType }; + break; + } } } void AstValidator::Visit(SwizzleExpression& node) { - RegisterScope(node); - if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; - ExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); - if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) - throw AstError{ "Cannot swizzle this type" }; + MandatoryExpr(node.expression); - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); + + ExpressionType exprType = GetExpressionType(*node.expression); + if (IsPrimitiveType(exprType) || IsVectorType(exprType)) + { + PrimitiveType baseType; + if (IsPrimitiveType(exprType)) + baseType = std::get(exprType); + else + baseType = std::get(exprType).type; + + if (node.componentCount > 1) + { + node.cachedExpressionType = VectorType{ + node.componentCount, + baseType + }; + } + else + node.cachedExpressionType = baseType; + } + else + throw AstError{ "Cannot swizzle this type" }; } void AstValidator::Visit(BranchStatement& node) { - RegisterScope(node); - for (auto& condStatement : node.condStatements) { - ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); + ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition)); if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) throw AstError{ "if expression must resolve to boolean type" }; MandatoryStatement(condStatement.statement); } - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } void AstValidator::Visit(ConditionalStatement& node) { MandatoryStatement(node.statement); - RegisterScope(node); - - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) // throw AstError{ "condition not found" }; } void AstValidator::Visit(DeclareExternalStatement& node) { - RegisterScope(node); - auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; - for (const auto& [attributeType, arg] : node.attributes) { switch (attributeType) @@ -513,7 +594,7 @@ namespace Nz::ShaderAst throw AstError{ "attribute layout requires a string parameter" }; if (std::get(arg) != "std140") - throw AstError{ "unknow layout type" }; + throw AstError{ "unknown layout type" }; hasLayout = true; break; @@ -528,14 +609,9 @@ namespace Nz::ShaderAst throw AstError{ "External variable " + extVar.name + " is already declared" }; m_context->declaredExternalVar.insert(extVar.name); - - ExpressionType subType = extVar.type; - if (IsUniformType(subType)) - subType = IdentifierType{ std::get(subType).containedType }; - - auto& identifier = scope.identifiers.emplace_back(); - identifier = AstCache::Identifier{ extVar.name, AstCache::Variable { std::move(subType) } }; } + + AstScopedVisitor::Visit(node); } void AstValidator::Visit(DeclareFunctionStatement& node) @@ -561,10 +637,10 @@ namespace Nz::ShaderAst ShaderStageType stageType = it->second; - if (m_context->cache->entryFunctions[UnderlyingCast(stageType)]) + if (m_context->entryFunctions[UnderlyingCast(stageType)]) throw AstError{ "the same entry type has been defined multiple times" }; - m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node; + m_context->entryFunctions[UnderlyingCast(it->second)] = &node; if (node.parameters.size() > 1) throw AstError{ "entry functions can either take one struct parameter or no parameter" }; @@ -578,103 +654,41 @@ namespace Nz::ShaderAst } } - auto& scope = EnterScope(); - RegisterScope(node); - - for (auto& parameter : node.parameters) - { - auto& identifier = scope.identifiers.emplace_back(); - identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } }; - } - for (auto& statement : node.statements) - MandatoryStatement(statement).Visit(*this); + MandatoryStatement(statement); - ExitScope(); + AstScopedVisitor::Visit(node); } void AstValidator::Visit(DeclareStructStatement& node) { assert(m_context); - if (!m_context->activeScopeId) - throw AstError{ "cannot declare variable without scope" }; - - RegisterScope(node); - //TODO: check members attributes - auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; - - auto& identifier = scope.identifiers.emplace_back(); - identifier = AstCache::Identifier{ node.description.name, node.description }; - - AstRecursiveVisitor::Visit(node); - } - - void AstValidator::Visit(DeclareVariableStatement& node) - { - assert(m_context); - - if (!m_context->activeScopeId) - throw AstError{ "cannot declare variable without scope" }; - - RegisterScope(node); - - auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; - - auto& identifier = scope.identifiers.emplace_back(); - identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } }; - - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } void AstValidator::Visit(ExpressionStatement& node) { - RegisterScope(node); - MandatoryExpr(node.expression); - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } void AstValidator::Visit(MultiStatement& node) { assert(m_context); - EnterScope(); - - RegisterScope(node); - for (auto& statement : node.statements) MandatoryStatement(statement); - ExitScope(); - - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } - void AstValidator::Visit(ReturnStatement& node) - { - RegisterScope(node); - - /*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void)) - { - if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType) - throw AstError{ "Return type doesn't match function return type" }; - } - else - { - if (node.returnExpr) - throw AstError{ "Unexpected expression for return (function doesn't return)" }; - }*/ - - AstRecursiveVisitor::Visit(node); - } - - bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache) + bool ValidateAst(StatementPtr& node, std::string* error) { AstValidator validator; - return validator.Validate(node, error, cache); + return validator.Validate(node, error); } } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index d559e7965..28aca39ca 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -39,13 +38,13 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node) { - ShaderAst::ExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache); + ShaderAst::ExpressionType resultExprType = GetExpressionType(node); assert(IsPrimitiveType(resultExprType)); - ShaderAst::ExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache); + ShaderAst::ExpressionType leftExprType = GetExpressionType(*node.left); assert(IsPrimitiveType(leftExprType)); - ShaderAst::ExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache); + ShaderAst::ExpressionType rightExprType = GetExpressionType(*node.right); assert(IsPrimitiveType(rightExprType)); ShaderAst::PrimitiveType resultType = std::get(resultExprType); @@ -582,7 +581,7 @@ namespace Nz { case ShaderAst::IntrinsicType::DotProduct: { - ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache); + ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0]); assert(IsVectorType(vecExprType)); const ShaderAst::VectorType& vecType = std::get(vecExprType); @@ -626,7 +625,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { - ShaderAst::ExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache); + ShaderAst::ExpressionType targetExprType = GetExpressionType(node); assert(IsPrimitiveType(targetExprType)); ShaderAst::PrimitiveType targetType = std::get(targetExprType); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index cc3421ac2..2aff0d642 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -50,6 +50,11 @@ namespace Nz return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType); } + bool Compare(const Identifier& lhs, const Identifier& rhs) const + { + return lhs.name == rhs.name; + } + bool Compare(const Image& lhs, const Image& rhs) const { return lhs.arrayed == rhs.arrayed @@ -226,6 +231,11 @@ namespace Nz void Register(const Integer&) {} void Register(const Void&) {} + void Register(const Identifier& identifier) + { + Register(identifier); + } + void Register(const Image& image) { Register(image.sampledType); @@ -456,6 +466,11 @@ namespace Nz UInt32 SpirvConstantCache::Register(Type t) { AnyType& type = t.type; + if (std::holds_alternative(type)) + { + assert(m_identifierCallback); + return Register(*m_identifierCallback(std::get(type).name)); + } DepRegisterer registerer(*this); registerer.Register(type); @@ -487,6 +502,11 @@ namespace Nz return it.value(); } + void SpirvConstantCache::SetIdentifierCallback(IdentifierCallback callback) + { + m_identifierCallback = std::move(callback); + } + void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) { for (auto&& [object, id] : m_internal->ids) @@ -597,7 +617,7 @@ namespace Nz return std::make_shared(Pointer{ BuildType(type), storageClass - }); + }); } auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr @@ -605,37 +625,16 @@ namespace Nz return std::visit([&](auto&& arg) -> TypePtr { return BuildType(arg); - /*else if constexpr (std::is_same_v) - { - // Register struct members type - const auto& structs = shader.GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); - if (it == structs.end()) - throw std::runtime_error("struct " + arg + " has not been defined"); - - const ShaderAst::Struct& s = *it; - - Structure sType; - sType.name = s.name; - - for (const auto& member : s.members) - { - auto& sMembers = sType.members.emplace_back(); - sMembers.name = member.name; - sMembers.type = BuildType(shader, member.type); - } - - return std::make_shared(std::move(sType)); - return nullptr; - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor");*/ }, type); } auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr { - throw std::runtime_error("unexpected type"); + return std::make_shared( + Identifier{ + type.name + } + ); } auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr @@ -691,6 +690,21 @@ namespace Nz return std::make_shared(SampledImage{ std::make_shared(imageType) }); } + auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) -> TypePtr + { + Structure sType; + sType.name = structDesc.name; + + for (const auto& member : structDesc.members) + { + auto& sMembers = sType.members.emplace_back(); + sMembers.name = member.name; + sMembers.type = BuildType(member.type); + } + + return std::make_shared(std::move(sType)); + } + auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr { return std::make_shared(Vector{ BuildType(type.type), UInt32(type.componentCount) }); @@ -767,6 +781,8 @@ namespace Nz appender(GetId(*param)); }); } + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected identifier"); else if constexpr (std::is_same_v) { UInt32 depth; @@ -915,6 +931,8 @@ namespace Nz } else if constexpr (std::is_same_v) throw std::runtime_error("unexpected function as struct member"); + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected identifier"); else if constexpr (std::is_same_v || std::is_same_v) throw std::runtime_error("unexpected opaque type as struct member"); else if constexpr (std::is_same_v) diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index fb1ad3725..0627c569e 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -25,18 +25,26 @@ namespace Nz { namespace { - class PreVisitor : public ShaderAst::AstRecursiveVisitor + class PreVisitor : public ShaderAst::AstScopedVisitor { public: using ExtInstList = std::unordered_set; using LocalContainer = std::unordered_set; using FunctionContainer = std::vector>; - PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : - m_cache(cache), + PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : m_conditions(conditions), m_constantCache(constantCache) { + m_constantCache.SetIdentifierCallback([&](const std::string& identifierName) + { + const Identifier* identifier = FindIdentifier(identifierName); + if (!identifier) + throw std::runtime_error("invalid identifier " + identifierName); + + assert(std::holds_alternative(identifier->value)); + return SpirvConstantCache::BuildType(std::get(identifier->value)); + }); } void Visit(ShaderAst::AccessMemberExpression& node) override @@ -74,7 +82,7 @@ namespace Nz m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg)); }, node.value); - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } void Visit(ShaderAst::DeclareFunctionStatement& node) override @@ -87,11 +95,13 @@ namespace Nz m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes)); - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } void Visit(ShaderAst::DeclareStructStatement& node) override { + AstScopedVisitor::Visit(node); + SpirvConstantCache::Structure sType; sType.name = node.description.name; @@ -107,21 +117,21 @@ namespace Nz void Visit(ShaderAst::DeclareVariableStatement& node) override { - m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType)); + AstScopedVisitor::Visit(node); - AstRecursiveVisitor::Visit(node); + m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType)); } void Visit(ShaderAst::IdentifierExpression& node) override { - m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache))); + m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value())); - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); } void Visit(ShaderAst::IntrinsicExpression& node) override { - AstRecursiveVisitor::Visit(node); + AstScopedVisitor::Visit(node); switch (node.intrinsic) { @@ -140,7 +150,6 @@ namespace Nz FunctionContainer funcs; private: - ShaderAst::AstCache* m_cache; const SpirvWriter::States& m_conditions; SpirvConstantCache& m_constantCache; }; @@ -214,7 +223,7 @@ namespace Nz std::vector SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) { std::string error; - if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache)) + if (!ShaderAst::ValidateAst(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); m_context.states = &conditions; @@ -229,7 +238,7 @@ namespace Nz ShaderAst::AstCloner cloner; // Register all extended instruction sets - PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache); + PreVisitor preVisitor(conditions, state.constantTypeCache); shader->Visit(preVisitor); for (const std::string& extInst : preVisitor.extInsts) @@ -397,7 +406,7 @@ namespace Nz state.parameterIds.emplace(param.name, std::move(parameterData)); } - SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache); + SpirvAstVisitor visitor(*this, state.functionBlocks); for (const auto& statement : func.statements) statement->Visit(visitor); @@ -419,7 +428,7 @@ namespace Nz for (std::size_t i = 0; i < ShaderStageTypeCount; ++i) { - const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i]; + /*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i]; if (!statement) continue; @@ -462,7 +471,7 @@ namespace Nz }); if (stage == ShaderStageType::Fragment) - state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft); + state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/ } std::vector ret;