From 83d26e209e70ed598fb77425e0075012c4bc2bfa Mon Sep 17 00:00:00 2001 From: Lynix Date: Tue, 8 Mar 2022 01:30:48 +0100 Subject: [PATCH] WIP2 --- include/Nazara/Shader/Ast/AstCloner.hpp | 1 + include/Nazara/Shader/Ast/AstCompare.hpp | 1 + include/Nazara/Shader/Ast/AstCompare.inl | 11 + include/Nazara/Shader/Ast/AstNodeList.hpp | 1 + .../Nazara/Shader/Ast/AstRecursiveVisitor.hpp | 1 + include/Nazara/Shader/Ast/AstSerializer.hpp | 1 + include/Nazara/Shader/Ast/Nodes.hpp | 10 + include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 50 +-- include/Nazara/Shader/GlslWriter.hpp | 1 + include/Nazara/Shader/LangWriter.hpp | 1 + include/Nazara/Shader/ShaderBuilder.hpp | 6 + include/Nazara/Shader/ShaderBuilder.inl | 9 + include/Nazara/Shader/SpirvAstVisitor.hpp | 1 + src/Nazara/Shader/Ast/AstCloner.cpp | 10 + src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp | 6 + src/Nazara/Shader/Ast/AstSerializer.cpp | 7 + src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 287 +++++++++++------- src/Nazara/Shader/GlslWriter.cpp | 5 + src/Nazara/Shader/LangWriter.cpp | 14 +- src/Nazara/Shader/ShaderLangParser.cpp | 1 + src/Nazara/Shader/SpirvAstVisitor.cpp | 5 + tests/resources.cpp | 8 +- 22 files changed, 295 insertions(+), 142 deletions(-) diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index f4520c6d9..3dcc8e5ff 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -58,6 +58,7 @@ namespace Nz::ShaderAst virtual StatementPtr Clone(BranchStatement& node); virtual StatementPtr Clone(ConditionalStatement& node); + virtual StatementPtr Clone(DeclareAliasStatement& node); virtual StatementPtr Clone(DeclareConstStatement& node); virtual StatementPtr Clone(DeclareExternalStatement& node); virtual StatementPtr Clone(DeclareFunctionStatement& node); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index e7c4dd2fc..e4cfde283 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -50,6 +50,7 @@ namespace Nz::ShaderAst inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs); inline bool Compare(const ConditionalStatement& lhs, const ConditionalStatement& rhs); + inline bool Compare(const DeclareAliasStatement& lhs, const DeclareAliasStatement& rhs); inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs); inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs); inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index e6ae5bb1a..e6502caff 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -419,6 +419,17 @@ namespace Nz::ShaderAst return true; } + bool Compare(const DeclareAliasStatement& lhs, const DeclareAliasStatement& rhs) + { + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.expression, rhs.expression)) + return false; + + return true; + } + inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs) { if (!Compare(lhs.name, rhs.name)) diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index 8714f3436..b208750fc 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(VariableExpression) NAZARA_SHADERAST_EXPRESSION(UnaryExpression) NAZARA_SHADERAST_STATEMENT(BranchStatement) NAZARA_SHADERAST_STATEMENT(ConditionalStatement) +NAZARA_SHADERAST_STATEMENT(DeclareAliasStatement) NAZARA_SHADERAST_STATEMENT(DeclareConstStatement) NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement) NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index d85ca554e..99d25c0cf 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -38,6 +38,7 @@ namespace Nz::ShaderAst void Visit(BranchStatement& node) override; void Visit(ConditionalStatement& node) override; + void Visit(DeclareAliasStatement& node) override; void Visit(DeclareConstStatement& node) override; void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index 3170bf479..6c87e8bb5 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -41,6 +41,7 @@ namespace Nz::ShaderAst void Serialize(BranchStatement& node); void Serialize(ConditionalStatement& node); + void Serialize(DeclareAliasStatement& node); void Serialize(DeclareConstStatement& node); void Serialize(DeclareExternalStatement& node); void Serialize(DeclareFunctionStatement& node); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 47e4f2f95..c16835b8e 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -244,6 +244,16 @@ namespace Nz::ShaderAst StatementPtr statement; }; + struct NAZARA_SHADER_API DeclareAliasStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + std::optional aliasIndex; + std::string name; + ExpressionPtr expression; + }; + struct NAZARA_SHADER_API DeclareConstStatement : Statement { NodeType GetType() const override; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 343bb7bff..fb0a6ce8e 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -57,11 +57,13 @@ namespace Nz::ShaderAst }; private: + enum class IdentifierCategory; struct CurrentFunctionData; struct Environment; struct FunctionData; struct Identifier; - template struct IdentifierData; + struct IdentifierData; + template struct IdentifierList; struct Scope; using AstCloner::CloneExpression; @@ -84,6 +86,7 @@ namespace Nz::ShaderAst StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override; + StatementPtr Clone(DeclareAliasStatement& node) override; StatementPtr Clone(DeclareConstStatement& node) override; StatementPtr Clone(DeclareExternalStatement& node) override; StatementPtr Clone(DeclareFunctionStatement& node) override; @@ -99,10 +102,10 @@ namespace Nz::ShaderAst StatementPtr Clone(ScopedStatement& node) override; StatementPtr Clone(WhileStatement& node) override; - const Identifier* FindIdentifier(const std::string_view& identifierName) const; - template const Identifier* FindIdentifier(const std::string_view& identifierName, F&& functor) const; - const Identifier* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const; - template const Identifier* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; + const IdentifierData* FindIdentifier(const std::string_view& identifierName) const; + template const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const; + const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const; + template const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; TypeParameter FindTypeParameter(const std::string_view& identifierName) const; Expression& MandatoryExpr(const ExpressionPtr& node) const; @@ -120,6 +123,8 @@ namespace Nz::ShaderAst void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); void RegisterBuiltin(); + + std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional index = {}); std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional index = {}); std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional index = {}); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); @@ -129,6 +134,7 @@ namespace Nz::ShaderAst std::size_t RegisterType(std::string name, PartialType partialType, std::optional index = {}); std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional index = {}); + const IdentifierData* ResolveAlias(const IdentifierData* identifier) const; void ResolveFunctions(); const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); std::size_t ResolveStruct(const ExpressionType& exprType); @@ -146,6 +152,7 @@ namespace Nz::ShaderAst StatementPtr Unscope(StatementPtr node); + void Validate(DeclareAliasStatement& node); void Validate(WhileStatement& node); void Validate(AccessIndexExpression& node); @@ -160,6 +167,18 @@ namespace Nz::ShaderAst void Validate(VariableExpression& node); ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr); + enum class IdentifierCategory + { + Alias, + Constant, + Function, + Intrinsic, + Module, + Struct, + Type, + Variable + }; + struct FunctionData { Bitset<> calledByFunctions; @@ -167,23 +186,16 @@ namespace Nz::ShaderAst FunctionFlags flags; }; + struct IdentifierData + { + std::size_t index; + IdentifierCategory category; + }; + struct Identifier { - enum class Type - { - Alias, - Constant, - Function, - Intrinsic, - Module, - Struct, - Type, - Variable - }; - std::string name; - std::size_t index; - Type type; + IdentifierData data; }; struct Context; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index e070be3fc..758087d5f 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -105,6 +105,7 @@ namespace Nz void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::BranchStatement& node) override; + void Visit(ShaderAst::DeclareAliasStatement& node) override; void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index 522023877..d4862641d 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -110,6 +110,7 @@ namespace Nz void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::ConditionalStatement& node) override; + void Visit(ShaderAst::DeclareAliasStatement& node) override; void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index a8f130849..3972ebc52 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -76,6 +76,11 @@ namespace Nz::ShaderBuilder template std::unique_ptr operator()(ShaderAst::ExpressionType type, T value) const; }; + struct DeclareAlias + { + inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionPtr expression) const; + }; + struct DeclareConst { inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const; @@ -191,6 +196,7 @@ namespace Nz::ShaderBuilder constexpr Impl::ConditionalStatement ConditionalStatement; constexpr Impl::Constant Constant; constexpr Impl::Branch ConstBranch; + constexpr Impl::DeclareAlias DeclareAlias; constexpr Impl::DeclareConst DeclareConst; constexpr Impl::DeclareFunction DeclareFunction; constexpr Impl::DeclareOption DeclareOption; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index f1f5dae2d..bf603414a 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -195,6 +195,15 @@ namespace Nz::ShaderBuilder throw std::runtime_error("unexpected primitive type"); } + inline std::unique_ptr Impl::DeclareAlias::operator()(std::string name, ShaderAst::ExpressionPtr expression) const + { + auto declareAliasNode = std::make_unique(); + declareAliasNode->name = std::move(name); + declareAliasNode->expression = std::move(expression); + + return declareAliasNode; + } + inline std::unique_ptr Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const { auto declareConstNode = std::make_unique(); diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index d9768a1d5..d5a174dac 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -48,6 +48,7 @@ namespace Nz void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::ConstantValueExpression& node) override; + void Visit(ShaderAst::DeclareAliasStatement& node) override; void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index b40ab1ea0..a38183177 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -77,6 +77,16 @@ namespace Nz::ShaderAst return clone; } + StatementPtr AstCloner::Clone(DeclareAliasStatement& node) + { + auto clone = std::make_unique(); + clone->aliasIndex = node.aliasIndex; + clone->name = node.name; + clone->expression = CloneExpression(node.expression); + + return clone; + } + StatementPtr AstCloner::Clone(DeclareConstStatement& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 613dea9c0..1545161a3 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -117,6 +117,12 @@ namespace Nz::ShaderAst node.statement->Visit(*this); } + void AstRecursiveVisitor::Visit(DeclareAliasStatement& node) + { + if (node.expression) + node.expression->Visit(*this); + } + void AstRecursiveVisitor::Visit(DeclareConstStatement& node) { if (node.expression) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 86b7a7b31..208f63961 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -193,6 +193,13 @@ namespace Nz::ShaderAst Node(node.statement); } + void AstSerializerBase::Serialize(DeclareAliasStatement& node) + { + OptVal(node.aliasIndex); + Value(node.name); + Node(node.expression); + } + void AstSerializerBase::Serialize(DeclareExternalStatement& node) { ExprValue(node.bindingSet); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index b3201b687..3d0dcdf51 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -46,7 +46,7 @@ namespace Nz::ShaderAst }; template - struct SanitizeVisitor::IdentifierData + struct SanitizeVisitor::IdentifierList { Bitset availableIndices; Bitset preregisteredIndices; @@ -117,13 +117,6 @@ namespace Nz::ShaderAst std::shared_ptr parentEnv; std::vector identifiersInScope; std::vector scopes; - IdentifierData constantValues; - IdentifierData functions; - IdentifierData intrinsics; - IdentifierData moduleIndices; - IdentifierData structs; - IdentifierData> types; - IdentifierData variableTypes; }; struct SanitizeVisitor::Context @@ -143,6 +136,14 @@ namespace Nz::ShaderAst std::unordered_set usedBindingIndexes; std::shared_ptr globalEnv; std::shared_ptr currentEnv; + IdentifierList constantValues; + IdentifierList functions; + IdentifierList aliases; + IdentifierList intrinsics; + IdentifierList moduleIndices; + IdentifierList structs; + IdentifierList> types; + IdentifierList variableTypes; Options options; CurrentFunctionData* currentFunction = nullptr; }; @@ -241,7 +242,7 @@ namespace Nz::ShaderAst else if (IsStructType(exprType)) { std::size_t structIndex = ResolveStruct(exprType); - const StructDescription* s = m_context->currentEnv->structs.Retrieve(structIndex); + const StructDescription* s = m_context->structs.Retrieve(structIndex); // Retrieve member index (not counting disabled fields) Int32 fieldIndex = 0; @@ -569,7 +570,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { // Replace by constant value - auto constant = ShaderBuilder::Constant(m_context->currentEnv->constantValues.Retrieve(node.constantId)); + auto constant = ShaderBuilder::Constant(m_context->constantValues.Retrieve(node.constantId)); constant->cachedExpressionType = GetExpressionType(constant->value); return constant; @@ -579,32 +580,32 @@ namespace Nz::ShaderAst { assert(m_context); - const Identifier* identifier = FindIdentifier(node.identifier); - if (!identifier) + const IdentifierData* identifierData = FindIdentifier(node.identifier); + if (!identifierData) throw AstError{ "unknown identifier " + node.identifier }; - switch (identifier->type) + switch (identifierData->category) { - case Identifier::Type::Constant: + case IdentifierCategory::Constant: { // Replace IdentifierExpression by Constant(Value)Expression ConstantExpression constantExpr; - constantExpr.constantId = identifier->index; + constantExpr.constantId = identifierData->index; return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression } - case Identifier::Type::Function: + case IdentifierCategory::Function: { auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = FunctionType{ identifier->index }; + clone->cachedExpressionType = FunctionType{ identifierData->index }; return clone; } - case Identifier::Type::Intrinsic: + case IdentifierCategory::Intrinsic: { - IntrinsicType intrinsicType = m_context->currentEnv->intrinsics.Retrieve(identifier->index); + IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index); auto clone = AstCloner::Clone(node); clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; @@ -612,28 +613,28 @@ namespace Nz::ShaderAst return clone; } - case Identifier::Type::Struct: + case IdentifierCategory::Struct: { auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = StructType{ identifier->index }; + clone->cachedExpressionType = StructType{ identifierData->index }; return clone; } - case Identifier::Type::Type: + case IdentifierCategory::Type: { auto clone = AstCloner::Clone(node); - clone->cachedExpressionType = Type{ identifier->index }; + clone->cachedExpressionType = Type{ identifierData->index }; return clone; } - case Identifier::Type::Variable: + case IdentifierCategory::Variable: { // Replace IdentifierExpression by VariableExpression auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = m_context->currentEnv->variableTypes.Retrieve(identifier->index); - varExpr->variableId = identifier->index; + varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index); + varExpr->variableId = identifierData->index; return varExpr; } @@ -763,6 +764,14 @@ namespace Nz::ShaderAst return ShaderBuilder::NoOp(); } + StatementPtr SanitizeVisitor::Clone(DeclareAliasStatement& node) + { + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + Validate(*clone); + + return clone; + } + StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); @@ -987,7 +996,7 @@ namespace Nz::ShaderAst else if (IsStructType(resolvedType)) { std::size_t structIndex = std::get(resolvedType).structIndex; - const StructDescription* desc = m_context->currentEnv->structs.Retrieve(structIndex); + const StructDescription* desc = m_context->structs.Retrieve(structIndex); if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) throw AstError{ "inner struct layout mismatch" }; } @@ -1353,8 +1362,7 @@ namespace Nz::ShaderAst targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString(); - m_context->currentEnv = m_context->moduleEnvironments.emplace_back(); - m_context->currentEnv->parentEnv = m_context->globalEnv; + m_context->currentEnv = m_context->moduleEnvironments.emplace_back(std::make_shared()); CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; }); ModulePtr sanitizedModule = std::make_shared(targetModule->metadata); @@ -1373,6 +1381,8 @@ namespace Nz::ShaderAst DependencyCheckerVisitor::UsageSet exportedSet; + MultiStatementPtr aliasBlock = std::make_unique(); + AstExportVisitor::Callbacks callbacks; callbacks.onExportedStruct = [&](DeclareStructStatement& node) { @@ -1380,6 +1390,12 @@ namespace Nz::ShaderAst moduleDependencies.MarkStructAsUsed(*node.structIndex); exportedSet.usedStructs.UnboundedSet(*node.structIndex); + + auto alias = Clone(node); + // TODO: DeclareAlias + + aliasBlock->statements.emplace_back(std::move(alias)); + }; AstExportVisitor exportVisitor; @@ -1387,46 +1403,13 @@ namespace Nz::ShaderAst moduleDependencies.Resolve(); - auto statementPtr = EliminateUnusedPass(*sanitizedModule->rootNode, moduleDependencies.GetUsage()); - - DependencyCheckerVisitor::UsageSet remappedExportedSet; - - IndexRemapperVisitor::Callbacks remapCallbacks; - remapCallbacks.constIndexGenerator = [this](std::size_t previousIndex) { return m_context->currentEnv->constantValues.RegisterNewIndex(true); }; - remapCallbacks.funcIndexGenerator = [&](std::size_t previousIndex) - { - std::size_t newIndex = m_context->currentEnv->functions.RegisterNewIndex(true); - if (exportedSet.usedFunctions.Test(previousIndex)) - remappedExportedSet.usedFunctions.UnboundedSet(newIndex); - - return newIndex; - }; - - remapCallbacks.structIndexGenerator = [&](std::size_t previousIndex) - { - std::size_t newIndex = m_context->currentEnv->structs.RegisterNewIndex(true); - if (exportedSet.usedStructs.Test(previousIndex)) - remappedExportedSet.usedStructs.UnboundedSet(newIndex); - - return newIndex; - }; - - remapCallbacks.varIndexGenerator = [&](std::size_t previousIndex) - { - std::size_t newIndex = m_context->currentEnv->variableTypes.RegisterNewIndex(true); - if (exportedSet.usedVariables.Test(previousIndex)) - remappedExportedSet.usedVariables.UnboundedSet(newIndex); - - return newIndex; - }; - - statementPtr = RemapIndices(*statementPtr, remapCallbacks); + //m_context-> // Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor) //m_context->importUsage = remappedExportedSet; //CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); }); - return AstCloner::Clone(*statementPtr); + return aliasBlock; } StatementPtr SanitizeVisitor::Clone(MultiStatement& node) @@ -1478,18 +1461,18 @@ namespace Nz::ShaderAst return clone; } - auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* + auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const IdentifierData* { return FindIdentifier(*m_context->currentEnv, identifierName); } template - auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const Identifier* + auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const IdentifierData* { return FindIdentifier(*m_context->currentEnv, identifierName, std::forward(functor)); } - auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName) const -> const Identifier* + auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName) const -> const IdentifierData* { auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); if (it == environment.identifiersInScope.rend()) @@ -1500,15 +1483,21 @@ namespace Nz::ShaderAst return nullptr; } - return &*it; + return ResolveAlias(&it->data); } template - auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const -> const Identifier* + auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const -> const IdentifierData* { auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { - return identifier.name == identifierName && functor(identifier); + if (identifier.name == identifierName) + { + if (functor(*ResolveAlias(&identifier.data))) + return true; + } + + return false; }); if (it == environment.identifiersInScope.rend()) { @@ -1518,7 +1507,7 @@ namespace Nz::ShaderAst return nullptr; } - return &*it; + return ResolveAlias(&it->data); } TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const @@ -1527,30 +1516,43 @@ namespace Nz::ShaderAst if (!identifier) throw std::runtime_error("identifier " + std::string(identifierName) + " not found"); - switch (identifier->type) + switch (identifier->category) { - case Identifier::Type::Constant: - return m_context->currentEnv->constantValues.Retrieve(identifier->index); + case IdentifierCategory::Constant: + return m_context->constantValues.Retrieve(identifier->index); - case Identifier::Type::Struct: + case IdentifierCategory::Struct: return StructType{ identifier->index }; - case Identifier::Type::Type: + case IdentifierCategory::Type: return std::visit([&](auto&& arg) -> TypeParameter { return arg; - }, m_context->currentEnv->types.Retrieve(identifier->index)); + }, m_context->types.Retrieve(identifier->index)); - case Identifier::Type::Alias: - throw std::runtime_error("TODO"); + case IdentifierCategory::Alias: + { + IdentifierCategory category; + std::size_t index; + do + { + const auto& aliasData = m_context->aliases.Retrieve(identifier->index); + category = aliasData.category; + index = aliasData.index; + } + while (category == IdentifierCategory::Alias); + } - case Identifier::Type::Function: + case IdentifierCategory::Function: throw std::runtime_error("unexpected function identifier"); - case Identifier::Type::Intrinsic: + case IdentifierCategory::Intrinsic: throw std::runtime_error("unexpected intrinsic identifier"); - case Identifier::Type::Variable: + case IdentifierCategory::Module: + throw std::runtime_error("unexpected module identifier"); + + case IdentifierCategory::Variable: throw std::runtime_error("unexpected variable identifier"); } @@ -1589,7 +1591,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression) { - // No need to cache LValues (variables/constants) (TODO: Improve this, as constants doens't need to be cached as well) + // No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well) if (GetExpressionCategory(*expression) == ExpressionCategory::LValue) return expression; @@ -1652,7 +1654,7 @@ namespace Nz::ShaderAst AstConstantPropagationVisitor::Options optimizerOptions; optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& { - return m_context->currentEnv->constantValues.Retrieve(constantId); + return m_context->constantValues.Retrieve(constantId); }; // Run optimizer on constant value to hopefully retrieve a single constant value @@ -1661,7 +1663,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) { - auto& funcData = m_context->currentEnv->functions.Retrieve(funcIndex); + auto& funcData = m_context->functions.Retrieve(funcIndex); funcData.flags |= flags; for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) @@ -1832,16 +1834,31 @@ namespace Nz::ShaderAst RegisterIntrinsic("reflect", IntrinsicType::Reflect); } + std::size_t SanitizeVisitor::RegisterAlias(std::string name, IdentifierData aliasData, std::optional index) + { + if (FindIdentifier(name)) + throw AstError{ name + " is already used" }; + + std::size_t aliasIndex = m_context->aliases.Register(std::move(aliasData), index); + m_context->currentEnv->identifiersInScope.push_back({ + std::move(name), + aliasIndex, + IdentifierCategory::Alias + }); + + return aliasIndex; + } + std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t constantIndex = m_context->currentEnv->constantValues.Register(std::move(value), index); + std::size_t constantIndex = m_context->constantValues.Register(std::move(value), index); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), constantIndex, - Identifier::Type::Constant + IdentifierCategory::Constant }); return constantIndex; @@ -1854,23 +1871,23 @@ namespace Nz::ShaderAst bool duplicate = true; // Functions cannot be declared twice, except for entry ones if their stages are different - if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function) + if (funcData.node->entryStage.HasValue() && identifier->category == IdentifierCategory::Function) { - auto& otherFunction = m_context->currentEnv->functions.Retrieve(identifier->index); + auto& otherFunction = m_context->functions.Retrieve(identifier->index); if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) duplicate = false; } if (duplicate) - throw AstError{ funcData.node->name + " is already used" }; + throw AstError{ name + " is already used" }; } - std::size_t functionIndex = m_context->currentEnv->functions.Register(std::move(funcData), index); + std::size_t functionIndex = m_context->functions.Register(std::move(funcData), index); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), functionIndex, - Identifier::Type::Function + IdentifierCategory::Function }); return functionIndex; @@ -1881,12 +1898,12 @@ namespace Nz::ShaderAst if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t intrinsicIndex = m_context->currentEnv->intrinsics.Register(std::move(type)); + std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type)); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), intrinsicIndex, - Identifier::Type::Intrinsic + IdentifierCategory::Intrinsic }); return intrinsicIndex; @@ -1894,7 +1911,18 @@ namespace Nz::ShaderAst std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex) { - return std::size_t(); + if (FindIdentifier(moduleIdentifier)) + throw AstError{ moduleIdentifier + " is already used" }; + + std::size_t intrinsicIndex = m_context->moduleIndices.Register(moduleIndex); + + m_context->currentEnv->identifiersInScope.push_back({ + std::move(moduleIdentifier), + intrinsicIndex, + IdentifierCategory::Module + }); + + return intrinsicIndex; } std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional index) @@ -1902,12 +1930,12 @@ namespace Nz::ShaderAst if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t structIndex = m_context->currentEnv->structs.Register(description, index); + std::size_t structIndex = m_context->structs.Register(description, index); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), structIndex, - Identifier::Type::Struct + IdentifierCategory::Struct }); return structIndex; @@ -1918,12 +1946,12 @@ namespace Nz::ShaderAst if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t typeIndex = m_context->currentEnv->types.Register(std::move(expressionType), index); + std::size_t typeIndex = m_context->types.Register(std::move(expressionType), index); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), typeIndex, - Identifier::Type::Type + IdentifierCategory::Type }); return typeIndex; @@ -1934,12 +1962,12 @@ namespace Nz::ShaderAst if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t typeIndex = m_context->currentEnv->types.Register(std::move(partialType), index); + std::size_t typeIndex = m_context->types.Register(std::move(partialType), index); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), typeIndex, - Identifier::Type::Type + IdentifierCategory::Type }); return typeIndex; @@ -1950,21 +1978,29 @@ namespace Nz::ShaderAst if (auto* identifier = FindIdentifier(name)) { // Allow variable shadowing - if (identifier->type != Identifier::Type::Variable) + if (identifier->category != IdentifierCategory::Variable) throw AstError{ name + " is already used" }; } - std::size_t varIndex = m_context->currentEnv->variableTypes.Register(std::move(type), index); + std::size_t varIndex = m_context->variableTypes.Register(std::move(type), index); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), varIndex, - Identifier::Type::Variable + IdentifierCategory::Variable }); return varIndex; } + auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData* + { + while (identifier->category == IdentifierCategory::Alias) + identifier = &m_context->aliases.Retrieve(identifier->index); + + return identifier; + } + void SanitizeVisitor::ResolveFunctions() { // Once every function is known, we can evaluate function content @@ -1997,7 +2033,7 @@ namespace Nz::ShaderAst std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex; for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) { - auto& targetFunc = m_context->currentEnv->functions.Retrieve(i); + auto& targetFunc = m_context->functions.Retrieve(i); targetFunc.calledByFunctions.UnboundedSet(funcIndex); } @@ -2006,13 +2042,13 @@ namespace Nz::ShaderAst m_context->pendingFunctions.clear(); Bitset<> seen; - for (const auto& [funcIndex, funcData] : m_context->currentEnv->functions.values) + for (const auto& [funcIndex, funcData] : m_context->functions.values) { PropagateFunctionFlags(funcIndex, funcData.flags, seen); seen.Clear(); } - for (const auto& [funcIndex, funcData] : m_context->currentEnv->functions.values) + for (const auto& [funcIndex, funcData] : m_context->functions.values) { if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) throw AstError{ "discard can only be used in the fragment stage" }; @@ -2064,14 +2100,14 @@ namespace Nz::ShaderAst std::size_t SanitizeVisitor::ResolveStruct(const IdentifierType& identifierType) { - const Identifier* identifier = FindIdentifier(identifierType.name); - if (!identifier) + const IdentifierData* identifierData = FindIdentifier(identifierType.name); + if (!identifierData) throw AstError{ "unknown identifier " + identifierType.name }; - if (identifier->type != Identifier::Type::Struct) + if (identifierData->category != IdentifierCategory::Struct) throw AstError{ identifierType.name + " is not a struct" }; - return identifier->index; + return identifierData->index; } std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType) @@ -2091,7 +2127,7 @@ namespace Nz::ShaderAst std::size_t typeIndex = std::get(exprType).typeIndex; - const auto& type = m_context->currentEnv->types.Retrieve(typeIndex); + const auto& type = m_context->types.Retrieve(typeIndex); if (std::holds_alternative(type)) throw AstError{ "full type expected" }; @@ -2186,6 +2222,21 @@ namespace Nz::ShaderAst return node; } + void SanitizeVisitor::Validate(DeclareAliasStatement& node) + { + if (node.name.empty()) + throw std::runtime_error("invalid alias name"); + + ExpressionType exprType = GetExpressionType(*node.expression); + if (IsStructType(exprType)) + { + std::size_t structIndex = ResolveStruct(exprType); + node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex); + } + else + throw AstError{ "for now, only structs can be aliased" }; + } + void SanitizeVisitor::Validate(WhileStatement& node) { if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean }) @@ -2198,7 +2249,7 @@ namespace Nz::ShaderAst if (IsTypeExpression(exprType)) { std::size_t typeIndex = std::get(exprType).typeIndex; - const auto& type = m_context->currentEnv->types.Retrieve(typeIndex); + const auto& type = m_context->types.Retrieve(typeIndex); if (!std::holds_alternative(type)) throw std::runtime_error("only partial types can be specialized"); @@ -2291,7 +2342,7 @@ namespace Nz::ShaderAst Int32 index = std::get(constantExpr.value); std::size_t structIndex = ResolveStruct(exprType); - const StructDescription* s = m_context->currentEnv->structs.Retrieve(structIndex); + const StructDescription* s = m_context->structs.Retrieve(structIndex); exprType = ResolveType(s->members[index].type); } @@ -2365,7 +2416,7 @@ namespace Nz::ShaderAst assert(std::holds_alternative(targetFuncType)); std::size_t targetFuncIndex = std::get(targetFuncType).funcIndex; - auto& funcData = m_context->currentEnv->functions.Retrieve(targetFuncIndex); + auto& funcData = m_context->functions.Retrieve(targetFuncIndex); const DeclareFunctionStatement* referenceDeclaration = funcData.node; @@ -2482,9 +2533,9 @@ namespace Nz::ShaderAst if (m_context->options.makeVariableNameUnique) { // Since we are registered, FindIdentifier will find us - auto IgnoreOurself = [varIndex = *node.varIndex](const Identifier& identifier) + auto IgnoreOurself = [varIndex = *node.varIndex](const IdentifierData& identifierData) { - if (identifier.type == Identifier::Type::Variable && identifier.index == varIndex) + if (identifierData.category == IdentifierCategory::Variable && identifierData.index == varIndex) return false; return true; @@ -2728,7 +2779,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(VariableExpression& node) { - node.cachedExpressionType = m_context->currentEnv->variableTypes.Retrieve(node.variableId); + node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId); } ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr) diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 33b22d7eb..faf9f88ca 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -1017,6 +1017,11 @@ namespace Nz } } + void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/) + { + /* nothing to do */ + } + void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/) { /* nothing to do */ diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 5f3c01a65..89d313ea1 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -768,6 +768,18 @@ namespace Nz node.statement->Visit(*this); } + void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node) + { + throw std::runtime_error("TODO"); //< missing registering + + assert(node.aliasIndex); + + Append("alias ", node.name, " = "); + assert(node.expression); + node.expression->Visit(*this); + AppendLine(";"); + } + void LangWriter::Visit(ShaderAst::DeclareConstStatement& node) { assert(node.constIndex); @@ -780,7 +792,7 @@ namespace Nz node.expression->Visit(*this); } - Append(";"); + AppendLine(";"); } void LangWriter::Visit(ShaderAst::ConstantValueExpression& node) diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 7e215ed00..89c7fd8e9 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -347,6 +347,7 @@ namespace Nz::ShaderLang auto& importedModule = m_context->module->importedModules.emplace_back(); importedModule.module = std::move(module); + importedModule.identifier = identifier; } else { diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index ba347f801..be89238b9 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -598,6 +598,11 @@ namespace Nz }, node.value); } + void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/) + { + /* nothing to do */ + } + void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/) { /* nothing to do */ diff --git a/tests/resources.cpp b/tests/resources.cpp index 8a71d9249..01ddc41ab 100644 --- a/tests/resources.cpp +++ b/tests/resources.cpp @@ -4,11 +4,11 @@ std::filesystem::path GetResourceDir() { static std::filesystem::path resourceDir = [] { - std::filesystem::path resourceDir = "resources"; - if (!std::filesystem::is_directory(resourceDir) && std::filesystem::is_directory(".." / resourceDir)) - return ".." / resourceDir; + std::filesystem::path dir = "resources"; + if (!std::filesystem::is_directory(dir) && std::filesystem::is_directory(".." / dir)) + return ".." / dir; else - return resourceDir; + return dir; }(); return resourceDir;