diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index e346509ca..0b1299c9a 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -67,6 +67,7 @@ namespace Nz::ShaderAst virtual StatementPtr Clone(MultiStatement& node); virtual StatementPtr Clone(NoOpStatement& node); virtual StatementPtr Clone(ReturnStatement& node); + virtual StatementPtr Clone(WhileStatement& node); #define NAZARA_SHADERAST_NODE(NodeType) void Visit(NodeType& node) override; #include diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index 927f36eb7..f82481594 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -55,7 +55,8 @@ NAZARA_SHADERAST_STATEMENT(DiscardStatement) NAZARA_SHADERAST_STATEMENT(ExpressionStatement) NAZARA_SHADERAST_STATEMENT(MultiStatement) NAZARA_SHADERAST_STATEMENT(NoOpStatement) -NAZARA_SHADERAST_STATEMENT_LAST(ReturnStatement) +NAZARA_SHADERAST_STATEMENT(ReturnStatement) +NAZARA_SHADERAST_STATEMENT_LAST(WhileStatement) #undef NAZARA_SHADERAST_EXPRESSION #undef NAZARA_SHADERAST_NODE diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 765ac0d89..f771cb353 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -49,6 +49,7 @@ namespace Nz::ShaderAst void Visit(MultiStatement& node) override; void Visit(NoOpStatement& node) override; void Visit(ReturnStatement& node) override; + void Visit(WhileStatement& node) override; }; } diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index 6e4ae5cf4..06d990b24 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -52,6 +52,7 @@ namespace Nz::ShaderAst void Serialize(MultiStatement& node); void Serialize(NoOpStatement& node); void Serialize(ReturnStatement& node); + void Serialize(WhileStatement& node); protected: template void Attribute(AttributeValue& attribute); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index d8219109d..7dfceec53 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -362,6 +362,15 @@ namespace Nz::ShaderAst ExpressionPtr returnExpr; }; + struct NAZARA_SHADER_API WhileStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + ExpressionPtr condition; + StatementPtr body; + }; + inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr); inline bool IsExpression(NodeType nodeType); inline bool IsStatement(NodeType nodeType); diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 22922b461..7750e3eea 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -74,6 +74,7 @@ namespace Nz::ShaderAst StatementPtr Clone(DiscardStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override; StatementPtr Clone(MultiStatement& node) override; + StatementPtr Clone(WhileStatement& node) override; const Identifier* FindIdentifier(const std::string_view& identifierName) const; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 504df4461..cd4e778d7 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -105,6 +105,7 @@ namespace Nz void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; + void Visit(ShaderAst::WhileStatement& node) override; static bool HasExplicitBinding(ShaderAst::StatementPtr& shader); static bool HasExplicitLocation(ShaderAst::StatementPtr& shader); diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index a1446d3a9..cee67006c 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -108,6 +108,7 @@ namespace Nz void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; + void Visit(ShaderAst::WhileStatement& node) override; struct State; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 12b8f0dfc..5e0a77400 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -140,6 +140,11 @@ namespace Nz::ShaderBuilder { inline std::unique_ptr operator()(ShaderAst::UnaryType op, ShaderAst::ExpressionPtr expression) const; }; + + struct While + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const; + }; } constexpr Impl::AccessIndex AccessIndex; @@ -167,6 +172,7 @@ namespace Nz::ShaderBuilder constexpr Impl::Return Return; constexpr Impl::Swizzle Swizzle; constexpr Impl::Unary Unary; + constexpr Impl::While While; } #include diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 70deabfd7..52188a99c 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -298,6 +298,15 @@ namespace Nz::ShaderBuilder return unaryNode; } + + inline std::unique_ptr Impl::While::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const + { + auto whileNode = std::make_unique(); + whileNode->condition = std::move(condition); + whileNode->body = std::move(body); + + return whileNode; + } } #include diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index ff40a5e84..cb6ebd471 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -98,6 +98,7 @@ namespace Nz::ShaderLang std::vector ParseStatementList(); ShaderAst::StatementPtr ParseStructDeclaration(std::vector attributes = {}); ShaderAst::StatementPtr ParseVariableDeclaration(); + ShaderAst::StatementPtr ParseWhileStatement(); // Expressions ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index e9a2ad288..c0797f4f6 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -60,6 +60,7 @@ NAZARA_SHADERLANG_TOKEN(Option) NAZARA_SHADERLANG_TOKEN(Semicolon) NAZARA_SHADERLANG_TOKEN(Return) NAZARA_SHADERLANG_TOKEN(Struct) +NAZARA_SHADERLANG_TOKEN(While) #undef NAZARA_SHADERLANG_TOKEN #undef NAZARA_SHADERLANG_TOKEN_LAST diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index c208db3af..991d95bd3 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -61,8 +61,9 @@ namespace Nz void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; - void Visit(ShaderAst::VariableExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override; + void Visit(ShaderAst::VariableExpression& node) override; + void Visit(ShaderAst::WhileStatement& node) override; SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete; SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 0bb0a873a..80dccbb21 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -193,6 +193,15 @@ namespace Nz::ShaderAst return clone; } + StatementPtr AstCloner::Clone(WhileStatement& node) + { + auto clone = std::make_unique(); + clone->condition = CloneExpression(node.condition); + clone->body = CloneStatement(node.body); + + return clone; + } + ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 4ebcd980a..8ee1dda84 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -175,4 +175,13 @@ namespace Nz::ShaderAst if (node.returnExpr) node.returnExpr->Visit(*this); } + + void AstRecursiveVisitor::Visit(WhileStatement& node) + { + if (node.condition) + node.condition->Visit(*this); + + if (node.body) + node.body->Visit(*this); + } } diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 74d06b650..28cc21875 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -318,6 +318,12 @@ namespace Nz::ShaderAst Node(node.returnExpr); } + void AstSerializerBase::Serialize(WhileStatement& node) + { + Node(node.condition); + Node(node.body); + } + void ShaderAstSerializer::Serialize(StatementPtr& shader) { m_stream << s_magicNumber << s_currentVersion; diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index b099fa69a..94c50af48 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -460,12 +460,11 @@ namespace Nz::ShaderAst { case Identifier::Type::Constant: { - // Replace IdentifierExpression by ConstantExpression - auto constantExpr = std::make_unique(); - constantExpr->cachedExpressionType = GetExpressionType(m_context->constantValues[identifier->index]); - constantExpr->constantId = identifier->index; + // Replace IdentifierExpression by Constant(Value)Expression + ConstantExpression constantExpr; + constantExpr.constantId = identifier->index; - return constantExpr; + return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression } case Identifier::Type::Variable: @@ -951,6 +950,19 @@ namespace Nz::ShaderAst return clone; } + StatementPtr SanitizeVisitor::Clone(WhileStatement& node) + { + MandatoryExpr(node.condition); + MandatoryStatement(node.body); + + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "expected a boolean value" }; + + return clone; + } + auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* { auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 7602c0da1..5bc853672 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -1179,6 +1179,17 @@ namespace Nz } } + void GlslWriter::Visit(ShaderAst::WhileStatement& node) + { + Append("while ("); + node.condition->Visit(*this); + AppendLine(")"); + + EnterScope(); + node.body->Visit(*this); + LeaveScope(); + } + void GlslWriter::Visit(ShaderAst::SwizzleExpression& node) { Visit(node.expression, true); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 356b035bd..dcad1f098 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -930,6 +930,17 @@ namespace Nz node.expression->Visit(*this); } + void LangWriter::Visit(ShaderAst::WhileStatement& node) + { + Append("while ("); + node.condition->Visit(*this); + AppendLine(")"); + + EnterScope(); + node.body->Visit(*this); + LeaveScope(); + } + void LangWriter::AppendHeader() { // Nothing yet diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index b4cb0b578..6e6a3ccae 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -52,7 +52,8 @@ namespace Nz::ShaderLang { "option", TokenType::Option }, { "return", TokenType::Return }, { "struct", TokenType::Struct }, - { "true", TokenType::BoolTrue } + { "true", TokenType::BoolTrue }, + { "while", TokenType::While } }; std::size_t currentPos = 0; diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 4c2034b93..674de3995 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -742,6 +742,10 @@ namespace Nz::ShaderLang statement = ParseReturnStatement(); break; + case TokenType::While: + statement = ParseWhileStatement(); + break; + default: throw UnexpectedToken{}; } @@ -905,6 +909,21 @@ namespace Nz::ShaderLang return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression)); } + ShaderAst::StatementPtr Parser::ParseWhileStatement() + { + Expect(Advance(), TokenType::While); + + Expect(Advance(), TokenType::OpenParenthesis); + + ShaderAst::ExpressionPtr condition = ParseExpression(); + + Expect(Advance(), TokenType::ClosingParenthesis); + + ShaderAst::StatementPtr body = ParseStatement(); + + return ShaderBuilder::While(std::move(condition), std::move(body)); + } + ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs) { for (;;) diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 56bf2d65b..5cd21f9db 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -944,12 +944,6 @@ namespace Nz PushResultId(resultId); } - void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node) - { - SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock); - PushResultId(loadVisitor.Evaluate(node)); - } - void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node) { const ShaderAst::ExpressionType& resultType = GetExpressionType(node); @@ -1011,6 +1005,40 @@ namespace Nz PushResultId(resultId); } + void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node) + { + SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock); + PushResultId(loadVisitor.Evaluate(node)); + } + + void SpirvAstVisitor::Visit(ShaderAst::WhileStatement& node) + { + assert(node.condition); + assert(node.body); + + SpirvBlock headerBlock(m_writer); + SpirvBlock bodyBlock(m_writer); + SpirvBlock mergeBlock(m_writer); + + m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId()); + m_currentBlock = &headerBlock; + + UInt32 expressionId = EvaluateExpression(node.condition); + + m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None); + m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId()); + + m_currentBlock = &bodyBlock; + node.body->Visit(*this); + + m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId()); + + m_functionBlocks.emplace_back(std::move(headerBlock)); + m_functionBlocks.emplace_back(std::move(bodyBlock)); + m_functionBlocks.emplace_back(std::move(mergeBlock)); + m_currentBlock = &m_functionBlocks.back(); + } + void SpirvAstVisitor::PushResultId(UInt32 value) { m_resultIds.push_back(value);