diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index a1076c8da..9b1a0de64 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -42,6 +42,7 @@ namespace Nz::ShaderAst virtual ExpressionPtr Clone(AccessIdentifierExpression& node); virtual ExpressionPtr Clone(AccessIndexExpression& node); + virtual ExpressionPtr Clone(AliasValueExpression& node); virtual ExpressionPtr Clone(AssignExpression& node); virtual ExpressionPtr Clone(BinaryExpression& node); virtual ExpressionPtr Clone(CallFunctionExpression& node); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index ec198d0a4..dc6d0dcac 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -34,6 +34,7 @@ namespace Nz::ShaderAst inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs); inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs); + inline bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs); inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs); inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs); inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index 30fa13245..a305a144f 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -248,6 +248,14 @@ namespace Nz::ShaderAst return true; } + bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs) + { + if (!Compare(lhs.aliasId, rhs.aliasId)) + return false; + + return true; + } + inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs) { if (!Compare(lhs.op, rhs.op)) diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index 505657722..ed23b3faa 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -30,6 +30,7 @@ NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression) NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression) +NAZARA_SHADERAST_EXPRESSION(AliasValueExpression) NAZARA_SHADERAST_EXPRESSION(AssignExpression) NAZARA_SHADERAST_EXPRESSION(BinaryExpression) NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 55106b67a..23954126f 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -22,6 +22,7 @@ namespace Nz::ShaderAst void Visit(AccessIdentifierExpression& node) override; void Visit(AccessIndexExpression& node) override; + void Visit(AliasValueExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CallFunctionExpression& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index d0ebde7a4..99b68f1c0 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -25,6 +25,7 @@ namespace Nz::ShaderAst void Serialize(AccessIdentifierExpression& node); void Serialize(AccessIndexExpression& node); + void Serialize(AliasValueExpression& node); void Serialize(AssignExpression& node); void Serialize(BinaryExpression& node); void Serialize(CallFunctionExpression& node); diff --git a/include/Nazara/Shader/Ast/AstUtils.hpp b/include/Nazara/Shader/Ast/AstUtils.hpp index 7134eb230..9e8271193 100644 --- a/include/Nazara/Shader/Ast/AstUtils.hpp +++ b/include/Nazara/Shader/Ast/AstUtils.hpp @@ -33,6 +33,7 @@ namespace Nz::ShaderAst void Visit(AccessIdentifierExpression& node) override; void Visit(AccessIndexExpression& node) override; + void Visit(AliasValueExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; void Visit(CallFunctionExpression& node) override; diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index 46fc3efbd..e699c3435 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -20,6 +20,22 @@ namespace Nz::ShaderAst { struct ContainedType; + struct NAZARA_SHADER_API AliasType + { + AliasType() = default; + AliasType(const AliasType& alias); + AliasType(AliasType&&) noexcept = default; + + AliasType& operator=(const AliasType& alias); + AliasType& operator=(AliasType&&) noexcept = default; + + std::size_t aliasIndex; + std::unique_ptr targetType; + + bool operator==(const AliasType& rhs) const; + inline bool operator!=(const AliasType& rhs) const; + }; + struct NAZARA_SHADER_API ArrayType { ArrayType() = default; @@ -134,7 +150,7 @@ namespace Nz::ShaderAst inline bool operator!=(const VectorType& rhs) const; }; - using ExpressionType = std::variant; + using ExpressionType = std::variant; struct ContainedType { @@ -157,6 +173,7 @@ namespace Nz::ShaderAst std::vector members; }; + inline bool IsAliasType(const ExpressionType& type); inline bool IsArrayType(const ExpressionType& type); inline bool IsFunctionType(const ExpressionType& type); inline bool IsIdentifierType(const ExpressionType& type); diff --git a/include/Nazara/Shader/Ast/ExpressionType.inl b/include/Nazara/Shader/Ast/ExpressionType.inl index daefa9ec2..ed7be280a 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.inl +++ b/include/Nazara/Shader/Ast/ExpressionType.inl @@ -8,6 +8,11 @@ namespace Nz::ShaderAst { + inline bool AliasType::operator!=(const AliasType& rhs) const + { + return !operator==(rhs); + } + inline bool ArrayType::operator!=(const ArrayType& rhs) const { return !operator==(rhs); @@ -130,6 +135,11 @@ namespace Nz::ShaderAst } + inline bool IsAliasType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + inline bool IsArrayType(const ExpressionType& type) { return std::holds_alternative(type); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 09c21f85b..95970687b 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -82,6 +82,14 @@ namespace Nz::ShaderAst ExpressionPtr expr; }; + struct NAZARA_SHADER_API AliasValueExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t aliasId; + }; + struct NAZARA_SHADER_API AssignExpression : Expression { NodeType GetType() const override; @@ -153,7 +161,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - ShaderAst::ConstantValue value; + ConstantValue value; }; struct NAZARA_SHADER_API FunctionExpression : Expression @@ -207,7 +215,6 @@ namespace Nz::ShaderAst ExpressionPtr expression; }; - struct NAZARA_SHADER_API VariableExpression : Expression struct NAZARA_SHADER_API VariableValueExpression : Expression { NodeType GetType() const override; @@ -455,10 +462,12 @@ namespace Nz::ShaderAst #include - inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr); - inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr); + inline const ExpressionType& GetExpressionType(Expression& expr); + inline ExpressionType& GetExpressionTypeMut(Expression& expr); inline bool IsExpression(NodeType nodeType); inline bool IsStatement(NodeType nodeType); + + inline const ExpressionType& ResolveAlias(const ExpressionType& exprType); } #include diff --git a/include/Nazara/Shader/Ast/Nodes.inl b/include/Nazara/Shader/Ast/Nodes.inl index aac844825..4d1ef740c 100644 --- a/include/Nazara/Shader/Ast/Nodes.inl +++ b/include/Nazara/Shader/Ast/Nodes.inl @@ -7,18 +7,29 @@ namespace Nz::ShaderAst { - const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) + inline const ExpressionType& GetExpressionType(Expression& expr) { assert(expr.cachedExpressionType); return expr.cachedExpressionType.value(); } - ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr) + inline ExpressionType& GetExpressionTypeMut(Expression& expr) { assert(expr.cachedExpressionType); return expr.cachedExpressionType.value(); } + inline const ExpressionType& ResolveAlias(const ExpressionType& exprType) + { + if (IsAliasType(exprType)) + { + const AliasType& alias = std::get(exprType); + return alias.targetType->type; + } + else + return exprType; + } + inline bool IsExpression(NodeType nodeType) { switch (nodeType) diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 9810dd2c5..71ad082df 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -47,6 +47,7 @@ namespace Nz::ShaderAst std::unordered_map optionValues; bool makeVariableNameUnique = false; bool reduceLoopsToWhile = false; + bool removeAliases = false; bool removeConstDeclaration = false; bool removeCompoundAssignments = false; bool removeMatrixCast = false; @@ -71,6 +72,7 @@ namespace Nz::ShaderAst ExpressionPtr Clone(AccessIdentifierExpression& node) override; ExpressionPtr Clone(AccessIndexExpression& node) override; + ExpressionPtr Clone(AliasValueExpression& node) override; ExpressionPtr Clone(AssignExpression& node) override; ExpressionPtr Clone(BinaryExpression& node) override; ExpressionPtr Clone(CallFunctionExpression& node) override; @@ -136,15 +138,16 @@ namespace Nz::ShaderAst std::size_t RegisterType(std::string name, PartialType partialType, std::optional index = {}); std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional index = {}); - const IdentifierData* ResolveAlias(const IdentifierData* identifier) const; + const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const; void ResolveFunctions(); const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); + std::size_t ResolveStruct(const AliasType& aliasType); std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const UniformType& uniformType); - ExpressionType ResolveType(const ExpressionType& exprType); - ExpressionType ResolveType(const ExpressionValue& exprTypeValue); + ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false); + ExpressionType ResolveType(const ExpressionValue& exprTypeValue, bool resolveAlias = false); void SanitizeIdentifier(std::string& identifier); MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error); diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 8a5446985..99b47cc9e 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -51,6 +51,7 @@ namespace Nz static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map optionValues, std::string* error = nullptr); private: + void Append(const ShaderAst::AliasType& aliasType); void Append(const ShaderAst::ArrayType& type); void Append(ShaderAst::BuiltinEntry builtin); void Append(const ShaderAst::ExpressionType& type); @@ -94,6 +95,7 @@ namespace Nz void Visit(ShaderAst::AccessIdentifierExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override; + void Visit(ShaderAst::AliasValueExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::CallFunctionExpression& node) override; diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index 4a4b957cd..441b65bd8 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -50,6 +50,7 @@ namespace Nz struct UnrollAttribute; struct UuidAttribute; + void Append(const ShaderAst::AliasType& type); void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ExpressionType& type); void Append(const ShaderAst::ExpressionValue& type); @@ -92,6 +93,7 @@ namespace Nz void EnterScope(); void LeaveScope(bool skipLine = true); + void RegisterAlias(std::size_t aliasIndex, std::string aliasName); void RegisterConstant(std::size_t constantIndex, std::string constantName); void RegisterStruct(std::size_t structIndex, std::string structName); void RegisterVariable(std::size_t varIndex, std::string varName); @@ -102,6 +104,7 @@ namespace Nz void Visit(ShaderAst::AccessIdentifierExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override; + void Visit(ShaderAst::AliasValueExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 82f04a035..967dccb89 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -48,8 +48,6 @@ namespace Nz void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::ConstantValueExpression& node) override; - void Visit(ShaderAst::DeclareAliasStatement& node) override; - void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareOptionStatement& node) override; @@ -64,7 +62,7 @@ namespace Nz void Visit(ShaderAst::ScopedStatement& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override; - void Visit(ShaderAst::VariableExpression& node) override; + void Visit(ShaderAst::VariableValueExpression& node) override; void Visit(ShaderAst::WhileStatement& node) override; SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete; diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index ad1924956..380403d4d 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -178,6 +178,7 @@ namespace Nz TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const; + TypePtr BuildType(const ShaderAst::AliasType& type) const; TypePtr BuildType(const ShaderAst::ArrayType& type) const; TypePtr BuildType(const ShaderAst::ExpressionType& type) const; TypePtr BuildType(const ShaderAst::IdentifierType& type) const; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 5c255ffa6..af4bb3677 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -300,6 +300,16 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr AstCloner::Clone(AliasValueExpression& node) + { + auto clone = std::make_unique(); + clone->aliasId = node.aliasId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + ExpressionPtr AstCloner::Clone(AssignExpression& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index e6aa679c8..a1bd3717b 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -19,6 +19,11 @@ namespace Nz::ShaderAst index->Visit(*this); } + void AstRecursiveVisitor::Visit(AliasValueExpression& /*node*/) + { + /* nothing to do */ + } + void AstRecursiveVisitor::Visit(AssignExpression& node) { node.left->Visit(*this); diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 0548e7c85..3cdcc9934 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -58,6 +58,11 @@ namespace Nz::ShaderAst Node(identifier); } + void AstSerializerBase::Serialize(AliasValueExpression& node) + { + SizeT(node.aliasId); + } + void AstSerializerBase::Serialize(AssignExpression& node) { Enum(node.op); @@ -485,6 +490,12 @@ namespace Nz::ShaderAst Type(arg.objectType->type); SizeT(arg.methodIndex); } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(13); + SizeT(arg.aliasIndex); + Type(arg.targetType->type); + } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, type); @@ -800,6 +811,22 @@ namespace Nz::ShaderAst type = std::move(methodType); } + case 13: //< AliasType + { + std::size_t aliasIndex; + ExpressionType containedType; + SizeT(aliasIndex); + Type(containedType); + + AliasType aliasType; + aliasType.aliasIndex = aliasIndex; + aliasType.targetType = std::make_unique(); + aliasType.targetType->type = std::move(containedType); + + type = std::move(aliasType); + break; + } + default: break; } diff --git a/src/Nazara/Shader/Ast/AstUtils.cpp b/src/Nazara/Shader/Ast/AstUtils.cpp index 8ae24f721..c5917a2ef 100644 --- a/src/Nazara/Shader/Ast/AstUtils.cpp +++ b/src/Nazara/Shader/Ast/AstUtils.cpp @@ -23,6 +23,11 @@ namespace Nz::ShaderAst node.expr->Visit(*this); } + void ShaderAstValueCategory::Visit(AliasValueExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + void ShaderAstValueCategory::Visit(AssignExpression& /*node*/) { m_expressionCategory = ExpressionCategory::RValue; diff --git a/src/Nazara/Shader/Ast/ExpressionType.cpp b/src/Nazara/Shader/Ast/ExpressionType.cpp index 7bc68b492..24b438dcd 100644 --- a/src/Nazara/Shader/Ast/ExpressionType.cpp +++ b/src/Nazara/Shader/Ast/ExpressionType.cpp @@ -9,6 +9,37 @@ namespace Nz::ShaderAst { + AliasType::AliasType(const AliasType& alias) : + aliasIndex(alias.aliasIndex) + { + assert(alias.targetType); + targetType = std::make_unique(*alias.targetType); + } + + AliasType& AliasType::operator=(const AliasType& alias) + { + aliasIndex = alias.aliasIndex; + + assert(alias.targetType); + targetType = std::make_unique(*alias.targetType); + + return *this; + } + + bool AliasType::operator==(const AliasType& rhs) const + { + assert(targetType); + assert(rhs.targetType); + + if (aliasIndex != rhs.aliasIndex) + return false; + + if (targetType->type != rhs.targetType->type) + return false; + + return true; + } + ArrayType::ArrayType(const ArrayType& array) : length(array.length) { @@ -31,10 +62,10 @@ namespace Nz::ShaderAst assert(containedType); assert(rhs.containedType); - if (containedType->type != rhs.containedType->type) + if (length != rhs.length) return false; - if (length != rhs.length) + if (containedType->type != rhs.containedType->type) return false; return true; diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index ceba36f4c..538e8c07e 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -282,7 +282,7 @@ namespace Nz::ShaderAst if (identifier.empty()) throw AstError{ "empty identifier" }; - const ExpressionType& exprType = GetExpressionType(*indexedExpr); + const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr)); // TODO: Add proper support for methods if (IsSamplerType(exprType)) { @@ -429,6 +429,25 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node) + { + const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId)); + ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier); + + if (m_context->options.removeAliases) + return targetExpr; + + AliasType aliasType; + aliasType.aliasIndex = node.aliasId; + aliasType.targetType = std::make_unique(); + aliasType.targetType->type = *targetExpr->cachedExpressionType; + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + clone->cachedExpressionType = std::move(aliasType); + + return clone; + } + ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node) { MandatoryExpr(node.left); @@ -543,7 +562,7 @@ namespace Nz::ShaderAst { const MatrixType& targetMatrixType = std::get(targetType); - const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front()); + const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front()); bool isMatrixCast = IsMatrixType(frontExprType); if (isMatrixCast && std::get(frontExprType) == targetMatrixType) { @@ -785,6 +804,9 @@ namespace Nz::ShaderAst auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); Validate(*clone); + if (m_context->options.removeAliases) + return ShaderBuilder::NoOp(); + return clone; } @@ -803,7 +825,7 @@ namespace Nz::ShaderAst ExpressionType expressionType = ResolveType(GetExpressionType(value)); - if (clone->type.HasValue() && ResolveType(clone->type) != expressionType) + if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType)) throw AstError{ "constant expression doesn't match type" }; clone->type = expressionType; @@ -852,12 +874,13 @@ namespace Nz::ShaderAst m_context->declaredExternalVar.insert(extVar.name); ExpressionType resolvedType = ResolveType(extVar.type); + const ExpressionType& targetType = ResolveAlias(resolvedType); ExpressionType varType; - if (IsUniformType(resolvedType)) - varType = std::get(resolvedType).containedType; - else if (IsSamplerType(resolvedType)) - varType = resolvedType; + if (IsUniformType(targetType)) + varType = std::get(targetType).containedType; + else if (IsSamplerType(targetType)) + varType = targetType; else throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; @@ -954,8 +977,9 @@ namespace Nz::ShaderAst throw AstError{ "empty option name" }; ExpressionType resolvedType = ResolveType(clone->optType); + const ExpressionType& targetType = ResolveAlias(resolvedType); - if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue)) + if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue)) throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" }; clone->optType = std::move(resolvedType); @@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst ExpressionType resolvedType = ResolveType(member.type); if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140) { - if (IsPrimitiveType(resolvedType) && std::get(resolvedType) == PrimitiveType::Boolean) + const ExpressionType& targetType = ResolveAlias(resolvedType); + + if (IsPrimitiveType(targetType) && std::get(targetType) == PrimitiveType::Boolean) throw AstError{ "boolean type is not allowed in std140 layout" }; - else if (IsStructType(resolvedType)) + else if (IsStructType(targetType)) { - std::size_t structIndex = std::get(resolvedType).structIndex; + std::size_t structIndex = std::get(targetType).structIndex; const StructDescription* desc = m_context->structs.Retrieve(structIndex); if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) throw AstError{ "inner struct layout mismatch" }; @@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst AstExportVisitor exportVisitor; exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks); - if (aliasStatements.empty()) + if (aliasStatements.empty() || m_context->options.removeAliases) return ShaderBuilder::NoOp(); // Register module and aliases @@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst return nullptr; } - return ResolveAlias(&it->data); + return &it->data; } template @@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst { if (identifier.name == identifierName) { - if (functor(*ResolveAlias(&identifier.data))) + if (functor(identifier.data)) return true; } @@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst return nullptr; } - return ResolveAlias(&it->data); + return &it->data; } TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const @@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst { switch (identifierData->category) { + case IdentifierCategory::Alias: + { + AliasValueExpression aliasValue; + aliasValue.aliasId = identifierData->index; + + return Clone(aliasValue); + } + case IdentifierCategory::Constant: { // Replace IdentifierExpression by Constant(Value)Expression @@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst return varIndex; } - auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData* + auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData* { while (identifier->category == IdentifierCategory::Alias) identifier = &m_context->aliases.Retrieve(identifier->index); @@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst for (const auto& [funcIndex, funcData] : m_context->functions.values) { - if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) + if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) throw AstError{ "discard can only be used in the fragment stage" }; } } @@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst } + std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType) + { + return ResolveStruct(aliasType.targetType->type); + } + std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType) { return std::visit([&](auto&& arg) -> std::size_t { using T = std::decay_t; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) return ResolveStruct(arg); else if constexpr (std::is_same_v || std::is_same_v || @@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst return uniformType.containedType.structIndex; } - ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType) + ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias) { if (!IsTypeExpression(exprType)) - return exprType; + { + if (resolveAlias || m_context->options.removeAliases) + return ResolveAlias(exprType); + else + return exprType; + } std::size_t typeIndex = std::get(exprType).typeIndex; @@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst return std::get(type); } - ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue& exprTypeValue) + ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue& exprTypeValue, bool resolveAlias) { if (!exprTypeValue.HasValue()) return {}; if (exprTypeValue.IsResultingValue()) - return ResolveType(exprTypeValue.GetResultingValue()); + return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias); assert(exprTypeValue.IsExpression()); ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression()); @@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst //if (!IsTypeType(exprType)) // throw AstError{ "type expected" }; - return ResolveType(exprType); + return ResolveType(exprType, resolveAlias); } void SanitizeVisitor::SanitizeIdentifier(std::string& identifier) @@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const { - if (left != right) + if (ResolveAlias(left) != ResolveAlias(right)) throw AstError{ "Left expression type must match right expression type" }; } @@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst std::size_t structIndex = ResolveStruct(exprType); node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex); } + else if (IsAliasType(exprType)) + { + const AliasType& alias = std::get(exprType); + node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex); + } else throw AstError{ "for now, only structs can be aliased" }; } @@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst case TypeParameterCategory::PrimitiveType: case TypeParameterCategory::StructType: { - ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr)); + ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true); switch (partialType.parameters[i]) { @@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst for (auto& index : node.indices) { - const ShaderAst::ExpressionType& indexType = GetExpressionType(*index); + const ExpressionType& indexType = GetExpressionType(*index); if (!IsPrimitiveType(indexType)) throw AstError{ "AccessIndex expects integer indices" }; @@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst } else if (IsStructType(exprType)) { - const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); + const ExpressionType& indexType = GetExpressionType(*indexExpr); if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) throw AstError{ "struct can only be accessed with constant i32 indices" }; @@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst std::size_t structIndex = ResolveStruct(exprType); const StructDescription* s = m_context->structs.Retrieve(structIndex); - exprType = ResolveType(s->members[index].type); + exprType = ResolveType(s->members[index].type, true); } else if (IsMatrixType(exprType)) { @@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(CallFunctionExpression& node) { - const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction); + const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction); assert(std::holds_alternative(targetFuncType)); std::size_t targetFuncIndex = std::get(targetFuncType).funcIndex; @@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(CastExpression& node) { ExpressionType resolvedType = ResolveType(node.targetType); + const ExpressionType& targetType = ResolveAlias(resolvedType); const auto& firstExprPtr = node.expressions.front(); if (!firstExprPtr) throw AstError{ "expected at least one expression" }; - if (IsMatrixType(resolvedType)) + if (IsMatrixType(targetType)) { - const MatrixType& targetMatrixType = std::get(resolvedType); + const MatrixType& targetMatrixType = std::get(targetType); const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr); if (IsMatrixType(firstExprType)) @@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst }; std::size_t componentCount = 0; - std::size_t requiredComponents = GetComponentCount(resolvedType); + std::size_t requiredComponents = GetComponentCount(targetType); for (auto& exprPtr : node.expressions) { @@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst case UnaryType::Minus: case UnaryType::Plus: { - ShaderAst::PrimitiveType basicType; + PrimitiveType basicType; if (IsPrimitiveType(exprType)) - basicType = std::get(exprType); + basicType = std::get(exprType); else if (IsVectorType(exprType)) - basicType = std::get(exprType).type; + basicType = std::get(exprType).type; else throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" }; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 2c69a9814..8e9062128 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -232,6 +232,7 @@ namespace Nz options.optionValues = std::move(optionValues); options.makeVariableNameUnique = true; options.reduceLoopsToWhile = true; + options.removeAliases = true; options.removeCompoundAssignments = false; options.removeConstDeclaration = true; options.removeOptionDeclaration = true; @@ -246,6 +247,11 @@ namespace Nz return ShaderAst::Sanitize(module, options, error); } + void GlslWriter::Append(const ShaderAst::AliasType& /*aliasType*/) + { + throw std::runtime_error("unexpected AliasType"); + } + void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/) { throw std::runtime_error("unexpected ArrayType"); @@ -689,11 +695,16 @@ namespace Nz builtin.identifier }); } - else if (member.locationIndex.HasValue()) + else { - Append("layout(location = "); - Append(member.locationIndex.GetResultingValue()); - Append(") ", keyword, " "); + if (member.locationIndex.HasValue()) + { + Append("layout(location = "); + Append(member.locationIndex.GetResultingValue()); + Append(") "); + } + + Append(keyword, " "); AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name); AppendLine(";"); @@ -805,6 +816,12 @@ namespace Nz Append("]"); } + void GlslWriter::Visit(ShaderAst::AliasValueExpression& /*node*/) + { + // all aliases should have been handled by sanitizer + throw std::runtime_error("unexpected alias value, is shader sanitized?"); + } + void GlslWriter::Visit(ShaderAst::AssignExpression& node) { node.left->Visit(*this); @@ -1038,12 +1055,14 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/) { - /* nothing to do */ + // all aliases should have been handled by sanitizer + throw std::runtime_error("unexpected alias declaration, is shader sanitized?"); } void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/) { - /* nothing to do */ + // all consts should have been handled by sanitizer + throw std::runtime_error("unexpected const declaration, is shader sanitized?"); } void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) @@ -1184,7 +1203,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/) { - /* nothing to do */ + // all options should have been handled by sanitizer + throw std::runtime_error("unexpected option declaration, is shader sanitized?"); } void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) @@ -1247,7 +1267,7 @@ namespace Nz Append(";"); } - void GlslWriter::Visit(ShaderAst::ImportStatement& node) + void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/) { throw std::runtime_error("unexpected import statement, is the shader sanitized properly?"); } @@ -1280,8 +1300,8 @@ namespace Nz const auto& structData = Retrieve(m_currentState->structs, structIndex); std::string outputStructVarName; - if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression) - outputStructVarName = Retrieve(m_currentState->variableNames, static_cast(*node.returnExpr).variableId); + if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableValueExpression) + outputStructVarName = Retrieve(m_currentState->variableNames, static_cast(*node.returnExpr).variableId); else { AppendLine(); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 05b6d5421..866b6fa97 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -116,6 +116,7 @@ namespace Nz ShaderAst::Module* module; std::size_t currentModuleIndex; std::stringstream stream; + std::unordered_map aliases; std::unordered_map constants; std::unordered_map structs; std::unordered_map variables; @@ -164,6 +165,11 @@ namespace Nz m_environment = std::move(environment); } + void LangWriter::Append(const ShaderAst::AliasType& type) + { + AppendIdentifier(m_currentState->aliases, type.aliasIndex); + } + void LangWriter::Append(const ShaderAst::ArrayType& type) { Append("array[", type.containedType->type, ", ", type.length, "]"); @@ -655,6 +661,16 @@ namespace Nz Append("}"); } + void LangWriter::RegisterAlias(std::size_t aliasIndex, std::string aliasName) + { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(aliasName); + + assert(m_currentState->aliases.find(aliasIndex) == m_currentState->aliases.end()); + m_currentState->aliases.emplace(aliasIndex, std::move(identifier)); + } + void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName) { State::Identifier identifier; @@ -714,7 +730,7 @@ namespace Nz { Visit(node.expr, true); - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); + const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr)); assert(IsStructType(exprType)); for (const std::string& identifier : node.identifiers) @@ -725,7 +741,7 @@ namespace Nz { Visit(node.expr, true); - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); + const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr)); assert(!IsStructType(exprType)); // Array access @@ -744,6 +760,11 @@ namespace Nz Append("]"); } + void LangWriter::Visit(ShaderAst::AliasValueExpression& node) + { + AppendIdentifier(m_currentState->aliases, node.aliasId); + } + void LangWriter::Visit(ShaderAst::AssignExpression& node) { node.left->Visit(*this); @@ -840,7 +861,13 @@ namespace Nz void LangWriter::Visit(ShaderAst::ConditionalExpression& node) { - throw std::runtime_error("fixme"); + Append("const_select("); + node.condition->Visit(*this); + Append(", "); + node.truePath->Visit(*this); + Append(", "); + node.falsePath->Visit(*this); + Append(")"); } void LangWriter::Visit(ShaderAst::ConditionalStatement& node) @@ -853,9 +880,8 @@ namespace Nz void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node) { - //throw std::runtime_error("TODO"); //< missing registering - assert(node.aliasIndex); + RegisterAlias(*node.aliasIndex, node.name); Append("alias ", node.name, " = "); assert(node.expression); diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index ff658154b..da71fd13f 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -598,16 +598,6 @@ namespace Nz }, node.value); } - void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/) - { - /* nothing to do */ - } - - void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/) - { - /* nothing to do */ - } - void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node) { for (auto&& extVar : node.externalVars) @@ -729,11 +719,6 @@ namespace Nz PopResultId(); } - void SpirvAstVisitor::Visit(ShaderAst::ImportStatement& node) - { - /* nothing to do */ - } - void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node) { switch (node.intrinsic) diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 2b731cc8a..eaa7bd589 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -655,6 +655,28 @@ namespace Nz return typePtr; } + auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr + { + bool wasInblockStruct = m_internal->isInBlockStruct; + if (storageClass == SpirvStorageClass::Uniform) + m_internal->isInBlockStruct = true; + + auto typePtr = std::make_shared(Pointer{ + BuildType(type), + storageClass + }); + + m_internal->isInBlockStruct = wasInblockStruct; + + return typePtr; + } + + auto SpirvConstantCache::BuildType(const ShaderAst::AliasType& /*type*/) const -> TypePtr + { + // No AliasType is expected (as they should have been resolved by now) + throw std::runtime_error("unexpected alias"); + } + auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr { const auto& containedType = type.containedType->type; @@ -678,22 +700,6 @@ namespace Nz }); } - auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr - { - bool wasInblockStruct = m_internal->isInBlockStruct; - if (storageClass == SpirvStorageClass::Uniform) - m_internal->isInBlockStruct = true; - - auto typePtr = std::make_shared(Pointer{ - BuildType(type), - storageClass - }); - - m_internal->isInBlockStruct = wasInblockStruct; - - return typePtr; - } - auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr { return std::visit([&](auto&& arg) -> TypePtr diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 63c28911b..1fde66638 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -505,6 +505,7 @@ namespace Nz ShaderAst::SanitizeVisitor::Options options; options.optionValues = states.optionValues; options.reduceLoopsToWhile = true; + options.removeAliases = true; options.removeCompoundAssignments = true; options.removeMatrixCast = true; options.removeOptionDeclaration = true; diff --git a/tests/Engine/Shader/AccessMemberTest.cpp b/tests/Engine/Shader/AccessMember.cpp similarity index 100% rename from tests/Engine/Shader/AccessMemberTest.cpp rename to tests/Engine/Shader/AccessMember.cpp diff --git a/tests/Engine/Shader/Alias.cpp b/tests/Engine/Shader/Alias.cpp new file mode 100644 index 000000000..d85f28a9d --- /dev/null +++ b/tests/Engine/Shader/Alias.cpp @@ -0,0 +1,97 @@ +#include +#include +#include +#include +#include +#include +#include + +TEST_CASE("aliases", "[Shader]") +{ + SECTION("Alias of structs") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +struct Data +{ + value: f32 +} + +alias ExtData = Data; + +external +{ + [binding(0)] extData: uniform[ExtData] +} + +struct Input +{ + value: f32 +} + +alias In = Input; + +struct Output +{ + [location(0)] value: f32 +} + +alias Out = Output; +alias FragOut = Out; + +[entry(frag)] +fn main(input: In) -> FragOut +{ + let output: Out; + output.value = extData.value * input.value; + return output; +} +)"; + + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shaderModule, R"( +void main() +{ + Input input_; + input_.value = _NzIn_value; + + Output output_; + output_.value = extData.value * input_.value; + + _NzOut_value = output_.value; + return; +} +)"); + + ExpectNZSL(*shaderModule, R"( +[entry(frag)] +fn main(input: In) -> FragOut +{ + let output: Out; + output.value = extData.value * input.value; + return output; +} +)"); + + ExpectSPIRV(*shaderModule, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpAccessChain +OpLoad +OpAccessChain +OpLoad +OpFMul +OpAccessChain +OpStore +OpLoad +OpCompositeExtract +OpStore +OpReturn +OpFunctionEnd)"); + } +} diff --git a/tests/Engine/Shader/Sanitizations.cpp b/tests/Engine/Shader/Sanitizations.cpp index abe810302..7745107c3 100644 --- a/tests/Engine/Shader/Sanitizations.cpp +++ b/tests/Engine/Shader/Sanitizations.cpp @@ -262,6 +262,47 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32] { return input; } +)"); + + } + + WHEN("removing aliases") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +struct inputStruct +{ + value: f32 +} + +alias Input = inputStruct; +alias In = Input; + +external +{ + [set(0), binding(0)] data: uniform[In] +} +)"; + + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); + + Nz::ShaderAst::SanitizeVisitor::Options options; + options.removeAliases = true; + + REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options)); + + ExpectNZSL(*shaderModule, R"( +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform[inputStruct] +} )"); } diff --git a/tests/Engine/Shader/ShaderUtils.cpp b/tests/Engine/Shader/ShaderUtils.cpp index e17626331..bab584ac2 100644 --- a/tests/Engine/Shader/ShaderUtils.cpp +++ b/tests/Engine/Shader/ShaderUtils.cpp @@ -142,8 +142,10 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput) Nz::ShaderAst::AstReflect reflectVisitor; reflectVisitor.Reflect(*shader.rootNode, callbacks); - INFO("no entry point found"); - REQUIRE(entryShaderStage.has_value()); + { + INFO("no entry point found"); + REQUIRE(entryShaderStage.has_value()); + } Nz::GlslWriter writer; std::string output = writer.Generate(entryShaderStage, shader);