Shader: Handle type as expressions

This commit is contained in:
Jérôme Leclercq
2022-02-08 17:03:34 +01:00
parent 5ce8120a0c
commit 402e16bd2b
53 changed files with 1746 additions and 1141 deletions

View File

@@ -20,13 +20,6 @@ namespace Nz::ShaderLang
{ "unchanged", ShaderAst::DepthWriteMode::Unchanged },
};
std::unordered_map<std::string, ShaderAst::PrimitiveType> s_identifierToBasicType = {
{ "bool", ShaderAst::PrimitiveType::Boolean },
{ "i32", ShaderAst::PrimitiveType::Int32 },
{ "f32", ShaderAst::PrimitiveType::Float32 },
{ "u32", ShaderAst::PrimitiveType::UInt32 }
};
std::unordered_map<std::string, ShaderAst::AttributeType> s_identifierToAttributeType = {
{ "binding", ShaderAst::AttributeType::Binding },
{ "builtin", ShaderAst::AttributeType::Builtin },
@@ -71,7 +64,7 @@ namespace Nz::ShaderLang
}
template<typename T>
void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, bool requireValue = true)
void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::ExpressionValue<T>& targetAttribute, ShaderAst::ExprValue::Param&& param, bool requireValue = true)
{
if (targetAttribute.HasValue())
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
@@ -83,7 +76,7 @@ namespace Nz::ShaderLang
}
template<typename T>
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional<T> defaultValue = {})
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::ExpressionValue<T>& targetAttribute, ShaderAst::ExprValue::Param&& param, std::optional<T> defaultValue = {})
{
if (targetAttribute.HasValue())
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
@@ -123,9 +116,7 @@ namespace Nz::ShaderLang
m_context = &context;
std::vector<ShaderAst::Attribute> attributes;
EnterScope();
std::vector<ShaderAst::ExprValue> attributes;
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
@@ -179,8 +170,6 @@ namespace Nz::ShaderLang
}
}
LeaveScope();
return std::move(context.root);
}
@@ -198,161 +187,6 @@ namespace Nz::ShaderLang
m_context->tokenIndex += count;
}
std::optional<ShaderAst::ExpressionType> Parser::DecodeType(const std::string& identifier)
{
if (auto it = s_identifierToBasicType.find(identifier); it != s_identifierToBasicType.end())
{
Consume();
return it->second;
}
//FIXME: Handle this better
if (identifier == "array")
{
Consume();
Expect(Advance(), TokenType::OpenSquareBracket); //< [
ShaderAst::ArrayType arrayType;
arrayType.containedType = std::make_unique<ShaderAst::ContainedType>();
arrayType.containedType->type = ParseType();
Expect(Advance(), TokenType::Comma); //< ,
arrayType.length = ParseExpression();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return arrayType;
}
else if (identifier == "mat4")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 4;
matrixType.rowCount = 4;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return matrixType;
}
else if (identifier == "mat3")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 3;
matrixType.rowCount = 3;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return matrixType;
}
else if (identifier == "mat2")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 2;
matrixType.rowCount = 2;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return matrixType;
}
else if (identifier == "sampler2D")
{
Consume();
ShaderAst::SamplerType samplerType;
samplerType.dim = ImageType::E2D;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return samplerType;
}
else if (identifier == "samplerCube")
{
Consume();
ShaderAst::SamplerType samplerType;
samplerType.dim = ImageType::Cubemap;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return samplerType;
}
else if (identifier == "uniform")
{
Consume();
ShaderAst::UniformType uniformType;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() };
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return uniformType;
}
else if (identifier == "vec2")
{
Consume();
ShaderAst::VectorType vectorType;
vectorType.componentCount = 2;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return vectorType;
}
else if (identifier == "vec3")
{
Consume();
ShaderAst::VectorType vectorType;
vectorType.componentCount = 3;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return vectorType;
}
else if (identifier == "vec4")
{
Consume();
ShaderAst::VectorType vectorType;
vectorType.componentCount = 4;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return vectorType;
}
else
return std::nullopt;
}
void Parser::EnterScope()
{
m_context->scopeSizes.push_back(m_context->identifiersInScope.size());
}
const Token& Parser::Expect(const Token& token, TokenType type)
{
if (token.type != type)
@@ -377,33 +211,15 @@ namespace Nz::ShaderLang
return token;
}
void Parser::LeaveScope()
{
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.resize(m_context->scopeSizes.back());
m_context->scopeSizes.pop_back();
}
bool Parser::IsVariableInScope(const std::string_view& identifier) const
{
return std::find(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), identifier) != m_context->identifiersInScope.rend();
}
void Parser::RegisterVariable(std::string identifier)
{
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.push_back(std::move(identifier));
}
const Token& Parser::Peek(std::size_t advance)
{
assert(m_context->tokenIndex + advance < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + advance];
}
std::vector<ShaderAst::Attribute> Parser::ParseAttributes()
std::vector<ShaderAst::ExprValue> Parser::ParseAttributes()
{
std::vector<ShaderAst::Attribute> attributes;
std::vector<ShaderAst::ExprValue> attributes;
Expect(Advance(), TokenType::OpenSquareBracket);
@@ -431,7 +247,7 @@ namespace Nz::ShaderLang
ShaderAst::AttributeType attributeType = ParseIdentifierAsAttributeType();
ShaderAst::Attribute::Param arg;
ShaderAst::ExprValue::Param arg;
if (Peek().type == TokenType::OpenParenthesis)
{
Consume();
@@ -454,7 +270,7 @@ namespace Nz::ShaderLang
return attributes;
}
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue)
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue)
{
name = ParseIdentifierAsName();
@@ -464,10 +280,8 @@ namespace Nz::ShaderLang
type = ParseType();
}
else
type = ShaderAst::NoType{};
if (IsNoType(type) || Peek().type == TokenType::Assign)
if (!type.HasValue() || Peek().type == TokenType::Assign)
{
Expect(Advance(), TokenType::Assign);
initialValue = ParseExpression();
@@ -522,11 +336,10 @@ namespace Nz::ShaderLang
case TokenType::Identifier:
{
std::string constName;
ShaderAst::ExpressionType constType;
ShaderAst::ExpressionValue<ShaderAst::ExpressionType> constType;
ShaderAst::ExpressionPtr initialValue;
ParseVariableDeclaration(constName, constType, initialValue);
RegisterVariable(constName);
return ShaderBuilder::DeclareConst(std::move(constName), std::move(constType), std::move(initialValue));
}
@@ -552,14 +365,14 @@ namespace Nz::ShaderLang
return ShaderBuilder::Discard();
}
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::External);
Expect(Advance(), TokenType::OpenCurlyBracket);
std::unique_ptr<ShaderAst::DeclareExternalStatement> externalStatement = std::make_unique<ShaderAst::DeclareExternalStatement>();
ShaderAst::AttributeValue<bool> condition;
ShaderAst::ExpressionValue<bool> condition;
for (auto&& [attributeType, arg] : attributes)
{
@@ -624,8 +437,6 @@ namespace Nz::ShaderLang
extVar.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
extVar.type = ParseType();
RegisterVariable(extVar.name);
}
Expect(Advance(), TokenType::ClosingCurlyBracket);
@@ -636,7 +447,7 @@ namespace Nz::ShaderLang
return externalStatement;
}
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::For);
@@ -710,7 +521,7 @@ namespace Nz::ShaderLang
return ParseStatementList();
}
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::FunctionDeclaration);
@@ -738,24 +549,18 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::ExpressionType returnType;
ShaderAst::ExpressionValue<ShaderAst::ExpressionType> returnType;
if (Peek().type == TokenType::Arrow)
{
Consume();
returnType = ParseType();
}
EnterScope();
for (const auto& parameter : parameters)
RegisterVariable(parameter.name);
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
LeaveScope();
auto func = ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
ShaderAst::AttributeValue<bool> condition;
ShaderAst::ExpressionValue<bool> condition;
for (auto&& [attributeType, arg] : attributes)
{
@@ -794,7 +599,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
ShaderAst::ExpressionType parameterType = ParseType();
ShaderAst::ExpressionPtr parameterType = ParseType();
return { parameterName, std::move(parameterType) };
}
@@ -807,7 +612,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
ShaderAst::ExpressionType optionType = ParseType();
ShaderAst::ExpressionPtr optionType = ParseType();
ShaderAst::ExpressionPtr initialValue;
if (Peek().type == TokenType::Assign)
@@ -837,7 +642,7 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseSingleStatement()
{
std::vector<ShaderAst::Attribute> attributes;
std::vector<ShaderAst::ExprValue> attributes;
ShaderAst::StatementPtr statement;
do
{
@@ -912,15 +717,13 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseStatement()
{
if (Peek().type == TokenType::OpenCurlyBracket)
return ShaderBuilder::MultiStatement(ParseStatementList());
return ShaderBuilder::Scoped(ShaderBuilder::MultiStatement(ParseStatementList()));
else
return ParseSingleStatement();
}
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{
EnterScope();
Expect(Advance(), TokenType::OpenCurlyBracket);
std::vector<ShaderAst::StatementPtr> statements;
@@ -931,19 +734,17 @@ namespace Nz::ShaderLang
}
Consume(); //< Consume closing curly bracket
LeaveScope();
return statements;
}
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::Struct);
ShaderAst::StructDescription description;
description.name = ParseIdentifierAsName();
ShaderAst::AttributeValue<bool> condition;
ShaderAst::ExpressionValue<bool> condition;
for (auto&& [attributeType, attributeParam] : attributes)
{
@@ -1065,16 +866,15 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Let);
std::string variableName;
ShaderAst::ExpressionType variableType;
ShaderAst::ExpressionValue<ShaderAst::ExpressionType> variableType;
ShaderAst::ExpressionPtr expression;
ParseVariableDeclaration(variableName, variableType, expression);
RegisterVariable(variableName);
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::While);
@@ -1133,19 +933,6 @@ namespace Nz::ShaderLang
accessMemberNode->identifiers.push_back(ParseIdentifierAsName());
} while (Peek().type == TokenType::Dot);
// FIXME
if (!accessMemberNode->identifiers.empty() && accessMemberNode->identifiers.front() == "Sample")
{
if (Peek().type == TokenType::OpenParenthesis)
{
auto parameters = ParseParameters();
parameters.insert(parameters.begin(), std::move(accessMemberNode->expr));
lhs = ShaderBuilder::Intrinsic(ShaderAst::IntrinsicType::SampleTexture, std::move(parameters));
break;
}
}
lhs = std::move(accessMemberNode);
}
else
@@ -1160,10 +947,10 @@ namespace Nz::ShaderLang
Consume();
indexNode->indices.push_back(ParseExpression());
Expect(Advance(), TokenType::ClosingSquareBracket);
}
while (Peek().type == TokenType::OpenSquareBracket);
while (Peek().type == TokenType::Comma);
Expect(Advance(), TokenType::ClosingSquareBracket);
lhs = std::move(indexNode);
}
@@ -1171,6 +958,15 @@ namespace Nz::ShaderLang
currentTokenType = Peek().type;
}
if (currentTokenType == TokenType::OpenParenthesis)
{
// Function call
auto parameters = ParseParameters();
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
c = true;
}
if (c)
continue;
@@ -1302,23 +1098,7 @@ namespace Nz::ShaderLang
return ParseFloatingPointExpression();
case TokenType::Identifier:
{
const std::string& identifier = std::get<std::string>(token.data);
// Is it a cast?
std::optional<ShaderAst::ExpressionType> exprType = DecodeType(identifier);
if (exprType)
return ShaderBuilder::Cast(std::move(*exprType), ParseParameters());
if (Peek(1).type == TokenType::OpenParenthesis)
{
// Function call
Consume();
return ShaderBuilder::CallFunction(identifier, ParseParameters());
}
else
return ParseIdentifier();
}
return ParseIdentifier();
case TokenType::IntegerValue:
return ParseIntegerExpression();
@@ -1370,28 +1150,10 @@ namespace Nz::ShaderLang
const std::string& Parser::ParseIdentifierAsName()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
auto it = s_identifierToBasicType.find(identifier);
if (it != s_identifierToBasicType.end())
throw ReservedKeyword{};
return identifier;
return std::get<std::string>(identifierToken.data);
}
ShaderAst::PrimitiveType Parser::ParsePrimitiveType()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
auto it = s_identifierToBasicType.find(identifier);
if (it == s_identifierToBasicType.end())
throw UnknownType{};
return it->second;
}
ShaderAst::ExpressionType Parser::ParseType()
ShaderAst::ExpressionPtr Parser::ParseType()
{
// Handle () as no type
if (Peek().type == TokenType::OpenParenthesis)
@@ -1399,20 +1161,10 @@ namespace Nz::ShaderLang
Consume();
Expect(Advance(), TokenType::ClosingParenthesis);
return ShaderAst::NoType{};
return ShaderBuilder::Constant(ShaderAst::NoValue{});
}
const Token& identifierToken = Expect(Peek(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
auto type = DecodeType(identifier);
if (!type)
{
Consume();
return ShaderAst::IdentifierType{ identifier };
}
return *std::move(type);
return ParseExpression();
}
int Parser::GetTokenPrecedence(TokenType token)
@@ -1433,6 +1185,7 @@ namespace Nz::ShaderLang
case TokenType::NotEqual: return 50;
case TokenType::Plus: return 60;
case TokenType::OpenSquareBracket: return 100;
case TokenType::OpenParenthesis: return 100;
default: return -1;
}
}