From 756fd773a9fca857547eafeeba68d9cf6ba7d92e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Thu, 6 Jan 2022 20:38:55 +0100 Subject: [PATCH] Shader: Add support for numerical fors --- include/Nazara/Shader/Ast/AstCloner.hpp | 1 + include/Nazara/Shader/Ast/AstCompare.hpp | 1 + include/Nazara/Shader/Ast/AstCompare.inl | 23 ++ include/Nazara/Shader/Ast/AstNodeList.hpp | 1 + .../Nazara/Shader/Ast/AstRecursiveVisitor.hpp | 1 + include/Nazara/Shader/Ast/AstSerializer.hpp | 1 + include/Nazara/Shader/Ast/Nodes.hpp | 14 ++ include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 2 + include/Nazara/Shader/GlslWriter.hpp | 2 +- include/Nazara/Shader/LangWriter.hpp | 7 +- include/Nazara/Shader/ShaderBuilder.hpp | 7 + include/Nazara/Shader/ShaderBuilder.inl | 23 ++ include/Nazara/Shader/ShaderLangTokenList.hpp | 4 +- src/Nazara/Shader/Ast/AstCloner.cpp | 14 ++ src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp | 15 ++ src/Nazara/Shader/Ast/AstSerializer.cpp | 10 + src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 232 ++++++++++++++++-- src/Nazara/Shader/GlslWriter.cpp | 63 ++--- src/Nazara/Shader/LangWriter.cpp | 176 +++++++++---- src/Nazara/Shader/ShaderLangLexer.cpp | 2 +- src/Nazara/Shader/ShaderLangParser.cpp | 69 ++++-- src/Nazara/Shader/SpirvWriter.cpp | 1 + tests/Engine/Shader/Const.cpp | 57 +++++ tests/Engine/Shader/Loops.cpp | 154 ++++++++++++ 24 files changed, 746 insertions(+), 134 deletions(-) diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index b898efc4f..b24c97855 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -64,6 +64,7 @@ namespace Nz::ShaderAst virtual StatementPtr Clone(DeclareVariableStatement& node); virtual StatementPtr Clone(DiscardStatement& node); virtual StatementPtr Clone(ExpressionStatement& node); + virtual StatementPtr Clone(ForStatement& node); virtual StatementPtr Clone(ForEachStatement& node); virtual StatementPtr Clone(MultiStatement& node); virtual StatementPtr Clone(NoOpStatement& node); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index 239d6ba5c..0683a9a7b 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -54,6 +54,7 @@ namespace Nz::ShaderAst inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs); inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs); inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs); + inline bool Compare(const ForStatement& lhs, const ForStatement& rhs); inline bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs); inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs); inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index c44753158..7c10bb09d 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -458,6 +458,29 @@ namespace Nz::ShaderAst return true; } + bool Compare(const ForStatement& lhs, const ForStatement& rhs) + { + if (!Compare(lhs.varName, rhs.varName)) + return false; + + if (!Compare(lhs.unroll, rhs.unroll)) + return false; + + if (!Compare(lhs.fromExpr, rhs.fromExpr)) + return false; + + if (!Compare(lhs.toExpr, rhs.toExpr)) + return false; + + if (!Compare(lhs.stepExpr, rhs.stepExpr)) + return false; + + if (!Compare(lhs.statement, rhs.statement)) + return false; + + return true; + } + bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs) { if (!Compare(lhs.varName, rhs.varName)) diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index 37cb4450c..d81e0d9ee 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -52,6 +52,7 @@ NAZARA_SHADERAST_STATEMENT(DeclareOptionStatement) NAZARA_SHADERAST_STATEMENT(DeclareStructStatement) NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement) NAZARA_SHADERAST_STATEMENT(DiscardStatement) +NAZARA_SHADERAST_STATEMENT(ForStatement) NAZARA_SHADERAST_STATEMENT(ForEachStatement) NAZARA_SHADERAST_STATEMENT(ExpressionStatement) NAZARA_SHADERAST_STATEMENT(MultiStatement) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 45645a199..992d4233a 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -46,6 +46,7 @@ namespace Nz::ShaderAst void Visit(DeclareVariableStatement& node) override; void Visit(DiscardStatement& node) override; void Visit(ExpressionStatement& node) override; + void Visit(ForStatement& node) override; void Visit(ForEachStatement& node) override; void Visit(MultiStatement& node) override; void Visit(NoOpStatement& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index e3751f89f..c1beb3edc 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -49,6 +49,7 @@ namespace Nz::ShaderAst void Serialize(DeclareVariableStatement& node); void Serialize(DiscardStatement& node); void Serialize(ExpressionStatement& node); + void Serialize(ForStatement& node); void Serialize(ForEachStatement& node); void Serialize(MultiStatement& node); void Serialize(NoOpStatement& node); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 43d515688..2987cd980 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -340,6 +340,20 @@ namespace Nz::ShaderAst ExpressionPtr expression; }; + struct NAZARA_SHADER_API ForStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + AttributeValue unroll; + std::optional varIndex; + std::string varName; + ExpressionPtr fromExpr; + ExpressionPtr stepExpr; + ExpressionPtr toExpr; + StatementPtr statement; + }; + struct NAZARA_SHADER_API ForEachStatement : Statement { NodeType GetType() const override; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 3c81c2590..e9383f1ca 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -45,6 +45,7 @@ namespace Nz::ShaderAst bool removeOptionDeclaration = false; bool removeScalarSwizzling = false; bool splitMultipleBranches = false; + bool useIdentifierAccessesForStructs = true; }; private: @@ -78,6 +79,7 @@ namespace Nz::ShaderAst StatementPtr Clone(DeclareVariableStatement& node) override; StatementPtr Clone(DiscardStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override; + StatementPtr Clone(ForStatement& node) override; StatementPtr Clone(ForEachStatement& node) override; StatementPtr Clone(MultiStatement& node) override; StatementPtr Clone(WhileStatement& node) override; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 5d019b807..71f7ae2ce 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -66,7 +66,6 @@ namespace Nz template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); void AppendCommentSection(const std::string& section); void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false); - void AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers); void AppendHeader(); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); @@ -84,6 +83,7 @@ namespace Nz void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); + void Visit(ShaderAst::AccessIdentifierExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index 3ff7e664c..9c7f4123a 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -45,6 +45,7 @@ namespace Nz struct LayoutAttribute; struct LocationAttribute; struct SetAttribute; + struct UnrollAttribute; void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ExpressionType& type); @@ -68,9 +69,9 @@ namespace Nz void AppendAttribute(EntryAttribute entry); void AppendAttribute(LayoutAttribute layout); void AppendAttribute(LocationAttribute location); - void AppendAttribute(SetAttribute location); + void AppendAttribute(SetAttribute set); + void AppendAttribute(UnrollAttribute unroll); void AppendCommentSection(const std::string& section); - void AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers); void AppendHeader(); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); @@ -85,6 +86,7 @@ namespace Nz void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); + void Visit(ShaderAst::AccessIdentifierExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; @@ -107,6 +109,7 @@ namespace Nz void Visit(ShaderAst::DeclareVariableStatement& node) override; void Visit(ShaderAst::DiscardStatement& node) override; void Visit(ShaderAst::ExpressionStatement& node) override; + void Visit(ShaderAst::ForStatement& node) override; void Visit(ShaderAst::ForEachStatement& node) override; void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 4a2ddd884..84b121a95 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -108,6 +108,12 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression) const; }; + struct For + { + inline std::unique_ptr operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const; + inline std::unique_ptr operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const; + }; + struct ForEach { inline std::unique_ptr operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const; @@ -179,6 +185,7 @@ namespace Nz::ShaderBuilder constexpr Impl::DeclareVariable DeclareVariable; constexpr Impl::ExpressionStatement ExpressionStatement; constexpr Impl::NoParam Discard; + constexpr Impl::For For; constexpr Impl::ForEach ForEach; constexpr Impl::Identifier Identifier; constexpr Impl::Intrinsic Intrinsic; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 5e7b81ea2..0bfc90187 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -269,6 +269,29 @@ namespace Nz::ShaderBuilder return expressionStatementNode; } + inline std::unique_ptr Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const + { + auto forNode = std::make_unique(); + forNode->fromExpr = std::move(fromExpression); + forNode->statement = std::move(statement); + forNode->toExpr = std::move(toExpression); + forNode->varName = std::move(varName); + + return forNode; + } + + inline std::unique_ptr Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const + { + auto forNode = std::make_unique(); + forNode->fromExpr = std::move(fromExpression); + forNode->statement = std::move(statement); + forNode->stepExpr = std::move(stepExpression); + forNode->toExpr = std::move(toExpression); + forNode->varName = std::move(varName); + + return forNode; + } + std::unique_ptr Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const { auto forEachNode = std::make_unique(); diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index 7234d5509..22a68ded3 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -12,6 +12,7 @@ #define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X) #endif +NAZARA_SHADERLANG_TOKEN(Arrow) NAZARA_SHADERLANG_TOKEN(Assign) NAZARA_SHADERLANG_TOKEN(BoolFalse) NAZARA_SHADERLANG_TOKEN(BoolTrue) @@ -33,7 +34,6 @@ NAZARA_SHADERLANG_TOKEN(External) NAZARA_SHADERLANG_TOKEN(FloatingPointValue) NAZARA_SHADERLANG_TOKEN(For) NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) -NAZARA_SHADERLANG_TOKEN(FunctionReturn) NAZARA_SHADERLANG_TOKEN(GreaterThan) NAZARA_SHADERLANG_TOKEN(GreaterThanEqual) NAZARA_SHADERLANG_TOKEN(IntegerValue) @@ -59,8 +59,8 @@ NAZARA_SHADERLANG_TOKEN(OpenCurlyBracket) NAZARA_SHADERLANG_TOKEN(OpenSquareBracket) NAZARA_SHADERLANG_TOKEN(OpenParenthesis) NAZARA_SHADERLANG_TOKEN(Option) -NAZARA_SHADERLANG_TOKEN(Semicolon) NAZARA_SHADERLANG_TOKEN(Return) +NAZARA_SHADERLANG_TOKEN(Semicolon) NAZARA_SHADERLANG_TOKEN(Struct) NAZARA_SHADERLANG_TOKEN(While) diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index ce773886f..9b4ce5b30 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -170,12 +170,26 @@ namespace Nz::ShaderAst return clone; } + StatementPtr AstCloner::Clone(ForStatement& node) + { + auto clone = std::make_unique(); + clone->fromExpr = CloneExpression(node.fromExpr); + clone->stepExpr = CloneExpression(node.stepExpr); + clone->toExpr = CloneExpression(node.toExpr); + clone->statement = CloneStatement(node.statement); + clone->unroll = Clone(node.unroll); + clone->varName = node.varName; + + return clone; + } + StatementPtr AstCloner::Clone(ForEachStatement& node) { auto clone = std::make_unique(); clone->expression = CloneExpression(node.expression); clone->statement = CloneStatement(node.statement); clone->unroll = Clone(node.unroll); + clone->varName = node.varName; return clone; } diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index e6b1368bc..061ef2d52 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -161,6 +161,21 @@ namespace Nz::ShaderAst node.expression->Visit(*this); } + void AstRecursiveVisitor::Visit(ForStatement& node) + { + if (node.fromExpr) + node.fromExpr->Visit(*this); + + if (node.toExpr) + node.toExpr->Visit(*this); + + if (node.stepExpr) + node.stepExpr->Visit(*this); + + if (node.statement) + node.statement->Visit(*this); + } + void AstRecursiveVisitor::Visit(ForEachStatement& node) { if (node.expression) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 2b277bca7..e5a453f75 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -301,6 +301,16 @@ namespace Nz::ShaderAst Node(node.expression); } + void AstSerializerBase::Serialize(ForStatement& node) + { + Attribute(node.unroll); + Value(node.varName); + Node(node.fromExpr); + Node(node.toExpr); + Node(node.stepExpr); + Node(node.statement); + } + void AstSerializerBase::Serialize(ForEachStatement& node) { Attribute(node.unroll); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 1933980d0..31d8b9ff5 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -163,19 +163,6 @@ namespace Nz::ShaderAst const ExpressionType& exprType = GetExpressionType(*indexedExpr); if (IsStructType(exprType)) { - // Transform to AccessIndexExpression - AccessIndexExpression* accessIndexPtr; - if (indexedExpr->GetType() != NodeType::AccessIndexExpression) - { - std::unique_ptr accessIndex = std::make_unique(); - accessIndex->expr = std::move(indexedExpr); - - accessIndexPtr = accessIndex.get(); - indexedExpr = std::move(accessIndex); - } - else - accessIndexPtr = static_cast(indexedExpr.get()); - std::size_t structIndex = ResolveStruct(exprType); assert(structIndex < m_context->structs.size()); const StructDescription* s = m_context->structs[structIndex]; @@ -200,8 +187,42 @@ namespace Nz::ShaderAst if (!fieldPtr) throw AstError{ "unknown field " + identifier }; - accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex)); - accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type); + if (m_context->options.useIdentifierAccessesForStructs) + { + // Use a AccessIdentifierExpression + AccessIdentifierExpression* accessIdentifierPtr; + if (indexedExpr->GetType() != NodeType::AccessIdentifierExpression) + { + std::unique_ptr accessIndex = std::make_unique(); + accessIndex->expr = std::move(indexedExpr); + + accessIdentifierPtr = accessIndex.get(); + indexedExpr = std::move(accessIndex); + } + else + accessIdentifierPtr = static_cast(indexedExpr.get()); + + accessIdentifierPtr->identifiers.push_back(s->members[fieldIndex].name); + accessIdentifierPtr->cachedExpressionType = ResolveType(fieldPtr->type); + } + else + { + // Transform to AccessIndexExpression + AccessIndexExpression* accessIndexPtr; + if (indexedExpr->GetType() != NodeType::AccessIndexExpression) + { + std::unique_ptr accessIndex = std::make_unique(); + accessIndex->expr = std::move(indexedExpr); + + accessIndexPtr = accessIndex.get(); + indexedExpr = std::move(accessIndex); + } + else + accessIndexPtr = static_cast(indexedExpr.get()); + + accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex)); + accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type); + } } else if (IsPrimitiveType(exprType) || IsVectorType(exprType)) { @@ -269,6 +290,8 @@ namespace Nz::ShaderAst auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); Validate(*clone); + // TODO: Handle AccessIndex on structs with m_context->options.useIdentifierAccessesForStructs + return clone; } @@ -829,9 +852,180 @@ namespace Nz::ShaderAst return AstCloner::Clone(node); } + StatementPtr SanitizeVisitor::Clone(ForStatement& node) + { + if (node.varName.empty()) + throw AstError{ "numerical for variable name cannot be empty" }; + + auto fromExpr = CloneExpression(MandatoryExpr(node.fromExpr)); + auto stepExpr = CloneExpression(node.stepExpr); + auto toExpr = CloneExpression(MandatoryExpr(node.toExpr)); + + MandatoryStatement(node.statement); + + const ExpressionType& fromExprType = GetExpressionType(*fromExpr); + if (!IsPrimitiveType(fromExprType)) + throw AstError{ "numerical for from expression must be an integer or unsigned integer" }; + + PrimitiveType fromType = std::get(fromExprType); + if (fromType != PrimitiveType::Int32 && fromType != PrimitiveType::UInt32) + throw AstError{ "numerical for from expression must be an integer or unsigned integer" }; + + const ExpressionType& toExprType = GetExpressionType(*fromExpr); + if (toExprType != fromExprType) + throw AstError{ "numerical for to expression type must match from expression type" }; + + if (stepExpr) + { + const ExpressionType& stepExprType = GetExpressionType(*fromExpr); + if (stepExprType != fromExprType) + throw AstError{ "numerical for step expression type must match from expression type" }; + } + + + AttributeValue unrollValue; + if (node.unroll.HasValue()) + { + unrollValue = ComputeAttributeValue(node.unroll); + if (unrollValue.GetResultingValue() == LoopUnroll::Always) + { + PushScope(); + + auto multi = std::make_unique(); + + auto Unroll = [&](auto dummy) + { + using T = std::decay_t; + + T counter = std::get(ComputeConstantValue(*fromExpr)); + T to = std::get(ComputeConstantValue(*toExpr)); + T step = (stepExpr) ? std::get(ComputeConstantValue(*stepExpr)) : T(1); + + for (; counter < to; counter += step) + { + auto var = ShaderBuilder::DeclareVariable(node.varName, ShaderBuilder::Constant(counter)); + Validate(*var); + multi->statements.emplace_back(std::move(var)); + + multi->statements.emplace_back(CloneStatement(node.statement)); + } + }; + + switch (fromType) + { + case PrimitiveType::Int32: + Unroll(Int32{}); + break; + + case PrimitiveType::UInt32: + Unroll(UInt32{}); + break; + + default: + throw AstError{ "internal error" }; + } + + PopScope(); + + return multi; + } + } + + if (m_context->options.reduceLoopsToWhile) + { + PushScope(); + + auto multi = std::make_unique(); + + // Counter variable + auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr)); + Validate(*counterVariable); + + std::size_t counterVarIndex = counterVariable->varIndex.value(); + multi->statements.emplace_back(std::move(counterVariable)); + + // Target variable + auto targetVariable = ShaderBuilder::DeclareVariable("to", std::move(toExpr)); + Validate(*targetVariable); + + std::size_t targetVarIndex = targetVariable->varIndex.value(); + multi->statements.emplace_back(std::move(targetVariable)); + + // Step variable + std::optional stepVarIndex; + + if (stepExpr) + { + auto stepVariable = ShaderBuilder::DeclareVariable("step", std::move(stepExpr)); + Validate(*stepVariable); + + stepVarIndex = stepVariable->varIndex; + multi->statements.emplace_back(std::move(stepVariable)); + } + + // While + auto whileStatement = std::make_unique(); + whileStatement->unroll = std::move(unrollValue); + + // While condition + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, fromType), ShaderBuilder::Variable(targetVarIndex, fromType)); + Validate(*condition); + + whileStatement->condition = std::move(condition); + + // While body + auto body = std::make_unique(); + body->statements.reserve(2); + + body->statements.emplace_back(CloneStatement(node.statement)); + + ExpressionPtr incrExpr; + if (stepVarIndex) + incrExpr = ShaderBuilder::Variable(*stepVarIndex, fromType); + else + incrExpr = (fromType == PrimitiveType::Int32) ? ShaderBuilder::Constant(1) : ShaderBuilder::Constant(1u); + + auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, fromType), std::move(incrExpr)); + Validate(*incrCounter); + + body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); + + whileStatement->body = std::move(body); + + multi->statements.emplace_back(std::move(whileStatement)); + + PopScope(); + + return multi; + } + else + { + auto clone = std::make_unique(); + clone->fromExpr = std::move(fromExpr); + clone->stepExpr = std::move(stepExpr); + clone->toExpr = std::move(toExpr); + clone->varName = node.varName; + clone->unroll = std::move(unrollValue); + + PushScope(); + { + clone->varIndex = RegisterVariable(node.varName, fromExprType); + clone->statement = CloneStatement(node.statement); + } + PopScope(); + + SanitizeIdentifier(clone->varName); + + return clone; + } + } + StatementPtr SanitizeVisitor::Clone(ForEachStatement& node) { - auto expr = CloneExpression(node.expression); + auto expr = CloneExpression(MandatoryExpr(node.expression)); + + if (node.varName.empty()) + throw AstError{ "for-each variable name cannot be empty"}; const ExpressionType& exprType = GetExpressionType(*expr); ExpressionType innerType; @@ -849,6 +1043,8 @@ namespace Nz::ShaderAst unrollValue = ComputeAttributeValue(node.unroll); if (unrollValue.GetResultingValue() == LoopUnroll::Always) { + PushScope(); + // Repeat code auto multi = std::make_unique(); if (IsArrayType(exprType)) @@ -869,6 +1065,8 @@ namespace Nz::ShaderAst } } + PopScope(); + return multi; } } @@ -943,7 +1141,7 @@ namespace Nz::ShaderAst } PopScope(); - SanitizeIdentifier(node.varName); + SanitizeIdentifier(clone->varName); return clone; } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 9a7b46e40..766a79ff4 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -390,41 +390,6 @@ namespace Nz AppendLine((forward) ? ");" : ")"); } - void GlslWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers) - { - ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); - - assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantValueExpression); - auto& constantValue = static_cast(**memberIndices); - Int32 index = std::get(constantValue.value); - assert(index >= 0); - - auto it = structDesc->members.begin(); - for (; it != structDesc->members.end(); ++it) - { - const auto& member = *it; - if (member.cond.HasValue() && !member.cond.GetResultingValue()) - continue; - - if (index == 0) - break; - - index--; - } - - assert(it != structDesc->members.end()); - const auto& member = *it; - - Append("."); - Append(member.name); - - if (remainingMembers > 1) - { - assert(IsStructType(member.type)); - AppendField(std::get(member.type).structIndex, memberIndices + 1, remainingMembers - 1); - } - } - void GlslWriter::AppendHeader() { unsigned int glslVersion; @@ -734,24 +699,30 @@ namespace Nz Append(")"); } + void GlslWriter::Visit(ShaderAst::AccessIdentifierExpression& node) + { + Visit(node.expr, true); + + const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); + assert(IsStructType(exprType)); + + for (const std::string& identifier : node.identifiers) + Append(".", identifier); + } + void GlslWriter::Visit(ShaderAst::AccessIndexExpression& node) { Visit(node.expr, true); const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); + assert(!IsStructType(exprType)); - // For structs, convert indices to field names - if (IsStructType(exprType)) - AppendField(std::get(exprType).structIndex, node.indices.data(), node.indices.size()); - else + // Array access + for (ShaderAst::ExpressionPtr& expr : node.indices) { - // Array access - for (ShaderAst::ExpressionPtr& expr : node.indices) - { - Append("["); - Visit(expr); - Append("]"); - } + Append("["); + Visit(expr); + Append("]"); } } diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 69e1b9eae..e6a486637 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -83,6 +83,13 @@ namespace Nz inline bool HasValue() const { return setIndex.HasValue(); } }; + struct LangWriter::UnrollAttribute + { + const ShaderAst::AttributeValue& unroll; + + inline bool HasValue() const { return unroll.HasValue(); } + }; + struct LangWriter::State { const States* states = nullptr; @@ -103,10 +110,7 @@ namespace Nz m_currentState = nullptr; }); - ShaderAst::SanitizeVisitor::Options options; - options.removeOptionDeclaration = false; - - ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader, options); + ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader); AppendHeader(); @@ -277,10 +281,14 @@ namespace Nz if (!binding.HasValue()) return; + Append("binding("); + if (binding.bindingIndex.IsResultingValue()) - Append("binding(", binding.bindingIndex.GetResultingValue(), ")"); + Append(binding.bindingIndex.GetResultingValue()); else binding.bindingIndex.GetExpression()->Visit(*this); + + Append(")"); } void LangWriter::AppendAttribute(BuiltinAttribute builtin) @@ -288,25 +296,29 @@ namespace Nz if (!builtin.HasValue()) return; + Append("builtin("); + if (builtin.builtin.IsResultingValue()) { switch (builtin.builtin.GetResultingValue()) { case ShaderAst::BuiltinEntry::FragCoord: - Append("builtin(fragcoord)"); + Append("fragcoord"); break; case ShaderAst::BuiltinEntry::FragDepth: - Append("builtin(fragdepth)"); + Append("fragdepth"); break; case ShaderAst::BuiltinEntry::VertexPosition: - Append("builtin(position)"); + Append("position"); break; } } else builtin.builtin.GetExpression()->Visit(*this); + + Append(")"); } void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite) @@ -314,29 +326,33 @@ namespace Nz if (!depthWrite.HasValue()) return; + Append("depth_write("); + if (depthWrite.writeMode.IsResultingValue()) { switch (depthWrite.writeMode.GetResultingValue()) { case ShaderAst::DepthWriteMode::Greater: - Append("depth_write(greater)"); + Append("greater"); break; case ShaderAst::DepthWriteMode::Less: - Append("depth_write(less)"); + Append("less"); break; case ShaderAst::DepthWriteMode::Replace: - Append("depth_write(replace)"); + Append("replace"); break; case ShaderAst::DepthWriteMode::Unchanged: - Append("depth_write(unchanged)"); + Append("unchanged"); break; } } else depthWrite.writeMode.GetExpression()->Visit(*this); + + Append(")"); } void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests) @@ -344,15 +360,19 @@ namespace Nz if (!earlyFragmentTests.HasValue()) return; + Append("early_fragment_tests("); + if (earlyFragmentTests.earlyFragmentTests.IsResultingValue()) { if (earlyFragmentTests.earlyFragmentTests.GetResultingValue()) - Append("early_fragment_tests(true)"); + Append("true"); else - Append("early_fragment_tests(false)"); + Append("false"); } else earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this); + + Append(")"); } void LangWriter::AppendAttribute(EntryAttribute entry) @@ -360,21 +380,25 @@ namespace Nz if (!entry.HasValue()) return; + Append("entry("); + if (entry.stageType.IsResultingValue()) { switch (entry.stageType.GetResultingValue()) { case ShaderStageType::Fragment: - Append("entry(frag)"); + Append("frag"); break; case ShaderStageType::Vertex: - Append("entry(vert)"); + Append("vert"); break; } } else entry.stageType.GetExpression()->Visit(*this); + + Append(")"); } void LangWriter::AppendAttribute(LayoutAttribute entry) @@ -382,17 +406,19 @@ namespace Nz if (!entry.HasValue()) return; + Append("layout("); if (entry.layout.IsResultingValue()) { switch (entry.layout.GetResultingValue()) { case StructLayout::Std140: - Append("layout(std140)"); + Append("std140"); break; } } else entry.layout.GetExpression()->Visit(*this); + Append(")"); } void LangWriter::AppendAttribute(LocationAttribute location) @@ -400,10 +426,14 @@ namespace Nz if (!location.HasValue()) return; + Append("location("); + if (location.locationIndex.IsResultingValue()) - Append("location(", location.locationIndex.GetResultingValue(), ")"); + Append(location.locationIndex.GetResultingValue()); else location.locationIndex.GetExpression()->Visit(*this); + + Append(")"); } void LangWriter::AppendAttribute(SetAttribute set) @@ -411,10 +441,45 @@ namespace Nz if (!set.HasValue()) return; + Append("set("); + if (set.setIndex.IsResultingValue()) - Append("set(", set.setIndex.GetResultingValue(), ")"); + Append(set.setIndex.GetResultingValue()); else set.setIndex.GetExpression()->Visit(*this); + + Append(")"); + } + + void LangWriter::AppendAttribute(UnrollAttribute unroll) + { + if (!unroll.HasValue()) + return; + + Append("unroll("); + + if (unroll.unroll.IsResultingValue()) + { + switch (unroll.unroll.GetResultingValue()) + { + case ShaderAst::LoopUnroll::Always: + Append("always"); + break; + + case ShaderAst::LoopUnroll::Hint: + Append("hint"); + break; + + case ShaderAst::LoopUnroll::Never: + Append("never"); + break; + + default: + break; + } + } + else + unroll.unroll.GetExpression()->Visit(*this); } void LangWriter::AppendCommentSection(const std::string& section) @@ -426,26 +491,6 @@ namespace Nz AppendLine(); } - void LangWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers) - { - ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); - - assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantValueExpression); - auto& constantValue = static_cast(**memberIndices); - Int32 index = std::get(constantValue.value); - - const auto& member = structDesc->members[index]; - - Append("."); - Append(member.name); - - if (remainingMembers > 1) - { - assert(IsStructType(member.type)); - AppendField(std::get(member.type).structIndex, memberIndices + 1, remainingMembers - 1); - } - } - void LangWriter::AppendLine(const std::string& txt) { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); @@ -526,24 +571,30 @@ namespace Nz Append(")"); } + void LangWriter::Visit(ShaderAst::AccessIdentifierExpression& node) + { + Visit(node.expr, true); + + const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); + assert(IsStructType(exprType)); + + for (const std::string& identifier : node.identifiers) + Append(".", identifier); + } + void LangWriter::Visit(ShaderAst::AccessIndexExpression& node) { Visit(node.expr, true); const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); + assert(!IsStructType(exprType)); - // For structs, convert indices to field names - if (IsStructType(exprType)) - AppendField(std::get(exprType).structIndex, node.indices.data(), node.indices.size()); - else + // Array access + for (ShaderAst::ExpressionPtr& expr : node.indices) { - // Array access - for (ShaderAst::ExpressionPtr& expr : node.indices) - { - Append("["); - Visit(expr); - Append("]"); - } + Append("["); + expr->Visit(*this); + Append("]"); } } @@ -826,11 +877,36 @@ namespace Nz Append(";"); } + void LangWriter::Visit(ShaderAst::ForStatement& node) + { + assert(node.varIndex); + RegisterVariable(*node.varIndex, node.varName); + + AppendAttributes(true, UnrollAttribute{ node.unroll }); + Append("for ", node.varName, " in "); + node.fromExpr->Visit(*this); + Append(" -> "); + node.toExpr->Visit(*this); + + if (node.stepExpr) + { + Append(" : "); + node.stepExpr->Visit(*this); + } + + AppendLine(); + + EnterScope(); + node.statement->Visit(*this); + LeaveScope(); + } + void LangWriter::Visit(ShaderAst::ForEachStatement& node) { assert(node.varIndex); RegisterVariable(*node.varIndex, node.varName); + AppendAttributes(true, UnrollAttribute{ node.unroll }); Append("for ", node.varName, " in "); node.expression->Visit(*this); AppendLine(); diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index 5015d3935..70955d6e4 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -113,7 +113,7 @@ namespace Nz::ShaderLang if (next == '>') { currentPos++; - tokenType = TokenType::FunctionReturn; + tokenType = TokenType::Arrow; break; } else if (next == '=') diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 4e50d41fa..210d20c01 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -614,24 +614,63 @@ namespace Nz::ShaderLang ShaderAst::ExpressionPtr expr = ParseExpression(); - ShaderAst::StatementPtr statement = ParseStatement(); - - auto forEach = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement)); - - for (auto&& [attributeType, arg] : attributes) + if (Peek().type == TokenType::Arrow) { - switch (attributeType) + // Numerical for + Consume(); + + ShaderAst::ExpressionPtr toExpr = ParseExpression(); + + ShaderAst::ExpressionPtr stepExpr; + if (Peek().type == TokenType::Colon) { - case ShaderAst::AttributeType::Unroll: - HandleUniqueStringAttribute("unroll", s_unrollModes, forEach->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always)); - break; - - default: - throw AttributeError{ "unhandled attribute for for-each" }; + Consume(); + stepExpr = ParseExpression(); } - } - return forEach; + ShaderAst::StatementPtr statement = ParseStatement(); + + auto forNode = ShaderBuilder::For(std::move(varName), std::move(expr), std::move(toExpr), std::move(stepExpr), std::move(statement)); + + // TODO: Deduplicate code + for (auto&& [attributeType, arg] : attributes) + { + switch (attributeType) + { + case ShaderAst::AttributeType::Unroll: + HandleUniqueStringAttribute("unroll", s_unrollModes, forNode->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always)); + break; + + default: + throw AttributeError{ "unhandled attribute for numerical for" }; + } + } + + return forNode; + } + else + { + // For each + ShaderAst::StatementPtr statement = ParseStatement(); + + auto forEachNode = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement)); + + // TODO: Deduplicate code + for (auto&& [attributeType, arg] : attributes) + { + switch (attributeType) + { + case ShaderAst::AttributeType::Unroll: + HandleUniqueStringAttribute("unroll", s_unrollModes, forEachNode->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always)); + break; + + default: + throw AttributeError{ "unhandled attribute for for-each" }; + } + } + + return forEachNode; + } } std::vector Parser::ParseFunctionBody() @@ -668,7 +707,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::ClosingParenthesis); ShaderAst::ExpressionType returnType; - if (Peek().type == TokenType::FunctionReturn) + if (Peek().type == TokenType::Arrow) { Consume(); returnType = ParseType(); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index ce56a2b01..e16ecfa88 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -491,6 +491,7 @@ namespace Nz options.removeCompoundAssignments = true; options.removeOptionDeclaration = true; options.splitMultipleBranches = true; + options.useIdentifierAccessesForStructs = false; sanitizedAst = ShaderAst::Sanitize(shader, options); targetAst = sanitizedAst.get(); diff --git a/tests/Engine/Shader/Const.cpp b/tests/Engine/Shader/Const.cpp index 81a74b77e..7bb3634cf 100644 --- a/tests/Engine/Shader/Const.cpp +++ b/tests/Engine/Shader/Const.cpp @@ -109,6 +109,63 @@ fn main() )"); } } + + WHEN("using [unroll] attribute on numerical for") + { + std::string_view sourceCode = R"( +const LightCount = 3; + +[layout(std140)] +struct Light +{ + color: vec4 +} + +[layout(std140)] +struct LightData +{ + lights: [Light; LightCount] +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let color = (0.0).xxxx; + + [unroll] + for i in 0 -> 10 : 2 + { + color += data.lights[i].color; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader; + REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + + ExpectOutput(*shader, {}, R"( +[entry(frag)] +fn main() +{ + let color: vec4 = (0.000000).xxxx; + let i: i32 = 0; + color += data.lights[i].color; + let i: i32 = 2; + color += data.lights[i].color; + let i: i32 = 4; + color += data.lights[i].color; + let i: i32 = 6; + color += data.lights[i].color; + let i: i32 = 8; + color += data.lights[i].color; +} +)"); + } WHEN("using [unroll] attribute on for-each") { diff --git a/tests/Engine/Shader/Loops.cpp b/tests/Engine/Shader/Loops.cpp index 6ba55d163..182e21873 100644 --- a/tests/Engine/Shader/Loops.cpp +++ b/tests/Engine/Shader/Loops.cpp @@ -88,6 +88,160 @@ OpStore OpBranch OpLabel OpReturn +OpFunctionEnd)"); + } + + WHEN("using a for range") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let x = 0; + for v in 0 -> 10 + { + x += v; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + + ExpectGLSL(*shader, R"( +void main() +{ + int x = 0; + int v = 0; + int to = 10; + while (v < to) + { + x += v; + v += 1; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let x: i32 = 0; + for v in 0 -> 10 + { + x += v; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpVariable +OpStore +OpStore +OpStore +OpBranch +OpLabel +OpLoad +OpLoad +OpSLessThan +OpLoopMerge +OpBranchConditional +OpLabel +OpLoad +OpLoad +OpIAdd +OpStore +OpLoad +OpIAdd +OpStore +OpBranch +OpLabel +OpReturn +OpFunctionEnd)"); + } + + WHEN("using a for range with step") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let x = 0; + for v in 0 -> 10 : 2 + { + x += v; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + + ExpectGLSL(*shader, R"( +void main() +{ + int x = 0; + int v = 0; + int to = 10; + int step = 2; + while (v < to) + { + x += v; + v += step; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let x: i32 = 0; + for v in 0 -> 10 : 2 + { + x += v; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpVariable +OpVariable +OpStore +OpStore +OpStore +OpStore +OpBranch +OpLabel +OpLoad +OpLoad +OpSLessThan +OpLoopMerge +OpBranchConditional +OpLabel +OpLoad +OpLoad +OpIAdd +OpStore +OpLoad +OpLoad +OpIAdd +OpStore +OpBranch +OpLabel +OpReturn OpFunctionEnd)"); }