Shader: Replace const for with [unroll] attribute
This commit is contained in:
@@ -173,9 +173,9 @@ namespace Nz::ShaderAst
|
||||
StatementPtr AstCloner::Clone(ForEachStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<ForEachStatement>();
|
||||
clone->isConst = node.isConst;
|
||||
clone->expression = CloneExpression(node.expression);
|
||||
clone->statement = CloneStatement(node.statement);
|
||||
clone->unroll = Clone(node.unroll);
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -208,6 +208,7 @@ namespace Nz::ShaderAst
|
||||
auto clone = std::make_unique<WhileStatement>();
|
||||
clone->condition = CloneExpression(node.condition);
|
||||
clone->body = CloneStatement(node.body);
|
||||
clone->unroll = Clone(node.unroll);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstSerializerBase::Serialize(ForEachStatement& node)
|
||||
{
|
||||
Value(node.isConst);
|
||||
Attribute(node.unroll);
|
||||
Value(node.varName);
|
||||
Node(node.expression);
|
||||
Node(node.statement);
|
||||
@@ -328,6 +328,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstSerializerBase::Serialize(WhileStatement& node)
|
||||
{
|
||||
Attribute(node.unroll);
|
||||
Node(node.condition);
|
||||
Node(node.body);
|
||||
}
|
||||
|
||||
@@ -843,29 +843,34 @@ namespace Nz::ShaderAst
|
||||
else
|
||||
throw AstError{ "for-each is only supported on arrays and range expressions" };
|
||||
|
||||
if (node.isConst)
|
||||
AttributeValue<LoopUnroll> unrollValue;
|
||||
if (node.unroll.HasValue())
|
||||
{
|
||||
// Repeat code
|
||||
auto multi = std::make_unique<MultiStatement>();
|
||||
if (IsArrayType(exprType))
|
||||
unrollValue = ComputeAttributeValue(node.unroll);
|
||||
if (unrollValue.GetResultingValue() == LoopUnroll::Always)
|
||||
{
|
||||
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
||||
UInt32 length = arrayType.length.GetResultingValue();
|
||||
|
||||
for (UInt32 i = 0; i < length; ++i)
|
||||
// Repeat code
|
||||
auto multi = std::make_unique<MultiStatement>();
|
||||
if (IsArrayType(exprType))
|
||||
{
|
||||
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
|
||||
Validate(*accessIndex);
|
||||
const ArrayType& arrayType = std::get<ArrayType>(exprType);
|
||||
UInt32 length = arrayType.length.GetResultingValue();
|
||||
|
||||
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
||||
Validate(*elementVariable);
|
||||
for (UInt32 i = 0; i < length; ++i)
|
||||
{
|
||||
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
|
||||
Validate(*accessIndex);
|
||||
|
||||
multi->statements.emplace_back(std::move(elementVariable));
|
||||
multi->statements.emplace_back(CloneStatement(node.statement));
|
||||
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
|
||||
Validate(*elementVariable);
|
||||
|
||||
multi->statements.emplace_back(std::move(elementVariable));
|
||||
multi->statements.emplace_back(CloneStatement(node.statement));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return multi;
|
||||
return multi;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_context->options.reduceLoopsToWhile)
|
||||
@@ -890,6 +895,7 @@ namespace Nz::ShaderAst
|
||||
multi->statements.emplace_back(std::move(counterVariable));
|
||||
|
||||
auto whileStatement = std::make_unique<WhileStatement>();
|
||||
whileStatement->unroll = std::move(unrollValue);
|
||||
|
||||
// While condition
|
||||
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length));
|
||||
@@ -928,6 +934,7 @@ namespace Nz::ShaderAst
|
||||
auto clone = std::make_unique<ForEachStatement>();
|
||||
clone->expression = std::move(expr);
|
||||
clone->varName = node.varName;
|
||||
clone->unroll = std::move(unrollValue);
|
||||
|
||||
PushScope();
|
||||
{
|
||||
@@ -968,9 +975,15 @@ namespace Nz::ShaderAst
|
||||
MandatoryStatement(node.body);
|
||||
|
||||
auto clone = static_unique_pointer_cast<WhileStatement>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
|
||||
if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean })
|
||||
throw AstError{ "expected a boolean value" };
|
||||
AttributeValue<LoopUnroll> unrollValue;
|
||||
if (node.unroll.HasValue())
|
||||
{
|
||||
clone->unroll = ComputeAttributeValue(node.unroll);
|
||||
if (clone->unroll.GetResultingValue() == LoopUnroll::Always)
|
||||
throw AstError{ "unroll(always) is not yet supported on while" };
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -1350,6 +1363,12 @@ namespace Nz::ShaderAst
|
||||
}
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(WhileStatement& node)
|
||||
{
|
||||
if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean })
|
||||
throw AstError{ "expected a boolean value" };
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(AccessIndexExpression& node)
|
||||
{
|
||||
if (node.indices.empty())
|
||||
|
||||
@@ -37,6 +37,7 @@ namespace Nz::ShaderLang
|
||||
{ "layout", ShaderAst::AttributeType::Layout },
|
||||
{ "location", ShaderAst::AttributeType::Location },
|
||||
{ "set", ShaderAst::AttributeType::Set },
|
||||
{ "unroll", ShaderAst::AttributeType::Unroll },
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
|
||||
@@ -54,6 +55,12 @@ namespace Nz::ShaderLang
|
||||
{ "std140", StructLayout::Std140 }
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, ShaderAst::LoopUnroll> s_unrollModes = {
|
||||
{ "always", ShaderAst::LoopUnroll::Always },
|
||||
{ "hint", ShaderAst::LoopUnroll::Hint },
|
||||
{ "never", ShaderAst::LoopUnroll::Never }
|
||||
};
|
||||
|
||||
template<typename T, typename U>
|
||||
std::optional<T> BoundCast(U val)
|
||||
{
|
||||
@@ -76,26 +83,33 @@ namespace Nz::ShaderLang
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param)
|
||||
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional<T> defaultValue = {})
|
||||
{
|
||||
if (targetAttribute.HasValue())
|
||||
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
|
||||
|
||||
//FIXME: This should be handled with global values at sanitization stage
|
||||
if (!param)
|
||||
throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" };
|
||||
if (param)
|
||||
{
|
||||
const ShaderAst::ExpressionPtr& expr = *param;
|
||||
if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression)
|
||||
throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" };
|
||||
|
||||
const ShaderAst::ExpressionPtr& expr = *param;
|
||||
if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression)
|
||||
throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" };
|
||||
const std::string& exprStr = static_cast<ShaderAst::IdentifierExpression&>(*expr).identifier;
|
||||
|
||||
const std::string& exprStr = static_cast<ShaderAst::IdentifierExpression&>(*expr).identifier;
|
||||
auto it = map.find(exprStr);
|
||||
if (it == map.end())
|
||||
throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() };
|
||||
|
||||
auto it = map.find(exprStr);
|
||||
if (it == map.end())
|
||||
throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() };
|
||||
targetAttribute = it->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!defaultValue)
|
||||
throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" };
|
||||
|
||||
targetAttribute = it->second;
|
||||
targetAttribute = defaultValue.value();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,14 +487,6 @@ namespace Nz::ShaderLang
|
||||
|
||||
switch (Peek().type)
|
||||
{
|
||||
case TokenType::For:
|
||||
{
|
||||
auto forEach = ParseForDeclaration();
|
||||
SafeCast<ShaderAst::ForEachStatement&>(*forEach).isConst = true;
|
||||
|
||||
return forEach;
|
||||
}
|
||||
|
||||
case TokenType::Identifier:
|
||||
{
|
||||
std::string constName;
|
||||
@@ -598,7 +604,7 @@ namespace Nz::ShaderLang
|
||||
return externalStatement;
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseForDeclaration()
|
||||
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::Attribute> attributes)
|
||||
{
|
||||
Expect(Advance(), TokenType::For);
|
||||
|
||||
@@ -610,7 +616,22 @@ namespace Nz::ShaderLang
|
||||
|
||||
ShaderAst::StatementPtr statement = ParseStatement();
|
||||
|
||||
return ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
|
||||
auto forEach = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
|
||||
|
||||
for (auto&& [attributeType, arg] : attributes)
|
||||
{
|
||||
switch (attributeType)
|
||||
{
|
||||
case ShaderAst::AttributeType::Unroll:
|
||||
HandleUniqueStringAttribute("unroll", s_unrollModes, forEach->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
|
||||
break;
|
||||
|
||||
default:
|
||||
throw AttributeError{ "unhandled attribute for for-each" };
|
||||
}
|
||||
}
|
||||
|
||||
return forEach;
|
||||
}
|
||||
|
||||
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
|
||||
@@ -745,47 +766,74 @@ namespace Nz::ShaderLang
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseSingleStatement()
|
||||
{
|
||||
const Token& token = Peek();
|
||||
|
||||
std::vector<ShaderAst::Attribute> attributes;
|
||||
ShaderAst::StatementPtr statement;
|
||||
switch (token.type)
|
||||
do
|
||||
{
|
||||
case TokenType::Const:
|
||||
statement = ParseConstStatement();
|
||||
break;
|
||||
const Token& token = Peek();
|
||||
switch (token.type)
|
||||
{
|
||||
case TokenType::Const:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
case TokenType::Discard:
|
||||
statement = ParseDiscardStatement();
|
||||
break;
|
||||
statement = ParseConstStatement();
|
||||
break;
|
||||
|
||||
case TokenType::For:
|
||||
statement = ParseForDeclaration();
|
||||
break;
|
||||
case TokenType::Discard:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
case TokenType::Let:
|
||||
statement = ParseVariableDeclaration();
|
||||
break;
|
||||
statement = ParseDiscardStatement();
|
||||
break;
|
||||
|
||||
case TokenType::Identifier:
|
||||
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
break;
|
||||
case TokenType::For:
|
||||
statement = ParseForDeclaration(std::move(attributes));
|
||||
break;
|
||||
|
||||
case TokenType::If:
|
||||
statement = ParseBranchStatement();
|
||||
break;
|
||||
case TokenType::Let:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
case TokenType::Return:
|
||||
statement = ParseReturnStatement();
|
||||
break;
|
||||
statement = ParseVariableDeclaration();
|
||||
break;
|
||||
|
||||
case TokenType::While:
|
||||
statement = ParseWhileStatement();
|
||||
break;
|
||||
case TokenType::Identifier:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
default:
|
||||
throw UnexpectedToken{};
|
||||
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
break;
|
||||
|
||||
case TokenType::If:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
statement = ParseBranchStatement();
|
||||
break;
|
||||
|
||||
case TokenType::OpenSquareBracket:
|
||||
assert(attributes.empty());
|
||||
attributes = ParseAttributes();
|
||||
break;
|
||||
|
||||
case TokenType::Return:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
statement = ParseReturnStatement();
|
||||
break;
|
||||
|
||||
case TokenType::While:
|
||||
statement = ParseWhileStatement(std::move(attributes));
|
||||
break;
|
||||
|
||||
default:
|
||||
throw UnexpectedToken{};
|
||||
}
|
||||
}
|
||||
while (!statement); //< small trick to repeat parsing once we got attributes
|
||||
|
||||
return statement;
|
||||
}
|
||||
@@ -955,7 +1003,7 @@ namespace Nz::ShaderLang
|
||||
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseWhileStatement()
|
||||
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::Attribute> attributes)
|
||||
{
|
||||
Expect(Advance(), TokenType::While);
|
||||
|
||||
@@ -967,7 +1015,22 @@ namespace Nz::ShaderLang
|
||||
|
||||
ShaderAst::StatementPtr body = ParseStatement();
|
||||
|
||||
return ShaderBuilder::While(std::move(condition), std::move(body));
|
||||
auto whileStatement = ShaderBuilder::While(std::move(condition), std::move(body));
|
||||
|
||||
for (auto&& [attributeType, arg] : attributes)
|
||||
{
|
||||
switch (attributeType)
|
||||
{
|
||||
case ShaderAst::AttributeType::Unroll:
|
||||
HandleUniqueStringAttribute("unroll", s_unrollModes, whileStatement->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
|
||||
break;
|
||||
|
||||
default:
|
||||
throw AttributeError{ "unhandled attribute for while" };
|
||||
}
|
||||
}
|
||||
|
||||
return whileStatement;
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
|
||||
|
||||
@@ -1030,7 +1030,28 @@ namespace Nz
|
||||
|
||||
UInt32 expressionId = EvaluateExpression(node.condition);
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None);
|
||||
SpirvLoopControl loopControl;
|
||||
if (node.unroll.HasValue())
|
||||
{
|
||||
switch (node.unroll.GetResultingValue())
|
||||
{
|
||||
case ShaderAst::LoopUnroll::Always:
|
||||
// it shouldn't be possible to have this attribute as the loop gets unrolled in the sanitizer
|
||||
throw std::runtime_error("unexpected unroll attribute");
|
||||
|
||||
case ShaderAst::LoopUnroll::Hint:
|
||||
loopControl = SpirvLoopControl::Unroll;
|
||||
break;
|
||||
|
||||
case ShaderAst::LoopUnroll::Never:
|
||||
loopControl = SpirvLoopControl::DontUnroll;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
loopControl = SpirvLoopControl::None;
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), loopControl);
|
||||
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId());
|
||||
|
||||
m_currentBlock = &bodyBlock;
|
||||
|
||||
Reference in New Issue
Block a user