Shader: Add support for while loops

This commit is contained in:
Jérôme Leclercq
2021-12-16 23:10:58 +01:00
parent 07199301df
commit 0f9060c45b
22 changed files with 154 additions and 14 deletions

View File

@@ -193,6 +193,15 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(WhileStatement& node)
{
auto clone = std::make_unique<WhileStatement>();
clone->condition = CloneExpression(node.condition);
clone->body = CloneStatement(node.body);
return clone;
}
ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node)
{
auto clone = std::make_unique<AccessIdentifierExpression>();

View File

@@ -175,4 +175,13 @@ namespace Nz::ShaderAst
if (node.returnExpr)
node.returnExpr->Visit(*this);
}
void AstRecursiveVisitor::Visit(WhileStatement& node)
{
if (node.condition)
node.condition->Visit(*this);
if (node.body)
node.body->Visit(*this);
}
}

View File

@@ -318,6 +318,12 @@ namespace Nz::ShaderAst
Node(node.returnExpr);
}
void AstSerializerBase::Serialize(WhileStatement& node)
{
Node(node.condition);
Node(node.body);
}
void ShaderAstSerializer::Serialize(StatementPtr& shader)
{
m_stream << s_magicNumber << s_currentVersion;

View File

@@ -460,12 +460,11 @@ namespace Nz::ShaderAst
{
case Identifier::Type::Constant:
{
// Replace IdentifierExpression by ConstantExpression
auto constantExpr = std::make_unique<ConstantExpression>();
constantExpr->cachedExpressionType = GetExpressionType(m_context->constantValues[identifier->index]);
constantExpr->constantId = identifier->index;
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifier->index;
return constantExpr;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case Identifier::Type::Variable:
@@ -951,6 +950,19 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr SanitizeVisitor::Clone(WhileStatement& node)
{
MandatoryExpr(node.condition);
MandatoryStatement(node.body);
auto clone = static_unique_pointer_cast<WhileStatement>(AstCloner::Clone(node));
if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
return clone;
}
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
{
auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });

View File

@@ -1179,6 +1179,17 @@ namespace Nz
}
}
void GlslWriter::Visit(ShaderAst::WhileStatement& node)
{
Append("while (");
node.condition->Visit(*this);
AppendLine(")");
EnterScope();
node.body->Visit(*this);
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::SwizzleExpression& node)
{
Visit(node.expression, true);

View File

@@ -930,6 +930,17 @@ namespace Nz
node.expression->Visit(*this);
}
void LangWriter::Visit(ShaderAst::WhileStatement& node)
{
Append("while (");
node.condition->Visit(*this);
AppendLine(")");
EnterScope();
node.body->Visit(*this);
LeaveScope();
}
void LangWriter::AppendHeader()
{
// Nothing yet

View File

@@ -52,7 +52,8 @@ namespace Nz::ShaderLang
{ "option", TokenType::Option },
{ "return", TokenType::Return },
{ "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue }
{ "true", TokenType::BoolTrue },
{ "while", TokenType::While }
};
std::size_t currentPos = 0;

View File

@@ -742,6 +742,10 @@ namespace Nz::ShaderLang
statement = ParseReturnStatement();
break;
case TokenType::While:
statement = ParseWhileStatement();
break;
default:
throw UnexpectedToken{};
}
@@ -905,6 +909,21 @@ namespace Nz::ShaderLang
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::StatementPtr Parser::ParseWhileStatement()
{
Expect(Advance(), TokenType::While);
Expect(Advance(), TokenType::OpenParenthesis);
ShaderAst::ExpressionPtr condition = ParseExpression();
Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::StatementPtr body = ParseStatement();
return ShaderBuilder::While(std::move(condition), std::move(body));
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
{
for (;;)

View File

@@ -944,12 +944,6 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
{
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
@@ -1011,6 +1005,40 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderAst::WhileStatement& node)
{
assert(node.condition);
assert(node.body);
SpirvBlock headerBlock(m_writer);
SpirvBlock bodyBlock(m_writer);
SpirvBlock mergeBlock(m_writer);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
m_currentBlock = &headerBlock;
UInt32 expressionId = EvaluateExpression(node.condition);
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None);
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId());
m_currentBlock = &bodyBlock;
node.body->Visit(*this);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(headerBlock));
m_functionBlocks.emplace_back(std::move(bodyBlock));
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back();
}
void SpirvAstVisitor::PushResultId(UInt32 value)
{
m_resultIds.push_back(value);