Shader: Replace const for with [unroll] attribute

This commit is contained in:
Jérôme Leclercq
2022-01-03 20:21:09 +01:00
parent b6e4a9470e
commit 2bdcc045cd
13 changed files with 204 additions and 87 deletions

View File

@@ -37,6 +37,7 @@ namespace Nz::ShaderLang
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
{ "set", ShaderAst::AttributeType::Set },
{ "unroll", ShaderAst::AttributeType::Unroll },
};
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
@@ -54,6 +55,12 @@ namespace Nz::ShaderLang
{ "std140", StructLayout::Std140 }
};
std::unordered_map<std::string, ShaderAst::LoopUnroll> s_unrollModes = {
{ "always", ShaderAst::LoopUnroll::Always },
{ "hint", ShaderAst::LoopUnroll::Hint },
{ "never", ShaderAst::LoopUnroll::Never }
};
template<typename T, typename U>
std::optional<T> BoundCast(U val)
{
@@ -76,26 +83,33 @@ namespace Nz::ShaderLang
}
template<typename T>
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param)
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional<T> defaultValue = {})
{
if (targetAttribute.HasValue())
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
//FIXME: This should be handled with global values at sanitization stage
if (!param)
throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" };
if (param)
{
const ShaderAst::ExpressionPtr& expr = *param;
if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression)
throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" };
const ShaderAst::ExpressionPtr& expr = *param;
if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression)
throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" };
const std::string& exprStr = static_cast<ShaderAst::IdentifierExpression&>(*expr).identifier;
const std::string& exprStr = static_cast<ShaderAst::IdentifierExpression&>(*expr).identifier;
auto it = map.find(exprStr);
if (it == map.end())
throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() };
auto it = map.find(exprStr);
if (it == map.end())
throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() };
targetAttribute = it->second;
}
else
{
if (!defaultValue)
throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" };
targetAttribute = it->second;
targetAttribute = defaultValue.value();
}
}
}
@@ -473,14 +487,6 @@ namespace Nz::ShaderLang
switch (Peek().type)
{
case TokenType::For:
{
auto forEach = ParseForDeclaration();
SafeCast<ShaderAst::ForEachStatement&>(*forEach).isConst = true;
return forEach;
}
case TokenType::Identifier:
{
std::string constName;
@@ -598,7 +604,7 @@ namespace Nz::ShaderLang
return externalStatement;
}
ShaderAst::StatementPtr Parser::ParseForDeclaration()
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::For);
@@ -610,7 +616,22 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr statement = ParseStatement();
return ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
auto forEach = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
for (auto&& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Unroll:
HandleUniqueStringAttribute("unroll", s_unrollModes, forEach->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
break;
default:
throw AttributeError{ "unhandled attribute for for-each" };
}
}
return forEach;
}
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
@@ -745,47 +766,74 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseSingleStatement()
{
const Token& token = Peek();
std::vector<ShaderAst::Attribute> attributes;
ShaderAst::StatementPtr statement;
switch (token.type)
do
{
case TokenType::Const:
statement = ParseConstStatement();
break;
const Token& token = Peek();
switch (token.type)
{
case TokenType::Const:
if (!attributes.empty())
throw UnexpectedToken{};
case TokenType::Discard:
statement = ParseDiscardStatement();
break;
statement = ParseConstStatement();
break;
case TokenType::For:
statement = ParseForDeclaration();
break;
case TokenType::Discard:
if (!attributes.empty())
throw UnexpectedToken{};
case TokenType::Let:
statement = ParseVariableDeclaration();
break;
statement = ParseDiscardStatement();
break;
case TokenType::Identifier:
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::For:
statement = ParseForDeclaration(std::move(attributes));
break;
case TokenType::If:
statement = ParseBranchStatement();
break;
case TokenType::Let:
if (!attributes.empty())
throw UnexpectedToken{};
case TokenType::Return:
statement = ParseReturnStatement();
break;
statement = ParseVariableDeclaration();
break;
case TokenType::While:
statement = ParseWhileStatement();
break;
case TokenType::Identifier:
if (!attributes.empty())
throw UnexpectedToken{};
default:
throw UnexpectedToken{};
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::If:
if (!attributes.empty())
throw UnexpectedToken{};
statement = ParseBranchStatement();
break;
case TokenType::OpenSquareBracket:
assert(attributes.empty());
attributes = ParseAttributes();
break;
case TokenType::Return:
if (!attributes.empty())
throw UnexpectedToken{};
statement = ParseReturnStatement();
break;
case TokenType::While:
statement = ParseWhileStatement(std::move(attributes));
break;
default:
throw UnexpectedToken{};
}
}
while (!statement); //< small trick to repeat parsing once we got attributes
return statement;
}
@@ -955,7 +1003,7 @@ namespace Nz::ShaderLang
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::StatementPtr Parser::ParseWhileStatement()
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::While);
@@ -967,7 +1015,22 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr body = ParseStatement();
return ShaderBuilder::While(std::move(condition), std::move(body));
auto whileStatement = ShaderBuilder::While(std::move(condition), std::move(body));
for (auto&& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Unroll:
HandleUniqueStringAttribute("unroll", s_unrollModes, whileStatement->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
break;
default:
throw AttributeError{ "unhandled attribute for while" };
}
}
return whileStatement;
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)