diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index 64ea853e5..c44753158 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -460,10 +460,10 @@ namespace Nz::ShaderAst bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs) { - if (!Compare(lhs.isConst, rhs.isConst)) + if (!Compare(lhs.varName, rhs.varName)) return false; - if (!Compare(lhs.varName, rhs.varName)) + if (!Compare(lhs.unroll, rhs.unroll)) return false; if (!Compare(lhs.expression, rhs.expression)) @@ -498,6 +498,9 @@ namespace Nz::ShaderAst inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs) { + if (!Compare(lhs.unroll, rhs.unroll)) + return false; + if (!Compare(lhs.condition, rhs.condition)) return false; diff --git a/include/Nazara/Shader/Ast/Enums.hpp b/include/Nazara/Shader/Ast/Enums.hpp index ec46ed856..e15b2ce1d 100644 --- a/include/Nazara/Shader/Ast/Enums.hpp +++ b/include/Nazara/Shader/Ast/Enums.hpp @@ -36,6 +36,7 @@ namespace Nz Layout, //< Struct layout (struct only) - has argument style Location, //< Location (struct member only) - has argument index Set, //< Binding set (external var only) - has argument index + Unroll, //< Unroll (for/for each only) - has argument mode }; enum class BinaryType @@ -106,6 +107,13 @@ namespace Nz SampleTexture = 2, }; + enum class LoopUnroll + { + Always, + Hint, + Never + }; + enum class MemoryLayout { Std140 diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 70e26c26d..43d515688 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -345,11 +345,11 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; + AttributeValue unroll; std::optional varIndex; std::string varName; ExpressionPtr expression; StatementPtr statement; - bool isConst = false; }; struct NAZARA_SHADER_API MultiStatement : Statement @@ -379,6 +379,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; + AttributeValue unroll; ExpressionPtr condition; StatementPtr body; }; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 3175d8b44..3c81c2590 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -118,6 +118,8 @@ namespace Nz::ShaderAst void SanitizeIdentifier(std::string& identifier); + void Validate(WhileStatement& node); + void Validate(AccessIndexExpression& node); void Validate(AssignExpression& node); void Validate(BinaryExpression& node); diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 5a547e73d..4a2ddd884 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -108,7 +108,6 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression) const; }; - template struct ForEach { inline std::unique_ptr operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const; @@ -173,7 +172,6 @@ namespace Nz::ShaderBuilder constexpr Impl::ConditionalStatement ConditionalStatement; constexpr Impl::Constant Constant; constexpr Impl::Branch ConstBranch; - constexpr Impl::ForEach ConstForEach; constexpr Impl::DeclareConst DeclareConst; constexpr Impl::DeclareFunction DeclareFunction; constexpr Impl::DeclareOption DeclareOption; @@ -181,7 +179,7 @@ namespace Nz::ShaderBuilder constexpr Impl::DeclareVariable DeclareVariable; constexpr Impl::ExpressionStatement ExpressionStatement; constexpr Impl::NoParam Discard; - constexpr Impl::ForEach ForEach; + constexpr Impl::ForEach ForEach; constexpr Impl::Identifier Identifier; constexpr Impl::Intrinsic Intrinsic; constexpr Impl::Multi MultiStatement; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index b646b632c..5e7b81ea2 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -269,11 +269,9 @@ namespace Nz::ShaderBuilder return expressionStatementNode; } - template - std::unique_ptr Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const + std::unique_ptr Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const { auto forEachNode = std::make_unique(); - forEachNode->isConst = Const; forEachNode->expression = std::move(expression); forEachNode->statement = std::move(statement); forEachNode->varName = std::move(varName); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index ea78c2867..0133967bc 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -88,7 +88,7 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr ParseConstStatement(); ShaderAst::StatementPtr ParseDiscardStatement(); ShaderAst::StatementPtr ParseExternalBlock(std::vector attributes = {}); - ShaderAst::StatementPtr ParseForDeclaration(); + ShaderAst::StatementPtr ParseForDeclaration(std::vector attributes = {}); std::vector ParseFunctionBody(); ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter(); @@ -99,7 +99,7 @@ namespace Nz::ShaderLang std::vector ParseStatementList(); ShaderAst::StatementPtr ParseStructDeclaration(std::vector attributes = {}); ShaderAst::StatementPtr ParseVariableDeclaration(); - ShaderAst::StatementPtr ParseWhileStatement(); + ShaderAst::StatementPtr ParseWhileStatement(std::vector attributes); // Expressions ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 2ab1f4263..ce773886f 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -173,9 +173,9 @@ namespace Nz::ShaderAst StatementPtr AstCloner::Clone(ForEachStatement& node) { auto clone = std::make_unique(); - clone->isConst = node.isConst; clone->expression = CloneExpression(node.expression); clone->statement = CloneStatement(node.statement); + clone->unroll = Clone(node.unroll); return clone; } @@ -208,6 +208,7 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->condition = CloneExpression(node.condition); clone->body = CloneStatement(node.body); + clone->unroll = Clone(node.unroll); return clone; } diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index ca78af1c4..2b277bca7 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -303,7 +303,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(ForEachStatement& node) { - Value(node.isConst); + Attribute(node.unroll); Value(node.varName); Node(node.expression); Node(node.statement); @@ -328,6 +328,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(WhileStatement& node) { + Attribute(node.unroll); Node(node.condition); Node(node.body); } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index ecdbce463..1933980d0 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -843,29 +843,34 @@ namespace Nz::ShaderAst else throw AstError{ "for-each is only supported on arrays and range expressions" }; - if (node.isConst) + AttributeValue unrollValue; + if (node.unroll.HasValue()) { - // Repeat code - auto multi = std::make_unique(); - if (IsArrayType(exprType)) + unrollValue = ComputeAttributeValue(node.unroll); + if (unrollValue.GetResultingValue() == LoopUnroll::Always) { - const ArrayType& arrayType = std::get(exprType); - UInt32 length = arrayType.length.GetResultingValue(); - - for (UInt32 i = 0; i < length; ++i) + // Repeat code + auto multi = std::make_unique(); + if (IsArrayType(exprType)) { - auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i)); - Validate(*accessIndex); + const ArrayType& arrayType = std::get(exprType); + UInt32 length = arrayType.length.GetResultingValue(); - auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex)); - Validate(*elementVariable); + for (UInt32 i = 0; i < length; ++i) + { + auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i)); + Validate(*accessIndex); - multi->statements.emplace_back(std::move(elementVariable)); - multi->statements.emplace_back(CloneStatement(node.statement)); + auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex)); + Validate(*elementVariable); + + multi->statements.emplace_back(std::move(elementVariable)); + multi->statements.emplace_back(CloneStatement(node.statement)); + } } - } - return multi; + return multi; + } } if (m_context->options.reduceLoopsToWhile) @@ -890,6 +895,7 @@ namespace Nz::ShaderAst multi->statements.emplace_back(std::move(counterVariable)); auto whileStatement = std::make_unique(); + whileStatement->unroll = std::move(unrollValue); // While condition auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length)); @@ -928,6 +934,7 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->expression = std::move(expr); clone->varName = node.varName; + clone->unroll = std::move(unrollValue); PushScope(); { @@ -968,9 +975,15 @@ namespace Nz::ShaderAst MandatoryStatement(node.body); auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + Validate(*clone); - if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; + AttributeValue unrollValue; + if (node.unroll.HasValue()) + { + clone->unroll = ComputeAttributeValue(node.unroll); + if (clone->unroll.GetResultingValue() == LoopUnroll::Always) + throw AstError{ "unroll(always) is not yet supported on while" }; + } return clone; } @@ -1350,6 +1363,12 @@ namespace Nz::ShaderAst } } + void SanitizeVisitor::Validate(WhileStatement& node) + { + if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "expected a boolean value" }; + } + void SanitizeVisitor::Validate(AccessIndexExpression& node) { if (node.indices.empty()) diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index d2622e76e..4e50d41fa 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -37,6 +37,7 @@ namespace Nz::ShaderLang { "layout", ShaderAst::AttributeType::Layout }, { "location", ShaderAst::AttributeType::Location }, { "set", ShaderAst::AttributeType::Set }, + { "unroll", ShaderAst::AttributeType::Unroll }, }; std::unordered_map s_entryPoints = { @@ -54,6 +55,12 @@ namespace Nz::ShaderLang { "std140", StructLayout::Std140 } }; + std::unordered_map s_unrollModes = { + { "always", ShaderAst::LoopUnroll::Always }, + { "hint", ShaderAst::LoopUnroll::Hint }, + { "never", ShaderAst::LoopUnroll::Never } + }; + template std::optional BoundCast(U val) { @@ -76,26 +83,33 @@ namespace Nz::ShaderLang } template - void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map& map, ShaderAst::AttributeValue& targetAttribute, ShaderAst::Attribute::Param&& param) + void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map& map, ShaderAst::AttributeValue& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional defaultValue = {}) { if (targetAttribute.HasValue()) throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" }; //FIXME: This should be handled with global values at sanitization stage - if (!param) - throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" }; + if (param) + { + const ShaderAst::ExpressionPtr& expr = *param; + if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression) + throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" }; - const ShaderAst::ExpressionPtr& expr = *param; - if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression) - throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" }; + const std::string& exprStr = static_cast(*expr).identifier; - const std::string& exprStr = static_cast(*expr).identifier; + auto it = map.find(exprStr); + if (it == map.end()) + throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() }; - auto it = map.find(exprStr); - if (it == map.end()) - throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() }; + targetAttribute = it->second; + } + else + { + if (!defaultValue) + throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" }; - targetAttribute = it->second; + targetAttribute = defaultValue.value(); + } } } @@ -473,14 +487,6 @@ namespace Nz::ShaderLang switch (Peek().type) { - case TokenType::For: - { - auto forEach = ParseForDeclaration(); - SafeCast(*forEach).isConst = true; - - return forEach; - } - case TokenType::Identifier: { std::string constName; @@ -598,7 +604,7 @@ namespace Nz::ShaderLang return externalStatement; } - ShaderAst::StatementPtr Parser::ParseForDeclaration() + ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector attributes) { Expect(Advance(), TokenType::For); @@ -610,7 +616,22 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr statement = ParseStatement(); - return ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement)); + auto forEach = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement)); + + for (auto&& [attributeType, arg] : attributes) + { + switch (attributeType) + { + case ShaderAst::AttributeType::Unroll: + HandleUniqueStringAttribute("unroll", s_unrollModes, forEach->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always)); + break; + + default: + throw AttributeError{ "unhandled attribute for for-each" }; + } + } + + return forEach; } std::vector Parser::ParseFunctionBody() @@ -745,47 +766,74 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr Parser::ParseSingleStatement() { - const Token& token = Peek(); - + std::vector attributes; ShaderAst::StatementPtr statement; - switch (token.type) + do { - case TokenType::Const: - statement = ParseConstStatement(); - break; + const Token& token = Peek(); + switch (token.type) + { + case TokenType::Const: + if (!attributes.empty()) + throw UnexpectedToken{}; - case TokenType::Discard: - statement = ParseDiscardStatement(); - break; + statement = ParseConstStatement(); + break; - case TokenType::For: - statement = ParseForDeclaration(); - break; + case TokenType::Discard: + if (!attributes.empty()) + throw UnexpectedToken{}; - case TokenType::Let: - statement = ParseVariableDeclaration(); - break; + statement = ParseDiscardStatement(); + break; - case TokenType::Identifier: - statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation()); - Expect(Advance(), TokenType::Semicolon); - break; + case TokenType::For: + statement = ParseForDeclaration(std::move(attributes)); + break; - case TokenType::If: - statement = ParseBranchStatement(); - break; + case TokenType::Let: + if (!attributes.empty()) + throw UnexpectedToken{}; - case TokenType::Return: - statement = ParseReturnStatement(); - break; + statement = ParseVariableDeclaration(); + break; - case TokenType::While: - statement = ParseWhileStatement(); - break; + case TokenType::Identifier: + if (!attributes.empty()) + throw UnexpectedToken{}; - default: - throw UnexpectedToken{}; + statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation()); + Expect(Advance(), TokenType::Semicolon); + break; + + case TokenType::If: + if (!attributes.empty()) + throw UnexpectedToken{}; + + statement = ParseBranchStatement(); + break; + + case TokenType::OpenSquareBracket: + assert(attributes.empty()); + attributes = ParseAttributes(); + break; + + case TokenType::Return: + if (!attributes.empty()) + throw UnexpectedToken{}; + + statement = ParseReturnStatement(); + break; + + case TokenType::While: + statement = ParseWhileStatement(std::move(attributes)); + break; + + default: + throw UnexpectedToken{}; + } } + while (!statement); //< small trick to repeat parsing once we got attributes return statement; } @@ -955,7 +1003,7 @@ namespace Nz::ShaderLang return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression)); } - ShaderAst::StatementPtr Parser::ParseWhileStatement() + ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector attributes) { Expect(Advance(), TokenType::While); @@ -967,7 +1015,22 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr body = ParseStatement(); - return ShaderBuilder::While(std::move(condition), std::move(body)); + auto whileStatement = ShaderBuilder::While(std::move(condition), std::move(body)); + + for (auto&& [attributeType, arg] : attributes) + { + switch (attributeType) + { + case ShaderAst::AttributeType::Unroll: + HandleUniqueStringAttribute("unroll", s_unrollModes, whileStatement->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always)); + break; + + default: + throw AttributeError{ "unhandled attribute for while" }; + } + } + + return whileStatement; } ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs) diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 3f4db6c0f..bc7fdaa95 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -1030,7 +1030,28 @@ namespace Nz UInt32 expressionId = EvaluateExpression(node.condition); - m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None); + SpirvLoopControl loopControl; + if (node.unroll.HasValue()) + { + switch (node.unroll.GetResultingValue()) + { + case ShaderAst::LoopUnroll::Always: + // it shouldn't be possible to have this attribute as the loop gets unrolled in the sanitizer + throw std::runtime_error("unexpected unroll attribute"); + + case ShaderAst::LoopUnroll::Hint: + loopControl = SpirvLoopControl::Unroll; + break; + + case ShaderAst::LoopUnroll::Never: + loopControl = SpirvLoopControl::DontUnroll; + break; + } + } + else + loopControl = SpirvLoopControl::None; + + m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), loopControl); m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId()); m_currentBlock = &bodyBlock; diff --git a/tests/Engine/Shader/Const.cpp b/tests/Engine/Shader/Const.cpp index 28cf29ef6..81a74b77e 100644 --- a/tests/Engine/Shader/Const.cpp +++ b/tests/Engine/Shader/Const.cpp @@ -110,7 +110,7 @@ fn main() } } - WHEN("using const for-each") + WHEN("using [unroll] attribute on for-each") { std::string_view sourceCode = R"( const LightCount = 3; @@ -136,7 +136,9 @@ external fn main() { let color = (0.0).xxxx; - const for light in data.lights + + [unroll] + for light in data.lights { color += light.color; }