Shader: Add support for branching and discard statements

This commit is contained in:
Jérôme Leclercq 2021-07-03 19:13:22 +02:00
parent f2bb1a839c
commit f9b453bd2b
6 changed files with 127 additions and 40 deletions

View File

@ -35,16 +35,15 @@ namespace Nz
enum class BinaryType enum class BinaryType
{ {
Add, //< + Add, //< +
Subtract, //< -
Multiply, //< *
Divide, //< /
CompEq, //< == CompEq, //< ==
CompGe, //< >= CompGe, //< >=
CompGt, //< > CompGt, //< >
CompLe, //< <= CompLe, //< <=
CompLt, //< < CompLt, //< <
CompNe //< <= CompNe, //< <=
Divide, //< /
Multiply, //< *
Subtract, //< -
}; };
enum class BuiltinEntry enum class BuiltinEntry

View File

@ -83,6 +83,8 @@ namespace Nz::ShaderLang
std::vector<ShaderAst::Attribute> ParseAttributes(); std::vector<ShaderAst::Attribute> ParseAttributes();
// Statements // Statements
ShaderAst::StatementPtr ParseBranchStatement();
ShaderAst::StatementPtr ParseDiscardStatement();
ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {}); ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {});
std::vector<ShaderAst::StatementPtr> ParseFunctionBody(); std::vector<ShaderAst::StatementPtr> ParseFunctionBody();
ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes = {}); ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
@ -90,6 +92,7 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr ParseOptionDeclaration(); ShaderAst::StatementPtr ParseOptionDeclaration();
ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {}); ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::StatementPtr ParseReturnStatement(); ShaderAst::StatementPtr ParseReturnStatement();
ShaderAst::StatementPtr ParseSingleStatement();
ShaderAst::StatementPtr ParseStatement(); ShaderAst::StatementPtr ParseStatement();
std::vector<ShaderAst::StatementPtr> ParseStatementList(); std::vector<ShaderAst::StatementPtr> ParseStatementList();
ShaderAst::StatementPtr ParseVariableDeclaration(); ShaderAst::StatementPtr ParseVariableDeclaration();
@ -126,6 +129,7 @@ namespace Nz::ShaderLang
Context* m_context; Context* m_context;
}; };
inline ShaderAst::StatementPtr Parse(const std::string_view& source);
inline ShaderAst::StatementPtr Parse(const std::vector<Token>& tokens); inline ShaderAst::StatementPtr Parse(const std::vector<Token>& tokens);
NAZARA_SHADER_API ShaderAst::StatementPtr Parse(const std::filesystem::path& sourcePath); NAZARA_SHADER_API ShaderAst::StatementPtr Parse(const std::filesystem::path& sourcePath);
} }

View File

@ -12,6 +12,11 @@ namespace Nz::ShaderLang
{ {
} }
inline ShaderAst::StatementPtr Parse(const std::string_view& source)
{
return Parse(Tokenize(source));
}
inline ShaderAst::StatementPtr Parse(const std::vector<Token>& tokens) inline ShaderAst::StatementPtr Parse(const std::vector<Token>& tokens)
{ {
Parser parser; Parser parser;

View File

@ -18,18 +18,21 @@ NAZARA_SHADERLANG_TOKEN(ClosingCurlyBracket)
NAZARA_SHADERLANG_TOKEN(ClosingSquareBracket) NAZARA_SHADERLANG_TOKEN(ClosingSquareBracket)
NAZARA_SHADERLANG_TOKEN(Colon) NAZARA_SHADERLANG_TOKEN(Colon)
NAZARA_SHADERLANG_TOKEN(Comma) NAZARA_SHADERLANG_TOKEN(Comma)
NAZARA_SHADERLANG_TOKEN(Discard)
NAZARA_SHADERLANG_TOKEN(Divide) NAZARA_SHADERLANG_TOKEN(Divide)
NAZARA_SHADERLANG_TOKEN(Dot) NAZARA_SHADERLANG_TOKEN(Dot)
NAZARA_SHADERLANG_TOKEN(Equal) NAZARA_SHADERLANG_TOKEN(Equal)
NAZARA_SHADERLANG_TOKEN(Else)
NAZARA_SHADERLANG_TOKEN(EndOfStream)
NAZARA_SHADERLANG_TOKEN(External) NAZARA_SHADERLANG_TOKEN(External)
NAZARA_SHADERLANG_TOKEN(FloatingPointValue) NAZARA_SHADERLANG_TOKEN(FloatingPointValue)
NAZARA_SHADERLANG_TOKEN(EndOfStream)
NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) NAZARA_SHADERLANG_TOKEN(FunctionDeclaration)
NAZARA_SHADERLANG_TOKEN(FunctionReturn) NAZARA_SHADERLANG_TOKEN(FunctionReturn)
NAZARA_SHADERLANG_TOKEN(GreatherThan) NAZARA_SHADERLANG_TOKEN(GreaterThan)
NAZARA_SHADERLANG_TOKEN(GreatherThanEqual) NAZARA_SHADERLANG_TOKEN(GreaterThanEqual)
NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(IntegerValue)
NAZARA_SHADERLANG_TOKEN(Identifier) NAZARA_SHADERLANG_TOKEN(Identifier)
NAZARA_SHADERLANG_TOKEN(If)
NAZARA_SHADERLANG_TOKEN(LessThan) NAZARA_SHADERLANG_TOKEN(LessThan)
NAZARA_SHADERLANG_TOKEN(LessThanEqual) NAZARA_SHADERLANG_TOKEN(LessThanEqual)
NAZARA_SHADERLANG_TOKEN(Let) NAZARA_SHADERLANG_TOKEN(Let)

View File

@ -40,9 +40,12 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale; ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = { std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "discard", TokenType::Discard },
{ "else", TokenType::Else },
{ "external", TokenType::External }, { "external", TokenType::External },
{ "false", TokenType::BoolFalse }, { "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration }, { "fn", TokenType::FunctionDeclaration },
{ "if", TokenType::If },
{ "let", TokenType::Let }, { "let", TokenType::Let },
{ "option", TokenType::Option }, { "option", TokenType::Option },
{ "return", TokenType::Return }, { "return", TokenType::Return },
@ -263,10 +266,10 @@ namespace Nz::ShaderLang
if (next == '=') if (next == '=')
{ {
currentPos++; currentPos++;
tokenType = TokenType::GreatherThanEqual; tokenType = TokenType::GreaterThanEqual;
} }
else else
tokenType = TokenType::GreatherThan; tokenType = TokenType::GreaterThan;
break; break;
} }

View File

@ -160,7 +160,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType(); matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return matrixType; return matrixType;
} }
@ -174,7 +174,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType(); matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return matrixType; return matrixType;
} }
@ -187,7 +187,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
samplerType.sampledType = ParsePrimitiveType(); samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return samplerType; return samplerType;
} }
@ -200,7 +200,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
samplerType.sampledType = ParsePrimitiveType(); samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return samplerType; return samplerType;
} }
@ -212,7 +212,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() }; uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() };
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return uniformType; return uniformType;
} }
@ -225,7 +225,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType(); vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return vectorType; return vectorType;
} }
@ -238,7 +238,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType(); vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return vectorType; return vectorType;
} }
@ -251,7 +251,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<' Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType(); vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>' Expect(Advance(), TokenType::GreaterThan); //< '>'
return vectorType; return vectorType;
} }
@ -302,9 +302,6 @@ namespace Nz::ShaderLang
void Parser::RegisterVariable(std::string identifier) void Parser::RegisterVariable(std::string identifier)
{ {
if (IsVariableInScope(identifier))
throw DuplicateIdentifier{ ("identifier name " + identifier + " is already taken").c_str() };
assert(!m_context->scopeSizes.empty()); assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.push_back(std::move(identifier)); m_context->identifiersInScope.push_back(std::move(identifier));
} }
@ -378,6 +375,47 @@ namespace Nz::ShaderLang
return attributes; return attributes;
} }
ShaderAst::StatementPtr Parser::ParseBranchStatement()
{
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
bool first = true;
for (;;)
{
if (!first)
Expect(Advance(), TokenType::Else);
first = false;
Expect(Advance(), TokenType::If);
auto& condStatement = branch->condStatements.emplace_back();
Expect(Advance(), TokenType::OpenParenthesis);
condStatement.condition = ParseExpression();
Expect(Advance(), TokenType::ClosingParenthesis);
condStatement.statement = ParseStatement();
if (Peek().type != TokenType::Else || Peek(1).type != TokenType::If)
break;
}
if (Peek().type == TokenType::Else)
branch->elseStatement = ParseStatement();
return branch;
}
ShaderAst::StatementPtr Parser::ParseDiscardStatement()
{
Expect(Advance(), TokenType::Discard);
return ShaderBuilder::Discard();
}
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes) ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
{ {
std::optional<UInt32> blockSetIndex; std::optional<UInt32> blockSetIndex;
@ -533,8 +571,6 @@ namespace Nz::ShaderLang
returnType = ParseType(); returnType = ParseType();
} }
Expect(Advance(), TokenType::OpenCurlyBracket);
EnterScope(); EnterScope();
for (const auto& parameter : parameters) for (const auto& parameter : parameters)
RegisterVariable(parameter.name); RegisterVariable(parameter.name);
@ -616,10 +652,10 @@ namespace Nz::ShaderLang
case ShaderAst::AttributeType::Option: case ShaderAst::AttributeType::Option:
{ {
if (!func->optionName.empty()) if (!func->optionName.empty())
throw AttributeError{ "attribute option must be present once" }; throw AttributeError{ "attribute opt must be present once" };
if (!std::holds_alternative<std::string>(arg)) if (!std::holds_alternative<std::string>(arg))
throw AttributeError{ "attribute option requires a string parameter" }; throw AttributeError{ "attribute opt requires a string parameter" };
func->optionName = std::get<std::string>(arg); func->optionName = std::get<std::string>(arg);
break; break;
@ -786,36 +822,58 @@ namespace Nz::ShaderLang
return ShaderBuilder::Return(std::move(expr)); return ShaderBuilder::Return(std::move(expr));
} }
ShaderAst::StatementPtr Parser::ParseStatement() ShaderAst::StatementPtr Parser::ParseSingleStatement()
{ {
const Token& token = Peek(); const Token& token = Peek();
ShaderAst::StatementPtr statement; ShaderAst::StatementPtr statement;
switch (token.type) switch (token.type)
{ {
case TokenType::Discard:
statement = ParseDiscardStatement();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Let: case TokenType::Let:
statement = ParseVariableDeclaration(); statement = ParseVariableDeclaration();
Expect(Advance(), TokenType::Semicolon);
break; break;
case TokenType::Identifier: case TokenType::Identifier:
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation()); statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::If:
statement = ParseBranchStatement();
break; break;
case TokenType::Return: case TokenType::Return:
statement = ParseReturnStatement(); statement = ParseReturnStatement();
Expect(Advance(), TokenType::Semicolon);
break; break;
default: default:
break; break;
} }
Expect(Advance(), TokenType::Semicolon);
return statement; return statement;
} }
ShaderAst::StatementPtr Parser::ParseStatement()
{
if (Peek().type == TokenType::OpenCurlyBracket)
return ShaderBuilder::MultiStatement(ParseStatementList());
else
return ParseSingleStatement();
}
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList() std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{ {
EnterScope();
Expect(Advance(), TokenType::OpenCurlyBracket);
std::vector<ShaderAst::StatementPtr> statements; std::vector<ShaderAst::StatementPtr> statements;
while (Peek().type != TokenType::ClosingCurlyBracket) while (Peek().type != TokenType::ClosingCurlyBracket)
{ {
@ -823,6 +881,8 @@ namespace Nz::ShaderLang
statements.push_back(ParseStatement()); statements.push_back(ParseStatement());
} }
LeaveScope();
return statements; return statements;
} }
@ -944,11 +1004,18 @@ namespace Nz::ShaderLang
{ {
switch (currentTokenType) switch (currentTokenType)
{ {
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break; case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break;
case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break; case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break; case TokenType::NotEqual: binaryType = ShaderAst::BinaryType::CompNe; break;
default: throw UnexpectedToken{}; case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
default:
throw UnexpectedToken{};
} }
} }
@ -1159,12 +1226,18 @@ namespace Nz::ShaderLang
{ {
switch (token) switch (token)
{ {
case TokenType::Plus: return 20; case TokenType::Divide: return 80;
case TokenType::Divide: return 40; case TokenType::Dot: return 100;
case TokenType::Multiply: return 40; case TokenType::Equal: return 50;
case TokenType::Minus: return 20; case TokenType::LessThan: return 40;
case TokenType::Dot: return 50; case TokenType::LessThanEqual: return 40;
case TokenType::OpenSquareBracket: return 50; case TokenType::GreaterThan: return 40;
case TokenType::GreaterThanEqual: return 40;
case TokenType::Multiply: return 80;
case TokenType::Minus: return 60;
case TokenType::NotEqual: return 50;
case TokenType::Plus: return 60;
case TokenType::OpenSquareBracket: return 100;
default: return -1; default: return -1;
} }
} }
@ -1187,6 +1260,6 @@ namespace Nz::ShaderLang
return {}; return {};
} }
return Parse(Tokenize(std::string_view(reinterpret_cast<const char*>(source.data()), source.size()))); return Parse(std::string_view(reinterpret_cast<const char*>(source.data()), source.size()));
} }
} }