diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index 0a14dff0a..b898efc4f 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -64,6 +64,7 @@ namespace Nz::ShaderAst virtual StatementPtr Clone(DeclareVariableStatement& node); virtual StatementPtr Clone(DiscardStatement& node); virtual StatementPtr Clone(ExpressionStatement& node); + virtual StatementPtr Clone(ForEachStatement& node); virtual StatementPtr Clone(MultiStatement& node); virtual StatementPtr Clone(NoOpStatement& node); virtual StatementPtr Clone(ReturnStatement& node); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index 4a338abc6..239d6ba5c 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -54,6 +54,7 @@ namespace Nz::ShaderAst inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs); inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs); inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs); + inline bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs); inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs); inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs); inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index 1ee389c18..64ea853e5 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -458,6 +458,23 @@ namespace Nz::ShaderAst return true; } + bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs) + { + if (!Compare(lhs.isConst, rhs.isConst)) + return false; + + if (!Compare(lhs.varName, rhs.varName)) + return false; + + if (!Compare(lhs.expression, rhs.expression)) + return false; + + if (!Compare(lhs.statement, rhs.statement)) + return false; + + return true; + } + inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs) { if (!Compare(lhs.statements, rhs.statements)) diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index da14b20f1..37cb4450c 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -52,6 +52,7 @@ NAZARA_SHADERAST_STATEMENT(DeclareOptionStatement) NAZARA_SHADERAST_STATEMENT(DeclareStructStatement) NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement) NAZARA_SHADERAST_STATEMENT(DiscardStatement) +NAZARA_SHADERAST_STATEMENT(ForEachStatement) NAZARA_SHADERAST_STATEMENT(ExpressionStatement) NAZARA_SHADERAST_STATEMENT(MultiStatement) NAZARA_SHADERAST_STATEMENT(NoOpStatement) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 186282a2a..45645a199 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -46,6 +46,7 @@ namespace Nz::ShaderAst void Visit(DeclareVariableStatement& node) override; void Visit(DiscardStatement& node) override; void Visit(ExpressionStatement& node) override; + void Visit(ForEachStatement& node) override; void Visit(MultiStatement& node) override; void Visit(NoOpStatement& node) override; void Visit(ReturnStatement& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index 36dfd2747..e3751f89f 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -49,6 +49,7 @@ namespace Nz::ShaderAst void Serialize(DeclareVariableStatement& node); void Serialize(DiscardStatement& node); void Serialize(ExpressionStatement& node); + void Serialize(ForEachStatement& node); void Serialize(MultiStatement& node); void Serialize(NoOpStatement& node); void Serialize(ReturnStatement& node); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index e66bef6cf..70e26c26d 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -340,6 +340,18 @@ namespace Nz::ShaderAst ExpressionPtr expression; }; + struct NAZARA_SHADER_API ForEachStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + std::optional varIndex; + std::string varName; + ExpressionPtr expression; + StatementPtr statement; + bool isConst = false; + }; + struct NAZARA_SHADER_API MultiStatement : Statement { NodeType GetType() const override; @@ -371,7 +383,12 @@ namespace Nz::ShaderAst StatementPtr body; }; +#define NAZARA_SHADERAST_NODE(X) using X##Ptr = std::unique_ptr; + +#include + inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr); + inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr); inline bool IsExpression(NodeType nodeType); inline bool IsStatement(NodeType nodeType); } diff --git a/include/Nazara/Shader/Ast/Nodes.inl b/include/Nazara/Shader/Ast/Nodes.inl index bc11eade6..aac844825 100644 --- a/include/Nazara/Shader/Ast/Nodes.inl +++ b/include/Nazara/Shader/Ast/Nodes.inl @@ -13,6 +13,12 @@ namespace Nz::ShaderAst return expr.cachedExpressionType.value(); } + ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr) + { + assert(expr.cachedExpressionType); + return expr.cachedExpressionType.value(); + } + inline bool IsExpression(NodeType nodeType) { switch (nodeType) diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index d3fb66a34..3175d8b44 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -40,8 +40,9 @@ namespace Nz::ShaderAst std::unordered_set reservedIdentifiers; std::unordered_map optionValues; bool makeVariableNameUnique = false; + bool reduceLoopsToWhile = false; bool removeCompoundAssignments = false; - bool removeOptionDeclaration = true; + bool removeOptionDeclaration = false; bool removeScalarSwizzling = false; bool splitMultipleBranches = false; }; @@ -77,6 +78,7 @@ namespace Nz::ShaderAst StatementPtr Clone(DeclareVariableStatement& node) override; StatementPtr Clone(DiscardStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override; + StatementPtr Clone(ForEachStatement& node) override; StatementPtr Clone(MultiStatement& node) override; StatementPtr Clone(WhileStatement& node) override; @@ -117,6 +119,8 @@ namespace Nz::ShaderAst void SanitizeIdentifier(std::string& identifier); void Validate(AccessIndexExpression& node); + void Validate(AssignExpression& node); + void Validate(BinaryExpression& node); void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration); void Validate(CastExpression& node); void Validate(DeclareVariableStatement& node); diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index deb2ae505..5d019b807 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -71,6 +71,7 @@ namespace Nz void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); void AppendStatementList(std::vector& statements); + void AppendVariableDeclaration(const ShaderAst::ExpressionType& varType, const std::string& varName); void EnterScope(); void LeaveScope(bool skipLine = true); diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index 8264048f2..3ff7e664c 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -99,6 +99,7 @@ namespace Nz void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::ConditionalStatement& node) override; + void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareOptionStatement& node) override; @@ -106,6 +107,7 @@ namespace Nz void Visit(ShaderAst::DeclareVariableStatement& node) override; void Visit(ShaderAst::DiscardStatement& node) override; void Visit(ShaderAst::ExpressionStatement& node) override; + void Visit(ShaderAst::ForEachStatement& node) override; void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 5035a3191..5a547e73d 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -19,7 +19,9 @@ namespace Nz::ShaderBuilder { struct AccessIndex { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr, Int32 index) const; inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr, const std::vector& indexConstants) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr, ShaderAst::ExpressionPtr indexExpression) const; inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr, std::vector indexExpressions) const; }; @@ -106,6 +108,12 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression) const; }; + template + struct ForEach + { + inline std::unique_ptr operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const; + }; + struct Identifier { inline std::unique_ptr operator()(std::string name) const; @@ -143,11 +151,16 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::UnaryType op, ShaderAst::ExpressionPtr expression) const; }; + struct Variable + { + inline std::unique_ptr operator()(std::size_t variableId, ShaderAst::ExpressionType expressionType) const; + }; + struct While { inline std::unique_ptr operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const; }; - } +} constexpr Impl::AccessIndex AccessIndex; constexpr Impl::AccessMember AccessMember; @@ -160,6 +173,7 @@ namespace Nz::ShaderBuilder constexpr Impl::ConditionalStatement ConditionalStatement; constexpr Impl::Constant Constant; constexpr Impl::Branch ConstBranch; + constexpr Impl::ForEach ConstForEach; constexpr Impl::DeclareConst DeclareConst; constexpr Impl::DeclareFunction DeclareFunction; constexpr Impl::DeclareOption DeclareOption; @@ -167,6 +181,7 @@ namespace Nz::ShaderBuilder constexpr Impl::DeclareVariable DeclareVariable; constexpr Impl::ExpressionStatement ExpressionStatement; constexpr Impl::NoParam Discard; + constexpr Impl::ForEach ForEach; constexpr Impl::Identifier Identifier; constexpr Impl::Intrinsic Intrinsic; constexpr Impl::Multi MultiStatement; @@ -174,6 +189,7 @@ namespace Nz::ShaderBuilder constexpr Impl::Return Return; constexpr Impl::Swizzle Swizzle; constexpr Impl::Unary Unary; + constexpr Impl::Variable Variable; constexpr Impl::While While; } diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 0a1b626d1..b646b632c 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -16,6 +16,15 @@ namespace Nz::ShaderBuilder return accessMemberNode; } + inline std::unique_ptr Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, Int32 index) const + { + auto accessMemberNode = std::make_unique(); + accessMemberNode->expr = std::move(expr); + accessMemberNode->indices.push_back(ShaderBuilder::Constant(index)); + + return accessMemberNode; + } + inline std::unique_ptr Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, const std::vector& indexConstants) const { auto accessMemberNode = std::make_unique(); @@ -28,6 +37,15 @@ namespace Nz::ShaderBuilder return accessMemberNode; } + inline std::unique_ptr Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, ShaderAst::ExpressionPtr indexExpression) const + { + auto accessMemberNode = std::make_unique(); + accessMemberNode->expr = std::move(expr); + accessMemberNode->indices.push_back(std::move(indexExpression)); + + return accessMemberNode; + } + inline std::unique_ptr Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, std::vector indexExpressions) const { auto accessMemberNode = std::make_unique(); @@ -136,6 +154,7 @@ namespace Nz::ShaderBuilder { auto constantNode = std::make_unique(); constantNode->value = std::move(value); + constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value); return constantNode; } @@ -250,6 +269,18 @@ namespace Nz::ShaderBuilder return expressionStatementNode; } + template + std::unique_ptr Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const + { + auto forEachNode = std::make_unique(); + forEachNode->isConst = Const; + forEachNode->expression = std::move(expression); + forEachNode->statement = std::move(statement); + forEachNode->varName = std::move(varName); + + return forEachNode; + } + inline std::unique_ptr Impl::Identifier::operator()(std::string name) const { auto identifierNode = std::make_unique(); @@ -327,6 +358,15 @@ namespace Nz::ShaderBuilder return unaryNode; } + inline std::unique_ptr Impl::Variable::operator()(std::size_t variableId, ShaderAst::ExpressionType expressionType) const + { + auto varNode = std::make_unique(); + varNode->variableId = variableId; + varNode->cachedExpressionType = std::move(expressionType); + + return varNode; + } + inline std::unique_ptr Impl::While::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const { auto whileNode = std::make_unique(); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 2bd89245e..ea78c2867 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -88,6 +88,7 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr ParseConstStatement(); ShaderAst::StatementPtr ParseDiscardStatement(); ShaderAst::StatementPtr ParseExternalBlock(std::vector attributes = {}); + ShaderAst::StatementPtr ParseForDeclaration(); std::vector ParseFunctionBody(); ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter(); diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index 7c98ab3ec..7234d5509 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -31,6 +31,7 @@ NAZARA_SHADERLANG_TOKEN(Else) NAZARA_SHADERLANG_TOKEN(EndOfStream) NAZARA_SHADERLANG_TOKEN(External) NAZARA_SHADERLANG_TOKEN(FloatingPointValue) +NAZARA_SHADERLANG_TOKEN(For) NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) NAZARA_SHADERLANG_TOKEN(FunctionReturn) NAZARA_SHADERLANG_TOKEN(GreaterThan) @@ -38,6 +39,7 @@ NAZARA_SHADERLANG_TOKEN(GreaterThanEqual) NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(Identifier) NAZARA_SHADERLANG_TOKEN(If) +NAZARA_SHADERLANG_TOKEN(In) NAZARA_SHADERLANG_TOKEN(LessThan) NAZARA_SHADERLANG_TOKEN(LessThanEqual) NAZARA_SHADERLANG_TOKEN(Let) diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index eee6add0c..a22da4a9a 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -42,7 +42,8 @@ namespace Nz struct Array { TypePtr elementType; - UInt32 length; + ConstantPtr length; + std::optional stride; }; struct Bool {}; @@ -129,7 +130,7 @@ namespace Nz struct ConstantScalar { - std::variant value; + std::variant value; }; using AnyConstant = std::variant; @@ -174,6 +175,7 @@ namespace Nz TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const; + TypePtr BuildType(const ShaderAst::ArrayType& type) const; TypePtr BuildType(const ShaderAst::ExpressionType& type) const; TypePtr BuildType(const ShaderAst::IdentifierType& type) const; TypePtr BuildType(const ShaderAst::MatrixType& type) const; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 2c793c006..2ab1f4263 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -170,6 +170,16 @@ namespace Nz::ShaderAst return clone; } + StatementPtr AstCloner::Clone(ForEachStatement& node) + { + auto clone = std::make_unique(); + clone->isConst = node.isConst; + clone->expression = CloneExpression(node.expression); + clone->statement = CloneStatement(node.statement); + + return clone; + } + StatementPtr AstCloner::Clone(MultiStatement& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 61ef253ea..e6b1368bc 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -85,7 +85,8 @@ namespace Nz::ShaderAst void AstRecursiveVisitor::Visit(SwizzleExpression& node) { - node.expression->Visit(*this); + if (node.expression) + node.expression->Visit(*this); } void AstRecursiveVisitor::Visit(VariableExpression& /*node*/) @@ -95,7 +96,8 @@ namespace Nz::ShaderAst void AstRecursiveVisitor::Visit(UnaryExpression& node) { - node.expression->Visit(*this); + if (node.expression) + node.expression->Visit(*this); } void AstRecursiveVisitor::Visit(BranchStatement& node) @@ -159,6 +161,15 @@ namespace Nz::ShaderAst node.expression->Visit(*this); } + void AstRecursiveVisitor::Visit(ForEachStatement& node) + { + if (node.expression) + node.expression->Visit(*this); + + if (node.statement) + node.statement->Visit(*this); + } + void AstRecursiveVisitor::Visit(MultiStatement& node) { for (auto& statement : node.statements) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index df15d3de2..ca78af1c4 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -301,6 +301,14 @@ namespace Nz::ShaderAst Node(node.expression); } + void AstSerializerBase::Serialize(ForEachStatement& node) + { + Value(node.isConst); + Value(node.varName); + Node(node.expression); + Node(node.statement); + } + void AstSerializerBase::Serialize(MultiStatement& node) { Container(node.statements); diff --git a/src/Nazara/Shader/Ast/ExpressionType.cpp b/src/Nazara/Shader/Ast/ExpressionType.cpp index 5d507fb4e..186bbc20b 100644 --- a/src/Nazara/Shader/Ast/ExpressionType.cpp +++ b/src/Nazara/Shader/Ast/ExpressionType.cpp @@ -13,7 +13,7 @@ namespace Nz::ShaderAst { assert(array.containedType); containedType = std::make_unique(*array.containedType); - length = Clone(length); + length = Clone(array.length); } ArrayType& ArrayType::operator=(const ArrayType& array) @@ -21,7 +21,7 @@ namespace Nz::ShaderAst assert(array.containedType); containedType = std::make_unique(*array.containedType); - length = Clone(length); + length = Clone(array.length); return *this; } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index ce20348a5..ecdbce463 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -278,39 +278,7 @@ namespace Nz::ShaderAst MandatoryExpr(node.right); auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - - if (GetExpressionCategory(*clone->left) != ExpressionCategory::LValue) - throw AstError{ "Assignation is only possible with a l-value" }; - - std::optional binaryType; - switch (clone->op) - { - case AssignType::Simple: - TypeMustMatch(clone->left, clone->right); - break; - - case AssignType::CompoundAdd: binaryType = BinaryType::Add; break; - case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break; - case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break; - case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break; - case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break; - case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break; - } - - if (binaryType) - { - ExpressionType expressionType = ValidateBinaryOp(*binaryType, clone->left, clone->right); - TypeMustMatch(GetExpressionType(*clone->left), expressionType); - - if (m_context->options.removeCompoundAssignments) - { - clone->op = AssignType::Simple; - clone->right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*clone->left), std::move(clone->right)); - clone->right->cachedExpressionType = std::move(expressionType); - } - } - - clone->cachedExpressionType = GetExpressionType(*clone->left); + Validate(*clone); return clone; } @@ -318,7 +286,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - clone->cachedExpressionType = ValidateBinaryOp(clone->op, clone->left, clone->right); + Validate(*clone); return clone; } @@ -861,6 +829,119 @@ namespace Nz::ShaderAst return AstCloner::Clone(node); } + StatementPtr SanitizeVisitor::Clone(ForEachStatement& node) + { + auto expr = CloneExpression(node.expression); + + const ExpressionType& exprType = GetExpressionType(*expr); + ExpressionType innerType; + if (IsArrayType(exprType)) + { + const ArrayType& arrayType = std::get(exprType); + innerType = arrayType.containedType->type; + } + else + throw AstError{ "for-each is only supported on arrays and range expressions" }; + + if (node.isConst) + { + // Repeat code + auto multi = std::make_unique(); + if (IsArrayType(exprType)) + { + const ArrayType& arrayType = std::get(exprType); + UInt32 length = arrayType.length.GetResultingValue(); + + for (UInt32 i = 0; i < length; ++i) + { + auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i)); + Validate(*accessIndex); + + auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex)); + Validate(*elementVariable); + + multi->statements.emplace_back(std::move(elementVariable)); + multi->statements.emplace_back(CloneStatement(node.statement)); + } + } + + return multi; + } + + if (m_context->options.reduceLoopsToWhile) + { + PushScope(); + + auto multi = std::make_unique(); + + if (IsArrayType(exprType)) + { + const ArrayType& arrayType = std::get(exprType); + UInt32 length = arrayType.length.GetResultingValue(); + + multi->statements.reserve(2); + + // Counter variable + auto counterVariable = ShaderBuilder::DeclareVariable("i", ShaderBuilder::Constant(0u)); + Validate(*counterVariable); + + std::size_t counterVarIndex = counterVariable->varIndex.value(); + + multi->statements.emplace_back(std::move(counterVariable)); + + auto whileStatement = std::make_unique(); + + // While condition + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length)); + Validate(*condition); + whileStatement->condition = std::move(condition); + + // While body + auto body = std::make_unique(); + body->statements.reserve(3); + + auto accessIndex = ShaderBuilder::AccessIndex(std::move(expr), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32)); + Validate(*accessIndex); + + auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex)); + Validate(*elementVariable); + body->statements.emplace_back(std::move(elementVariable)); + + body->statements.emplace_back(CloneStatement(node.statement)); + + auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(1u)); + Validate(*incrCounter); + + body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); + + whileStatement->body = std::move(body); + + multi->statements.emplace_back(std::move(whileStatement)); + } + + PopScope(); + + return multi; + } + else + { + auto clone = std::make_unique(); + clone->expression = std::move(expr); + clone->varName = node.varName; + + PushScope(); + { + clone->varIndex = RegisterVariable(node.varName, innerType); + clone->statement = CloneStatement(node.statement); + } + PopScope(); + + SanitizeIdentifier(node.varName); + + return clone; + } + } + StatementPtr SanitizeVisitor::Clone(MultiStatement& node) { PushScope(); @@ -1206,7 +1287,6 @@ namespace Nz::ShaderAst using T = std::decay_t; if constexpr (std::is_same_v || - std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -1215,6 +1295,22 @@ namespace Nz::ShaderAst { return exprType; } + else if constexpr (std::is_same_v) + { + ArrayType resolvedArrayType; + if (arg.length.IsExpression()) + { + resolvedArrayType.length = CloneExpression(arg.length.GetExpression()); + ComputeAttributeValue(resolvedArrayType.length); + } + else if (arg.length.IsResultingValue()) + resolvedArrayType.length = arg.length.GetResultingValue(); + + resolvedArrayType.containedType = std::make_unique(); + resolvedArrayType.containedType->type = ResolveType(arg.containedType->type); + + return resolvedArrayType; + } else if constexpr (std::is_same_v) { const Identifier* identifier = FindIdentifier(arg.name); @@ -1262,8 +1358,12 @@ namespace Nz::ShaderAst for (auto& index : node.indices) { const ShaderAst::ExpressionType& indexType = GetExpressionType(*index); - if (!IsPrimitiveType(indexType) || std::get(indexType) != PrimitiveType::Int32) - throw AstError{ "AccessIndex expects Int32 indices" }; + if (!IsPrimitiveType(indexType)) + throw AstError{ "AccessIndex expects integer indices" }; + + PrimitiveType primitiveIndexType = std::get(indexType); + if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32) + throw AstError{ "AccessIndex expects integer indices" }; } ExpressionType exprType = GetExpressionType(*node.expr); @@ -1272,8 +1372,8 @@ namespace Nz::ShaderAst if (IsArrayType(exprType)) { const ArrayType& arrayType = std::get(exprType); - - exprType = arrayType.containedType->type; + ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType + exprType = std::move(containedType); } else if (IsStructType(exprType)) { @@ -1294,7 +1394,7 @@ namespace Nz::ShaderAst else if (IsMatrixType(exprType)) { // Matrix index (ex: mat[2]) - const MatrixType& matrixType = std::get(exprType); + MatrixType matrixType = std::get(exprType); //TODO: Handle row-major matrices exprType = VectorType{ matrixType.rowCount, matrixType.type }; @@ -1302,7 +1402,7 @@ namespace Nz::ShaderAst else if (IsVectorType(exprType)) { // Swizzle expression with one component (ex: vec[2]) - const VectorType& swizzledVec = std::get(exprType); + VectorType swizzledVec = std::get(exprType); exprType = swizzledVec.type; } @@ -1313,6 +1413,47 @@ namespace Nz::ShaderAst node.cachedExpressionType = std::move(exprType); } + void SanitizeVisitor::Validate(AssignExpression& node) + { + if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) + throw AstError{ "Assignation is only possible with a l-value" }; + + std::optional binaryType; + switch (node.op) + { + case AssignType::Simple: + TypeMustMatch(node.left, node.right); + break; + + case AssignType::CompoundAdd: binaryType = BinaryType::Add; break; + case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break; + case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break; + case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break; + case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break; + case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break; + } + + if (binaryType) + { + ExpressionType expressionType = ValidateBinaryOp(*binaryType, node.left, node.right); + TypeMustMatch(GetExpressionType(*node.left), expressionType); + + if (m_context->options.removeCompoundAssignments) + { + node.op = AssignType::Simple; + node.right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*node.left), std::move(node.right)); + node.right->cachedExpressionType = std::move(expressionType); + } + } + + node.cachedExpressionType = GetExpressionType(*node.left); + } + + void SanitizeVisitor::Validate(BinaryExpression& node) + { + node.cachedExpressionType = ValidateBinaryOp(node.op, node.left, node.right); + } + void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration) { if (referenceDeclaration->entryStage.HasValue()) diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 86bad6d63..9a7b46e40 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -207,28 +207,23 @@ namespace Nz ShaderAst::SanitizeVisitor::Options options; options.optionValues = std::move(optionValues); options.makeVariableNameUnique = true; + options.reduceLoopsToWhile = true; options.removeCompoundAssignments = false; + options.removeOptionDeclaration = true; options.removeScalarSwizzling = true; options.reservedIdentifiers = { // All reserved GLSL keywords as of GLSL ES 3.2 "active", "asm", "atomic_uint", "attribute", "bool", "break", "buffer", "bvec2", "bvec3", "bvec4", "case", "cast", "centroid", "class", "coherent", "common", "const", "continue", "default", "discard", "dmat2", "dmat2x2", "dmat2x3", "dmat2x4", "dmat3", "dmat3x2", "dmat3x3", "dmat3x4", "dmat4", "dmat4x2", "dmat4x3", "dmat4x4", "do", "double", "dvec2", "dvec3", "dvec4", "else", "enum", "extern", "external", "false", "filter", "fixed", "flat", "float", "for", "fvec2", "fvec3", "fvec4", "goto", "half", "highp", "hvec2", "hvec3", "hvec4", "if", "iimage1D", "iimage1DArray", "iimage2D", "iimage2DArray", "iimage2DMS", "iimage2DMSArray", "iimage2DRect", "iimage3D", "iimageBuffer", "iimageCube", "iimageCubeArray", "image1D", "image1DArray", "image2D", "image2DArray", "image2DMS", "image2DMSArray", "image2DRect", "image3D", "imageBuffer", "imageCube", "imageCubeArray", "in", "inline", "inout", "input", "int", "interface", "invariant", "isampler1D", "isampler1DArray", "isampler2D", "isampler2DArray", "isampler2DMS", "isampler2DMSArray", "isampler2DRect", "isampler3D", "isamplerBuffer", "isamplerCube", "isamplerCubeArray", "isubpassInput", "isubpassInputMS", "itexture2D", "itexture2DArray", "itexture2DMS", "itexture2DMSArray", "itexture3D", "itextureBuffer", "itextureCube", "itextureCubeArray", "ivec2", "ivec3", "ivec4", "layout", "long", "lowp", "mat2", "mat2x2", "mat2x3", "mat2x4", "mat3", "mat3x2", "mat3x3", "mat3x4", "mat4", "mat4x2", "mat4x3", "mat4x4", "mediump", "namespace", "noinline", "noperspective", "out", "output", "partition", "patch", "precise", "precision", "public", "readonly", "resource", "restrict", "return", "sample", "sampler", "sampler1D", "sampler1DArray", "sampler1DArrayShadow", "sampler1DShadow", "sampler2D", "sampler2DArray", "sampler2DArrayShadow", "sampler2DMS", "sampler2DMSArray", "sampler2DRect", "sampler2DRectShadow", "sampler2DShadow", "sampler3D", "sampler3DRect", "samplerBuffer", "samplerCube", "samplerCubeArray", "samplerCubeArrayShadow", "samplerCubeShadow", "samplerShadow", "shared", "short", "sizeof", "smooth", "static", "struct", "subpassInput", "subpassInputMS", "subroutine", "superp", "switch", "template", "texture2D", "texture2DArray", "texture2DMS", "texture2DMSArray", "texture3D", "textureBuffer", "textureCube", "textureCubeArray", "this", "true", "typedef", "uimage1D", "uimage1DArray", "uimage2D", "uimage2DArray", "uimage2DMS", "uimage2DMSArray", "uimage2DRect", "uimage3D", "uimageBuffer", "uimageCube", "uimageCubeArray", "uint", "uniform", "union", "unsigned", "usampler1D", "usampler1DArray", "usampler2D", "usampler2DArray", "usampler2DMS", "usampler2DMSArray", "usampler2DRect", "usampler3D", "usamplerBuffer", "usamplerCube", "usamplerCubeArray", "using", "usubpassInput", "usubpassInputMS", "utexture2D", "utexture2DArray", "utexture2DMS", "utexture2DMSArray", "utexture3D", "utextureBuffer", "utextureCube", "utextureCubeArray", "uvec2", "uvec3", "uvec4", "varying", "vec2", "vec3", "vec4", "void", "volatile", "while", "writeonly" // GLSL functions - "cross", "dot", "length", "max", "min", "pow", "texture" + "cross", "dot", "exp", "length", "max", "min", "pow", "texture" }; return ShaderAst::Sanitize(ast, options, error); } - void GlslWriter::Append(const ShaderAst::ArrayType& type) + void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/) { - Append(type.containedType->type, "["); - - if (type.length.IsResultingValue()) - Append(type.length.GetResultingValue()); - else - type.length.GetExpression()->Visit(*this); - - Append("]"); + throw std::runtime_error("unexpected ArrayType"); } void GlslWriter::Append(const ShaderAst::ExpressionType& type) @@ -390,7 +385,7 @@ namespace Nz first = false; - Append(parameter.type, " ", parameter.name); + AppendVariableDeclaration(parameter.type, parameter.name); } AppendLine((forward) ? ");" : ")"); } @@ -538,6 +533,40 @@ namespace Nz } } + void GlslWriter::AppendVariableDeclaration(const ShaderAst::ExpressionType& varType, const std::string& varName) + { + if (ShaderAst::IsArrayType(varType)) + { + std::vector*> lengths; + + const ShaderAst::ExpressionType* exprType = &varType; + while (ShaderAst::IsArrayType(*exprType)) + { + const auto& arrayType = std::get(*exprType); + lengths.push_back(&arrayType.length); + + exprType = &arrayType.containedType->type; + } + + assert(!ShaderAst::IsArrayType(*exprType)); + Append(*exprType, " ", varName); + + for (const auto* lengthAttribute : lengths) + { + Append("["); + + if (lengthAttribute->IsResultingValue()) + Append(lengthAttribute->GetResultingValue()); + else + lengthAttribute->GetExpression()->Visit(*this); + + Append("]"); + } + } + else + Append(varType, " ", varName); + } + void GlslWriter::EnterScope() { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); @@ -632,13 +661,8 @@ namespace Nz { Append("layout(location = "); Append(member.locationIndex.GetResultingValue()); - Append(") "); - Append(keyword); - Append(" "); - Append(member.type); - Append(" "); - Append(targetPrefix); - Append(member.name); + Append(") ", keyword, " "); + AppendVariableDeclaration(member.type, targetPrefix + member.name); AppendLine(";"); fields.push_back({ @@ -824,8 +848,10 @@ namespace Nz throw std::runtime_error("invalid type (value expected)"); else if constexpr (std::is_same_v) Append((arg) ? "true" : "false"); - else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) Append(std::to_string(arg)); + else if constexpr (std::is_same_v) + Append(std::to_string(arg), "u"); else if constexpr (std::is_same_v || std::is_same_v) Append("vec2(" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ")"); else if constexpr (std::is_same_v || std::is_same_v) @@ -1033,19 +1059,18 @@ namespace Nz first = false; - Append(member.type); - Append(" "); - Append(member.name); + AppendVariableDeclaration(member.type, member.name); Append(";"); } } LeaveScope(false); + + Append(" "); + Append(externalVar.name); } else - Append(externalVar.type); + AppendVariableDeclaration(externalVar.type, externalVar.name); - Append(" "); - Append(externalVar.name); AppendLine(";"); if (IsUniformType(externalVar.type)) @@ -1127,9 +1152,7 @@ namespace Nz first = false; - Append(member.type); - Append(" "); - Append(member.name); + AppendVariableDeclaration(member.type, member.name); Append(";"); } } @@ -1142,7 +1165,7 @@ namespace Nz assert(node.varIndex); RegisterVariable(*node.varIndex, node.varName); - Append(node.varType, " ", node.varName); + AppendVariableDeclaration(node.varType, node.varName); if (node.initialExpression) { Append(" = "); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index f9cd93029..69e1b9eae 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -170,7 +170,7 @@ namespace Nz case ShaderAst::PrimitiveType::Boolean: return Append("bool"); case ShaderAst::PrimitiveType::Float32: return Append("f32"); case ShaderAst::PrimitiveType::Int32: return Append("i32"); - case ShaderAst::PrimitiveType::UInt32: return Append("ui32"); + case ShaderAst::PrimitiveType::UInt32: return Append("u32"); } } @@ -185,7 +185,7 @@ namespace Nz case ImageType::E2D: Append("2D"); break; case ImageType::E2D_Array: Append("2DArray"); break; case ImageType::E3D: Append("3D"); break; - case ImageType::Cubemap: Append("Cube"); break; + case ImageType::Cubemap: Append("Cube"); break; } Append("<", samplerType.sampledType, ">"); @@ -653,6 +653,21 @@ namespace Nz node.statement->Visit(*this); } + void LangWriter::Visit(ShaderAst::DeclareConstStatement& node) + { + assert(node.constIndex); + RegisterConstant(*node.constIndex, node.name); + + Append("const ", node.name, ": ", node.type); + if (node.expression) + { + Append(" = "); + node.expression->Visit(*this); + } + + Append(";"); + } + void LangWriter::Visit(ShaderAst::ConstantValueExpression& node) { std::visit([&](auto&& arg) @@ -811,6 +826,20 @@ namespace Nz Append(";"); } + void LangWriter::Visit(ShaderAst::ForEachStatement& node) + { + assert(node.varIndex); + RegisterVariable(*node.varIndex, node.varName); + + Append("for ", node.varName, " in "); + node.expression->Visit(*this); + AppendLine(); + + EnterScope(); + node.statement->Visit(*this); + LeaveScope(); + } + void LangWriter::Visit(ShaderAst::IntrinsicExpression& node) { bool method = false; diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index c7ad71449..5015d3935 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -47,7 +47,9 @@ namespace Nz::ShaderLang { "external", TokenType::External }, { "false", TokenType::BoolFalse }, { "fn", TokenType::FunctionDeclaration }, + { "for", TokenType::For }, { "if", TokenType::If }, + { "in", TokenType::In }, { "let", TokenType::Let }, { "option", TokenType::Option }, { "return", TokenType::Return }, diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 8c800843d..d2622e76e 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include #include #include @@ -472,6 +473,14 @@ namespace Nz::ShaderLang switch (Peek().type) { + case TokenType::For: + { + auto forEach = ParseForDeclaration(); + SafeCast(*forEach).isConst = true; + + return forEach; + } + case TokenType::Identifier: { std::string constName; @@ -487,7 +496,7 @@ namespace Nz::ShaderLang case TokenType::If: { auto branch = ParseBranchStatement(); - static_cast(*branch).isConst = true; + SafeCast(*branch).isConst = true; return branch; } @@ -589,6 +598,21 @@ namespace Nz::ShaderLang return externalStatement; } + ShaderAst::StatementPtr Parser::ParseForDeclaration() + { + Expect(Advance(), TokenType::For); + + std::string varName = ParseIdentifierAsName(); + + Expect(Advance(), TokenType::In); + + ShaderAst::ExpressionPtr expr = ParseExpression(); + + ShaderAst::StatementPtr statement = ParseStatement(); + + return ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement)); + } + std::vector Parser::ParseFunctionBody() { return ParseStatementList(); @@ -734,6 +758,10 @@ namespace Nz::ShaderLang statement = ParseDiscardStatement(); break; + case TokenType::For: + statement = ParseForDeclaration(); + break; + case TokenType::Let: statement = ParseVariableDeclaration(); break; diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 625cf6fa8..d3b9eb027 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -38,7 +38,7 @@ namespace Nz bool Compare(const Array& lhs, const Array& rhs) const { - return lhs.length == rhs.length && Compare(lhs.elementType, rhs.elementType); + return Compare(lhs.length, rhs.length) && Compare(lhs.elementType, rhs.elementType) && lhs.stride == rhs.stride; } bool Compare(const Bool& /*lhs*/, const Bool& /*rhs*/) const @@ -237,6 +237,8 @@ namespace Nz { assert(array.elementType); cache.Register(*array.elementType); + assert(array.length); + cache.Register(*array.length); } void Register(const Bool&) {} @@ -416,6 +418,7 @@ namespace Nz tsl::ordered_map structureSizes; StructCallback structCallback; UInt32& nextResultId; + bool isInBlockStruct = false; }; SpirvConstantCache::SpirvConstantCache(UInt32& resultId) @@ -493,26 +496,59 @@ namespace Nz auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr { - return std::make_shared(Pointer{ + bool wasInblockStruct = m_internal->isInBlockStruct; + if (storageClass == SpirvStorageClass::Uniform) + m_internal->isInBlockStruct = true; + + auto typePtr = std::make_shared(Pointer{ BuildType(type), storageClass }); + + m_internal->isInBlockStruct = wasInblockStruct; + + return typePtr; } auto SpirvConstantCache::BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const -> TypePtr { - return std::make_shared(Pointer{ + bool wasInblockStruct = m_internal->isInBlockStruct; + if (storageClass == SpirvStorageClass::Uniform) + m_internal->isInBlockStruct = true; + + auto typePtr = std::make_shared(Pointer{ type, storageClass + }); + + m_internal->isInBlockStruct = wasInblockStruct; + + return typePtr; + } + + auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr + { + return std::make_shared(Array{ + BuildType(type.containedType->type), + BuildConstant(type.length.GetResultingValue()), + (m_internal->isInBlockStruct) ? std::make_optional(16) : std::nullopt }); } auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr { - return std::make_shared(Pointer{ + bool wasInblockStruct = m_internal->isInBlockStruct; + if (storageClass == SpirvStorageClass::Uniform) + m_internal->isInBlockStruct = true; + + auto typePtr = std::make_shared(Pointer{ BuildType(type), storageClass }); + + m_internal->isInBlockStruct = wasInblockStruct; + + return typePtr; } auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr @@ -614,6 +650,10 @@ namespace Nz sType.name = structDesc.name; sType.decorations = std::move(decorations); + bool wasInBlock = m_internal->isInBlockStruct; + if (!wasInBlock) + m_internal->isInBlockStruct = std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::Block) != sType.decorations.end(); + for (const auto& member : structDesc.members) { if (member.cond.HasValue() && !member.cond.GetResultingValue()) @@ -624,6 +664,8 @@ namespace Nz sMembers.type = BuildType(member.type); } + m_internal->isInBlockStruct = wasInBlock; + return std::make_shared(std::move(sType)); } @@ -814,7 +856,11 @@ namespace Nz using T = std::decay_t; if constexpr (std::is_same_v) - constants.Append(SpirvOp::OpTypeArray, resultId, GetId(*arg.elementType), arg.length); + { + constants.Append(SpirvOp::OpTypeArray, resultId, GetId(*arg.elementType), GetId(*arg.length)); + if (arg.stride) + annotations.Append(SpirvOp::OpDecorate, resultId, SpirvDecoration::ArrayStride, *arg.stride); + } else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeBool, resultId); else if constexpr (std::is_same_v) @@ -908,8 +954,23 @@ namespace Nz if constexpr (std::is_same_v) { - // TODO - throw std::runtime_error("todo"); + assert(std::holds_alternative(arg.length->constant)); + const auto& scalar = std::get(arg.length->constant); + assert(std::holds_alternative(scalar.value)); + std::size_t length = std::get(scalar.value); + + if (!std::holds_alternative(arg.elementType->type)) + throw std::runtime_error("todo"); + + // FIXME: Virer cette implémentation du ghetto + + const Float& fData = std::get(arg.elementType->type); + switch (fData.width) + { + case 32: return structOffsets.AddFieldArray(StructFieldType::Float1, length); + case 64: return structOffsets.AddFieldArray(StructFieldType::Double1, length); + default: throw std::runtime_error("unexpected float width " + std::to_string(fData.width)); + } } else if constexpr (std::is_same_v) return structOffsets.AddField(StructFieldType::Bool1); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 2e308704e..ce56a2b01 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -487,7 +487,9 @@ namespace Nz { ShaderAst::SanitizeVisitor::Options options; options.optionValues = states.optionValues; + options.reduceLoopsToWhile = true; options.removeCompoundAssignments = true; + options.removeOptionDeclaration = true; options.splitMultipleBranches = true; sanitizedAst = ShaderAst::Sanitize(shader, options); diff --git a/tests/Engine/Shader/Const.cpp b/tests/Engine/Shader/Const.cpp index 96d72c8c7..28cf29ef6 100644 --- a/tests/Engine/Shader/Const.cpp +++ b/tests/Engine/Shader/Const.cpp @@ -109,4 +109,55 @@ fn main() )"); } } + + WHEN("using const for-each") + { + std::string_view sourceCode = R"( +const LightCount = 3; + +[layout(std140)] +struct Light +{ + color: vec4 +} + +[layout(std140)] +struct LightData +{ + lights: [Light; LightCount] +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let color = (0.0).xxxx; + const for light in data.lights + { + color += light.color; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader; + REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + + ExpectOutput(*shader, {}, R"( +[entry(frag)] +fn main() +{ + let color: vec4 = (0.000000).xxxx; + let light: Light = data.lights[0]; + color += light.color; + let light: Light = data.lights[1]; + color += light.color; + let light: Light = data.lights[2]; + color += light.color; +} +)"); + } } diff --git a/tests/Engine/Shader/Loops.cpp b/tests/Engine/Shader/Loops.cpp index b40aa58d3..6ba55d163 100644 --- a/tests/Engine/Shader/Loops.cpp +++ b/tests/Engine/Shader/Loops.cpp @@ -8,7 +8,9 @@ TEST_CASE("loops", "[Shader]") { - std::string_view nzslSource = R"( + WHEN("using a while") + { + std::string_view nzslSource = R"( struct inputStruct { value: f32 @@ -32,9 +34,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shader, R"( void main() { float value = 0.000000; @@ -48,7 +50,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shader, R"( [entry(frag)] fn main() { @@ -63,7 +65,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSpirV(*shader, R"( OpFunction OpLabel OpVariable @@ -87,4 +89,93 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + } + + WHEN("using a for-each") + { + std::string_view nzslSource = R"( +struct inputStruct +{ + value: [f32; 10] +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let x = 0.0; + for v in data.value + { + x += v; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + + ExpectGLSL(*shader, R"( +void main() +{ + float x = 0.000000; + uint i = 0u; + while (i < (10u)) + { + float v = data.value[i]; + x += v; + i += 1u; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let x: f32 = 0.000000; + for v in data.value + { + x += v; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpVariable +OpStore +OpStore +OpBranch +OpLabel +OpLoad +OpULessThan +OpLoopMerge +OpBranchConditional +OpLabel +OpAccessChain +OpLoad +OpAccessChain +OpLoad +OpStore +OpLoad +OpLoad +OpFAdd +OpStore +OpLoad +OpIAdd +OpStore +OpBranch +OpLabel +OpReturn +OpFunctionEnd)"); + } } diff --git a/tests/Engine/Shader/Sanitizations.cpp b/tests/Engine/Shader/Sanitizations.cpp index 9a376c0a8..e28256adb 100644 --- a/tests/Engine/Shader/Sanitizations.cpp +++ b/tests/Engine/Shader/Sanitizations.cpp @@ -74,6 +74,55 @@ fn main() } +} +)"); + + } + + WHEN("reducing for-each to while") + { + std::string_view nzslSource = R"( +struct inputStruct +{ + value: [f32; 10] +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let x = 0.0; + for v in data.value + { + x += v; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + Nz::ShaderAst::SanitizeVisitor::Options options; + options.reduceLoopsToWhile = true; + + REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options)); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let x: f32 = 0.000000; + let i: u32 = 0; + while (i < (10)) + { + let v: f32 = data.value[i]; + x += v; + i += 1; + } + } )");