Shader: Add function parameters and return handling

This commit is contained in:
Jérôme Leclercq
2021-02-28 17:50:32 +01:00
parent 9a0f201433
commit b320b5b44e
39 changed files with 818 additions and 327 deletions

View File

@@ -8,7 +8,33 @@
namespace Nz::ShaderLang
{
void Parser::Parse(const std::vector<Token>& tokens)
namespace
{
std::unordered_map<std::string, ShaderNodes::BasicType> identifierToBasicType = {
{ "bool", ShaderNodes::BasicType::Boolean },
{ "i32", ShaderNodes::BasicType::Int1 },
{ "vec2i32", ShaderNodes::BasicType::Int2 },
{ "vec3i32", ShaderNodes::BasicType::Int3 },
{ "vec4i32", ShaderNodes::BasicType::Int4 },
{ "f32", ShaderNodes::BasicType::Float1 },
{ "vec2f32", ShaderNodes::BasicType::Float2 },
{ "vec3f32", ShaderNodes::BasicType::Float3 },
{ "vec4f32", ShaderNodes::BasicType::Float4 },
{ "mat4x4f32", ShaderNodes::BasicType::Mat4x4 },
{ "sampler2D", ShaderNodes::BasicType::Sampler2D },
{ "void", ShaderNodes::BasicType::Void },
{ "u32", ShaderNodes::BasicType::UInt1 },
{ "vec2u32", ShaderNodes::BasicType::UInt3 },
{ "vec3u32", ShaderNodes::BasicType::UInt3 },
{ "vec4u32", ShaderNodes::BasicType::UInt4 },
};
}
ShaderAst Parser::Parse(const std::vector<Token>& tokens)
{
Context context;
context.tokenCount = tokens.size();
@@ -16,18 +42,28 @@ namespace Nz::ShaderLang
m_context = &context;
for (const Token& token : tokens)
m_context->tokenIndex = -1;
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
{
switch (token.type)
const Token& nextToken = PeekNext();
switch (nextToken.type)
{
case TokenType::FunctionDeclaration:
ParseFunctionDeclaration();
break;
case TokenType::EndOfStream:
reachedEndOfStream = true;
break;
default:
throw UnexpectedToken{};
}
}
return std::move(context.result);
}
const Token& Parser::Advance()
@@ -42,24 +78,34 @@ namespace Nz::ShaderLang
throw ExpectedToken{};
}
void Parser::ExpectNext(TokenType type)
const Token& Parser::ExpectNext(TokenType type)
{
Expect(m_context->tokens[m_context->tokenIndex + 1], type);
const Token& token = Advance();
Expect(token, type);
return token;
}
void Parser::ParseFunctionBody()
const Token& Parser::PeekNext()
{
assert(m_context->tokenIndex + 1 < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + 1];
}
ShaderNodes::StatementPtr Parser::ParseFunctionBody()
{
return ParseStatementList();
}
void Parser::ParseFunctionDeclaration()
{
ExpectNext(TokenType::Identifier);
ExpectNext(TokenType::FunctionDeclaration);
std::string functionName = std::get<std::string>(Advance().data);
std::string functionName = ParseIdentifierAsName();
ExpectNext(TokenType::OpenParenthesis);
Advance();
std::vector<ShaderAst::FunctionParameter> parameters;
bool firstParameter = true;
for (;;)
@@ -74,45 +120,192 @@ namespace Nz::ShaderLang
Advance();
}
ParseFunctionParameter();
parameters.push_back(ParseFunctionParameter());
firstParameter = false;
}
ExpectNext(TokenType::ClosingParenthesis);
Advance();
ShaderExpressionType returnType = ShaderNodes::BasicType::Void;
if (PeekNext().type == TokenType::FunctionReturn)
{
Advance();
Advance(); //< Consume ->
std::string returnType = std::get<std::string>(Advance().data);
returnType = ParseIdentifierAsType();
}
ExpectNext(TokenType::OpenCurlyBracket);
Advance();
ParseFunctionBody();
ShaderNodes::StatementPtr functionBody = ParseFunctionBody();
ExpectNext(TokenType::ClosingCurlyBracket);
Advance();
m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType);
}
void Parser::ParseFunctionParameter()
ShaderAst::FunctionParameter Parser::ParseFunctionParameter()
{
ExpectNext(TokenType::Identifier);
std::string parameterName = std::get<std::string>(Advance().data);
std::string parameterName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
Advance();
ExpectNext(TokenType::Identifier);
std::string parameterType = std::get<std::string>(Advance().data);
ShaderExpressionType parameterType = ParseIdentifierAsType();
return { parameterName, parameterType };
}
const Token& Parser::PeekNext()
ShaderNodes::StatementPtr Parser::ParseReturnStatement()
{
assert(m_context->tokenIndex + 1 < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + 1];
ExpectNext(TokenType::Return);
ShaderNodes::ExpressionPtr expr;
if (PeekNext().type != TokenType::Semicolon)
expr = ParseExpression();
return ShaderNodes::ReturnStatement::Build(std::move(expr));
}
ShaderNodes::StatementPtr Parser::ParseStatement()
{
const Token& token = PeekNext();
ShaderNodes::StatementPtr statement;
switch (token.type)
{
case TokenType::Return:
statement = ParseReturnStatement();
break;
default:
break;
}
ExpectNext(TokenType::Semicolon);
return statement;
}
ShaderNodes::StatementPtr Parser::ParseStatementList()
{
std::vector<ShaderNodes::StatementPtr> statements;
while (PeekNext().type != TokenType::ClosingCurlyBracket)
{
statements.push_back(ParseStatement());
}
return ShaderNodes::StatementBlock::Build(std::move(statements));
}
ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs)
{
for (;;)
{
const Token& currentOp = PeekNext();
int tokenPrecedence = GetTokenPrecedence(currentOp.type);
if (tokenPrecedence < exprPrecedence)
return lhs;
Advance();
ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression();
const Token& nextOp = PeekNext();
int nextTokenPrecedence = GetTokenPrecedence(nextOp.type);
if (tokenPrecedence < nextTokenPrecedence)
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
ShaderNodes::BinaryType binaryType;
{
switch (currentOp.type)
{
case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break;
default: throw UnexpectedToken{};
}
}
lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs));
}
}
ShaderNodes::ExpressionPtr Parser::ParseExpression()
{
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderNodes::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get<std::string>(identifier.data), ShaderNodes::BasicType::Float3));
}
ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression()
{
const Token& integer = ExpectNext(TokenType::IntegerValue);
return ShaderNodes::Constant::Build(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
}
ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression()
{
ExpectNext(TokenType::OpenParenthesis);
ShaderNodes::ExpressionPtr expression = ParseExpression();
ExpectNext(TokenType::ClosingParenthesis);
return expression;
}
ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression()
{
const Token& token = PeekNext();
switch (token.type)
{
case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false);
case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true);
case TokenType::Identifier: return ParseIdentifier();
case TokenType::IntegerValue: return ParseIntegerExpression();
case TokenType::OpenParenthesis: return ParseParenthesisExpression();
default: throw UnexpectedToken{};
}
}
std::string Parser::ParseIdentifierAsName()
{
const Token& identifierToken = ExpectNext(TokenType::Identifier);
std::string identifier = std::get<std::string>(identifierToken.data);
auto it = identifierToBasicType.find(identifier);
if (it != identifierToBasicType.end())
throw ReservedKeyword{};
return identifier;
}
ShaderExpressionType Parser::ParseIdentifierAsType()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
auto it = identifierToBasicType.find(std::get<std::string>(identifier.data));
if (it == identifierToBasicType.end())
throw UnknownType{};
return it->second;
}
int Parser::GetTokenPrecedence(TokenType token)
{
switch (token)
{
case TokenType::Plus: return 20;
case TokenType::Divide: return 40;
case TokenType::Multiply: return 40;
case TokenType::Minus: return 20;
default: return -1;
}
}
}