From be9bdc4705b0320ba6d55310ffc7eade2f19b46d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Tue, 8 Mar 2022 20:26:02 +0100 Subject: [PATCH] Modules are workings \o/ --- examples/RenderTest/main.cpp | 31 +- include/Nazara/Shader/Ast/AstCloner.hpp | 3 + include/Nazara/Shader/Ast/AstCompare.hpp | 3 + include/Nazara/Shader/Ast/AstCompare.inl | 24 ++ include/Nazara/Shader/Ast/AstNodeList.hpp | 3 + .../Nazara/Shader/Ast/AstRecursiveVisitor.hpp | 3 + include/Nazara/Shader/Ast/AstSerializer.hpp | 3 + include/Nazara/Shader/Ast/AstUtils.hpp | 3 + .../Shader/Ast/EliminateUnusedPassVisitor.inl | 3 + .../Shader/Ast/IndexRemapperVisitor.hpp | 3 +- include/Nazara/Shader/Ast/Nodes.hpp | 24 ++ include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 2 + include/Nazara/Shader/GlslWriter.hpp | 2 +- include/Nazara/Shader/LangWriter.hpp | 26 +- include/Nazara/Shader/ShaderBuilder.hpp | 18 + include/Nazara/Shader/ShaderBuilder.inl | 31 +- include/Nazara/Shader/ShaderLangParser.hpp | 2 + include/Nazara/Shader/ShaderLangTokenList.hpp | 1 + src/Nazara/Shader/Ast/AstCloner.cpp | 30 ++ src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp | 15 + src/Nazara/Shader/Ast/AstSerializer.cpp | 15 + src/Nazara/Shader/Ast/AstUtils.cpp | 15 + .../Shader/Ast/IndexRemapperVisitor.cpp | 22 +- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 340 ++++++++++++------ src/Nazara/Shader/GlslWriter.cpp | 85 +++-- src/Nazara/Shader/LangWriter.cpp | 242 ++++++++----- src/Nazara/Shader/ShaderLangLexer.cpp | 1 + src/Nazara/Shader/ShaderLangParser.cpp | 40 ++- src/Nazara/Shader/SpirvWriter.cpp | 8 + 29 files changed, 742 insertions(+), 256 deletions(-) diff --git a/examples/RenderTest/main.cpp b/examples/RenderTest/main.cpp index d6ec21f97..4bfce1b37 100644 --- a/examples/RenderTest/main.cpp +++ b/examples/RenderTest/main.cpp @@ -9,16 +9,26 @@ NAZARA_REQUEST_DEDICATED_GPU() -const char moduleSource[] = R"( +const char barModuleSource[] = R"( [nzsl_version("1.0")] +[uuid("4BB09DEE-F70A-442E-859F-E8F2F3F8583D")] module; fn dummy() {} +[export] [layout(std140)] struct Bar { } +)"; + +const char dataModuleSource[] = R"( +[nzsl_version("1.0")] +[uuid("E49DC9AD-469C-462C-9719-A6F012372029")] +module; + +import Test/Bar; struct Foo {} @@ -37,14 +47,15 @@ const char shaderSource[] = R"( [nzsl_version("1.0")] module; -import Test/module_test; +import Test/Data; +import Test/Bar; option red: bool = false; [set(0)] external { - [binding(0)] viewerData: uniform[Data], + [binding(0)] viewerData: uniform[Data] } [set(1)] @@ -138,10 +149,12 @@ int main() if (modulePath[0] != "Test") return {}; - if (modulePath[1] != "module_test") + if (modulePath[1] == "Bar") + return Nz::ShaderLang::Parse(std::string_view(barModuleSource, sizeof(barModuleSource))); + else if (modulePath[1] == "Data") + return Nz::ShaderLang::Parse(std::string_view(dataModuleSource, sizeof(dataModuleSource))); + else return {}; - - return Nz::ShaderLang::Parse(std::string_view(moduleSource, sizeof(moduleSource))); }; shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt); @@ -152,10 +165,12 @@ int main() } Nz::LangWriter langWriter; - std::cout << langWriter.Generate(*shaderModule) << std::endl; + std::string output = langWriter.Generate(*shaderModule); + std::cout << output << std::endl; + assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output))); Nz::ShaderWriter::States states; - states.optimize = true; + states.optimize = false; auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states); if (!fragVertShader) diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index 3dcc8e5ff..bf1eed77f 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -50,8 +50,11 @@ namespace Nz::ShaderAst virtual ExpressionPtr Clone(ConditionalExpression& node); virtual ExpressionPtr Clone(ConstantExpression& node); virtual ExpressionPtr Clone(ConstantValueExpression& node); + virtual ExpressionPtr Clone(FunctionExpression& node); virtual ExpressionPtr Clone(IdentifierExpression& node); virtual ExpressionPtr Clone(IntrinsicExpression& node); + virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node); + virtual ExpressionPtr Clone(StructTypeExpression& node); virtual ExpressionPtr Clone(SwizzleExpression& node); virtual ExpressionPtr Clone(VariableExpression& node); virtual ExpressionPtr Clone(UnaryExpression& node); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index e4cfde283..8f73dd91f 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -42,8 +42,11 @@ namespace Nz::ShaderAst inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs); inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs); inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs); + inline bool Compare(const FunctionExpression& lhs, const FunctionExpression& rhs); inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs); inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs); + inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs); + inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs); inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs); inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs); inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index e6502caff..d8e3aad4c 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -342,6 +342,14 @@ namespace Nz::ShaderAst return true; } + inline bool Compare(const FunctionExpression& lhs, const FunctionExpression& rhs) + { + if (!Compare(lhs.funcId, rhs.funcId)) + return false; + + return true; + } + inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs) { if (!Compare(lhs.identifier, rhs.identifier)) @@ -361,6 +369,22 @@ namespace Nz::ShaderAst return true; } + inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs) + { + if (!Compare(lhs.intrinsicId, rhs.intrinsicId)) + return false; + + return true; + } + + inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs) + { + if (!Compare(lhs.structTypeId, rhs.structTypeId)) + return false; + + return true; + } + inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs) { if (!Compare(lhs.componentCount, rhs.componentCount)) diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index b208750fc..442b5d96d 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -38,8 +38,11 @@ NAZARA_SHADERAST_EXPRESSION(CastExpression) NAZARA_SHADERAST_EXPRESSION(ConditionalExpression) NAZARA_SHADERAST_EXPRESSION(ConstantExpression) NAZARA_SHADERAST_EXPRESSION(ConstantValueExpression) +NAZARA_SHADERAST_EXPRESSION(FunctionExpression) NAZARA_SHADERAST_EXPRESSION(IdentifierExpression) NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) +NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression) +NAZARA_SHADERAST_EXPRESSION(StructTypeExpression) NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) NAZARA_SHADERAST_EXPRESSION(VariableExpression) NAZARA_SHADERAST_EXPRESSION(UnaryExpression) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 99d25c0cf..05c92d7de 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -30,8 +30,11 @@ namespace Nz::ShaderAst void Visit(ConditionalExpression& node) override; void Visit(ConstantValueExpression& node) override; void Visit(ConstantExpression& node) override; + void Visit(FunctionExpression& node) override; void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; + void Visit(IntrinsicFunctionExpression& node) override; + void Visit(StructTypeExpression& node) override; void Visit(SwizzleExpression& node) override; void Visit(VariableExpression& node) override; void Visit(UnaryExpression& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index 6c87e8bb5..1569182b3 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -33,8 +33,11 @@ namespace Nz::ShaderAst void Serialize(ConstantExpression& node); void Serialize(ConditionalExpression& node); void Serialize(ConstantValueExpression& node); + void Serialize(FunctionExpression& node); void Serialize(IdentifierExpression& node); void Serialize(IntrinsicExpression& node); + void Serialize(IntrinsicFunctionExpression& node); + void Serialize(StructTypeExpression& node); void Serialize(SwizzleExpression& node); void Serialize(VariableExpression& node); void Serialize(UnaryExpression& node); diff --git a/include/Nazara/Shader/Ast/AstUtils.hpp b/include/Nazara/Shader/Ast/AstUtils.hpp index eafc3c088..4b2798515 100644 --- a/include/Nazara/Shader/Ast/AstUtils.hpp +++ b/include/Nazara/Shader/Ast/AstUtils.hpp @@ -41,8 +41,11 @@ namespace Nz::ShaderAst void Visit(ConditionalExpression& node) override; void Visit(ConstantValueExpression& node) override; void Visit(ConstantExpression& node) override; + void Visit(FunctionExpression& node) override; void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; + void Visit(IntrinsicFunctionExpression& node) override; + void Visit(StructTypeExpression& node) override; void Visit(SwizzleExpression& node) override; void Visit(VariableExpression& node) override; void Visit(UnaryExpression& node) override; diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl index 61eb50cb1..d7642b8d7 100644 --- a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl @@ -16,6 +16,9 @@ namespace Nz::ShaderAst inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::Config& config) { DependencyCheckerVisitor dependencyVisitor; + for (const auto& importedModule : shaderModule.importedModules) + dependencyVisitor.Process(*importedModule.module->rootNode, config); + dependencyVisitor.Process(*shaderModule.rootNode, config); dependencyVisitor.Resolve(); diff --git a/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp b/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp index cb24e3fdc..24d2e362a 100644 --- a/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp +++ b/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp @@ -45,7 +45,8 @@ namespace Nz::ShaderAst StatementPtr Clone(DeclareStructStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override; - ExpressionPtr Clone(CallFunctionExpression& node) override; + ExpressionPtr Clone(FunctionExpression& node) override; + ExpressionPtr Clone(StructTypeExpression& node) override; ExpressionPtr Clone(VariableExpression& node) override; void HandleType(ExpressionValue& exprType); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index c16835b8e..aff06c1f6 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -156,6 +156,14 @@ namespace Nz::ShaderAst ShaderAst::ConstantValue value; }; + struct NAZARA_SHADER_API FunctionExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t funcId; + }; + struct NAZARA_SHADER_API IdentifierExpression : Expression { NodeType GetType() const override; @@ -173,6 +181,22 @@ namespace Nz::ShaderAst IntrinsicType intrinsic; }; + struct NAZARA_SHADER_API IntrinsicFunctionExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t intrinsicId; + }; + + struct NAZARA_SHADER_API StructTypeExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t structTypeId; + }; + struct NAZARA_SHADER_API SwizzleExpression : Expression { NodeType GetType() const override; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index fb0a6ce8e..ef1118931 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -108,6 +108,8 @@ namespace Nz::ShaderAst template const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; TypeParameter FindTypeParameter(const std::string_view& identifierName) const; + ExpressionPtr HandleIdentifier(const IdentifierData* identifierData); + Expression& MandatoryExpr(const ExpressionPtr& node) const; Statement& MandatoryStatement(const StatementPtr& node) const; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 758087d5f..fe03238cc 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -85,7 +85,7 @@ namespace Nz void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node); void HandleInOut(); - void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); + void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc, std::string structName); void RegisterVariable(std::size_t varIndex, std::string varName); void ScopeVisit(ShaderAst::Statement& node); diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index d4862641d..453c2310f 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -43,10 +43,12 @@ namespace Nz struct DepthWriteAttribute; struct EarlyFragmentTestsAttribute; struct EntryAttribute; + struct LangVersionAttribute; struct LayoutAttribute; struct LocationAttribute; struct SetAttribute; struct UnrollAttribute; + struct UuidAttribute; void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ExpressionType& type); @@ -68,18 +70,21 @@ namespace Nz template void AppendAttributes(bool appendLine, Args&&... params); template void AppendAttributesInternal(bool& first, const T& param); template void AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params); - void AppendAttribute(BindingAttribute binding); - void AppendAttribute(BuiltinAttribute builtin); - void AppendAttribute(DepthWriteAttribute depthWrite); - void AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests); - void AppendAttribute(EntryAttribute entry); - void AppendAttribute(LayoutAttribute layout); - void AppendAttribute(LocationAttribute location); - void AppendAttribute(SetAttribute set); - void AppendAttribute(UnrollAttribute unroll); + void AppendAttribute(BindingAttribute attribute); + void AppendAttribute(BuiltinAttribute attribute); + void AppendAttribute(DepthWriteAttribute attribute); + void AppendAttribute(EarlyFragmentTestsAttribute attribute); + void AppendAttribute(EntryAttribute attribute); + void AppendAttribute(LangVersionAttribute attribute); + void AppendAttribute(LayoutAttribute attribute); + void AppendAttribute(LocationAttribute attribute); + void AppendAttribute(SetAttribute seattributet); + void AppendAttribute(UnrollAttribute attribute); + void AppendAttribute(UuidAttribute attribute); void AppendComment(const std::string& section); void AppendCommentSection(const std::string& section); void AppendHeader(); + template void AppendIdentifier(const T& map, std::size_t id); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); void AppendStatementList(std::vector& statements); @@ -88,7 +93,7 @@ namespace Nz void LeaveScope(bool skipLine = true); void RegisterConstant(std::size_t constantIndex, std::string constantName); - void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); + void RegisterStruct(std::size_t structIndex, std::string structName); void RegisterVariable(std::size_t varIndex, std::string varName); void ScopeVisit(ShaderAst::Statement& node); @@ -104,6 +109,7 @@ namespace Nz void Visit(ShaderAst::ConstantValueExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override; + void Visit(ShaderAst::StructTypeExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::VariableExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 3972ebc52..240504c46 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -127,6 +127,11 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const; }; + struct Function + { + inline std::unique_ptr operator()(std::size_t funcId) const; + }; + struct Identifier { inline std::unique_ptr operator()(std::string name) const; @@ -142,6 +147,11 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::IntrinsicType intrinsicType, std::vector parameters) const; }; + struct IntrinsicFunction + { + inline std::unique_ptr operator()(std::size_t intrinsicFunctionId, ShaderAst::IntrinsicType intrinsicType) const; + }; + struct Multi { inline std::unique_ptr operator()(std::vector statements = {}) const; @@ -163,6 +173,11 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::StatementPtr statement) const; }; + struct StructType + { + inline std::unique_ptr operator()(std::size_t structTypeId) const; + }; + struct Swizzle { inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::array swizzleComponents, std::size_t componentCount) const; @@ -206,13 +221,16 @@ namespace Nz::ShaderBuilder constexpr Impl::NoParam Discard; constexpr Impl::For For; constexpr Impl::ForEach ForEach; + constexpr Impl::Function Function; constexpr Impl::Identifier Identifier; + constexpr Impl::IntrinsicFunction IntrinsicFunction; constexpr Impl::Import Import; constexpr Impl::Intrinsic Intrinsic; constexpr Impl::Multi MultiStatement; constexpr Impl::NoParam NoOp; constexpr Impl::Return Return; constexpr Impl::Scoped Scoped; + constexpr Impl::StructType StructType; constexpr Impl::Swizzle Swizzle; constexpr Impl::Unary Unary; constexpr Impl::Variable Variable; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index bf603414a..7936545c6 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -315,7 +315,7 @@ namespace Nz::ShaderBuilder return expressionStatementNode; } - inline std::unique_ptr Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const + inline std::unique_ptr Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const { auto forNode = std::make_unique(); forNode->fromExpr = std::move(fromExpression); @@ -326,7 +326,7 @@ namespace Nz::ShaderBuilder return forNode; } - inline std::unique_ptr Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const + inline std::unique_ptr Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const { auto forNode = std::make_unique(); forNode->fromExpr = std::move(fromExpression); @@ -348,6 +348,15 @@ namespace Nz::ShaderBuilder return forEachNode; } + inline std::unique_ptr Impl::Function::operator()(std::size_t funcId) const + { + auto intrinsicTypeExpr = std::make_unique(); + intrinsicTypeExpr->cachedExpressionType = ShaderAst::FunctionType{ funcId }; + intrinsicTypeExpr->funcId = funcId; + + return intrinsicTypeExpr; + } + inline std::unique_ptr Impl::Identifier::operator()(std::string name) const { auto identifierNode = std::make_unique(); @@ -373,6 +382,15 @@ namespace Nz::ShaderBuilder return intrinsicExpression; } + inline std::unique_ptr Impl::IntrinsicFunction::operator()(std::size_t intrinsicFunctionId, ShaderAst::IntrinsicType intrinsicType) const + { + auto intrinsicTypeExpr = std::make_unique(); + intrinsicTypeExpr->cachedExpressionType = ShaderAst::IntrinsicFunctionType{ intrinsicType }; + intrinsicTypeExpr->intrinsicId = intrinsicFunctionId; + + return intrinsicTypeExpr; + } + inline std::unique_ptr Impl::Multi::operator()(std::vector statements) const { auto multiStatement = std::make_unique(); @@ -403,6 +421,15 @@ namespace Nz::ShaderBuilder return scopedNode; } + inline std::unique_ptr Impl::StructType::operator()(std::size_t structTypeId) const + { + auto structTypeExpr = std::make_unique(); + structTypeExpr->cachedExpressionType = ShaderAst::StructType{ structTypeId }; + structTypeExpr->structTypeId = structTypeId; + + return structTypeExpr; + } + inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array swizzleComponents, std::size_t componentCount) const { assert(componentCount > 0); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 1a7a6db44..d135b31f4 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -86,6 +86,7 @@ namespace Nz::ShaderLang void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue& type, ShaderAst::ExpressionPtr& initialValue); // Statements + ShaderAst::StatementPtr ParseAliasDeclaration(); ShaderAst::StatementPtr ParseBranchStatement(); ShaderAst::StatementPtr ParseConstStatement(); ShaderAst::StatementPtr ParseDiscardStatement(); @@ -130,6 +131,7 @@ namespace Nz::ShaderLang std::size_t tokenIndex = 0; ShaderAst::ModulePtr module; const Token* tokens; + bool parsingImportedModule = false; }; Context* m_context; diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index 4596050ee..a94e359eb 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -12,6 +12,7 @@ #define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X) #endif +NAZARA_SHADERLANG_TOKEN(Alias) NAZARA_SHADERLANG_TOKEN(Arrow) NAZARA_SHADERLANG_TOKEN(Assign) NAZARA_SHADERLANG_TOKEN(BoolFalse) diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index a38183177..10157d21e 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -405,6 +405,16 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr AstCloner::Clone(FunctionExpression& node) + { + auto clone = std::make_unique(); + clone->funcId = node.funcId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + ExpressionPtr AstCloner::Clone(IdentifierExpression& node) { auto clone = std::make_unique(); @@ -429,6 +439,26 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr AstCloner::Clone(IntrinsicFunctionExpression& node) + { + auto clone = std::make_unique(); + clone->intrinsicId = node.intrinsicId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + + ExpressionPtr AstCloner::Clone(StructTypeExpression& node) + { + auto clone = std::make_unique(); + clone->structTypeId = node.structTypeId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + ExpressionPtr AstCloner::Clone(SwizzleExpression& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 1545161a3..06eb211c7 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -72,6 +72,11 @@ namespace Nz::ShaderAst /* Nothing to do */ } + void AstRecursiveVisitor::Visit(FunctionExpression& /*node*/) + { + /* Nothing to do */ + } + void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/) { /* Nothing to do */ @@ -83,6 +88,16 @@ namespace Nz::ShaderAst param->Visit(*this); } + void AstRecursiveVisitor::Visit(IntrinsicFunctionExpression& /*node*/) + { + /* Nothing to do */ + } + + void AstRecursiveVisitor::Visit(StructTypeExpression& /*node*/) + { + /* Nothing to do */ + } + void AstRecursiveVisitor::Visit(SwizzleExpression& node) { if (node.expression) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 208f63961..b1a1cd04e 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -153,6 +153,21 @@ namespace Nz::ShaderAst Node(param); } + void AstSerializerBase::Serialize(IntrinsicFunctionExpression& node) + { + SizeT(node.intrinsicId); + } + + void AstSerializerBase::Serialize(StructTypeExpression& node) + { + SizeT(node.structTypeId); + } + + void AstSerializerBase::Serialize(FunctionExpression& node) + { + SizeT(node.funcId); + } + void AstSerializerBase::Serialize(SwizzleExpression& node) { SizeT(node.componentCount); diff --git a/src/Nazara/Shader/Ast/AstUtils.cpp b/src/Nazara/Shader/Ast/AstUtils.cpp index eedfedd6a..96006b552 100644 --- a/src/Nazara/Shader/Ast/AstUtils.cpp +++ b/src/Nazara/Shader/Ast/AstUtils.cpp @@ -72,6 +72,11 @@ namespace Nz::ShaderAst m_expressionCategory = ExpressionCategory::LValue; } + void ShaderAstValueCategory::Visit(FunctionExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/) { m_expressionCategory = ExpressionCategory::LValue; @@ -82,6 +87,16 @@ namespace Nz::ShaderAst m_expressionCategory = ExpressionCategory::RValue; } + void ShaderAstValueCategory::Visit(IntrinsicFunctionExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + + void ShaderAstValueCategory::Visit(StructTypeExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + void ShaderAstValueCategory::Visit(SwizzleExpression& node) { if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1) diff --git a/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp b/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp index b7d62266b..86ae274e7 100644 --- a/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp +++ b/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp @@ -137,20 +137,22 @@ namespace Nz::ShaderAst return clone; } - ExpressionPtr IndexRemapperVisitor::Clone(CallFunctionExpression& node) + ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node) { - CallFunctionExpressionPtr clone = static_unique_pointer_cast(AstCloner::Clone(node)); + FunctionExpressionPtr clone = static_unique_pointer_cast(AstCloner::Clone(node)); - const auto& targetFuncType = GetExpressionType(*node.targetFunction); - if (std::holds_alternative(targetFuncType)) - { - const auto& funcType = std::get(targetFuncType); + assert(clone->funcId); + clone->funcId = Retrieve(m_context->newFuncIndices, clone->funcId); - FunctionType newFunc; - newFunc.funcIndex = Retrieve(m_context->newFuncIndices, funcType.funcIndex); - clone->cachedExpressionType = ExpressionType{ newFunc }; //< FIXME We should add FunctionExpression like VariableExpression to handle this - } + return clone; + } + ExpressionPtr IndexRemapperVisitor::Clone(StructTypeExpression& node) + { + StructTypeExpressionPtr clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + assert(clone->structTypeId); + clone->structTypeId = Retrieve(m_context->newStructIndices, clone->structTypeId); return clone; } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 3d0dcdf51..86dace007 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -114,6 +115,7 @@ namespace Nz::ShaderAst struct SanitizeVisitor::Environment { + Uuid moduleId; std::shared_ptr parentEnv; std::vector identifiersInScope; std::vector scopes; @@ -121,21 +123,31 @@ namespace Nz::ShaderAst struct SanitizeVisitor::Context { + struct ModuleData + { + std::unordered_map exportedSetByModule; + std::shared_ptr environment; + std::unique_ptr dependenciesVisitor; + }; + struct PendingFunction { DeclareFunctionStatement* cloneNode; const DeclareFunctionStatement* node; }; + static constexpr std::size_t ModuleIdSentinel = std::numeric_limits::max(); + std::array entryFunctions = {}; - std::vector> moduleEnvironments; + std::vector modules; std::vector pendingFunctions; std::vector* currentStatementList = nullptr; std::unordered_map moduleByUuid; std::unordered_set declaredExternalVar; std::unordered_set usedBindingIndexes; - std::shared_ptr globalEnv; + std::shared_ptr builtinEnv; std::shared_ptr currentEnv; + std::shared_ptr globalEnv; IdentifierList constantValues; IdentifierList functions; IdentifierList aliases; @@ -144,22 +156,57 @@ namespace Nz::ShaderAst IdentifierList structs; IdentifierList> types; IdentifierList variableTypes; + ModulePtr currentModule; Options options; CurrentFunctionData* currentFunction = nullptr; }; ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) { - Context currentContext; - currentContext.options = options; - ModulePtr clone = std::make_shared(module.metadata); clone->importedModules = module.importedModules; + Context currentContext; + currentContext.options = options; + currentContext.currentModule = clone; + m_context = ¤tContext; CallOnExit resetContext([&] { m_context = nullptr; }); + // Register builtin env + m_context->builtinEnv = std::make_shared(); + m_context->currentEnv = m_context->builtinEnv; + RegisterBuiltin(); + m_context->globalEnv = std::make_shared(); + m_context->globalEnv->moduleId = clone->metadata->moduleId; + m_context->globalEnv->parentEnv = m_context->builtinEnv; + + for (std::size_t moduleId = 0; moduleId < clone->importedModules.size(); ++moduleId) + { + auto moduleEnv = std::make_shared(); + moduleEnv->moduleId = clone->importedModules[moduleId].module->metadata->moduleId; + moduleEnv->parentEnv = m_context->builtinEnv; + + m_context->currentEnv = moduleEnv; + + // Previous modules are visibles + for (std::size_t previousModuleId = 0; previousModuleId < moduleId; ++previousModuleId) + RegisterModule(clone->importedModules[previousModuleId].identifier, previousModuleId); + + auto& importedModule = clone->importedModules[moduleId]; + importedModule.module->rootNode = SanitizeInternal(*importedModule.module->rootNode, error); + if (!importedModule.module->rootNode) + return {}; + + m_context->moduleByUuid[importedModule.module->metadata->moduleId] = moduleId; + auto& moduleData = m_context->modules.emplace_back(); + moduleData.environment = std::move(moduleEnv); + + m_context->currentEnv = m_context->globalEnv; + RegisterModule(importedModule.identifier, moduleId); + } + m_context->currentEnv = m_context->globalEnv; clone->rootNode = SanitizeInternal(*module.rootNode, error); @@ -211,7 +258,25 @@ namespace Nz::ShaderAst if (node.identifiers.empty()) throw AstError{ "AccessIdentifierExpression must have at least one identifier" }; - ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr)); + MandatoryExpr(node.expr); + + // Handle module access (TODO: Add namespace expression?) + if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1) + { + auto& identifierExpr = static_cast(*node.expr); + const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier); + if (identifierData && identifierData->category == IdentifierCategory::Module) + { + std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index); + + const auto& env = *m_context->modules[moduleIndex].environment; + identifierData = FindIdentifier(env, node.identifiers.front()); + if (identifierData) + return HandleIdentifier(identifierData); + } + } + + ExpressionPtr indexedExpr = CloneExpression(node.expr); for (const std::string& identifier : node.identifiers) { if (identifier.empty()) @@ -393,7 +458,10 @@ namespace Nz::ShaderAst if (!m_context->currentFunction) throw AstError{ "function calls must happen inside a function" }; - std::size_t targetFuncIndex = std::get(targetExprType).funcIndex; + if (targetExpr->GetType() != NodeType::FunctionExpression) + throw AstError{ "expected function expression" }; + + std::size_t targetFuncIndex = static_cast(*targetExpr).funcId; auto clone = std::make_unique(); clone->targetFunction = std::move(targetExpr); @@ -410,13 +478,18 @@ namespace Nz::ShaderAst } else if (IsIntrinsicFunctionType(targetExprType)) { + if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression) + throw AstError{ "expected intrinsic function expression" }; + + std::size_t targetIntrinsicId = static_cast(*targetExpr).intrinsicId; + std::vector parameters; parameters.reserve(node.parameters.size()); for (const auto& param : node.parameters) parameters.push_back(CloneExpression(param)); - auto intrinsic = ShaderBuilder::Intrinsic(std::get(targetExprType).intrinsic, std::move(parameters)); + auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId), std::move(parameters)); Validate(*intrinsic); return intrinsic; @@ -584,64 +657,7 @@ namespace Nz::ShaderAst if (!identifierData) throw AstError{ "unknown identifier " + node.identifier }; - switch (identifierData->category) - { - case IdentifierCategory::Constant: - { - // Replace IdentifierExpression by Constant(Value)Expression - ConstantExpression constantExpr; - constantExpr.constantId = identifierData->index; - - return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression - } - - case IdentifierCategory::Function: - { - auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = FunctionType{ identifierData->index }; - - return clone; - } - - case IdentifierCategory::Intrinsic: - { - IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index); - - auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; - - return clone; - } - - case IdentifierCategory::Struct: - { - auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = StructType{ identifierData->index }; - - return clone; - } - - case IdentifierCategory::Type: - { - auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = Type{ identifierData->index }; - - return clone; - } - - case IdentifierCategory::Variable: - { - // Replace IdentifierExpression by VariableExpression - auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index); - varExpr->variableId = identifierData->index; - - return varExpr; - } - - default: - throw AstError{ "unexpected identifier" }; - } + return HandleIdentifier(identifierData); } ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) @@ -872,6 +888,8 @@ namespace Nz::ShaderAst if (node.returnType.HasValue()) clone->returnType = ResolveType(node.returnType); + else + clone->returnType = ExpressionType{ NoType{} }; if (node.depthWrite.HasValue()) clone->depthWrite = ComputeExprValue(node.depthWrite); @@ -1360,54 +1378,99 @@ namespace Nz::ShaderAst if (!targetModule) throw AstError{ "module " + ModulePathAsString() + " not found" }; - targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString(); + std::size_t moduleIndex; - m_context->currentEnv = m_context->moduleEnvironments.emplace_back(std::make_shared()); - CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; }); + const Uuid& moduleUuid = targetModule->metadata->moduleId; + auto it = m_context->moduleByUuid.find(moduleUuid); + if (it == m_context->moduleByUuid.end()) + { + m_context->moduleByUuid[moduleUuid] = Context::ModuleIdSentinel; - ModulePtr sanitizedModule = std::make_shared(targetModule->metadata); + // Generate module identifier (based on UUID) + const auto& moduleUuidBytes = moduleUuid.ToArray(); - std::string error; - sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error); - if (!sanitizedModule) - throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error }; + SHA256Hash hasher; + hasher.Begin(); + hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size()); + hasher.End(); + + std::string identifier = "__" + hasher.End().ToHex().substr(0, 8); + + // Load new module + auto moduleEnvironment = std::make_shared(); + moduleEnvironment->parentEnv = m_context->builtinEnv; + + auto previousEnv = m_context->currentEnv; + m_context->currentEnv = moduleEnvironment; + + ModulePtr sanitizedModule = std::make_shared(targetModule->metadata); + + std::string error; + sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error); + if (!sanitizedModule) + throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error }; + + moduleIndex = m_context->modules.size(); + + assert(m_context->modules.size() == moduleIndex); + auto& moduleData = m_context->modules.emplace_back(); + moduleData.dependenciesVisitor = std::make_unique(); + moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode); + moduleData.environment = std::move(moduleEnvironment); + + assert(m_context->currentModule->importedModules.size() == moduleIndex); + auto& importedModule = m_context->currentModule->importedModules.emplace_back(); + importedModule.identifier = identifier; + importedModule.module = std::move(sanitizedModule); + + m_context->currentEnv = std::move(previousEnv); + + RegisterModule(identifier, moduleIndex); + + m_context->moduleByUuid[moduleUuid] = moduleIndex; + } + else + { + // Module has already been imported + moduleIndex = it->second; + if (moduleIndex == Context::ModuleIdSentinel) + throw AstError{ "circular import detected" }; + } + + auto& moduleData = m_context->modules[moduleIndex]; + + auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId]; // Extract exported nodes and their dependencies - DependencyCheckerVisitor::Config depConfig; - depConfig.usedShaderStages.Clear(); - - DependencyCheckerVisitor moduleDependencies; - moduleDependencies.Process(*sanitizedModule->rootNode, depConfig); - - DependencyCheckerVisitor::UsageSet exportedSet; - - MultiStatementPtr aliasBlock = std::make_unique(); + std::vector aliasStatements; AstExportVisitor::Callbacks callbacks; callbacks.onExportedStruct = [&](DeclareStructStatement& node) { assert(node.structIndex); - moduleDependencies.MarkStructAsUsed(*node.structIndex); - exportedSet.usedStructs.UnboundedSet(*node.structIndex); - - auto alias = Clone(node); - // TODO: DeclareAlias - - aliasBlock->statements.emplace_back(std::move(alias)); + moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex); + if (!exportedSet.usedStructs.UnboundedTest(*node.structIndex)) + { + exportedSet.usedStructs.UnboundedSet(*node.structIndex); + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex))); + } }; AstExportVisitor exportVisitor; - exportVisitor.Visit(*sanitizedModule->rootNode, callbacks); + exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks); - moduleDependencies.Resolve(); + if (aliasStatements.empty()) + return ShaderBuilder::NoOp(); - //m_context-> - - // Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor) - //m_context->importUsage = remappedExportedSet; - //CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); }); + // Register module and aliases + MultiStatementPtr aliasBlock = std::make_unique(); + for (auto& aliasPtr : aliasStatements) + { + Validate(*aliasPtr); + aliasBlock->statements.push_back(std::move(aliasPtr)); + } return aliasBlock; } @@ -1559,6 +1622,74 @@ namespace Nz::ShaderAst throw std::runtime_error("internal error"); } + ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData) + { + switch (identifierData->category) + { + case IdentifierCategory::Constant: + { + // Replace IdentifierExpression by Constant(Value)Expression + ConstantExpression constantExpr; + constantExpr.constantId = identifierData->index; + + return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression + } + + case IdentifierCategory::Function: + { + // Replace IdentifierExpression by FunctionExpression + auto funcExpr = std::make_unique(); + funcExpr->cachedExpressionType = FunctionType{ identifierData->index }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type + funcExpr->funcId = identifierData->index; + + return funcExpr; + } + + case IdentifierCategory::Intrinsic: + { + IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index); + + // Replace IdentifierExpression by IntrinsicFunctionExpression + auto intrinsicExpr = std::make_unique(); + intrinsicExpr->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type + intrinsicExpr->intrinsicId = identifierData->index; + + return intrinsicExpr; + } + + case IdentifierCategory::Struct: + { + // Replace IdentifierExpression by StructTypeExpression + auto structExpr = std::make_unique(); + structExpr->cachedExpressionType = StructType{ identifierData->index }; + structExpr->structTypeId = identifierData->index; + + return structExpr; + } + + case IdentifierCategory::Type: + { + auto clone = ShaderBuilder::Identifier("dummy"); + clone->cachedExpressionType = Type{ identifierData->index }; + + return clone; + } + + case IdentifierCategory::Variable: + { + // Replace IdentifierExpression by VariableExpression + auto varExpr = std::make_unique(); + varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index); + varExpr->variableId = identifierData->index; + + return varExpr; + } + + default: + throw AstError{ "unexpected identifier" }; + } + } + Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const { if (!node) @@ -1909,20 +2040,20 @@ namespace Nz::ShaderAst return intrinsicIndex; } - std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex) + std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t index) { if (FindIdentifier(moduleIdentifier)) throw AstError{ moduleIdentifier + " is already used" }; - std::size_t intrinsicIndex = m_context->moduleIndices.Register(moduleIndex); + std::size_t moduleIndex = m_context->moduleIndices.Register(index); m_context->currentEnv->identifiersInScope.push_back({ std::move(moduleIdentifier), - intrinsicIndex, + moduleIndex, IdentifierCategory::Module }); - return intrinsicIndex; + return moduleIndex; } std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional index) @@ -2169,11 +2300,7 @@ namespace Nz::ShaderAst MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error) { MultiStatementPtr output; - - PushScope(); //< Global scope { - RegisterBuiltin(); - // First pass, evaluate everything except function code try { @@ -2196,7 +2323,6 @@ namespace Nz::ShaderAst ResolveFunctions(); } - PopScope(); return output; } @@ -2261,7 +2387,7 @@ namespace Nz::ShaderAst StackVector parameters = NazaraStackVector(TypeParameter, partialType.parameters.size()); for (std::size_t i = 0; i < partialType.parameters.size(); ++i) { - ExpressionPtr indexExpr = CloneExpression(node.indices[i]); + const ExpressionPtr& indexExpr = node.indices[i]; switch (partialType.parameters[i]) { case TypeParameterCategory::ConstantValue: diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index faf9f88ca..79e274d7f 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -137,9 +137,16 @@ namespace Nz std::string targetName; }; + struct StructData + { + std::string nameOverride; + const ShaderAst::StructDescription* desc; + }; + std::optional stage; + std::string moduleSuffix; std::stringstream stream; - std::unordered_map structs; + std::unordered_map structs; std::unordered_map variableNames; std::vector inputFields; std::vector outputFields; @@ -172,6 +179,7 @@ namespace Nz else targetAst = module.rootNode.get(); + const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module; ShaderAst::StatementPtr optimizedAst; if (states.optimize) @@ -193,6 +201,15 @@ namespace Nz AppendHeader(); + for (const auto& importedModule : targetModule.importedModules) + { + m_currentState->moduleSuffix = importedModule.identifier; + + importedModule.module->rootNode->Visit(state.previsitor); + importedModule.module->rootNode->Visit(*this); + } + + m_currentState->moduleSuffix = {}; targetAst->Visit(*this); return state.stream.str(); @@ -342,8 +359,8 @@ namespace Nz void GlslWriter::Append(const ShaderAst::StructType& structType) { - ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex); - Append(structDesc->name); + const auto& structData = Retrieve(m_currentState->structs, structType.structIndex); + Append(structData.nameOverride); } void GlslWriter::Append(const ShaderAst::Type& /*type*/) @@ -630,9 +647,9 @@ namespace Nz assert(IsStructType(parameter.type.GetResultingValue())); std::size_t structIndex = std::get(parameter.type.GetResultingValue()).structIndex; - const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); + const auto& structData = Retrieve(m_currentState->structs, structIndex); - AppendLine(structDesc->name, " ", varName, ";"); + AppendLine(structData.nameOverride, " ", varName, ";"); for (const auto& [memberName, targetName] : m_currentState->inputFields) AppendLine(varName, ".", memberName, " = ", targetName, ";"); @@ -651,9 +668,9 @@ namespace Nz void GlslWriter::HandleInOut() { - auto AppendInOut = [this](const ShaderAst::StructDescription& structDesc, std::vector& fields, const char* keyword, const char* targetPrefix) + auto AppendInOut = [this](const State::StructData& structData, std::vector& fields, const char* keyword, const char* targetPrefix) { - for (const auto& member : structDesc.members) + for (const auto& member : structData.desc->members) { if (member.cond.HasValue() && !member.cond.GetResultingValue()) continue; @@ -691,8 +708,6 @@ namespace Nz const ShaderAst::DeclareFunctionStatement& node = *m_currentState->previsitor.entryPoint; - const ShaderAst::StructDescription* inputStruct = nullptr; - if (!node.parameters.empty()) { assert(node.parameters.size() == 1); @@ -700,10 +715,10 @@ namespace Nz assert(std::holds_alternative(parameter.type.GetResultingValue())); std::size_t inputStructIndex = std::get(parameter.type.GetResultingValue()).structIndex; - inputStruct = Retrieve(m_currentState->structs, inputStructIndex); + const auto& inputStruct = Retrieve(m_currentState->structs, inputStructIndex); AppendCommentSection("Inputs"); - AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix); + AppendInOut(inputStruct, m_currentState->inputFields, "in", s_inputPrefix); } if (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition) @@ -717,17 +732,21 @@ namespace Nz assert(std::holds_alternative(node.returnType.GetResultingValue())); std::size_t outputStructIndex = std::get(node.returnType.GetResultingValue()).structIndex; - const ShaderAst::StructDescription* outputStruct = Retrieve(m_currentState->structs, outputStructIndex); + const auto& outputStruct = Retrieve(m_currentState->structs, outputStructIndex); AppendCommentSection("Outputs"); - AppendInOut(*outputStruct, m_currentState->outputFields, "out", s_outputPrefix); + AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix); } } - void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc) + void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc, std::string structName) { assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); - m_currentState->structs.emplace(structIndex, desc); + State::StructData structData; + structData.desc = desc; + structData.nameOverride = std::move(structName); + + m_currentState->structs.emplace(structIndex, std::move(structData)); } void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName) @@ -1035,11 +1054,13 @@ namespace Nz if (IsUniformType(externalVar.type.GetResultingValue())) { auto& uniform = std::get(externalVar.type.GetResultingValue()); - ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex); - if (structInfo->layout.HasValue()) - isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140; + const auto& structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex); + if (structInfo.desc->layout.HasValue()) + isStd140 = structInfo.desc->layout.GetResultingValue() == StructLayout::Std140; } + std::string varName = externalVar.name + m_currentState->moduleSuffix; + if (!m_currentState->bindingMapping.empty() || isStd140) Append("layout("); @@ -1074,15 +1095,15 @@ namespace Nz if (IsUniformType(externalVar.type.GetResultingValue())) { Append("_NzBinding_"); - AppendLine(externalVar.name); + AppendLine(varName); EnterScope(); { - auto& uniform = std::get(externalVar.type.GetResultingValue()); - auto& structDesc = Retrieve(m_currentState->structs, uniform.containedType.structIndex); + const auto& uniform = std::get(externalVar.type.GetResultingValue()); + const auto& structData = Retrieve(m_currentState->structs, uniform.containedType.structIndex); bool first = true; - for (const auto& member : structDesc->members) + for (const auto& member : structData.desc->members) { if (member.cond.HasValue() && !member.cond.GetResultingValue()) continue; @@ -1099,10 +1120,10 @@ namespace Nz LeaveScope(false); Append(" "); - Append(externalVar.name); + Append(varName); } else - AppendVariableDeclaration(externalVar.type.GetResultingValue(), externalVar.name); + AppendVariableDeclaration(externalVar.type.GetResultingValue(), varName); AppendLine(";"); @@ -1110,7 +1131,7 @@ namespace Nz AppendLine(); assert(externalVar.varIndex); - RegisterVariable(*externalVar.varIndex, externalVar.name); + RegisterVariable(*externalVar.varIndex, varName); } } @@ -1168,11 +1189,13 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) { + std::string structName = node.description.name + m_currentState->moduleSuffix; + assert(node.structIndex); - RegisterStruct(*node.structIndex, &node.description); + RegisterStruct(*node.structIndex, &node.description, structName); Append("struct "); - AppendLine(node.description.name); + AppendLine(structName); EnterScope(); { bool first = true; @@ -1224,9 +1247,9 @@ namespace Nz Append(";"); } - void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/) + void GlslWriter::Visit(ShaderAst::ImportStatement& node) { - /* nothing to do */ + throw std::runtime_error("unexpected import statement, is the shader sanitized properly?"); } void GlslWriter::Visit(ShaderAst::MultiStatement& node) @@ -1254,7 +1277,7 @@ namespace Nz const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr); assert(IsStructType(returnType)); std::size_t structIndex = std::get(returnType).structIndex; - const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); + const auto& structData = Retrieve(m_currentState->structs, structIndex); std::string outputStructVarName; if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression) @@ -1262,7 +1285,7 @@ namespace Nz else { AppendLine(); - Append(structDesc->name, " ", s_outputVarName, " = "); + Append(structData.nameOverride, " ", s_outputVarName, " = "); node.returnExpr->Visit(*this); AppendLine(";"); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 89d313ea1..93ba2c960 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -31,73 +31,95 @@ namespace Nz { const ShaderAst::ExpressionValue& bindingIndex; - inline bool HasValue() const { return bindingIndex.HasValue(); } + bool HasValue() const { return bindingIndex.HasValue(); } }; struct LangWriter::BuiltinAttribute { const ShaderAst::ExpressionValue& builtin; - inline bool HasValue() const { return builtin.HasValue(); } + bool HasValue() const { return builtin.HasValue(); } }; struct LangWriter::DepthWriteAttribute { const ShaderAst::ExpressionValue& writeMode; - inline bool HasValue() const { return writeMode.HasValue(); } + bool HasValue() const { return writeMode.HasValue(); } }; struct LangWriter::EarlyFragmentTestsAttribute { const ShaderAst::ExpressionValue& earlyFragmentTests; - inline bool HasValue() const { return earlyFragmentTests.HasValue(); } + bool HasValue() const { return earlyFragmentTests.HasValue(); } }; struct LangWriter::EntryAttribute { const ShaderAst::ExpressionValue& stageType; - inline bool HasValue() const { return stageType.HasValue(); } + bool HasValue() const { return stageType.HasValue(); } + }; + + struct LangWriter::LangVersionAttribute + { + UInt32 version; + + bool HasValue() const { return true; } }; struct LangWriter::LayoutAttribute { const ShaderAst::ExpressionValue& layout; - inline bool HasValue() const { return layout.HasValue(); } + bool HasValue() const { return layout.HasValue(); } }; struct LangWriter::LocationAttribute { const ShaderAst::ExpressionValue& locationIndex; - inline bool HasValue() const { return locationIndex.HasValue(); } + bool HasValue() const { return locationIndex.HasValue(); } }; struct LangWriter::SetAttribute { const ShaderAst::ExpressionValue& setIndex; - inline bool HasValue() const { return setIndex.HasValue(); } + bool HasValue() const { return setIndex.HasValue(); } }; struct LangWriter::UnrollAttribute { const ShaderAst::ExpressionValue& unroll; - inline bool HasValue() const { return unroll.HasValue(); } + bool HasValue() const { return unroll.HasValue(); } + }; + + struct LangWriter::UuidAttribute + { + Uuid uuid; + + bool HasValue() const { return true; } }; struct LangWriter::State { + struct Identifier + { + std::size_t moduleIndex; + std::string name; + }; + const States* states = nullptr; ShaderAst::Module* module; + std::size_t currentModuleIndex; std::stringstream stream; - std::unordered_map constantNames; - std::unordered_map structs; - std::unordered_map variableNames; + std::unordered_map constantNames; + std::unordered_map structs; + std::unordered_map variableNames; + std::vector moduleNames; bool isInEntryPoint = false; unsigned int indentLevel = 0; }; @@ -116,6 +138,22 @@ namespace Nz AppendHeader(); + // Register imported modules + m_currentState->currentModuleIndex = 0; + for (const auto& importedModule : sanitizedModule->importedModules) + { + AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion }); + AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId }); + AppendLine("module ", importedModule.identifier); + EnterScope(); + importedModule.module->rootNode->Visit(*this); + LeaveScope(true); + + m_currentState->currentModuleIndex++; + m_currentState->moduleNames.push_back(importedModule.identifier); + } + + m_currentState->currentModuleIndex = std::numeric_limits::max(); sanitizedModule->rootNode->Visit(*this); return state.stream.str(); @@ -213,8 +251,7 @@ namespace Nz void LangWriter::Append(const ShaderAst::StructType& structType) { - ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex); - Append(structDesc->name); + AppendIdentifier(m_currentState->structs, structType.structIndex); } void LangWriter::Append(const ShaderAst::Type& /*type*/) @@ -292,31 +329,31 @@ namespace Nz AppendAttributesInternal(first, secondParam, std::forward(params)...); } - void LangWriter::AppendAttribute(BindingAttribute binding) + void LangWriter::AppendAttribute(BindingAttribute attribute) { - if (!binding.HasValue()) + if (!attribute.HasValue()) return; Append("binding("); - if (binding.bindingIndex.IsResultingValue()) - Append(binding.bindingIndex.GetResultingValue()); + if (attribute.bindingIndex.IsResultingValue()) + Append(attribute.bindingIndex.GetResultingValue()); else - binding.bindingIndex.GetExpression()->Visit(*this); + attribute.bindingIndex.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(BuiltinAttribute builtin) + void LangWriter::AppendAttribute(BuiltinAttribute attribute) { - if (!builtin.HasValue()) + if (!attribute.HasValue()) return; Append("builtin("); - if (builtin.builtin.IsResultingValue()) + if (attribute.builtin.IsResultingValue()) { - switch (builtin.builtin.GetResultingValue()) + switch (attribute.builtin.GetResultingValue()) { case ShaderAst::BuiltinEntry::FragCoord: Append("fragcoord"); @@ -332,21 +369,21 @@ namespace Nz } } else - builtin.builtin.GetExpression()->Visit(*this); + attribute.builtin.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite) + void LangWriter::AppendAttribute(DepthWriteAttribute attribute) { - if (!depthWrite.HasValue()) + if (!attribute.HasValue()) return; Append("depth_write("); - if (depthWrite.writeMode.IsResultingValue()) + if (attribute.writeMode.IsResultingValue()) { - switch (depthWrite.writeMode.GetResultingValue()) + switch (attribute.writeMode.GetResultingValue()) { case ShaderAst::DepthWriteMode::Greater: Append("greater"); @@ -366,41 +403,41 @@ namespace Nz } } else - depthWrite.writeMode.GetExpression()->Visit(*this); + attribute.writeMode.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests) + void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute attribute) { - if (!earlyFragmentTests.HasValue()) + if (!attribute.HasValue()) return; Append("early_fragment_tests("); - if (earlyFragmentTests.earlyFragmentTests.IsResultingValue()) + if (attribute.earlyFragmentTests.IsResultingValue()) { - if (earlyFragmentTests.earlyFragmentTests.GetResultingValue()) + if (attribute.earlyFragmentTests.GetResultingValue()) Append("true"); else Append("false"); } else - earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this); + attribute.earlyFragmentTests.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(EntryAttribute entry) + void LangWriter::AppendAttribute(EntryAttribute attribute) { - if (!entry.HasValue()) + if (!attribute.HasValue()) return; Append("entry("); - if (entry.stageType.IsResultingValue()) + if (attribute.stageType.IsResultingValue()) { - switch (entry.stageType.GetResultingValue()) + switch (attribute.stageType.GetResultingValue()) { case ShaderStageType::Fragment: Append("frag"); @@ -412,20 +449,39 @@ namespace Nz } } else - entry.stageType.GetExpression()->Visit(*this); + attribute.stageType.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(LayoutAttribute entry) + void LangWriter::AppendAttribute(LangVersionAttribute attribute) { - if (!entry.HasValue()) + UInt32 shaderLangVersion = attribute.version; + UInt32 majorVersion = shaderLangVersion / 100; + shaderLangVersion -= majorVersion * 100; + + UInt32 minorVersion = shaderLangVersion / 10; + shaderLangVersion -= minorVersion * 100; + + UInt32 patchVersion = shaderLangVersion; + + // nzsl_version + Append("nzsl_version(\"", majorVersion, ".", minorVersion); + if (patchVersion != 0) + Append(".", patchVersion); + + Append("\")"); + } + + void LangWriter::AppendAttribute(LayoutAttribute attribute) + { + if (!attribute.HasValue()) return; Append("layout("); - if (entry.layout.IsResultingValue()) + if (attribute.layout.IsResultingValue()) { - switch (entry.layout.GetResultingValue()) + switch (attribute.layout.GetResultingValue()) { case StructLayout::Packed: Append("packed"); @@ -437,50 +493,50 @@ namespace Nz } } else - entry.layout.GetExpression()->Visit(*this); + attribute.layout.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(LocationAttribute location) + void LangWriter::AppendAttribute(LocationAttribute attribute) { - if (!location.HasValue()) + if (!attribute.HasValue()) return; Append("location("); - if (location.locationIndex.IsResultingValue()) - Append(location.locationIndex.GetResultingValue()); + if (attribute.locationIndex.IsResultingValue()) + Append(attribute.locationIndex.GetResultingValue()); else - location.locationIndex.GetExpression()->Visit(*this); + attribute.locationIndex.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(SetAttribute set) + void LangWriter::AppendAttribute(SetAttribute attribute) { - if (!set.HasValue()) + if (!attribute.HasValue()) return; Append("set("); - if (set.setIndex.IsResultingValue()) - Append(set.setIndex.GetResultingValue()); + if (attribute.setIndex.IsResultingValue()) + Append(attribute.setIndex.GetResultingValue()); else - set.setIndex.GetExpression()->Visit(*this); + attribute.setIndex.GetExpression()->Visit(*this); Append(")"); } - void LangWriter::AppendAttribute(UnrollAttribute unroll) + void LangWriter::AppendAttribute(UnrollAttribute attribute) { - if (!unroll.HasValue()) + if (!attribute.HasValue()) return; Append("unroll("); - if (unroll.unroll.IsResultingValue()) + if (attribute.unroll.IsResultingValue()) { - switch (unroll.unroll.GetResultingValue()) + switch (attribute.unroll.GetResultingValue()) { case ShaderAst::LoopUnroll::Always: Append("always"); @@ -499,7 +555,12 @@ namespace Nz } } else - unroll.unroll.GetExpression()->Visit(*this); + attribute.unroll.GetExpression()->Visit(*this); + } + + void LangWriter::AppendAttribute(UuidAttribute attribute) + { + Append("uuid(\"", attribute.uuid.ToString(), "\")"); } void LangWriter::AppendComment(const std::string& section) @@ -539,6 +600,16 @@ namespace Nz m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t'); } + template + void LangWriter::AppendIdentifier(const T& map, std::size_t id) + { + const auto& structIdentifier = Retrieve(map, id); + if (structIdentifier.moduleIndex != m_currentState->currentModuleIndex) + Append(m_currentState->moduleNames[structIdentifier.moduleIndex], '.'); + + Append(structIdentifier.name); + } + template void LangWriter::AppendLine(Args&&... params) { @@ -586,20 +657,32 @@ namespace Nz void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName) { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(constantName); + assert(m_currentState->constantNames.find(constantIndex) == m_currentState->constantNames.end()); - m_currentState->constantNames.emplace(constantIndex, std::move(constantName)); + m_currentState->constantNames.emplace(constantIndex, std::move(identifier)); } - void LangWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc) + void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName) { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(structName); + assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); - m_currentState->structs.emplace(structIndex, desc); + m_currentState->structs.emplace(structIndex, std::move(identifier)); } void LangWriter::RegisterVariable(std::size_t varIndex, std::string varName) { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(varName); + assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end()); - m_currentState->variableNames.emplace(varIndex, std::move(varName)); + m_currentState->variableNames.emplace(varIndex, std::move(identifier)); } void LangWriter::ScopeVisit(ShaderAst::Statement& node) @@ -770,7 +853,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node) { - throw std::runtime_error("TODO"); //< missing registering + //throw std::runtime_error("TODO"); //< missing registering assert(node.aliasIndex); @@ -828,7 +911,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::ConstantExpression& node) { - Append(Retrieve(m_currentState->constantNames, node.constantId)); + AppendIdentifier(m_currentState->constantNames, node.constantId); } void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node) @@ -908,7 +991,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::DeclareStructStatement& node) { assert(node.structIndex); - RegisterStruct(*node.structIndex, &node.description); + RegisterStruct(*node.structIndex, node.description.name); AppendAttributes(true, LayoutAttribute{ node.description.layout }); Append("struct "); @@ -1068,6 +1151,11 @@ namespace Nz Append(")"); } + void LangWriter::Visit(ShaderAst::StructTypeExpression& node) + { + AppendIdentifier(m_currentState->structs, node.structTypeId); + } + void LangWriter::Visit(ShaderAst::MultiStatement& node) { if (!node.sectionName.empty()) @@ -1100,7 +1188,7 @@ namespace Nz { EnterScope(); node.statement->Visit(*this); - LeaveScope(true); + LeaveScope(); } void LangWriter::Visit(ShaderAst::SwizzleExpression& node) @@ -1115,8 +1203,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::VariableExpression& node) { - const std::string& varName = Retrieve(m_currentState->variableNames, node.variableId); - Append(varName); + AppendIdentifier(m_currentState->variableNames, node.variableId); } void LangWriter::Visit(ShaderAst::UnaryExpression& node) @@ -1150,22 +1237,7 @@ namespace Nz void LangWriter::AppendHeader() { - UInt32 shaderLangVersion = m_currentState->module->metadata->shaderLangVersion; - UInt32 majorVersion = shaderLangVersion / 100; - shaderLangVersion -= majorVersion * 100; - - UInt32 minorVersion = shaderLangVersion / 10; - shaderLangVersion -= minorVersion * 100; - - UInt32 patchVersion = shaderLangVersion; - - // nzsl_version - Append("[nzsl_version(\"", majorVersion, ".", minorVersion); - if (patchVersion != 0) - Append(".", patchVersion); - - AppendLine("\")]"); - + AppendAttributes(true, LangVersionAttribute{ m_currentState->module->metadata->shaderLangVersion }); AppendLine("module;"); AppendLine(); } diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index 0523e47a8..12329ef62 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -40,6 +40,7 @@ namespace Nz::ShaderLang ForceCLocale forceCLocale; std::unordered_map reservedKeywords = { + { "alias", TokenType::Alias }, { "const", TokenType::Const }, { "const_select", TokenType::ConstSelect }, { "discard", TokenType::Discard }, diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 89c7fd8e9..c4d3f9c2d 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -246,6 +246,9 @@ namespace Nz::ShaderLang void Parser::ParseModuleStatement(std::vector attributes) { + if (m_context->parsingImportedModule) + throw UnexpectedToken{}; + Expect(Advance(), TokenType::Module); std::optional moduleVersion; @@ -343,7 +346,21 @@ namespace Nz::ShaderLang const std::string& identifier = std::get(Peek().data); Consume(); - module->rootNode->statements = ParseStatementList(); + Expect(Advance(), TokenType::OpenCurlyBracket); + + m_context->parsingImportedModule = true; + + while (Peek().type != TokenType::ClosingCurlyBracket) + { + ShaderAst::StatementPtr statement = ParseRootStatement(); + if (!statement) + throw UnexpectedToken{}; //< "unexpected end of file" + + module->rootNode->statements.push_back(std::move(statement)); + } + Consume(); //< Consume ClosingCurlyBracket + + m_context->parsingImportedModule = false; auto& importedModule = m_context->module->importedModules.emplace_back(); importedModule.module = std::move(module); @@ -380,6 +397,21 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Semicolon); } + ShaderAst::StatementPtr Parser::ParseAliasDeclaration() + { + Expect(Advance(), TokenType::Alias); + + std::string name = ParseIdentifierAsName(); + + Expect(Advance(), TokenType::Assign); + + ShaderAst::ExpressionPtr expr = ParseExpression(); + + Expect(Advance(), TokenType::Semicolon); + + return ShaderBuilder::DeclareAlias(std::move(name), std::move(expr)); + } + ShaderAst::StatementPtr Parser::ParseBranchStatement() { std::unique_ptr branch = std::make_unique(); @@ -756,6 +788,12 @@ namespace Nz::ShaderLang const Token& nextToken = Peek(); switch (nextToken.type) { + case TokenType::Alias: + if (!attributes.empty()) + throw UnexpectedToken{}; + + return ParseAliasDeclaration(); + case TokenType::Const: if (!attributes.empty()) throw UnexpectedToken{}; diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 81823abbc..98000f756 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -517,6 +517,8 @@ namespace Nz else targetAst = module.rootNode.get(); + const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module; + ShaderAst::StatementPtr optimizedAst; if (states.optimize) { @@ -542,6 +544,9 @@ namespace Nz // Register all extended instruction sets PreVisitor preVisitor(state.constantTypeCache, state.funcs); + for (const auto& importedModule : targetModule.importedModules) + importedModule.module->rootNode->Visit(preVisitor); + targetAst->Visit(preVisitor); m_currentState->preVisitor = &preVisitor; @@ -554,6 +559,9 @@ namespace Nz func.funcId = AllocateResultId(); SpirvAstVisitor visitor(*this, state.instructions, state.funcs); + for (const auto& importedModule : targetModule.importedModules) + importedModule.module->rootNode->Visit(visitor); + targetAst->Visit(visitor); AppendHeader();