From a54f70fd24dd5e86320b5a13d1c656353156023f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Mon, 21 Mar 2022 23:11:28 +0100 Subject: [PATCH] Shader: Fix parsing of unary/dot/indices/and/or --- include/Nazara/Shader/ShaderLangParser.hpp | 6 +- src/Nazara/Shader/ShaderLangParser.cpp | 162 ++++++++++----------- tests/Engine/Shader/BranchTests.cpp | 91 ++++++++++++ tests/Engine/Shader/FunctionsTests.cpp | 7 +- 4 files changed, 173 insertions(+), 93 deletions(-) diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 16fd84e53..900e57ca1 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -85,6 +85,10 @@ namespace Nz::ShaderLang void ParseModuleStatement(std::vector attributes); void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue& type, ShaderAst::ExpressionPtr& initialValue); + ShaderAst::ExpressionPtr BuildIdentifierAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs); + ShaderAst::ExpressionPtr BuildIndexAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs); + ShaderAst::ExpressionPtr BuildBinary(ShaderAst::BinaryType binaryType, ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs); + // Statements ShaderAst::StatementPtr ParseAliasDeclaration(); ShaderAst::StatementPtr ParseBranchStatement(); @@ -110,10 +114,10 @@ namespace Nz::ShaderLang ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); ShaderAst::ExpressionPtr ParseConstSelectExpression(); ShaderAst::ExpressionPtr ParseExpression(); + std::vector ParseExpressionList(TokenType terminationToken); ShaderAst::ExpressionPtr ParseFloatingPointExpression(); ShaderAst::ExpressionPtr ParseIdentifier(); ShaderAst::ExpressionPtr ParseIntegerExpression(); - std::vector ParseParameters(); ShaderAst::ExpressionPtr ParseParenthesisExpression(); ShaderAst::ExpressionPtr ParsePrimaryExpression(); ShaderAst::ExpressionPtr ParseStringExpression(); diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index f3a565d53..1f549d5a1 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -390,6 +390,24 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Semicolon); } + ShaderAst::ExpressionPtr Parser::BuildIdentifierAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs) + { + if (rhs->GetType() == ShaderAst::NodeType::IdentifierExpression) + return ShaderBuilder::AccessMember(std::move(lhs), { std::move(SafeCast(*rhs).identifier) }); + else + return BuildIndexAccess(std::move(lhs), std::move(rhs)); + } + + ShaderAst::ExpressionPtr Parser::BuildIndexAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs) + { + return ShaderBuilder::AccessIndex(std::move(lhs), std::move(rhs)); + } + + ShaderAst::ExpressionPtr Parser::BuildBinary(ShaderAst::BinaryType binaryType, ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs) + { + return ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs)); + } + ShaderAst::StatementPtr Parser::ParseAliasDeclaration() { Expect(Advance(), TokenType::Alias); @@ -1124,59 +1142,25 @@ namespace Nz::ShaderLang if (tokenPrecedence < exprPrecedence) return lhs; - bool c = false; - while (currentTokenType == TokenType::Dot || currentTokenType == TokenType::OpenSquareBracket) - { - c = true; - - if (currentTokenType == TokenType::Dot) - { - std::unique_ptr accessMemberNode = std::make_unique(); - accessMemberNode->expr = std::move(lhs); - - do - { - Consume(); - - accessMemberNode->identifiers.push_back(ParseIdentifierAsName()); - } while (Peek().type == TokenType::Dot); - - lhs = std::move(accessMemberNode); - } - else - { - assert(currentTokenType == TokenType::OpenSquareBracket); - - std::unique_ptr indexNode = std::make_unique(); - indexNode->expr = std::move(lhs); - - do - { - Consume(); - - indexNode->indices.push_back(ParseExpression()); - } - while (Peek().type == TokenType::Comma); - - Expect(Advance(), TokenType::ClosingSquareBracket); - - lhs = std::move(indexNode); - } - - currentTokenType = Peek().type; - } - if (currentTokenType == TokenType::OpenParenthesis) { - // Function call - auto parameters = ParseParameters(); - lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters)); + Consume(); - c = true; + // Function call + auto parameters = ParseExpressionList(TokenType::ClosingParenthesis); + lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters)); + continue; } - if (c) + if (currentTokenType == TokenType::OpenSquareBracket) + { + Consume(); + + // Indices + auto parameters = ParseExpressionList(TokenType::ClosingSquareBracket); + lhs = ShaderBuilder::AccessIndex(std::move(lhs), std::move(parameters)); continue; + } Consume(); ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression(); @@ -1187,28 +1171,30 @@ namespace Nz::ShaderLang if (tokenPrecedence < nextTokenPrecedence) rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs)); - ShaderAst::BinaryType binaryType; + lhs = [&] { switch (currentTokenType) { - case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break; - case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break; - case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break; - case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; - case TokenType::LogicalAnd: binaryType = ShaderAst::BinaryType::LogicalAnd; break; - case TokenType::LogicalOr: binaryType = ShaderAst::BinaryType::LogicalOr; break; - case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompGt; break; - case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompGe; break; - case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; - case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break; - case TokenType::NotEqual: binaryType = ShaderAst::BinaryType::CompNe; break; - case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break; + case TokenType::Dot: + return BuildIdentifierAccess(std::move(lhs), std::move(rhs)); + + case TokenType::Divide: return BuildBinary(ShaderAst::BinaryType::Divide, std::move(lhs), std::move(rhs)); + case TokenType::Equal: return BuildBinary(ShaderAst::BinaryType::CompEq, std::move(lhs), std::move(rhs)); + case TokenType::LessThan: return BuildBinary(ShaderAst::BinaryType::CompLt, std::move(lhs), std::move(rhs)); + case TokenType::LessThanEqual: return BuildBinary(ShaderAst::BinaryType::CompLe, std::move(lhs), std::move(rhs)); + case TokenType::LogicalAnd: return BuildBinary(ShaderAst::BinaryType::LogicalAnd, std::move(lhs), std::move(rhs)); + case TokenType::LogicalOr: return BuildBinary(ShaderAst::BinaryType::LogicalOr, std::move(lhs), std::move(rhs)); + case TokenType::GreaterThan: return BuildBinary(ShaderAst::BinaryType::CompGt, std::move(lhs), std::move(rhs)); + case TokenType::GreaterThanEqual: return BuildBinary(ShaderAst::BinaryType::CompGe, std::move(lhs), std::move(rhs)); + case TokenType::Minus: return BuildBinary(ShaderAst::BinaryType::Subtract, std::move(lhs), std::move(rhs)); + case TokenType::Multiply: return BuildBinary(ShaderAst::BinaryType::Multiply, std::move(lhs), std::move(rhs)); + case TokenType::NotEqual: return BuildBinary(ShaderAst::BinaryType::CompNe, std::move(lhs), std::move(rhs)); + case TokenType::Plus: return BuildBinary(ShaderAst::BinaryType::Add, std::move(lhs), std::move(rhs)); default: throw UnexpectedToken{}; - } - } - lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs)); + } + }(); } } @@ -1237,6 +1223,24 @@ namespace Nz::ShaderLang return ParseBinOpRhs(0, ParsePrimaryExpression()); } + std::vector Parser::ParseExpressionList(TokenType terminationToken) + { + std::vector parameters; + bool first = true; + while (Peek().type != terminationToken) + { + if (!first) + Expect(Advance(), TokenType::Comma); + + first = false; + parameters.push_back(ParseExpression()); + } + + Expect(Advance(), terminationToken); + + return parameters; + } + ShaderAst::ExpressionPtr Parser::ParseFloatingPointExpression() { const Token& floatingPointToken = Expect(Advance(), TokenType::FloatingPointValue); @@ -1257,26 +1261,6 @@ namespace Nz::ShaderLang return ShaderBuilder::Constant(SafeCast(std::get(integerToken.data))); //< FIXME } - std::vector Parser::ParseParameters() - { - Expect(Advance(), TokenType::OpenParenthesis); - - std::vector parameters; - bool first = true; - while (Peek().type != TokenType::ClosingParenthesis) - { - if (!first) - Expect(Advance(), TokenType::Comma); - - first = false; - parameters.push_back(ParseExpression()); - } - - Expect(Advance(), TokenType::ClosingParenthesis); - - return parameters; - } - ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression() { Expect(Advance(), TokenType::OpenParenthesis); @@ -1314,7 +1298,7 @@ namespace Nz::ShaderLang case TokenType::Minus: { Consume(); - ShaderAst::ExpressionPtr expr = ParsePrimaryExpression(); + ShaderAst::ExpressionPtr expr = ParseExpression(); return ShaderBuilder::Unary(ShaderAst::UnaryType::Minus, std::move(expr)); } @@ -1322,7 +1306,7 @@ namespace Nz::ShaderLang case TokenType::Plus: { Consume(); - ShaderAst::ExpressionPtr expr = ParsePrimaryExpression(); + ShaderAst::ExpressionPtr expr = ParseExpression(); return ShaderBuilder::Unary(ShaderAst::UnaryType::Plus, std::move(expr)); } @@ -1330,7 +1314,7 @@ namespace Nz::ShaderLang case TokenType::Not: { Consume(); - ShaderAst::ExpressionPtr expr = ParsePrimaryExpression(); + ShaderAst::ExpressionPtr expr = ParseExpression(); return ShaderBuilder::Unary(ShaderAst::UnaryType::LogicalNot, std::move(expr)); } @@ -1404,12 +1388,12 @@ namespace Nz::ShaderLang switch (token) { case TokenType::Divide: return 80; - case TokenType::Dot: return 100; + case TokenType::Dot: return 150; case TokenType::Equal: return 50; case TokenType::LessThan: return 40; case TokenType::LessThanEqual: return 40; - case TokenType::LogicalAnd: return 120; - case TokenType::LogicalOr: return 140; + case TokenType::LogicalAnd: return 20; + case TokenType::LogicalOr: return 10; case TokenType::GreaterThan: return 40; case TokenType::GreaterThanEqual: return 40; case TokenType::Multiply: return 80; diff --git a/tests/Engine/Shader/BranchTests.cpp b/tests/Engine/Shader/BranchTests.cpp index ee836edeb..e677d6d27 100644 --- a/tests/Engine/Shader/BranchTests.cpp +++ b/tests/Engine/Shader/BranchTests.cpp @@ -88,6 +88,97 @@ OpStore OpBranch OpLabel OpReturn +OpFunctionEnd)"); + } + + WHEN("using a more complex branch") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform[inputStruct] +} + +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > 42.0 || data.value <= 50.0 && data.value < 0.0) + value = 1.0; + else + value = 0.0; +} +)"; + + Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); + shaderModule = SanitizeModule(*shaderModule); + + ExpectGLSL(*shaderModule, R"( +void main() +{ + float value; + if ((data.value > (42.000000)) || ((data.value <= (50.000000)) && (data.value < (0.000000)))) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + +} +)"); + + ExpectNZSL(*shaderModule, R"( +[entry(frag)] +fn main() +{ + let value: f32; + if ((data.value > (42.000000)) || ((data.value <= (50.000000)) && (data.value < (0.000000)))) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + +} +)"); + + ExpectSPIRV(*shaderModule, R"( +OpFunction +OpLabel +OpVariable +OpAccessChain +OpLoad +OpFOrdGreaterThanEqual +OpAccessChain +OpLoad +OpFOrdLessThanEqual +OpAccessChain +OpLoad +OpFOrdLessThan +OpLogicalAnd +OpLogicalOr +OpSelectionMerge +OpBranchConditional +OpLabel +OpStore +OpBranch +OpLabel +OpStore +OpBranch +OpLabel +OpReturn OpFunctionEnd)"); } diff --git a/tests/Engine/Shader/FunctionsTests.cpp b/tests/Engine/Shader/FunctionsTests.cpp index 4bbc250cb..5da086c3e 100644 --- a/tests/Engine/Shader/FunctionsTests.cpp +++ b/tests/Engine/Shader/FunctionsTests.cpp @@ -28,7 +28,7 @@ fn GetValue() -> f32 fn main() -> FragOut { let output: FragOut; - output.value = GetValue(); + output.value = -GetValue(); return output; } @@ -49,7 +49,7 @@ layout(location = 0) out float _NzOut_value; void main() { FragOut output_; - output_.value = GetValue(); + output_.value = -GetValue(); _NzOut_value = output_.value; return; @@ -66,7 +66,7 @@ fn GetValue() -> f32 fn main() -> FragOut { let output: FragOut; - output.value = GetValue(); + output.value = -GetValue(); return output; } )"); @@ -80,6 +80,7 @@ OpFunction OpLabel OpVariable OpFunctionCall +OpFNegate OpAccessChain OpStore OpLoad