From afe3a0ea93c1274b7e1af552d7ddd7972d11e7ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Thu, 15 Apr 2021 11:20:56 +0200 Subject: [PATCH] Shader: Merge AstScopedVisitor, AstValidator and TransformVisitor to SanitizeVisitor --- include/Nazara/Shader.hpp | 1 - include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 112 +++ ...ansformVisitor.inl => SanitizeVisitor.inl} | 24 +- .../Nazara/Shader/Ast/TransformVisitor.hpp | 90 -- include/Nazara/Shader/GlslWriter.hpp | 4 +- include/Nazara/Shader/GlslWriter.inl | 9 + include/Nazara/Shader/ShaderAstCloner.hpp | 6 +- include/Nazara/Shader/ShaderAstCloner.inl | 16 + .../Nazara/Shader/ShaderAstScopedVisitor.hpp | 68 -- .../Nazara/Shader/ShaderAstScopedVisitor.inl | 38 - include/Nazara/Shader/ShaderAstValidator.hpp | 66 -- include/Nazara/Shader/ShaderAstValidator.inl | 16 - src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 832 ++++++++++++++++++ src/Nazara/Shader/Ast/TransformVisitor.cpp | 228 ----- src/Nazara/Shader/GlslWriter.cpp | 58 +- src/Nazara/Shader/ShaderAstCloner.cpp | 14 +- src/Nazara/Shader/ShaderAstScopedVisitor.cpp | 110 --- src/Nazara/Shader/ShaderAstValidator.cpp | 642 -------------- src/Nazara/Shader/SpirvWriter.cpp | 32 +- 19 files changed, 1027 insertions(+), 1339 deletions(-) create mode 100644 include/Nazara/Shader/Ast/SanitizeVisitor.hpp rename include/Nazara/Shader/Ast/{TransformVisitor.inl => SanitizeVisitor.inl} (57%) delete mode 100644 include/Nazara/Shader/Ast/TransformVisitor.hpp delete mode 100644 include/Nazara/Shader/ShaderAstScopedVisitor.hpp delete mode 100644 include/Nazara/Shader/ShaderAstScopedVisitor.inl delete mode 100644 include/Nazara/Shader/ShaderAstValidator.hpp delete mode 100644 include/Nazara/Shader/ShaderAstValidator.inl create mode 100644 src/Nazara/Shader/Ast/SanitizeVisitor.cpp delete mode 100644 src/Nazara/Shader/Ast/TransformVisitor.cpp delete mode 100644 src/Nazara/Shader/ShaderAstScopedVisitor.cpp delete mode 100644 src/Nazara/Shader/ShaderAstValidator.cpp diff --git a/include/Nazara/Shader.hpp b/include/Nazara/Shader.hpp index 65cd2025e..122296819 100644 --- a/include/Nazara/Shader.hpp +++ b/include/Nazara/Shader.hpp @@ -41,7 +41,6 @@ #include #include #include -#include #include #include #include diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp new file mode 100644 index 000000000..43167a7f8 --- /dev/null +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -0,0 +1,112 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERAST_TRANSFORMVISITOR_HPP +#define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API SanitizeVisitor final : AstCloner + { + public: + inline SanitizeVisitor(); + SanitizeVisitor(const SanitizeVisitor&) = delete; + SanitizeVisitor(SanitizeVisitor&&) = delete; + ~SanitizeVisitor() = default; + + StatementPtr Sanitize(StatementPtr& statement, std::string* error = nullptr); + + SanitizeVisitor& operator=(const SanitizeVisitor&) = delete; + SanitizeVisitor& operator=(SanitizeVisitor&&) = delete; + + private: + struct Identifier; + + const ExpressionType& CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices); + + using AstCloner::CloneExpression; + + ExpressionPtr Clone(AccessMemberIdentifierExpression& node) override; + ExpressionPtr Clone(AssignExpression& node) override; + ExpressionPtr Clone(BinaryExpression& node) override; + ExpressionPtr Clone(CastExpression& node) override; + ExpressionPtr Clone(ConditionalExpression& node) override; + ExpressionPtr Clone(ConstantExpression& node) override; + ExpressionPtr Clone(IdentifierExpression& node) override; + ExpressionPtr Clone(IntrinsicExpression& node) override; + ExpressionPtr Clone(SwizzleExpression& node) override; + + StatementPtr Clone(BranchStatement& node) override; + StatementPtr Clone(ConditionalStatement& node) override; + StatementPtr Clone(DeclareExternalStatement& node) override; + StatementPtr Clone(DeclareFunctionStatement& node) override; + StatementPtr Clone(DeclareStructStatement& node) override; + StatementPtr Clone(DeclareVariableStatement& node) override; + StatementPtr Clone(ExpressionStatement& node) override; + StatementPtr Clone(MultiStatement& node) override; + + inline const Identifier* FindIdentifier(const std::string_view& identifierName) const; + + Expression& MandatoryExpr(ExpressionPtr& node); + Statement& MandatoryStatement(StatementPtr& node); + void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right); + void TypeMustMatch(const ExpressionType& left, const ExpressionType& right); + + void PushScope(); + void PopScope(); + + inline std::size_t RegisterFunction(std::string name); + inline std::size_t RegisterStruct(std::string name, StructDescription description); + inline std::size_t RegisterVariable(std::string name, ExpressionType type); + + std::size_t ResolveStruct(const ExpressionType& exprType); + std::size_t ResolveStruct(const IdentifierType& identifierType); + std::size_t ResolveStruct(const StructType& structType); + std::size_t ResolveStruct(const UniformType& uniformType); + ExpressionType ResolveType(const ExpressionType& exprType); + + struct Alias + { + std::variant value; + }; + + struct Struct + { + std::size_t structIndex; + }; + + struct Variable + { + std::size_t varIndex; + }; + + struct Identifier + { + std::string name; + std::variant value; + }; + + std::size_t m_nextFuncIndex; + std::vector m_identifiersInScope; + std::vector m_structs; + std::vector m_variables; + std::vector m_scopeSizes; + + struct Context; + Context* m_context; + }; + + inline StatementPtr Sanitize(StatementPtr& ast, std::string* error = nullptr); +} + +#include + +#endif diff --git a/include/Nazara/Shader/Ast/TransformVisitor.inl b/include/Nazara/Shader/Ast/SanitizeVisitor.inl similarity index 57% rename from include/Nazara/Shader/Ast/TransformVisitor.inl rename to include/Nazara/Shader/Ast/SanitizeVisitor.inl index 07601b026..29a902b8a 100644 --- a/include/Nazara/Shader/Ast/TransformVisitor.inl +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.inl @@ -2,18 +2,17 @@ // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp -#include +#include #include namespace Nz::ShaderAst { - inline TransformVisitor::TransformVisitor() : - m_nextFuncIndex(0), - m_nextVarIndex(0) + inline SanitizeVisitor::SanitizeVisitor() : + m_nextFuncIndex(0) { } - inline auto TransformVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* + inline auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* { auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); if (it == m_identifiersInScope.rend()) @@ -22,14 +21,14 @@ namespace Nz::ShaderAst return &*it; } - inline std::size_t TransformVisitor::RegisterFunction(std::string name) + inline std::size_t SanitizeVisitor::RegisterFunction(std::string name) { std::size_t funcIndex = m_nextFuncIndex++; return funcIndex; } - inline std::size_t TransformVisitor::RegisterStruct(std::string name, StructDescription description) + inline std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description) { std::size_t structIndex = m_structs.size(); m_structs.emplace_back(std::move(description)); @@ -44,9 +43,10 @@ namespace Nz::ShaderAst return structIndex; } - inline std::size_t TransformVisitor::RegisterVariable(std::string name) + inline std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type) { - std::size_t varIndex = m_nextVarIndex++; + std::size_t varIndex = m_variables.size(); + m_variables.emplace_back(std::move(type)); m_identifiersInScope.push_back({ std::move(name), @@ -57,6 +57,12 @@ namespace Nz::ShaderAst return varIndex; } + + StatementPtr Sanitize(StatementPtr& ast, std::string* error) + { + SanitizeVisitor sanitizer; + return sanitizer.Sanitize(ast, error); + } } #include diff --git a/include/Nazara/Shader/Ast/TransformVisitor.hpp b/include/Nazara/Shader/Ast/TransformVisitor.hpp deleted file mode 100644 index d2483a095..000000000 --- a/include/Nazara/Shader/Ast/TransformVisitor.hpp +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERAST_TRANSFORMVISITOR_HPP -#define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP - -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - class NAZARA_SHADER_API TransformVisitor : AstCloner - { - public: - inline TransformVisitor(); - TransformVisitor(const TransformVisitor&) = delete; - TransformVisitor(TransformVisitor&&) = delete; - ~TransformVisitor() = default; - - StatementPtr Transform(StatementPtr& statement); - - TransformVisitor& operator=(const TransformVisitor&) = delete; - TransformVisitor& operator=(TransformVisitor&&) = delete; - - private: - struct Identifier; - - ExpressionPtr Clone(AccessMemberIdentifierExpression& node) override; - ExpressionPtr Clone(CastExpression& node) override; - ExpressionPtr Clone(IdentifierExpression& node) override; - ExpressionPtr CloneExpression(ExpressionPtr& expr) override; - - inline const Identifier* FindIdentifier(const std::string_view& identifierName) const; - - void PushScope(); - void PopScope(); - - inline std::size_t RegisterFunction(std::string name); - inline std::size_t RegisterStruct(std::string name, StructDescription description); - inline std::size_t RegisterVariable(std::string name); - - ExpressionType ResolveType(const ExpressionType& exprType); - - using AstCloner::Visit; - void Visit(BranchStatement& node) override; - void Visit(ConditionalStatement& node) override; - void Visit(DeclareExternalStatement& node) override; - void Visit(DeclareFunctionStatement& node) override; - void Visit(DeclareStructStatement& node) override; - void Visit(DeclareVariableStatement& node) override; - void Visit(MultiStatement& node) override; - - struct Alias - { - std::variant value; - }; - - struct Struct - { - std::size_t structIndex; - }; - - struct Variable - { - std::size_t varIndex; - }; - - struct Identifier - { - std::string name; - std::variant value; - }; - - private: - std::size_t m_nextFuncIndex; - std::size_t m_nextVarIndex; - std::vector m_identifiersInScope; - std::vector m_structs; - std::vector m_scopeSizes; - }; -} - -#include - -#endif diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 99e68d9f0..faea46ef7 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -24,12 +24,12 @@ namespace Nz struct Environment; using ExtSupportCallback = std::function; - GlslWriter(); + inline GlslWriter(); GlslWriter(const GlslWriter&) = delete; GlslWriter(GlslWriter&&) = delete; ~GlslWriter() = default; - std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {}); + inline std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {}); std::string Generate(std::optional shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {}); void SetEnv(Environment environment); diff --git a/include/Nazara/Shader/GlslWriter.inl b/include/Nazara/Shader/GlslWriter.inl index 1ecd13aee..e8c4d052c 100644 --- a/include/Nazara/Shader/GlslWriter.inl +++ b/include/Nazara/Shader/GlslWriter.inl @@ -7,6 +7,15 @@ namespace Nz { + inline GlslWriter::GlslWriter() : + m_currentState(nullptr) + { + } + + inline std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) + { + return Generate(std::nullopt, shader, conditions); + } } #include diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index 8bd43c4fb..0d034cab9 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -30,9 +30,11 @@ namespace Nz::ShaderAst AstCloner& operator=(AstCloner&&) = delete; protected: - virtual ExpressionPtr CloneExpression(ExpressionPtr& expr); - virtual StatementPtr CloneStatement(StatementPtr& statement); + inline ExpressionPtr CloneExpression(ExpressionPtr& expr); + inline StatementPtr CloneStatement(StatementPtr& statement); + virtual ExpressionPtr CloneExpression(Expression& expr); + virtual StatementPtr CloneStatement(Statement& statement); virtual ExpressionPtr Clone(AccessMemberIdentifierExpression& node); virtual ExpressionPtr Clone(AccessMemberIndexExpression& node); diff --git a/include/Nazara/Shader/ShaderAstCloner.inl b/include/Nazara/Shader/ShaderAstCloner.inl index 20a829343..9d099bf63 100644 --- a/include/Nazara/Shader/ShaderAstCloner.inl +++ b/include/Nazara/Shader/ShaderAstCloner.inl @@ -7,6 +7,22 @@ namespace Nz::ShaderAst { + ExpressionPtr AstCloner::CloneExpression(ExpressionPtr& expr) + { + if (!expr) + return nullptr; + + return CloneExpression(*expr); + } + + StatementPtr AstCloner::CloneStatement(StatementPtr& statement) + { + if (!statement) + return nullptr; + + return CloneStatement(*statement); + } + inline ExpressionPtr Clone(ExpressionPtr& node) { AstCloner cloner; diff --git a/include/Nazara/Shader/ShaderAstScopedVisitor.hpp b/include/Nazara/Shader/ShaderAstScopedVisitor.hpp deleted file mode 100644 index e212a30b5..000000000 --- a/include/Nazara/Shader/ShaderAstScopedVisitor.hpp +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADER_SCOPED_VISITOR_HPP -#define NAZARA_SHADER_SCOPED_VISITOR_HPP - -#include -#include -#include - -namespace Nz::ShaderAst -{ - class NAZARA_SHADER_API AstScopedVisitor : public AstRecursiveVisitor - { - public: - struct Identifier; - - AstScopedVisitor() = default; - ~AstScopedVisitor() = default; - - inline const Identifier* FindIdentifier(const std::string_view& identifierName) const; - - void ScopedVisit(StatementPtr& nodePtr); - - using AstRecursiveVisitor::Visit; - void Visit(BranchStatement& node) override; - void Visit(ConditionalStatement& node) override; - void Visit(DeclareExternalStatement& node) override; - void Visit(DeclareFunctionStatement& node) override; - void Visit(DeclareStructStatement& node) override; - void Visit(DeclareVariableStatement& node) override; - void Visit(MultiStatement& node) override; - - struct Alias - { - std::variant value; - }; - - struct Variable - { - ExpressionType type; - }; - - struct Identifier - { - std::string name; - std::variant value; - }; - - protected: - void PushScope(); - void PopScope(); - - inline void RegisterStruct(StructDescription structDesc); - inline void RegisterVariable(std::string name, ExpressionType type); - - private: - std::vector m_identifiersInScope; - std::vector m_scopeSizes; - }; -} - -#include - -#endif diff --git a/include/Nazara/Shader/ShaderAstScopedVisitor.inl b/include/Nazara/Shader/ShaderAstScopedVisitor.inl deleted file mode 100644 index 67042069a..000000000 --- a/include/Nazara/Shader/ShaderAstScopedVisitor.inl +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz::ShaderAst -{ - inline auto AstScopedVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* - { - auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); - if (it == m_identifiersInScope.rend()) - return nullptr; - - return &*it; - } - - inline void AstScopedVisitor::RegisterStruct(StructDescription structDesc) - { - std::string name = structDesc.name; - - m_identifiersInScope.push_back({ - std::move(name), - std::move(structDesc) - }); - } - - inline void AstScopedVisitor::RegisterVariable(std::string name, ExpressionType type) - { - m_identifiersInScope.push_back({ - std::move(name), - Variable { std::move(type) } - }); - } -} - -#include diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp deleted file mode 100644 index 14a491682..000000000 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERVALIDATOR_HPP -#define NAZARA_SHADERVALIDATOR_HPP - -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - class NAZARA_SHADER_API AstValidator final : public AstScopedVisitor - { - public: - inline AstValidator(); - AstValidator(const AstValidator&) = delete; - AstValidator(AstValidator&&) = delete; - ~AstValidator() = default; - - bool Validate(StatementPtr& node, std::string* error = nullptr); - - private: - const ExpressionType& GetExpressionType(Expression& expression); - Expression& MandatoryExpr(ExpressionPtr& node); - Statement& MandatoryStatement(StatementPtr& node); - void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right); - void TypeMustMatch(const ExpressionType& left, const ExpressionType& right); - - ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); - const ExpressionType& ResolveAlias(const ExpressionType& expressionType); - - void Visit(AccessMemberIdentifierExpression& node) override; - void Visit(AssignExpression& node) override; - void Visit(BinaryExpression& node) override; - void Visit(CastExpression& node) override; - void Visit(ConstantExpression& node) override; - void Visit(ConditionalExpression& node) override; - void Visit(IdentifierExpression& node) override; - void Visit(IntrinsicExpression& node) override; - void Visit(SwizzleExpression& node) override; - - void Visit(BranchStatement& node) override; - void Visit(ConditionalStatement& node) override; - void Visit(DeclareExternalStatement& node) override; - void Visit(DeclareFunctionStatement& node) override; - void Visit(DeclareStructStatement& node) override; - void Visit(DeclareVariableStatement& node) override; - void Visit(ExpressionStatement& node) override; - void Visit(MultiStatement& node) override; - - struct Context; - - Context* m_context; - }; - - NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr); -} - -#include - -#endif diff --git a/include/Nazara/Shader/ShaderAstValidator.inl b/include/Nazara/Shader/ShaderAstValidator.inl deleted file mode 100644 index 2020badd4..000000000 --- a/include/Nazara/Shader/ShaderAstValidator.inl +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz::ShaderAst -{ - AstValidator::AstValidator() : - m_context(nullptr) - { - } -} - -#include diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp new file mode 100644 index 000000000..6da7ef18a --- /dev/null +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -0,0 +1,832 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + namespace + { + struct AstError + { + std::string errMsg; + }; + + template + std::unique_ptr static_unique_pointer_cast(std::unique_ptr&& ptr) + { + return std::unique_ptr(static_cast(ptr.release())); + } + } + + struct SanitizeVisitor::Context + { + std::array entryFunctions = {}; + std::unordered_set declaredExternalVar; + std::unordered_set usedBindingIndexes; + }; + + StatementPtr SanitizeVisitor::Sanitize(StatementPtr& nodePtr, std::string* error) + { + StatementPtr clone; + + Context currentContext; + + m_context = ¤tContext; + CallOnExit resetContext([&] { m_context = nullptr; }); + + PushScope(); //< Global scope + { + try + { + clone = AstCloner::Clone(nodePtr); + } + catch (const AstError& err) + { + if (!error) + throw std::runtime_error(err.errMsg); + + *error = err.errMsg; + } + } + PopScope(); + + return clone; + } + + const ExpressionType& SanitizeVisitor::CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices) + { + std::size_t structIndex = ResolveStruct(structType); + + *structIndices++ = structIndex; + + assert(structIndex < m_structs.size()); + const StructDescription& s = m_structs[structIndex]; + + auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); + if (memberIt == s.members.end()) + throw AstError{ "unknown field " + memberIdentifier[0] }; + + const auto& member = *memberIt; + + if (remainingMembers > 1) + return CheckField(member.type, memberIdentifier + 1, remainingMembers - 1, structIndices); + else + return member.type; + } + + ExpressionPtr SanitizeVisitor::Clone(AccessMemberIdentifierExpression& node) + { + auto structExpr = CloneExpression(MandatoryExpr(node.structExpr)); + + const ExpressionType& exprType = GetExpressionType(*structExpr); + + // Transform to AccessMemberIndexExpression + auto accessMemberIndex = std::make_unique(); + accessMemberIndex->structExpr = std::move(structExpr); + + StackArray structIndices = NazaraStackArrayNoInit(std::size_t, node.memberIdentifiers.size()); + + accessMemberIndex->cachedExpressionType = ResolveType(CheckField(exprType, node.memberIdentifiers.data(), node.memberIdentifiers.size(), structIndices.data())); + + accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size()); + for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i) + { + std::size_t structIndex = structIndices[i]; + assert(structIndex < m_structs.size()); + const StructDescription& structDesc = m_structs[structIndex]; + + auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; }); + assert(it != structDesc.members.end()); + + accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it); + } + + return accessMemberIndex; + } + + ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node) + { + MandatoryExpr(node.left); + MandatoryExpr(node.right); + + if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) + throw AstError{ "Assignation is only possible with a l-value" }; + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + TypeMustMatch(clone->left, clone->right); + clone->cachedExpressionType = GetExpressionType(*clone->right); + + return clone; + } + + ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node) + { + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(clone->left)); + if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) + throw AstError{ "left expression type does not support binary operation" }; + + const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(clone->right)); + if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) + throw AstError{ "right expression type does not support binary operation" }; + + if (IsPrimitiveType(leftExprType)) + { + PrimitiveType leftType = std::get(leftExprType); + switch (clone->op) + { + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + if (leftType == PrimitiveType::Boolean) + throw AstError{ "this operation is not supported for booleans" }; + + TypeMustMatch(clone->left, clone->right); + + clone->cachedExpressionType = PrimitiveType::Boolean; + break; + + case BinaryType::Add: + case BinaryType::CompEq: + case BinaryType::CompNe: + case BinaryType::Subtract: + TypeMustMatch(clone->left, clone->right); + + clone->cachedExpressionType = leftExprType; + break; + + case BinaryType::Multiply: + case BinaryType::Divide: + { + switch (leftType) + { + case PrimitiveType::Float32: + case PrimitiveType::Int32: + case PrimitiveType::UInt32: + { + if (IsMatrixType(rightExprType)) + { + TypeMustMatch(leftType, std::get(rightExprType).type); + clone->cachedExpressionType = rightExprType; + } + else if (IsPrimitiveType(rightExprType)) + { + TypeMustMatch(leftType, rightExprType); + clone->cachedExpressionType = leftExprType; + } + else if (IsVectorType(rightExprType)) + { + TypeMustMatch(leftType, std::get(rightExprType).type); + clone->cachedExpressionType = rightExprType; + } + else + throw AstError{ "incompatible types" }; + + break; + } + + case PrimitiveType::Boolean: + throw AstError{ "this operation is not supported for booleans" }; + + default: + throw AstError{ "incompatible types" }; + } + } + } + } + else if (IsMatrixType(leftExprType)) + { + const MatrixType& leftType = std::get(leftExprType); + switch (clone->op) + { + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + case BinaryType::CompEq: + case BinaryType::CompNe: + TypeMustMatch(clone->left, clone->right); + clone->cachedExpressionType = PrimitiveType::Boolean; + break; + + case BinaryType::Add: + case BinaryType::Subtract: + TypeMustMatch(clone->left, clone->right); + clone->cachedExpressionType = leftExprType; + break; + + case BinaryType::Multiply: + case BinaryType::Divide: + { + if (IsMatrixType(rightExprType)) + { + TypeMustMatch(leftExprType, rightExprType); + clone->cachedExpressionType = leftExprType; //< FIXME + } + else if (IsPrimitiveType(rightExprType)) + { + TypeMustMatch(leftType.type, rightExprType); + clone->cachedExpressionType = leftExprType; + } + else if (IsVectorType(rightExprType)) + { + const VectorType& rightType = std::get(rightExprType); + TypeMustMatch(leftType.type, rightType.type); + + if (leftType.columnCount != rightType.componentCount) + throw AstError{ "incompatible types" }; + + clone->cachedExpressionType = rightExprType; + } + else + throw AstError{ "incompatible types" }; + } + } + } + else if (IsVectorType(leftExprType)) + { + const VectorType& leftType = std::get(leftExprType); + switch (clone->op) + { + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + case BinaryType::CompEq: + case BinaryType::CompNe: + TypeMustMatch(clone->left, clone->right); + clone->cachedExpressionType = PrimitiveType::Boolean; + break; + + case BinaryType::Add: + case BinaryType::Subtract: + TypeMustMatch(clone->left, clone->right); + clone->cachedExpressionType = leftExprType; + break; + + case BinaryType::Multiply: + case BinaryType::Divide: + { + if (IsPrimitiveType(rightExprType)) + { + TypeMustMatch(leftType.type, rightExprType); + clone->cachedExpressionType = rightExprType; + } + else if (IsVectorType(rightExprType)) + { + TypeMustMatch(leftType, rightExprType); + clone->cachedExpressionType = rightExprType; + } + else + throw AstError{ "incompatible types" }; + } + } + } + + return clone; + } + + ExpressionPtr SanitizeVisitor::Clone(CastExpression& node) + { + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t + { + if (IsVectorType(exprType)) + return std::get(exprType).componentCount; + else + { + assert(IsPrimitiveType(exprType)); + return 1; + } + }; + + std::size_t componentCount = 0; + std::size_t requiredComponents = GetComponentCount(clone->targetType); + + for (auto& exprPtr : clone->expressions) + { + if (!exprPtr) + break; + + const ExpressionType& exprType = GetExpressionType(*exprPtr); + if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) + throw AstError{ "incompatible type" }; + + componentCount += GetComponentCount(exprType); + } + + if (componentCount != requiredComponents) + throw AstError{ "component count doesn't match required component count" }; + + clone->targetType = ResolveType(clone->targetType); + clone->cachedExpressionType = clone->targetType; + + return clone; + } + + ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) + { + MandatoryExpr(node.truePath); + MandatoryExpr(node.falsePath); + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + const ExpressionType& leftExprType = GetExpressionType(*clone->truePath); + if (leftExprType != GetExpressionType(*clone->falsePath)) + throw AstError{ "true path type must match false path type" }; + + clone->cachedExpressionType = leftExprType; + + return clone; + } + + ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) + { + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + clone->cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return PrimitiveType::Boolean; + else if constexpr (std::is_same_v) + return PrimitiveType::Float32; + else if constexpr (std::is_same_v) + return PrimitiveType::Int32; + else if constexpr (std::is_same_v) + return PrimitiveType::UInt32; + else if constexpr (std::is_same_v) + return VectorType{ 2, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 3, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 4, PrimitiveType::Float32 }; + else if constexpr (std::is_same_v) + return VectorType{ 2, PrimitiveType::Int32 }; + else if constexpr (std::is_same_v) + return VectorType{ 3, PrimitiveType::Int32 }; + else if constexpr (std::is_same_v) + return VectorType{ 4, PrimitiveType::Int32 }; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, clone->value); + + return clone; + } + + ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node) + { + assert(m_context); + + const Identifier* identifier = FindIdentifier(node.identifier); + if (!identifier) + throw AstError{ "unknown identifier " + node.identifier }; + + if (!std::holds_alternative(identifier->value)) + throw AstError{ "expected variable identifier" }; + + const Variable& variable = std::get(identifier->value); + + // Replace IdentifierExpression by VariableExpression + auto varExpr = std::make_unique(); + varExpr->cachedExpressionType = m_variables[variable.varIndex]; + varExpr->variableId = variable.varIndex; + + return varExpr; + } + + ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) + { + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + // Parameter validation + switch (clone->intrinsic) + { + case IntrinsicType::CrossProduct: + case IntrinsicType::DotProduct: + { + if (clone->parameters.size() != 2) + throw AstError { "Expected two parameters" }; + + for (auto& param : clone->parameters) + MandatoryExpr(param); + + const ExpressionType& type = GetExpressionType(*clone->parameters.front()); + + for (std::size_t i = 1; i < clone->parameters.size(); ++i) + { + if (type != GetExpressionType(*clone->parameters[i])) + throw AstError{ "All type must match" }; + } + + break; + } + + case IntrinsicType::SampleTexture: + { + if (clone->parameters.size() != 2) + throw AstError{ "Expected two parameters" }; + + for (auto& param : clone->parameters) + MandatoryExpr(param); + + if (!IsSamplerType(GetExpressionType(*clone->parameters[0]))) + throw AstError{ "First parameter must be a sampler" }; + + if (!IsVectorType(GetExpressionType(*clone->parameters[1]))) + throw AstError{ "Second parameter must be a vector" }; + } + } + + // Return type attribution + switch (clone->intrinsic) + { + case IntrinsicType::CrossProduct: + { + const ExpressionType& type = GetExpressionType(*clone->parameters.front()); + if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) + throw AstError{ "CrossProduct only works with vec3 expressions" }; + + clone->cachedExpressionType = type; + break; + } + + case IntrinsicType::DotProduct: + { + ExpressionType type = GetExpressionType(*clone->parameters.front()); + if (!IsVectorType(type)) + throw AstError{ "DotProduct expects vector types" }; + + clone->cachedExpressionType = std::get(type).type; + break; + } + + case IntrinsicType::SampleTexture: + { + clone->cachedExpressionType = VectorType{ 4, std::get(GetExpressionType(*clone->parameters.front())).sampledType }; + break; + } + } + + return clone; + } + + ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node) + { + if (node.componentCount > 4) + throw AstError{ "Cannot swizzle more than four elements" }; + + MandatoryExpr(node.expression); + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + const ExpressionType& exprType = GetExpressionType(*clone->expression); + if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) + throw AstError{ "Cannot swizzle this type" }; + + PrimitiveType baseType; + if (IsPrimitiveType(exprType)) + baseType = std::get(exprType); + else + baseType = std::get(exprType).type; + + if (clone->componentCount > 1) + { + clone->cachedExpressionType = VectorType{ + clone->componentCount, + baseType + }; + } + else + clone->cachedExpressionType = baseType; + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(BranchStatement& node) + { + auto clone = std::make_unique(); + clone->condStatements.reserve(node.condStatements.size()); + + for (auto& cond : node.condStatements) + { + PushScope(); + + auto& condStatement = clone->condStatements.emplace_back(); + condStatement.condition = CloneExpression(MandatoryExpr(cond.condition)); + + const ExpressionType& condType = GetExpressionType(*condStatement.condition); + if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) + throw AstError{ "branch expressions must resolve to boolean type" }; + + condStatement.statement = CloneStatement(MandatoryStatement(cond.statement)); + + PopScope(); + } + + if (node.elseStatement) + { + PushScope(); + clone->elseStatement = CloneStatement(node.elseStatement); + PopScope(); + } + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(ConditionalStatement& node) + { + MandatoryStatement(node.statement); + + PushScope(); + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + PopScope(); + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node) + { + assert(m_context); + + for (const auto& extVar : node.externalVars) + { + if (extVar.bindingIndex) + { + unsigned int bindingIndex = extVar.bindingIndex.value(); + if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end()) + throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" }; + + m_context->usedBindingIndexes.insert(bindingIndex); + } + + if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) + throw AstError{ "External variable " + extVar.name + " is already declared" }; + + m_context->declaredExternalVar.insert(extVar.name); + } + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + for (auto& extVar : clone->externalVars) + { + extVar.type = ResolveType(extVar.type); + + ExpressionType varType = extVar.type; + if (IsUniformType(extVar.type)) + varType = std::get(std::get(varType).containedType); + + std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType)); + if (!clone->varIndex) + clone->varIndex = varIndex; //< First external variable index is node variable index + } + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node) + { + if (node.entryStage) + { + ShaderStageType stageType = *node.entryStage; + + if (m_context->entryFunctions[UnderlyingCast(stageType)]) + throw AstError{ "the same entry type has been defined multiple times" }; + + m_context->entryFunctions[UnderlyingCast(stageType)] = &node; + + if (node.parameters.size() > 1) + throw AstError{ "entry functions can either take one struct parameter or no parameter" }; + } + + auto clone = std::make_unique(); + clone->entryStage = node.entryStage; + clone->name = node.name; + clone->funcIndex = m_nextFuncIndex++; + clone->parameters = node.parameters; + clone->returnType = ResolveType(node.returnType); + + PushScope(); + { + for (auto& parameter : clone->parameters) + { + parameter.type = ResolveType(parameter.type); + std::size_t varIndex = RegisterVariable(parameter.name, parameter.type); + if (!clone->varIndex) + clone->varIndex = varIndex; //< First parameter variable index is node variable index + } + + clone->statements.reserve(node.statements.size()); + for (auto& statement : node.statements) + clone->statements.push_back(CloneStatement(MandatoryStatement(statement))); + } + PopScope(); + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node) + { + std::unordered_set declaredMembers; + + for (auto& member : node.description.members) + { + if (declaredMembers.find(member.name) != declaredMembers.end()) + throw AstError{ "struct member " + member.name + " found multiple time" }; + + declaredMembers.insert(member.name); + } + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + for (auto& member : clone->description.members) + member.type = ResolveType(member.type); + + clone->structIndex = RegisterStruct(clone->description.name, clone->description); + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node) + { + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + if (IsNoType(clone->varType)) + { + if (!clone->initialExpression) + throw AstError{ "variable must either have a type or an initial value" }; + + clone->varType = ResolveType(GetExpressionType(*clone->initialExpression)); + } + else + clone->varType = ResolveType(clone->varType); + + clone->varIndex = RegisterVariable(clone->varName, clone->varType); + + return clone; + } + + StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node) + { + MandatoryExpr(node.expression); + + return AstCloner::Clone(node); + } + + StatementPtr SanitizeVisitor::Clone(MultiStatement& node) + { + for (auto& statement : node.statements) + MandatoryStatement(statement); + + PushScope(); + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + PopScope(); + + return clone; + } + + Expression& SanitizeVisitor::MandatoryExpr(ExpressionPtr& node) + { + if (!node) + throw AstError{ "Invalid expression" }; + + return *node; + } + + Statement& SanitizeVisitor::MandatoryStatement(StatementPtr& node) + { + if (!node) + throw AstError{ "Invalid statement" }; + + return *node; + } + + void SanitizeVisitor::PushScope() + { + m_scopeSizes.push_back(m_identifiersInScope.size()); + } + + void SanitizeVisitor::PopScope() + { + assert(!m_scopeSizes.empty()); + m_identifiersInScope.resize(m_scopeSizes.back()); + m_scopeSizes.pop_back(); + } + + std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType) + { + return std::visit([&](auto&& arg) -> std::size_t + { + using T = std::decay_t; + + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) + return ResolveStruct(arg); + else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) + { + throw AstError{ "expression is not a structure" }; + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, exprType); + } + + std::size_t SanitizeVisitor::ResolveStruct(const IdentifierType& identifierType) + { + const Identifier* identifier = FindIdentifier(identifierType.name); + if (!identifier) + throw AstError{ "unknown identifier " + identifierType.name }; + + if (!std::holds_alternative(identifier->value)) + throw AstError{ identifierType.name + " is not a struct" }; + + return std::get(identifier->value).structIndex; + } + + std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType) + { + return structType.structIndex; + } + + std::size_t SanitizeVisitor::ResolveStruct(const UniformType& uniformType) + { + return std::visit([&](auto&& arg) -> std::size_t + { + using T = std::decay_t; + + if constexpr (std::is_same_v || std::is_same_v) + return ResolveStruct(arg); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, uniformType.containedType); + } + + ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType) + { + return std::visit([&](auto&& arg) -> ExpressionType + { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return exprType; + } + else if constexpr (std::is_same_v) + { + const Identifier* identifier = FindIdentifier(arg.name); + if (!identifier) + throw AstError{ "unknown identifier " + arg.name }; + + if (!std::holds_alternative(identifier->value)) + throw AstError{ "expected type identifier" }; + + return StructType{ std::get(identifier->value).structIndex }; + } + else if constexpr (std::is_same_v) + { + return std::visit([&](auto&& containedArg) + { + ExpressionType resolvedType = ResolveType(containedArg); + assert(std::holds_alternative(resolvedType)); + + return UniformType{ std::get(resolvedType) }; + }, arg.containedType); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, exprType); + } + + void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right) + { + return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); + } + + void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) + { + if (left != right) + throw AstError{ "Left expression type must match right expression type" }; + } +} diff --git a/src/Nazara/Shader/Ast/TransformVisitor.cpp b/src/Nazara/Shader/Ast/TransformVisitor.cpp deleted file mode 100644 index 5e8e660ab..000000000 --- a/src/Nazara/Shader/Ast/TransformVisitor.cpp +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include - -namespace Nz::ShaderAst -{ - StatementPtr TransformVisitor::Transform(StatementPtr& nodePtr) - { - StatementPtr clone; - - PushScope(); //< Global scope - { - clone = AstCloner::Clone(nodePtr); - } - PopScope(); - - return clone; - } - - void TransformVisitor::Visit(BranchStatement& node) - { - for (auto& cond : node.condStatements) - { - PushScope(); - { - cond.condition->Visit(*this); - cond.statement->Visit(*this); - } - PopScope(); - } - - if (node.elseStatement) - { - PushScope(); - { - node.elseStatement->Visit(*this); - } - PopScope(); - } - } - - void TransformVisitor::Visit(ConditionalStatement& node) - { - PushScope(); - { - AstCloner::Visit(node); - } - PopScope(); - } - - ExpressionType TransformVisitor::ResolveType(const ExpressionType& exprType) - { - return std::visit([&](auto&& arg) -> ExpressionType - { - using T = std::decay_t; - - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) - { - return exprType; - } - else if constexpr (std::is_same_v) - { - const Identifier* identifier = FindIdentifier(arg.name); - assert(identifier); - assert(std::holds_alternative(identifier->value)); - - return StructType{ std::get(identifier->value).structIndex }; - } - else if constexpr (std::is_same_v) - { - return std::visit([&](auto&& containedArg) - { - ExpressionType resolvedType = ResolveType(containedArg); - assert(std::holds_alternative(resolvedType)); - - return UniformType{ std::get(resolvedType) }; - }, arg.containedType); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, exprType); - } - - void TransformVisitor::Visit(DeclareExternalStatement& node) - { - for (auto& extVar : node.externalVars) - { - extVar.type = ResolveType(extVar.type); - - std::size_t varIndex = RegisterVariable(extVar.name); - if (!node.varIndex) - node.varIndex = varIndex; - } - - AstCloner::Visit(node); - } - - void TransformVisitor::Visit(DeclareFunctionStatement& node) - { - node.funcIndex = m_nextFuncIndex++; - node.returnType = ResolveType(node.returnType); - for (auto& parameter : node.parameters) - parameter.type = ResolveType(parameter.type); - - PushScope(); - { - for (auto& parameter : node.parameters) - { - std::size_t varIndex = RegisterVariable(parameter.name); - if (!node.varIndex) - node.varIndex = varIndex; - } - - AstCloner::Visit(node); - } - PopScope(); - } - - void TransformVisitor::Visit(DeclareStructStatement& node) - { - for (auto& member : node.description.members) - member.type = ResolveType(member.type); - - node.structIndex = RegisterStruct(node.description.name, node.description); - - AstCloner::Visit(node); - } - - void TransformVisitor::Visit(DeclareVariableStatement& node) - { - node.varType = ResolveType(node.varType); - node.varIndex = RegisterVariable(node.varName); - - AstCloner::Visit(node); - } - - void TransformVisitor::Visit(MultiStatement& node) - { - PushScope(); - { - AstCloner::Visit(node); - } - PopScope(); - } - - ExpressionPtr TransformVisitor::Clone(AccessMemberIdentifierExpression& node) - { - auto accessMemberIndex = std::make_unique(); - accessMemberIndex->structExpr = CloneExpression(node.structExpr); - accessMemberIndex->cachedExpressionType = node.cachedExpressionType; - accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size()); - - ExpressionType exprType = GetExpressionType(*node.structExpr); - for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i) - { - exprType = ResolveType(exprType); - assert(std::holds_alternative(exprType)); - - std::size_t structIndex = std::get(exprType).structIndex; - assert(structIndex < m_structs.size()); - const StructDescription& structDesc = m_structs[structIndex]; - - auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; }); - assert(it != structDesc.members.end()); - - accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it); - exprType = it->type; - } - - return accessMemberIndex; - } - - ExpressionPtr TransformVisitor::Clone(CastExpression& node) - { - ExpressionPtr expr = AstCloner::Clone(node); - - CastExpression* castExpr = static_cast(expr.get()); - castExpr->targetType = ResolveType(castExpr->targetType); - - return expr; - } - - ExpressionPtr TransformVisitor::Clone(IdentifierExpression& node) - { - const Identifier* identifier = FindIdentifier(node.identifier); - assert(identifier); - assert(std::holds_alternative(identifier->value)); - - auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = node.cachedExpressionType; - varExpr->variableId = std::get(identifier->value).varIndex; - - return varExpr; - } - - ExpressionPtr TransformVisitor::CloneExpression(ExpressionPtr& expr) - { - ExpressionPtr exprPtr = AstCloner::CloneExpression(expr); - if (exprPtr) - { - assert(exprPtr->cachedExpressionType); - *exprPtr->cachedExpressionType = ResolveType(*exprPtr->cachedExpressionType); - } - - return exprPtr; - } - - void TransformVisitor::PushScope() - { - m_scopeSizes.push_back(m_identifiersInScope.size()); - } - - void TransformVisitor::PopScope() - { - assert(!m_scopeSizes.empty()); - m_identifiersInScope.resize(m_scopeSizes.back()); - m_scopeSizes.pop_back(); - } -} diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index a509245e4..c354ffbd1 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -8,9 +8,9 @@ #include #include #include +#include #include -#include -#include +#include #include #include #include @@ -31,36 +31,29 @@ namespace Nz return it->second; } - struct PreVisitor : ShaderAst::AstCloner + struct PreVisitor : ShaderAst::AstRecursiveVisitor { - using AstCloner::Clone; + using AstRecursiveVisitor::Visit; - ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override + void Visit(ShaderAst::DeclareFunctionStatement& node) override { - auto clone = AstCloner::Clone(node); - assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement); - - ShaderAst::DeclareFunctionStatement* func = static_cast(clone.get()); - - // Remove function if it's an entry point of another type than the one selected + // Dismiss function if it's an entry point of another type than the one selected if (selectedStage) { if (node.entryStage) { ShaderStageType stage = *node.entryStage; if (stage != *selectedStage) - return ShaderBuilder::NoOp(); + return; - entryPoint = func; + entryPoint = &node; } } else { assert(!entryPoint); - entryPoint = func; + entryPoint = &node; } - - return clone; } std::optional selectedStage; @@ -99,17 +92,6 @@ namespace Nz unsigned int indentLevel = 0; }; - - GlslWriter::GlslWriter() : - m_currentState(nullptr) - { - } - - std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) - { - return Generate(std::nullopt, shader, conditions); - } - std::string GlslWriter::Generate(std::optional shaderStage, ShaderAst::StatementPtr& shader, const States& conditions) { State state; @@ -121,17 +103,11 @@ namespace Nz m_currentState = nullptr; }); - std::string error; - if (!ShaderAst::ValidateAst(shader, &error)) - throw std::runtime_error("Invalid shader AST: " + error); - - ShaderAst::TransformVisitor transformVisitor; - ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader); + ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader); PreVisitor previsitor; previsitor.selectedStage = shaderStage; - - ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader); + sanitizedAst->Visit(previsitor); if (!previsitor.entryPoint) throw std::runtime_error("missing entry point"); @@ -140,7 +116,7 @@ namespace Nz AppendHeader(); - adaptedShader->Visit(*this); + sanitizedAst->Visit(*this); return state.stream.str(); } @@ -361,6 +337,9 @@ namespace Nz void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node) { + if (m_currentState->entryFunc != &node) + return; //< Ignore other entry points + HandleInOut(); AppendLine("void main()"); EnterScope(); @@ -712,11 +691,10 @@ namespace Nz { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); - if (m_currentState->entryFunc == &node) + if (node.entryStage) return HandleEntryPoint(node); - assert(node.varIndex); - std::size_t varIndex = *node.varIndex; + std::optional varIndexOpt = node.varIndex; Append(node.returnType); Append(" "); @@ -731,6 +709,8 @@ namespace Nz Append(" "); Append(node.parameters[i].name); + assert(varIndexOpt); + std::size_t& varIndex = *varIndexOpt; RegisterVariable(varIndex++, node.parameters[i].name); } Append(")\n"); diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index efa5be3a5..af1348c72 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -24,21 +24,15 @@ namespace Nz::ShaderAst return PopStatement(); } - ExpressionPtr AstCloner::CloneExpression(ExpressionPtr& expr) + ExpressionPtr AstCloner::CloneExpression(Expression& expr) { - if (!expr) - return nullptr; - - expr->Visit(*this); + expr.Visit(*this); return PopExpression(); } - StatementPtr AstCloner::CloneStatement(StatementPtr& statement) + StatementPtr AstCloner::CloneStatement(Statement& statement) { - if (!statement) - return nullptr; - - statement->Visit(*this); + statement.Visit(*this); return PopStatement(); } diff --git a/src/Nazara/Shader/ShaderAstScopedVisitor.cpp b/src/Nazara/Shader/ShaderAstScopedVisitor.cpp deleted file mode 100644 index 39e80470e..000000000 --- a/src/Nazara/Shader/ShaderAstScopedVisitor.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz::ShaderAst -{ - void AstScopedVisitor::ScopedVisit(StatementPtr& nodePtr) - { - PushScope(); //< Global scope - { - nodePtr->Visit(*this); - } - PopScope(); - } - - void AstScopedVisitor::Visit(BranchStatement& node) - { - for (auto& cond : node.condStatements) - { - PushScope(); - { - cond.condition->Visit(*this); - cond.statement->Visit(*this); - } - PopScope(); - } - - if (node.elseStatement) - { - PushScope(); - { - node.elseStatement->Visit(*this); - } - PopScope(); - } - } - - void AstScopedVisitor::Visit(ConditionalStatement& node) - { - PushScope(); - { - AstRecursiveVisitor::Visit(node); - } - PopScope(); - } - - void AstScopedVisitor::Visit(DeclareExternalStatement& node) - { - for (auto& extVar : node.externalVars) - { - ExpressionType subType = extVar.type; - if (IsUniformType(subType)) - subType = std::get(std::get(subType).containedType); - - RegisterVariable(extVar.name, std::move(subType)); - } - - AstRecursiveVisitor::Visit(node); - } - - void AstScopedVisitor::Visit(DeclareFunctionStatement& node) - { - PushScope(); - { - for (auto& parameter : node.parameters) - RegisterVariable(parameter.name, parameter.type); - - AstRecursiveVisitor::Visit(node); - } - PopScope(); - } - - void AstScopedVisitor::Visit(DeclareStructStatement& node) - { - RegisterStruct(node.description); - - AstRecursiveVisitor::Visit(node); - } - - void AstScopedVisitor::Visit(DeclareVariableStatement& node) - { - RegisterVariable(node.varName, node.varType); - - AstRecursiveVisitor::Visit(node); - } - - void AstScopedVisitor::Visit(MultiStatement& node) - { - PushScope(); - { - AstRecursiveVisitor::Visit(node); - } - PopScope(); - } - - void AstScopedVisitor::PushScope() - { - m_scopeSizes.push_back(m_identifiersInScope.size()); - } - - void AstScopedVisitor::PopScope() - { - assert(!m_scopeSizes.empty()); - m_identifiersInScope.resize(m_scopeSizes.back()); - m_scopeSizes.pop_back(); - } -} diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp deleted file mode 100644 index d83e0dbce..000000000 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ /dev/null @@ -1,642 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include -#include -#include -#include - -namespace Nz::ShaderAst -{ - struct AstError - { - std::string errMsg; - }; - - struct AstValidator::Context - { - std::array entryFunctions = {}; - std::unordered_set declaredExternalVar; - std::unordered_set usedBindingIndexes; - }; - - bool AstValidator::Validate(StatementPtr& node, std::string* error) - { - try - { - Context currentContext; - - m_context = ¤tContext; - CallOnExit resetContext([&] { m_context = nullptr; }); - - ScopedVisit(node); - return true; - } - catch (const AstError& e) - { - if (error) - *error = e.errMsg; - - return false; - } - } - - const ExpressionType& AstValidator::GetExpressionType(Expression& expression) - { - assert(expression.cachedExpressionType); - return ResolveAlias(expression.cachedExpressionType.value()); - } - - Expression& AstValidator::MandatoryExpr(ExpressionPtr& node) - { - if (!node) - throw AstError{ "Invalid expression" }; - - return *node; - } - - Statement& AstValidator::MandatoryStatement(StatementPtr& node) - { - if (!node) - throw AstError{ "Invalid statement" }; - - return *node; - } - - void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right) - { - return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); - } - - void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) - { - if (left != right) - throw AstError{ "Left expression type must match right expression type" }; - } - - ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) - { - const Identifier* identifier = FindIdentifier(structName); - if (!identifier) - throw AstError{ "unknown identifier " + structName }; - - if (!std::holds_alternative(identifier->value)) - throw AstError{ "identifier is not a struct" }; - - const StructDescription& s = std::get(identifier->value); - - auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); - if (memberIt == s.members.end()) - throw AstError{ "unknown field " + memberIdentifier[0]}; - - const auto& member = *memberIt; - - if (remainingMembers > 1) - return CheckField(std::get(member.type).name, memberIdentifier + 1, remainingMembers - 1); - else - return member.type; - } - - const ExpressionType& AstValidator::ResolveAlias(const ExpressionType& expressionType) - { - if (!IsIdentifierType(expressionType)) - return expressionType; - - const Identifier* identifier = FindIdentifier(std::get(expressionType).name); - if (identifier && std::holds_alternative(identifier->value)) - { - const Alias& alias = std::get(identifier->value); - return std::visit([&](auto&& arg) -> const ShaderAst::ExpressionType& - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return arg; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, alias.value); - } - - return expressionType; - } - - void AstValidator::Visit(AccessMemberIdentifierExpression& node) - { - // Register expressions types - AstScopedVisitor::Visit(node); - - ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr)); - if (!IsIdentifierType(exprType)) - throw AstError{ "expression is not a structure" }; - - const std::string& structName = std::get(exprType).name; - - node.cachedExpressionType = CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()); - } - - void AstValidator::Visit(AssignExpression& node) - { - MandatoryExpr(node.left); - MandatoryExpr(node.right); - - // Register expressions types - AstScopedVisitor::Visit(node); - - TypeMustMatch(node.left, node.right); - - if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) - throw AstError { "Assignation is only possible with a l-value" }; - - node.cachedExpressionType = GetExpressionType(*node.right); - } - - void AstValidator::Visit(BinaryExpression& node) - { - // Register expression type - AstScopedVisitor::Visit(node); - - ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left)); - if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) - throw AstError{ "left expression type does not support binary operation" }; - - ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right)); - if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) - throw AstError{ "right expression type does not support binary operation" }; - - if (IsPrimitiveType(leftExprType)) - { - PrimitiveType leftType = std::get(leftExprType); - switch (node.op) - { - case BinaryType::CompGe: - case BinaryType::CompGt: - case BinaryType::CompLe: - case BinaryType::CompLt: - if (leftType == PrimitiveType::Boolean) - throw AstError{ "this operation is not supported for booleans" }; - - TypeMustMatch(node.left, node.right); - - node.cachedExpressionType = PrimitiveType::Boolean; - break; - - case BinaryType::Add: - case BinaryType::CompEq: - case BinaryType::CompNe: - case BinaryType::Subtract: - TypeMustMatch(node.left, node.right); - - node.cachedExpressionType = leftExprType; - break; - - case BinaryType::Multiply: - case BinaryType::Divide: - { - switch (leftType) - { - case PrimitiveType::Float32: - case PrimitiveType::Int32: - case PrimitiveType::UInt32: - { - if (IsMatrixType(rightExprType)) - { - TypeMustMatch(leftType, std::get(rightExprType).type); - node.cachedExpressionType = rightExprType; - } - else if (IsPrimitiveType(rightExprType)) - { - TypeMustMatch(leftType, rightExprType); - node.cachedExpressionType = leftExprType; - } - else if (IsVectorType(rightExprType)) - { - TypeMustMatch(leftType, std::get(rightExprType).type); - node.cachedExpressionType = rightExprType; - } - else - throw AstError{ "incompatible types" }; - - break; - } - - case PrimitiveType::Boolean: - throw AstError{ "this operation is not supported for booleans" }; - - default: - throw AstError{ "incompatible types" }; - } - } - } - } - else if (IsMatrixType(leftExprType)) - { - const MatrixType& leftType = std::get(leftExprType); - switch (node.op) - { - case BinaryType::CompGe: - case BinaryType::CompGt: - case BinaryType::CompLe: - case BinaryType::CompLt: - case BinaryType::CompEq: - case BinaryType::CompNe: - TypeMustMatch(node.left, node.right); - node.cachedExpressionType = PrimitiveType::Boolean; - break; - - case BinaryType::Add: - case BinaryType::Subtract: - TypeMustMatch(node.left, node.right); - node.cachedExpressionType = leftExprType; - break; - - case BinaryType::Multiply: - case BinaryType::Divide: - { - if (IsMatrixType(rightExprType)) - { - TypeMustMatch(leftExprType, rightExprType); - node.cachedExpressionType = leftExprType; //< FIXME - } - else if (IsPrimitiveType(rightExprType)) - { - TypeMustMatch(leftType.type, rightExprType); - node.cachedExpressionType = leftExprType; - } - else if (IsVectorType(rightExprType)) - { - const VectorType& rightType = std::get(rightExprType); - TypeMustMatch(leftType.type, rightType.type); - - if (leftType.columnCount != rightType.componentCount) - throw AstError{ "incompatible types" }; - - node.cachedExpressionType = rightExprType; - } - else - throw AstError{ "incompatible types" }; - } - } - } - else if (IsVectorType(leftExprType)) - { - const VectorType& leftType = std::get(leftExprType); - switch (node.op) - { - case BinaryType::CompGe: - case BinaryType::CompGt: - case BinaryType::CompLe: - case BinaryType::CompLt: - case BinaryType::CompEq: - case BinaryType::CompNe: - TypeMustMatch(node.left, node.right); - node.cachedExpressionType = PrimitiveType::Boolean; - break; - - case BinaryType::Add: - case BinaryType::Subtract: - TypeMustMatch(node.left, node.right); - node.cachedExpressionType = leftExprType; - break; - - case BinaryType::Multiply: - case BinaryType::Divide: - { - if (IsPrimitiveType(rightExprType)) - { - TypeMustMatch(leftType.type, rightExprType); - node.cachedExpressionType = rightExprType; - } - else if (IsVectorType(rightExprType)) - { - TypeMustMatch(leftType, rightExprType); - node.cachedExpressionType = rightExprType; - } - else - throw AstError{ "incompatible types" }; - } - } - } - } - - void AstValidator::Visit(CastExpression& node) - { - AstScopedVisitor::Visit(node); - - auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t - { - if (IsPrimitiveType(exprType)) - return 1; - else if (IsVectorType(exprType)) - return std::get(exprType).componentCount; - else - throw AstError{ "wut" }; - }; - - std::size_t componentCount = 0; - std::size_t requiredComponents = GetComponentCount(node.targetType); - - for (auto& exprPtr : node.expressions) - { - if (!exprPtr) - break; - - const ExpressionType& exprType = GetExpressionType(*exprPtr); - if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) - throw AstError{ "incompatible type" }; - - componentCount += GetComponentCount(exprType); - } - - if (componentCount != requiredComponents) - throw AstError{ "component count doesn't match required component count" }; - - node.cachedExpressionType = node.targetType; - } - - void AstValidator::Visit(ConstantExpression& node) - { - node.cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return PrimitiveType::Boolean; - else if constexpr (std::is_same_v) - return PrimitiveType::Float32; - else if constexpr (std::is_same_v) - return PrimitiveType::Int32; - else if constexpr (std::is_same_v) - return PrimitiveType::UInt32; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Float32 }; - else if constexpr (std::is_same_v) - return VectorType{ 2, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 3, PrimitiveType::Int32 }; - else if constexpr (std::is_same_v) - return VectorType{ 4, PrimitiveType::Int32 }; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, node.value); - - } - - void AstValidator::Visit(ConditionalExpression& node) - { - MandatoryExpr(node.truePath); - MandatoryExpr(node.falsePath); - - AstScopedVisitor::Visit(node); - - ExpressionType leftExprType = GetExpressionType(*node.truePath); - if (leftExprType != GetExpressionType(*node.falsePath)) - throw AstError{ "true path type must match false path type" }; - - node.cachedExpressionType = leftExprType; - //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) - // throw AstError{ "condition not found" }; - } - - void AstValidator::Visit(IdentifierExpression& node) - { - assert(m_context); - - const Identifier* identifier = FindIdentifier(node.identifier); - if (!identifier) - throw AstError{ "Unknown identifier " + node.identifier }; - - node.cachedExpressionType = ResolveAlias(std::get(identifier->value).type); - } - - void AstValidator::Visit(IntrinsicExpression& node) - { - AstScopedVisitor::Visit(node); - - switch (node.intrinsic) - { - case IntrinsicType::CrossProduct: - case IntrinsicType::DotProduct: - { - if (node.parameters.size() != 2) - throw AstError { "Expected two parameters" }; - - for (auto& param : node.parameters) - MandatoryExpr(param); - - ExpressionType type = GetExpressionType(*node.parameters.front()); - - for (std::size_t i = 1; i < node.parameters.size(); ++i) - { - if (type != GetExpressionType(MandatoryExpr(node.parameters[i]))) - throw AstError{ "All type must match" }; - } - - break; - } - - case IntrinsicType::SampleTexture: - { - if (node.parameters.size() != 2) - throw AstError{ "Expected two parameters" }; - - for (auto& param : node.parameters) - MandatoryExpr(param); - - if (!IsSamplerType(GetExpressionType(*node.parameters[0]))) - throw AstError{ "First parameter must be a sampler" }; - - if (!IsVectorType(GetExpressionType(*node.parameters[1]))) - throw AstError{ "Second parameter must be a vector" }; - } - } - - switch (node.intrinsic) - { - case IntrinsicType::CrossProduct: - { - ExpressionType type = GetExpressionType(*node.parameters.front()); - if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) - throw AstError{ "CrossProduct only works with vec3 expressions" }; - - node.cachedExpressionType = std::move(type); - break; - } - - case IntrinsicType::DotProduct: - { - ExpressionType type = GetExpressionType(*node.parameters.front()); - if (!IsVectorType(type)) - throw AstError{ "DotProduct expects vector types" }; - - node.cachedExpressionType = std::get(type).type; - break; - } - - case IntrinsicType::SampleTexture: - { - node.cachedExpressionType = VectorType{ 4, std::get(GetExpressionType(*node.parameters.front())).sampledType }; - break; - } - } - } - - void AstValidator::Visit(SwizzleExpression& node) - { - if (node.componentCount > 4) - throw AstError{ "Cannot swizzle more than four elements" }; - - MandatoryExpr(node.expression); - - AstScopedVisitor::Visit(node); - - ExpressionType exprType = GetExpressionType(*node.expression); - if (IsPrimitiveType(exprType) || IsVectorType(exprType)) - { - PrimitiveType baseType; - if (IsPrimitiveType(exprType)) - baseType = std::get(exprType); - else - baseType = std::get(exprType).type; - - if (node.componentCount > 1) - { - node.cachedExpressionType = VectorType{ - node.componentCount, - baseType - }; - } - else - node.cachedExpressionType = baseType; - } - else - throw AstError{ "Cannot swizzle this type" }; - } - - void AstValidator::Visit(BranchStatement& node) - { - for (auto& condStatement : node.condStatements) - { - ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition)); - if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) - throw AstError{ "if expression must resolve to boolean type" }; - - MandatoryStatement(condStatement.statement); - } - - AstScopedVisitor::Visit(node); - } - - void AstValidator::Visit(ConditionalStatement& node) - { - MandatoryStatement(node.statement); - - AstScopedVisitor::Visit(node); - //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) - // throw AstError{ "condition not found" }; - } - - void AstValidator::Visit(DeclareExternalStatement& node) - { - for (const auto& extVar : node.externalVars) - { - if (extVar.bindingIndex) - { - unsigned int bindingIndex = extVar.bindingIndex.value(); - if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end()) - throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" }; - - m_context->usedBindingIndexes.insert(bindingIndex); - } - - if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) - throw AstError{ "External variable " + extVar.name + " is already declared" }; - - m_context->declaredExternalVar.insert(extVar.name); - } - - AstScopedVisitor::Visit(node); - } - - void AstValidator::Visit(DeclareFunctionStatement& node) - { - if (node.entryStage) - { - ShaderStageType stageType = *node.entryStage; - - if (m_context->entryFunctions[UnderlyingCast(stageType)]) - throw AstError{ "the same entry type has been defined multiple times" }; - - m_context->entryFunctions[UnderlyingCast(stageType)] = &node; - - if (node.parameters.size() > 1) - throw AstError{ "entry functions can either take one struct parameter or no parameter" }; - } - - for (auto& statement : node.statements) - MandatoryStatement(statement); - - AstScopedVisitor::Visit(node); - } - - void AstValidator::Visit(DeclareStructStatement& node) - { - assert(m_context); - - std::unordered_set declaredMembers; - - for (auto& member : node.description.members) - { - if (declaredMembers.find(member.name) != declaredMembers.end()) - throw AstError{ "struct member " + member.name + " found multiple time" }; - - declaredMembers.insert(member.name); - } - - AstScopedVisitor::Visit(node); - } - - void AstValidator::Visit(DeclareVariableStatement& node) - { - if (IsNoType(node.varType)) - { - if (!node.initialExpression) - throw AstError{ "variable must either have a type or an initial value" }; - - node.initialExpression->Visit(*this); - - node.varType = GetExpressionType(*node.initialExpression); - } - - AstScopedVisitor::Visit(node); - } - - void AstValidator::Visit(ExpressionStatement& node) - { - MandatoryExpr(node.expression); - - AstScopedVisitor::Visit(node); - } - - void AstValidator::Visit(MultiStatement& node) - { - assert(m_context); - - for (auto& statement : node.statements) - MandatoryStatement(statement); - - AstScopedVisitor::Visit(node); - } - - bool ValidateAst(StatementPtr& node, std::string* error) - { - AstValidator validator; - return validator.Validate(node, error); - } -} diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index af994659a..61e3be2f2 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -6,13 +6,13 @@ #include #include #include -#include +#include #include #include #include #include #include -#include +#include #include #include #include @@ -38,7 +38,7 @@ namespace Nz { ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } } }; - class PreVisitor : public ShaderAst::AstScopedVisitor + class PreVisitor : public ShaderAst::AstRecursiveVisitor { public: struct UniformVar @@ -107,7 +107,7 @@ namespace Nz m_constantCache.Register(*m_constantCache.BuildConstant(arg)); }, node.value); - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } void Visit(ShaderAst::DeclareExternalStatement& node) override @@ -233,13 +233,13 @@ namespace Nz } m_funcIndex = funcIndex; - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); m_funcIndex.reset(); } void Visit(ShaderAst::DeclareStructStatement& node) override { - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); assert(node.structIndex); std::size_t structIndex = *node.structIndex; @@ -253,7 +253,7 @@ namespace Nz void Visit(ShaderAst::DeclareVariableStatement& node) override { - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); assert(m_funcIndex); auto& func = m_funcs[*m_funcIndex]; @@ -269,12 +269,12 @@ namespace Nz { m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } void Visit(ShaderAst::IntrinsicExpression& node) override { - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); switch (node.intrinsic) { @@ -285,6 +285,7 @@ namespace Nz // Part of SPIR-V core case ShaderAst::IntrinsicType::DotProduct: + case ShaderAst::IntrinsicType::SampleTexture: break; } @@ -293,7 +294,7 @@ namespace Nz void Visit(ShaderAst::SwizzleExpression& node) override { - AstScopedVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); } @@ -391,12 +392,7 @@ namespace Nz std::vector SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) { - std::string error; - if (!ShaderAst::ValidateAst(shader, &error)) - throw std::runtime_error("Invalid shader AST: " + error); - - ShaderAst::TransformVisitor transformVisitor; - ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader); + ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader); m_context.states = &conditions; @@ -409,7 +405,7 @@ namespace Nz // Register all extended instruction sets PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs); - transformedShader->Visit(preVisitor); + sanitizedAst->Visit(preVisitor); m_currentState->preVisitor = &preVisitor; @@ -417,7 +413,7 @@ namespace Nz state.extensionInstructions[extInst] = AllocateResultId(); SpirvAstVisitor visitor(*this, state.instructions, state.funcs); - transformedShader->Visit(visitor); + sanitizedAst->Visit(visitor); AppendHeader();