Shader: Replace const for with [unroll] attribute

This commit is contained in:
Jérôme Leclercq 2022-01-03 20:21:09 +01:00
parent b6e4a9470e
commit 2bdcc045cd
13 changed files with 204 additions and 87 deletions

View File

@ -460,10 +460,10 @@ namespace Nz::ShaderAst
bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs)
{
if (!Compare(lhs.isConst, rhs.isConst))
if (!Compare(lhs.varName, rhs.varName))
return false;
if (!Compare(lhs.varName, rhs.varName))
if (!Compare(lhs.unroll, rhs.unroll))
return false;
if (!Compare(lhs.expression, rhs.expression))
@ -498,6 +498,9 @@ namespace Nz::ShaderAst
inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs)
{
if (!Compare(lhs.unroll, rhs.unroll))
return false;
if (!Compare(lhs.condition, rhs.condition))
return false;

View File

@ -36,6 +36,7 @@ namespace Nz
Layout, //< Struct layout (struct only) - has argument style
Location, //< Location (struct member only) - has argument index
Set, //< Binding set (external var only) - has argument index
Unroll, //< Unroll (for/for each only) - has argument mode
};
enum class BinaryType
@ -106,6 +107,13 @@ namespace Nz
SampleTexture = 2,
};
enum class LoopUnroll
{
Always,
Hint,
Never
};
enum class MemoryLayout
{
Std140

View File

@ -345,11 +345,11 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
AttributeValue<LoopUnroll> unroll;
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr expression;
StatementPtr statement;
bool isConst = false;
};
struct NAZARA_SHADER_API MultiStatement : Statement
@ -379,6 +379,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
AttributeValue<LoopUnroll> unroll;
ExpressionPtr condition;
StatementPtr body;
};

View File

@ -118,6 +118,8 @@ namespace Nz::ShaderAst
void SanitizeIdentifier(std::string& identifier);
void Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node);
void Validate(BinaryExpression& node);

View File

@ -108,7 +108,6 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ExpressionStatement> operator()(ShaderAst::ExpressionPtr expression) const;
};
template<bool Const>
struct ForEach
{
inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const;
@ -173,7 +172,6 @@ namespace Nz::ShaderBuilder
constexpr Impl::ConditionalStatement ConditionalStatement;
constexpr Impl::Constant Constant;
constexpr Impl::Branch<true> ConstBranch;
constexpr Impl::ForEach<false> ConstForEach;
constexpr Impl::DeclareConst DeclareConst;
constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareOption DeclareOption;
@ -181,7 +179,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::DeclareVariable DeclareVariable;
constexpr Impl::ExpressionStatement ExpressionStatement;
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::ForEach<false> ForEach;
constexpr Impl::ForEach ForEach;
constexpr Impl::Identifier Identifier;
constexpr Impl::Intrinsic Intrinsic;
constexpr Impl::Multi MultiStatement;

View File

@ -269,11 +269,9 @@ namespace Nz::ShaderBuilder
return expressionStatementNode;
}
template<bool Const>
std::unique_ptr<ShaderAst::ForEachStatement> Impl::ForEach<Const>::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const
std::unique_ptr<ShaderAst::ForEachStatement> Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const
{
auto forEachNode = std::make_unique<ShaderAst::ForEachStatement>();
forEachNode->isConst = Const;
forEachNode->expression = std::move(expression);
forEachNode->statement = std::move(statement);
forEachNode->varName = std::move(varName);

View File

@ -88,7 +88,7 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr ParseConstStatement();
ShaderAst::StatementPtr ParseDiscardStatement();
ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::StatementPtr ParseForDeclaration();
ShaderAst::StatementPtr ParseForDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
std::vector<ShaderAst::StatementPtr> ParseFunctionBody();
ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter();
@ -99,7 +99,7 @@ namespace Nz::ShaderLang
std::vector<ShaderAst::StatementPtr> ParseStatementList();
ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::StatementPtr ParseVariableDeclaration();
ShaderAst::StatementPtr ParseWhileStatement();
ShaderAst::StatementPtr ParseWhileStatement(std::vector<ShaderAst::Attribute> attributes);
// Expressions
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);

View File

@ -173,9 +173,9 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(ForEachStatement& node)
{
auto clone = std::make_unique<ForEachStatement>();
clone->isConst = node.isConst;
clone->expression = CloneExpression(node.expression);
clone->statement = CloneStatement(node.statement);
clone->unroll = Clone(node.unroll);
return clone;
}
@ -208,6 +208,7 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<WhileStatement>();
clone->condition = CloneExpression(node.condition);
clone->body = CloneStatement(node.body);
clone->unroll = Clone(node.unroll);
return clone;
}

View File

@ -303,7 +303,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ForEachStatement& node)
{
Value(node.isConst);
Attribute(node.unroll);
Value(node.varName);
Node(node.expression);
Node(node.statement);
@ -328,6 +328,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(WhileStatement& node)
{
Attribute(node.unroll);
Node(node.condition);
Node(node.body);
}

View File

@ -843,29 +843,34 @@ namespace Nz::ShaderAst
else
throw AstError{ "for-each is only supported on arrays and range expressions" };
if (node.isConst)
AttributeValue<LoopUnroll> unrollValue;
if (node.unroll.HasValue())
{
// Repeat code
auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType))
unrollValue = ComputeAttributeValue(node.unroll);
if (unrollValue.GetResultingValue() == LoopUnroll::Always)
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
UInt32 length = arrayType.length.GetResultingValue();
for (UInt32 i = 0; i < length; ++i)
// Repeat code
auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType))
{
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
Validate(*accessIndex);
const ArrayType& arrayType = std::get<ArrayType>(exprType);
UInt32 length = arrayType.length.GetResultingValue();
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
Validate(*elementVariable);
for (UInt32 i = 0; i < length; ++i)
{
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
Validate(*accessIndex);
multi->statements.emplace_back(std::move(elementVariable));
multi->statements.emplace_back(CloneStatement(node.statement));
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
Validate(*elementVariable);
multi->statements.emplace_back(std::move(elementVariable));
multi->statements.emplace_back(CloneStatement(node.statement));
}
}
}
return multi;
return multi;
}
}
if (m_context->options.reduceLoopsToWhile)
@ -890,6 +895,7 @@ namespace Nz::ShaderAst
multi->statements.emplace_back(std::move(counterVariable));
auto whileStatement = std::make_unique<WhileStatement>();
whileStatement->unroll = std::move(unrollValue);
// While condition
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length));
@ -928,6 +934,7 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<ForEachStatement>();
clone->expression = std::move(expr);
clone->varName = node.varName;
clone->unroll = std::move(unrollValue);
PushScope();
{
@ -968,9 +975,15 @@ namespace Nz::ShaderAst
MandatoryStatement(node.body);
auto clone = static_unique_pointer_cast<WhileStatement>(AstCloner::Clone(node));
Validate(*clone);
if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
AttributeValue<LoopUnroll> unrollValue;
if (node.unroll.HasValue())
{
clone->unroll = ComputeAttributeValue(node.unroll);
if (clone->unroll.GetResultingValue() == LoopUnroll::Always)
throw AstError{ "unroll(always) is not yet supported on while" };
}
return clone;
}
@ -1350,6 +1363,12 @@ namespace Nz::ShaderAst
}
}
void SanitizeVisitor::Validate(WhileStatement& node)
{
if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
}
void SanitizeVisitor::Validate(AccessIndexExpression& node)
{
if (node.indices.empty())

View File

@ -37,6 +37,7 @@ namespace Nz::ShaderLang
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
{ "set", ShaderAst::AttributeType::Set },
{ "unroll", ShaderAst::AttributeType::Unroll },
};
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
@ -54,6 +55,12 @@ namespace Nz::ShaderLang
{ "std140", StructLayout::Std140 }
};
std::unordered_map<std::string, ShaderAst::LoopUnroll> s_unrollModes = {
{ "always", ShaderAst::LoopUnroll::Always },
{ "hint", ShaderAst::LoopUnroll::Hint },
{ "never", ShaderAst::LoopUnroll::Never }
};
template<typename T, typename U>
std::optional<T> BoundCast(U val)
{
@ -76,26 +83,33 @@ namespace Nz::ShaderLang
}
template<typename T>
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param)
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional<T> defaultValue = {})
{
if (targetAttribute.HasValue())
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
//FIXME: This should be handled with global values at sanitization stage
if (!param)
throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" };
if (param)
{
const ShaderAst::ExpressionPtr& expr = *param;
if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression)
throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" };
const ShaderAst::ExpressionPtr& expr = *param;
if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression)
throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" };
const std::string& exprStr = static_cast<ShaderAst::IdentifierExpression&>(*expr).identifier;
const std::string& exprStr = static_cast<ShaderAst::IdentifierExpression&>(*expr).identifier;
auto it = map.find(exprStr);
if (it == map.end())
throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() };
auto it = map.find(exprStr);
if (it == map.end())
throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() };
targetAttribute = it->second;
}
else
{
if (!defaultValue)
throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" };
targetAttribute = it->second;
targetAttribute = defaultValue.value();
}
}
}
@ -473,14 +487,6 @@ namespace Nz::ShaderLang
switch (Peek().type)
{
case TokenType::For:
{
auto forEach = ParseForDeclaration();
SafeCast<ShaderAst::ForEachStatement&>(*forEach).isConst = true;
return forEach;
}
case TokenType::Identifier:
{
std::string constName;
@ -598,7 +604,7 @@ namespace Nz::ShaderLang
return externalStatement;
}
ShaderAst::StatementPtr Parser::ParseForDeclaration()
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::For);
@ -610,7 +616,22 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr statement = ParseStatement();
return ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
auto forEach = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
for (auto&& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Unroll:
HandleUniqueStringAttribute("unroll", s_unrollModes, forEach->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
break;
default:
throw AttributeError{ "unhandled attribute for for-each" };
}
}
return forEach;
}
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
@ -745,47 +766,74 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseSingleStatement()
{
const Token& token = Peek();
std::vector<ShaderAst::Attribute> attributes;
ShaderAst::StatementPtr statement;
switch (token.type)
do
{
case TokenType::Const:
statement = ParseConstStatement();
break;
const Token& token = Peek();
switch (token.type)
{
case TokenType::Const:
if (!attributes.empty())
throw UnexpectedToken{};
case TokenType::Discard:
statement = ParseDiscardStatement();
break;
statement = ParseConstStatement();
break;
case TokenType::For:
statement = ParseForDeclaration();
break;
case TokenType::Discard:
if (!attributes.empty())
throw UnexpectedToken{};
case TokenType::Let:
statement = ParseVariableDeclaration();
break;
statement = ParseDiscardStatement();
break;
case TokenType::Identifier:
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::For:
statement = ParseForDeclaration(std::move(attributes));
break;
case TokenType::If:
statement = ParseBranchStatement();
break;
case TokenType::Let:
if (!attributes.empty())
throw UnexpectedToken{};
case TokenType::Return:
statement = ParseReturnStatement();
break;
statement = ParseVariableDeclaration();
break;
case TokenType::While:
statement = ParseWhileStatement();
break;
case TokenType::Identifier:
if (!attributes.empty())
throw UnexpectedToken{};
default:
throw UnexpectedToken{};
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::If:
if (!attributes.empty())
throw UnexpectedToken{};
statement = ParseBranchStatement();
break;
case TokenType::OpenSquareBracket:
assert(attributes.empty());
attributes = ParseAttributes();
break;
case TokenType::Return:
if (!attributes.empty())
throw UnexpectedToken{};
statement = ParseReturnStatement();
break;
case TokenType::While:
statement = ParseWhileStatement(std::move(attributes));
break;
default:
throw UnexpectedToken{};
}
}
while (!statement); //< small trick to repeat parsing once we got attributes
return statement;
}
@ -955,7 +1003,7 @@ namespace Nz::ShaderLang
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::StatementPtr Parser::ParseWhileStatement()
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::While);
@ -967,7 +1015,22 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr body = ParseStatement();
return ShaderBuilder::While(std::move(condition), std::move(body));
auto whileStatement = ShaderBuilder::While(std::move(condition), std::move(body));
for (auto&& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Unroll:
HandleUniqueStringAttribute("unroll", s_unrollModes, whileStatement->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
break;
default:
throw AttributeError{ "unhandled attribute for while" };
}
}
return whileStatement;
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)

View File

@ -1030,7 +1030,28 @@ namespace Nz
UInt32 expressionId = EvaluateExpression(node.condition);
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None);
SpirvLoopControl loopControl;
if (node.unroll.HasValue())
{
switch (node.unroll.GetResultingValue())
{
case ShaderAst::LoopUnroll::Always:
// it shouldn't be possible to have this attribute as the loop gets unrolled in the sanitizer
throw std::runtime_error("unexpected unroll attribute");
case ShaderAst::LoopUnroll::Hint:
loopControl = SpirvLoopControl::Unroll;
break;
case ShaderAst::LoopUnroll::Never:
loopControl = SpirvLoopControl::DontUnroll;
break;
}
}
else
loopControl = SpirvLoopControl::None;
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), loopControl);
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId());
m_currentBlock = &bodyBlock;

View File

@ -110,7 +110,7 @@ fn main()
}
}
WHEN("using const for-each")
WHEN("using [unroll] attribute on for-each")
{
std::string_view sourceCode = R"(
const LightCount = 3;
@ -136,7 +136,9 @@ external
fn main()
{
let color = (0.0).xxxx;
const for light in data.lights
[unroll]
for light in data.lights
{
color += light.color;
}