Shader: Add support for branching and discard statements

This commit is contained in:
Jérôme Leclercq
2021-07-03 19:13:22 +02:00
parent f2bb1a839c
commit f9b453bd2b
6 changed files with 127 additions and 40 deletions

View File

@@ -40,9 +40,12 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "discard", TokenType::Discard },
{ "else", TokenType::Else },
{ "external", TokenType::External },
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "if", TokenType::If },
{ "let", TokenType::Let },
{ "option", TokenType::Option },
{ "return", TokenType::Return },
@@ -263,10 +266,10 @@ namespace Nz::ShaderLang
if (next == '=')
{
currentPos++;
tokenType = TokenType::GreatherThanEqual;
tokenType = TokenType::GreaterThanEqual;
}
else
tokenType = TokenType::GreatherThan;
tokenType = TokenType::GreaterThan;
break;
}

View File

@@ -160,7 +160,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return matrixType;
}
@@ -174,7 +174,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return matrixType;
}
@@ -187,7 +187,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return samplerType;
}
@@ -200,7 +200,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return samplerType;
}
@@ -212,7 +212,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() };
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return uniformType;
}
@@ -225,7 +225,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return vectorType;
}
@@ -238,7 +238,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return vectorType;
}
@@ -251,7 +251,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
Expect(Advance(), TokenType::GreaterThan); //< '>'
return vectorType;
}
@@ -302,9 +302,6 @@ namespace Nz::ShaderLang
void Parser::RegisterVariable(std::string identifier)
{
if (IsVariableInScope(identifier))
throw DuplicateIdentifier{ ("identifier name " + identifier + " is already taken").c_str() };
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.push_back(std::move(identifier));
}
@@ -378,6 +375,47 @@ namespace Nz::ShaderLang
return attributes;
}
ShaderAst::StatementPtr Parser::ParseBranchStatement()
{
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
bool first = true;
for (;;)
{
if (!first)
Expect(Advance(), TokenType::Else);
first = false;
Expect(Advance(), TokenType::If);
auto& condStatement = branch->condStatements.emplace_back();
Expect(Advance(), TokenType::OpenParenthesis);
condStatement.condition = ParseExpression();
Expect(Advance(), TokenType::ClosingParenthesis);
condStatement.statement = ParseStatement();
if (Peek().type != TokenType::Else || Peek(1).type != TokenType::If)
break;
}
if (Peek().type == TokenType::Else)
branch->elseStatement = ParseStatement();
return branch;
}
ShaderAst::StatementPtr Parser::ParseDiscardStatement()
{
Expect(Advance(), TokenType::Discard);
return ShaderBuilder::Discard();
}
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
{
std::optional<UInt32> blockSetIndex;
@@ -533,8 +571,6 @@ namespace Nz::ShaderLang
returnType = ParseType();
}
Expect(Advance(), TokenType::OpenCurlyBracket);
EnterScope();
for (const auto& parameter : parameters)
RegisterVariable(parameter.name);
@@ -616,10 +652,10 @@ namespace Nz::ShaderLang
case ShaderAst::AttributeType::Option:
{
if (!func->optionName.empty())
throw AttributeError{ "attribute option must be present once" };
throw AttributeError{ "attribute opt must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AttributeError{ "attribute option requires a string parameter" };
throw AttributeError{ "attribute opt requires a string parameter" };
func->optionName = std::get<std::string>(arg);
break;
@@ -786,36 +822,58 @@ namespace Nz::ShaderLang
return ShaderBuilder::Return(std::move(expr));
}
ShaderAst::StatementPtr Parser::ParseStatement()
ShaderAst::StatementPtr Parser::ParseSingleStatement()
{
const Token& token = Peek();
ShaderAst::StatementPtr statement;
switch (token.type)
{
case TokenType::Discard:
statement = ParseDiscardStatement();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Let:
statement = ParseVariableDeclaration();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Identifier:
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::If:
statement = ParseBranchStatement();
break;
case TokenType::Return:
statement = ParseReturnStatement();
Expect(Advance(), TokenType::Semicolon);
break;
default:
break;
}
Expect(Advance(), TokenType::Semicolon);
return statement;
}
ShaderAst::StatementPtr Parser::ParseStatement()
{
if (Peek().type == TokenType::OpenCurlyBracket)
return ShaderBuilder::MultiStatement(ParseStatementList());
else
return ParseSingleStatement();
}
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{
EnterScope();
Expect(Advance(), TokenType::OpenCurlyBracket);
std::vector<ShaderAst::StatementPtr> statements;
while (Peek().type != TokenType::ClosingCurlyBracket)
{
@@ -823,6 +881,8 @@ namespace Nz::ShaderLang
statements.push_back(ParseStatement());
}
LeaveScope();
return statements;
}
@@ -944,11 +1004,18 @@ namespace Nz::ShaderLang
{
switch (currentTokenType)
{
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
default: throw UnexpectedToken{};
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break;
case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::NotEqual: binaryType = ShaderAst::BinaryType::CompNe; break;
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
default:
throw UnexpectedToken{};
}
}
@@ -1159,12 +1226,18 @@ namespace Nz::ShaderLang
{
switch (token)
{
case TokenType::Plus: return 20;
case TokenType::Divide: return 40;
case TokenType::Multiply: return 40;
case TokenType::Minus: return 20;
case TokenType::Dot: return 50;
case TokenType::OpenSquareBracket: return 50;
case TokenType::Divide: return 80;
case TokenType::Dot: return 100;
case TokenType::Equal: return 50;
case TokenType::LessThan: return 40;
case TokenType::LessThanEqual: return 40;
case TokenType::GreaterThan: return 40;
case TokenType::GreaterThanEqual: return 40;
case TokenType::Multiply: return 80;
case TokenType::Minus: return 60;
case TokenType::NotEqual: return 50;
case TokenType::Plus: return 60;
case TokenType::OpenSquareBracket: return 100;
default: return -1;
}
}
@@ -1187,6 +1260,6 @@ namespace Nz::ShaderLang
return {};
}
return Parse(Tokenize(std::string_view(reinterpret_cast<const char*>(source.data()), source.size())));
return Parse(std::string_view(reinterpret_cast<const char*>(source.data()), source.size()));
}
}