Rework shader AST (WIP)

This commit is contained in:
Jérôme Leclercq
2021-03-10 11:18:13 +01:00
parent b320b5b44e
commit fed7370e77
73 changed files with 2721 additions and 4312 deletions

View File

@@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <cassert>
#include <Nazara/Shader/Debug.hpp>
@@ -10,36 +11,38 @@ namespace Nz::ShaderLang
{
namespace
{
std::unordered_map<std::string, ShaderNodes::BasicType> identifierToBasicType = {
{ "bool", ShaderNodes::BasicType::Boolean },
std::unordered_map<std::string, ShaderAst::BasicType> identifierToBasicType = {
{ "bool", ShaderAst::BasicType::Boolean },
{ "i32", ShaderNodes::BasicType::Int1 },
{ "vec2i32", ShaderNodes::BasicType::Int2 },
{ "vec3i32", ShaderNodes::BasicType::Int3 },
{ "vec4i32", ShaderNodes::BasicType::Int4 },
{ "i32", ShaderAst::BasicType::Int1 },
{ "vec2i32", ShaderAst::BasicType::Int2 },
{ "vec3i32", ShaderAst::BasicType::Int3 },
{ "vec4i32", ShaderAst::BasicType::Int4 },
{ "f32", ShaderNodes::BasicType::Float1 },
{ "vec2f32", ShaderNodes::BasicType::Float2 },
{ "vec3f32", ShaderNodes::BasicType::Float3 },
{ "vec4f32", ShaderNodes::BasicType::Float4 },
{ "f32", ShaderAst::BasicType::Float1 },
{ "vec2f32", ShaderAst::BasicType::Float2 },
{ "vec3f32", ShaderAst::BasicType::Float3 },
{ "vec4f32", ShaderAst::BasicType::Float4 },
{ "mat4x4f32", ShaderNodes::BasicType::Mat4x4 },
{ "sampler2D", ShaderNodes::BasicType::Sampler2D },
{ "void", ShaderNodes::BasicType::Void },
{ "mat4x4f32", ShaderAst::BasicType::Mat4x4 },
{ "sampler2D", ShaderAst::BasicType::Sampler2D },
{ "void", ShaderAst::BasicType::Void },
{ "u32", ShaderNodes::BasicType::UInt1 },
{ "vec2u32", ShaderNodes::BasicType::UInt3 },
{ "vec3u32", ShaderNodes::BasicType::UInt3 },
{ "vec4u32", ShaderNodes::BasicType::UInt4 },
{ "u32", ShaderAst::BasicType::UInt1 },
{ "vec2u32", ShaderAst::BasicType::UInt3 },
{ "vec3u32", ShaderAst::BasicType::UInt3 },
{ "vec4u32", ShaderAst::BasicType::UInt4 },
};
}
ShaderAst Parser::Parse(const std::vector<Token>& tokens)
ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens)
{
Context context;
context.tokenCount = tokens.size();
context.tokens = tokens.data();
context.root = std::make_unique<ShaderAst::MultiStatement>();
m_context = &context;
m_context->tokenIndex = -1;
@@ -51,7 +54,7 @@ namespace Nz::ShaderLang
switch (nextToken.type)
{
case TokenType::FunctionDeclaration:
ParseFunctionDeclaration();
context.root->statements.push_back(ParseFunctionDeclaration());
break;
case TokenType::EndOfStream:
@@ -63,7 +66,7 @@ namespace Nz::ShaderLang
}
}
return std::move(context.result);
return std::move(context.root);
}
const Token& Parser::Advance()
@@ -92,12 +95,12 @@ namespace Nz::ShaderLang
return m_context->tokens[m_context->tokenIndex + 1];
}
ShaderNodes::StatementPtr Parser::ParseFunctionBody()
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
{
return ParseStatementList();
}
void Parser::ParseFunctionDeclaration()
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration()
{
ExpectNext(TokenType::FunctionDeclaration);
@@ -105,7 +108,7 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::OpenParenthesis);
std::vector<ShaderAst::FunctionParameter> parameters;
std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters;
bool firstParameter = true;
for (;;)
@@ -126,7 +129,7 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::ClosingParenthesis);
ShaderExpressionType returnType = ShaderNodes::BasicType::Void;
ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void;
if (PeekNext().type == TokenType::FunctionReturn)
{
Advance(); //< Consume ->
@@ -136,42 +139,46 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::OpenCurlyBracket);
ShaderNodes::StatementPtr functionBody = ParseFunctionBody();
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
ExpectNext(TokenType::ClosingCurlyBracket);
m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType);
return ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
}
ShaderAst::FunctionParameter Parser::ParseFunctionParameter()
ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter()
{
std::string parameterName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
ShaderExpressionType parameterType = ParseIdentifierAsType();
ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType();
return { parameterName, parameterType };
}
ShaderNodes::StatementPtr Parser::ParseReturnStatement()
ShaderAst::StatementPtr Parser::ParseReturnStatement()
{
ExpectNext(TokenType::Return);
ShaderNodes::ExpressionPtr expr;
ShaderAst::ExpressionPtr expr;
if (PeekNext().type != TokenType::Semicolon)
expr = ParseExpression();
return ShaderNodes::ReturnStatement::Build(std::move(expr));
return ShaderBuilder::Return(std::move(expr));
}
ShaderNodes::StatementPtr Parser::ParseStatement()
ShaderAst::StatementPtr Parser::ParseStatement()
{
const Token& token = PeekNext();
ShaderNodes::StatementPtr statement;
ShaderAst::StatementPtr statement;
switch (token.type)
{
case TokenType::Let:
statement = ParseVariableDeclaration();
break;
case TokenType::Return:
statement = ParseReturnStatement();
break;
@@ -185,18 +192,38 @@ namespace Nz::ShaderLang
return statement;
}
ShaderNodes::StatementPtr Parser::ParseStatementList()
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{
std::vector<ShaderNodes::StatementPtr> statements;
std::vector<ShaderAst::StatementPtr> statements;
while (PeekNext().type != TokenType::ClosingCurlyBracket)
{
statements.push_back(ParseStatement());
}
return ShaderNodes::StatementBlock::Build(std::move(statements));
return statements;
}
ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs)
ShaderAst::StatementPtr Parser::ParseVariableDeclaration()
{
ExpectNext(TokenType::Let);
std::string variableName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType();
ShaderAst::ExpressionPtr expression;
if (PeekNext().type == TokenType::Assign)
{
Advance();
expression = ParseExpression();
}
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
{
for (;;)
{
@@ -207,7 +234,7 @@ namespace Nz::ShaderLang
return lhs;
Advance();
ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression();
ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression();
const Token& nextOp = PeekNext();
@@ -215,57 +242,58 @@ namespace Nz::ShaderLang
if (tokenPrecedence < nextTokenPrecedence)
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
ShaderNodes::BinaryType binaryType;
ShaderAst::BinaryType binaryType;
{
switch (currentOp.type)
{
case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break;
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
default: throw UnexpectedToken{};
}
}
lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs));
lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
}
}
ShaderNodes::ExpressionPtr Parser::ParseExpression()
ShaderAst::ExpressionPtr Parser::ParseExpression()
{
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderNodes::ExpressionPtr Parser::ParseIdentifier()
ShaderAst::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get<std::string>(identifier.data), ShaderNodes::BasicType::Float3));
return ShaderBuilder::Identifier(std::get<std::string>(identifier.data));
}
ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression()
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression()
{
const Token& integer = ExpectNext(TokenType::IntegerValue);
return ShaderNodes::Constant::Build(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
return ShaderBuilder::Constant(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
}
ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression()
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
{
ExpectNext(TokenType::OpenParenthesis);
ShaderNodes::ExpressionPtr expression = ParseExpression();
ShaderAst::ExpressionPtr expression = ParseExpression();
ExpectNext(TokenType::ClosingParenthesis);
return expression;
}
ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression()
ShaderAst::ExpressionPtr Parser::ParsePrimaryExpression()
{
const Token& token = PeekNext();
switch (token.type)
{
case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false);
case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true);
case TokenType::BoolFalse: return ShaderBuilder::Constant(false);
case TokenType::BoolTrue: return ShaderBuilder::Constant(true);
case TokenType::FloatingPointValue: return ShaderBuilder::Constant(float(std::get<double>(Advance().data))); //< FIXME
case TokenType::Identifier: return ParseIdentifier();
case TokenType::IntegerValue: return ParseIntegerExpression();
case TokenType::OpenParenthesis: return ParseParenthesisExpression();
@@ -286,7 +314,7 @@ namespace Nz::ShaderLang
return identifier;
}
ShaderExpressionType Parser::ParseIdentifierAsType()
ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType()
{
const Token& identifier = ExpectNext(TokenType::Identifier);