From f9b453bd2b18f9b3a86a68abc60cfa2cc5e43d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sat, 3 Jul 2021 19:13:22 +0200 Subject: [PATCH] Shader: Add support for branching and discard statements --- include/Nazara/Shader/Ast/Enums.hpp | 9 +- include/Nazara/Shader/ShaderLangParser.hpp | 4 + include/Nazara/Shader/ShaderLangParser.inl | 5 + include/Nazara/Shader/ShaderLangTokenList.hpp | 9 +- src/Nazara/Shader/ShaderLangLexer.cpp | 7 +- src/Nazara/Shader/ShaderLangParser.cpp | 133 ++++++++++++++---- 6 files changed, 127 insertions(+), 40 deletions(-) diff --git a/include/Nazara/Shader/Ast/Enums.hpp b/include/Nazara/Shader/Ast/Enums.hpp index a567aaa4a..027b6c500 100644 --- a/include/Nazara/Shader/Ast/Enums.hpp +++ b/include/Nazara/Shader/Ast/Enums.hpp @@ -35,16 +35,15 @@ namespace Nz enum class BinaryType { Add, //< + - Subtract, //< - - Multiply, //< * - Divide, //< / - CompEq, //< == CompGe, //< >= CompGt, //< > CompLe, //< <= CompLt, //< < - CompNe //< <= + CompNe, //< <= + Divide, //< / + Multiply, //< * + Subtract, //< - }; enum class BuiltinEntry diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 525906b0c..b8b43257e 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -83,6 +83,8 @@ namespace Nz::ShaderLang std::vector ParseAttributes(); // Statements + ShaderAst::StatementPtr ParseBranchStatement(); + ShaderAst::StatementPtr ParseDiscardStatement(); ShaderAst::StatementPtr ParseExternalBlock(std::vector attributes = {}); std::vector ParseFunctionBody(); ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); @@ -90,6 +92,7 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr ParseOptionDeclaration(); ShaderAst::StatementPtr ParseStructDeclaration(std::vector attributes = {}); ShaderAst::StatementPtr ParseReturnStatement(); + ShaderAst::StatementPtr ParseSingleStatement(); ShaderAst::StatementPtr ParseStatement(); std::vector ParseStatementList(); ShaderAst::StatementPtr ParseVariableDeclaration(); @@ -126,6 +129,7 @@ namespace Nz::ShaderLang Context* m_context; }; + inline ShaderAst::StatementPtr Parse(const std::string_view& source); inline ShaderAst::StatementPtr Parse(const std::vector& tokens); NAZARA_SHADER_API ShaderAst::StatementPtr Parse(const std::filesystem::path& sourcePath); } diff --git a/include/Nazara/Shader/ShaderLangParser.inl b/include/Nazara/Shader/ShaderLangParser.inl index a8f0ca69f..e9ca1e8fd 100644 --- a/include/Nazara/Shader/ShaderLangParser.inl +++ b/include/Nazara/Shader/ShaderLangParser.inl @@ -12,6 +12,11 @@ namespace Nz::ShaderLang { } + inline ShaderAst::StatementPtr Parse(const std::string_view& source) + { + return Parse(Tokenize(source)); + } + inline ShaderAst::StatementPtr Parse(const std::vector& tokens) { Parser parser; diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index a7b8c50f7..20b0f1d7d 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -18,18 +18,21 @@ NAZARA_SHADERLANG_TOKEN(ClosingCurlyBracket) NAZARA_SHADERLANG_TOKEN(ClosingSquareBracket) NAZARA_SHADERLANG_TOKEN(Colon) NAZARA_SHADERLANG_TOKEN(Comma) +NAZARA_SHADERLANG_TOKEN(Discard) NAZARA_SHADERLANG_TOKEN(Divide) NAZARA_SHADERLANG_TOKEN(Dot) NAZARA_SHADERLANG_TOKEN(Equal) +NAZARA_SHADERLANG_TOKEN(Else) +NAZARA_SHADERLANG_TOKEN(EndOfStream) NAZARA_SHADERLANG_TOKEN(External) NAZARA_SHADERLANG_TOKEN(FloatingPointValue) -NAZARA_SHADERLANG_TOKEN(EndOfStream) NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) NAZARA_SHADERLANG_TOKEN(FunctionReturn) -NAZARA_SHADERLANG_TOKEN(GreatherThan) -NAZARA_SHADERLANG_TOKEN(GreatherThanEqual) +NAZARA_SHADERLANG_TOKEN(GreaterThan) +NAZARA_SHADERLANG_TOKEN(GreaterThanEqual) NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(Identifier) +NAZARA_SHADERLANG_TOKEN(If) NAZARA_SHADERLANG_TOKEN(LessThan) NAZARA_SHADERLANG_TOKEN(LessThanEqual) NAZARA_SHADERLANG_TOKEN(Let) diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index 91343bc04..6213c95da 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -40,9 +40,12 @@ namespace Nz::ShaderLang ForceCLocale forceCLocale; std::unordered_map reservedKeywords = { + { "discard", TokenType::Discard }, + { "else", TokenType::Else }, { "external", TokenType::External }, { "false", TokenType::BoolFalse }, { "fn", TokenType::FunctionDeclaration }, + { "if", TokenType::If }, { "let", TokenType::Let }, { "option", TokenType::Option }, { "return", TokenType::Return }, @@ -263,10 +266,10 @@ namespace Nz::ShaderLang if (next == '=') { currentPos++; - tokenType = TokenType::GreatherThanEqual; + tokenType = TokenType::GreaterThanEqual; } else - tokenType = TokenType::GreatherThan; + tokenType = TokenType::GreaterThan; break; } diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 7e495a646..738787f9c 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -160,7 +160,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' matrixType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return matrixType; } @@ -174,7 +174,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' matrixType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return matrixType; } @@ -187,7 +187,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' samplerType.sampledType = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return samplerType; } @@ -200,7 +200,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' samplerType.sampledType = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return samplerType; } @@ -212,7 +212,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() }; - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return uniformType; } @@ -225,7 +225,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' vectorType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return vectorType; } @@ -238,7 +238,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' vectorType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return vectorType; } @@ -251,7 +251,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::LessThan); //< '<' vectorType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::GreatherThan); //< '>' + Expect(Advance(), TokenType::GreaterThan); //< '>' return vectorType; } @@ -302,9 +302,6 @@ namespace Nz::ShaderLang void Parser::RegisterVariable(std::string identifier) { - if (IsVariableInScope(identifier)) - throw DuplicateIdentifier{ ("identifier name " + identifier + " is already taken").c_str() }; - assert(!m_context->scopeSizes.empty()); m_context->identifiersInScope.push_back(std::move(identifier)); } @@ -378,6 +375,47 @@ namespace Nz::ShaderLang return attributes; } + ShaderAst::StatementPtr Parser::ParseBranchStatement() + { + std::unique_ptr branch = std::make_unique(); + + bool first = true; + for (;;) + { + if (!first) + Expect(Advance(), TokenType::Else); + + first = false; + + Expect(Advance(), TokenType::If); + + auto& condStatement = branch->condStatements.emplace_back(); + + Expect(Advance(), TokenType::OpenParenthesis); + + condStatement.condition = ParseExpression(); + + Expect(Advance(), TokenType::ClosingParenthesis); + + condStatement.statement = ParseStatement(); + + if (Peek().type != TokenType::Else || Peek(1).type != TokenType::If) + break; + } + + if (Peek().type == TokenType::Else) + branch->elseStatement = ParseStatement(); + + return branch; + } + + ShaderAst::StatementPtr Parser::ParseDiscardStatement() + { + Expect(Advance(), TokenType::Discard); + + return ShaderBuilder::Discard(); + } + ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector attributes) { std::optional blockSetIndex; @@ -533,8 +571,6 @@ namespace Nz::ShaderLang returnType = ParseType(); } - Expect(Advance(), TokenType::OpenCurlyBracket); - EnterScope(); for (const auto& parameter : parameters) RegisterVariable(parameter.name); @@ -616,10 +652,10 @@ namespace Nz::ShaderLang case ShaderAst::AttributeType::Option: { if (!func->optionName.empty()) - throw AttributeError{ "attribute option must be present once" }; + throw AttributeError{ "attribute opt must be present once" }; if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute option requires a string parameter" }; + throw AttributeError{ "attribute opt requires a string parameter" }; func->optionName = std::get(arg); break; @@ -786,36 +822,58 @@ namespace Nz::ShaderLang return ShaderBuilder::Return(std::move(expr)); } - ShaderAst::StatementPtr Parser::ParseStatement() + ShaderAst::StatementPtr Parser::ParseSingleStatement() { const Token& token = Peek(); ShaderAst::StatementPtr statement; switch (token.type) { + case TokenType::Discard: + statement = ParseDiscardStatement(); + Expect(Advance(), TokenType::Semicolon); + break; + case TokenType::Let: statement = ParseVariableDeclaration(); + Expect(Advance(), TokenType::Semicolon); break; case TokenType::Identifier: statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation()); + Expect(Advance(), TokenType::Semicolon); + break; + + case TokenType::If: + statement = ParseBranchStatement(); break; case TokenType::Return: statement = ParseReturnStatement(); + Expect(Advance(), TokenType::Semicolon); break; default: break; } - Expect(Advance(), TokenType::Semicolon); - return statement; } + ShaderAst::StatementPtr Parser::ParseStatement() + { + if (Peek().type == TokenType::OpenCurlyBracket) + return ShaderBuilder::MultiStatement(ParseStatementList()); + else + return ParseSingleStatement(); + } + std::vector Parser::ParseStatementList() { + EnterScope(); + + Expect(Advance(), TokenType::OpenCurlyBracket); + std::vector statements; while (Peek().type != TokenType::ClosingCurlyBracket) { @@ -823,6 +881,8 @@ namespace Nz::ShaderLang statements.push_back(ParseStatement()); } + LeaveScope(); + return statements; } @@ -944,11 +1004,18 @@ namespace Nz::ShaderLang { switch (currentTokenType) { - case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break; - case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; - case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break; - case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break; - default: throw UnexpectedToken{}; + case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break; + case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break; + case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break; + case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; + case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompLt; break; + case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; + case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; + case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break; + case TokenType::NotEqual: binaryType = ShaderAst::BinaryType::CompNe; break; + case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break; + default: + throw UnexpectedToken{}; } } @@ -1159,12 +1226,18 @@ namespace Nz::ShaderLang { switch (token) { - case TokenType::Plus: return 20; - case TokenType::Divide: return 40; - case TokenType::Multiply: return 40; - case TokenType::Minus: return 20; - case TokenType::Dot: return 50; - case TokenType::OpenSquareBracket: return 50; + case TokenType::Divide: return 80; + case TokenType::Dot: return 100; + case TokenType::Equal: return 50; + case TokenType::LessThan: return 40; + case TokenType::LessThanEqual: return 40; + case TokenType::GreaterThan: return 40; + case TokenType::GreaterThanEqual: return 40; + case TokenType::Multiply: return 80; + case TokenType::Minus: return 60; + case TokenType::NotEqual: return 50; + case TokenType::Plus: return 60; + case TokenType::OpenSquareBracket: return 100; default: return -1; } } @@ -1187,6 +1260,6 @@ namespace Nz::ShaderLang return {}; } - return Parse(Tokenize(std::string_view(reinterpret_cast(source.data()), source.size()))); + return Parse(std::string_view(reinterpret_cast(source.data()), source.size())); } }