Shader: Add support for while loops

This commit is contained in:
Jérôme Leclercq 2021-12-16 23:10:58 +01:00
parent 07199301df
commit 0f9060c45b
22 changed files with 154 additions and 14 deletions

View File

@ -67,6 +67,7 @@ namespace Nz::ShaderAst
virtual StatementPtr Clone(MultiStatement& node); virtual StatementPtr Clone(MultiStatement& node);
virtual StatementPtr Clone(NoOpStatement& node); virtual StatementPtr Clone(NoOpStatement& node);
virtual StatementPtr Clone(ReturnStatement& node); virtual StatementPtr Clone(ReturnStatement& node);
virtual StatementPtr Clone(WhileStatement& node);
#define NAZARA_SHADERAST_NODE(NodeType) void Visit(NodeType& node) override; #define NAZARA_SHADERAST_NODE(NodeType) void Visit(NodeType& node) override;
#include <Nazara/Shader/Ast/AstNodeList.hpp> #include <Nazara/Shader/Ast/AstNodeList.hpp>

View File

@ -55,7 +55,8 @@ NAZARA_SHADERAST_STATEMENT(DiscardStatement)
NAZARA_SHADERAST_STATEMENT(ExpressionStatement) NAZARA_SHADERAST_STATEMENT(ExpressionStatement)
NAZARA_SHADERAST_STATEMENT(MultiStatement) NAZARA_SHADERAST_STATEMENT(MultiStatement)
NAZARA_SHADERAST_STATEMENT(NoOpStatement) NAZARA_SHADERAST_STATEMENT(NoOpStatement)
NAZARA_SHADERAST_STATEMENT_LAST(ReturnStatement) NAZARA_SHADERAST_STATEMENT(ReturnStatement)
NAZARA_SHADERAST_STATEMENT_LAST(WhileStatement)
#undef NAZARA_SHADERAST_EXPRESSION #undef NAZARA_SHADERAST_EXPRESSION
#undef NAZARA_SHADERAST_NODE #undef NAZARA_SHADERAST_NODE

View File

@ -49,6 +49,7 @@ namespace Nz::ShaderAst
void Visit(MultiStatement& node) override; void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override; void Visit(NoOpStatement& node) override;
void Visit(ReturnStatement& node) override; void Visit(ReturnStatement& node) override;
void Visit(WhileStatement& node) override;
}; };
} }

View File

@ -52,6 +52,7 @@ namespace Nz::ShaderAst
void Serialize(MultiStatement& node); void Serialize(MultiStatement& node);
void Serialize(NoOpStatement& node); void Serialize(NoOpStatement& node);
void Serialize(ReturnStatement& node); void Serialize(ReturnStatement& node);
void Serialize(WhileStatement& node);
protected: protected:
template<typename T> void Attribute(AttributeValue<T>& attribute); template<typename T> void Attribute(AttributeValue<T>& attribute);

View File

@ -362,6 +362,15 @@ namespace Nz::ShaderAst
ExpressionPtr returnExpr; ExpressionPtr returnExpr;
}; };
struct NAZARA_SHADER_API WhileStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
ExpressionPtr condition;
StatementPtr body;
};
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr); inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
inline bool IsExpression(NodeType nodeType); inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType); inline bool IsStatement(NodeType nodeType);

View File

@ -74,6 +74,7 @@ namespace Nz::ShaderAst
StatementPtr Clone(DiscardStatement& node) override; StatementPtr Clone(DiscardStatement& node) override;
StatementPtr Clone(ExpressionStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override;
StatementPtr Clone(MultiStatement& node) override; StatementPtr Clone(MultiStatement& node) override;
StatementPtr Clone(WhileStatement& node) override;
const Identifier* FindIdentifier(const std::string_view& identifierName) const; const Identifier* FindIdentifier(const std::string_view& identifierName) const;

View File

@ -105,6 +105,7 @@ namespace Nz
void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override;
void Visit(ShaderAst::WhileStatement& node) override;
static bool HasExplicitBinding(ShaderAst::StatementPtr& shader); static bool HasExplicitBinding(ShaderAst::StatementPtr& shader);
static bool HasExplicitLocation(ShaderAst::StatementPtr& shader); static bool HasExplicitLocation(ShaderAst::StatementPtr& shader);

View File

@ -108,6 +108,7 @@ namespace Nz
void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override;
void Visit(ShaderAst::WhileStatement& node) override;
struct State; struct State;

View File

@ -140,6 +140,11 @@ namespace Nz::ShaderBuilder
{ {
inline std::unique_ptr<ShaderAst::UnaryExpression> operator()(ShaderAst::UnaryType op, ShaderAst::ExpressionPtr expression) const; inline std::unique_ptr<ShaderAst::UnaryExpression> operator()(ShaderAst::UnaryType op, ShaderAst::ExpressionPtr expression) const;
}; };
struct While
{
inline std::unique_ptr<ShaderAst::WhileStatement> operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const;
};
} }
constexpr Impl::AccessIndex AccessIndex; constexpr Impl::AccessIndex AccessIndex;
@ -167,6 +172,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::Return Return; constexpr Impl::Return Return;
constexpr Impl::Swizzle Swizzle; constexpr Impl::Swizzle Swizzle;
constexpr Impl::Unary Unary; constexpr Impl::Unary Unary;
constexpr Impl::While While;
} }
#include <Nazara/Shader/ShaderBuilder.inl> #include <Nazara/Shader/ShaderBuilder.inl>

View File

@ -298,6 +298,15 @@ namespace Nz::ShaderBuilder
return unaryNode; return unaryNode;
} }
inline std::unique_ptr<ShaderAst::WhileStatement> Impl::While::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const
{
auto whileNode = std::make_unique<ShaderAst::WhileStatement>();
whileNode->condition = std::move(condition);
whileNode->body = std::move(body);
return whileNode;
}
} }
#include <Nazara/Shader/DebugOff.hpp> #include <Nazara/Shader/DebugOff.hpp>

View File

@ -98,6 +98,7 @@ namespace Nz::ShaderLang
std::vector<ShaderAst::StatementPtr> ParseStatementList(); std::vector<ShaderAst::StatementPtr> ParseStatementList();
ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {}); ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::StatementPtr ParseVariableDeclaration(); ShaderAst::StatementPtr ParseVariableDeclaration();
ShaderAst::StatementPtr ParseWhileStatement();
// Expressions // Expressions
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);

View File

@ -60,6 +60,7 @@ NAZARA_SHADERLANG_TOKEN(Option)
NAZARA_SHADERLANG_TOKEN(Semicolon) NAZARA_SHADERLANG_TOKEN(Semicolon)
NAZARA_SHADERLANG_TOKEN(Return) NAZARA_SHADERLANG_TOKEN(Return)
NAZARA_SHADERLANG_TOKEN(Struct) NAZARA_SHADERLANG_TOKEN(Struct)
NAZARA_SHADERLANG_TOKEN(While)
#undef NAZARA_SHADERLANG_TOKEN #undef NAZARA_SHADERLANG_TOKEN
#undef NAZARA_SHADERLANG_TOKEN_LAST #undef NAZARA_SHADERLANG_TOKEN_LAST

View File

@ -61,8 +61,9 @@ namespace Nz
void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::WhileStatement& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete; SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete; SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;

View File

@ -193,6 +193,15 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
StatementPtr AstCloner::Clone(WhileStatement& node)
{
auto clone = std::make_unique<WhileStatement>();
clone->condition = CloneExpression(node.condition);
clone->body = CloneStatement(node.body);
return clone;
}
ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node) ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node)
{ {
auto clone = std::make_unique<AccessIdentifierExpression>(); auto clone = std::make_unique<AccessIdentifierExpression>();

View File

@ -175,4 +175,13 @@ namespace Nz::ShaderAst
if (node.returnExpr) if (node.returnExpr)
node.returnExpr->Visit(*this); node.returnExpr->Visit(*this);
} }
void AstRecursiveVisitor::Visit(WhileStatement& node)
{
if (node.condition)
node.condition->Visit(*this);
if (node.body)
node.body->Visit(*this);
}
} }

View File

@ -318,6 +318,12 @@ namespace Nz::ShaderAst
Node(node.returnExpr); Node(node.returnExpr);
} }
void AstSerializerBase::Serialize(WhileStatement& node)
{
Node(node.condition);
Node(node.body);
}
void ShaderAstSerializer::Serialize(StatementPtr& shader) void ShaderAstSerializer::Serialize(StatementPtr& shader)
{ {
m_stream << s_magicNumber << s_currentVersion; m_stream << s_magicNumber << s_currentVersion;

View File

@ -460,12 +460,11 @@ namespace Nz::ShaderAst
{ {
case Identifier::Type::Constant: case Identifier::Type::Constant:
{ {
// Replace IdentifierExpression by ConstantExpression // Replace IdentifierExpression by Constant(Value)Expression
auto constantExpr = std::make_unique<ConstantExpression>(); ConstantExpression constantExpr;
constantExpr->cachedExpressionType = GetExpressionType(m_context->constantValues[identifier->index]); constantExpr.constantId = identifier->index;
constantExpr->constantId = identifier->index;
return constantExpr; return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
} }
case Identifier::Type::Variable: case Identifier::Type::Variable:
@ -951,6 +950,19 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
StatementPtr SanitizeVisitor::Clone(WhileStatement& node)
{
MandatoryExpr(node.condition);
MandatoryStatement(node.body);
auto clone = static_unique_pointer_cast<WhileStatement>(AstCloner::Clone(node));
if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
return clone;
}
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
{ {
auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });

View File

@ -1179,6 +1179,17 @@ namespace Nz
} }
} }
void GlslWriter::Visit(ShaderAst::WhileStatement& node)
{
Append("while (");
node.condition->Visit(*this);
AppendLine(")");
EnterScope();
node.body->Visit(*this);
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::SwizzleExpression& node) void GlslWriter::Visit(ShaderAst::SwizzleExpression& node)
{ {
Visit(node.expression, true); Visit(node.expression, true);

View File

@ -930,6 +930,17 @@ namespace Nz
node.expression->Visit(*this); node.expression->Visit(*this);
} }
void LangWriter::Visit(ShaderAst::WhileStatement& node)
{
Append("while (");
node.condition->Visit(*this);
AppendLine(")");
EnterScope();
node.body->Visit(*this);
LeaveScope();
}
void LangWriter::AppendHeader() void LangWriter::AppendHeader()
{ {
// Nothing yet // Nothing yet

View File

@ -52,7 +52,8 @@ namespace Nz::ShaderLang
{ "option", TokenType::Option }, { "option", TokenType::Option },
{ "return", TokenType::Return }, { "return", TokenType::Return },
{ "struct", TokenType::Struct }, { "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue } { "true", TokenType::BoolTrue },
{ "while", TokenType::While }
}; };
std::size_t currentPos = 0; std::size_t currentPos = 0;

View File

@ -742,6 +742,10 @@ namespace Nz::ShaderLang
statement = ParseReturnStatement(); statement = ParseReturnStatement();
break; break;
case TokenType::While:
statement = ParseWhileStatement();
break;
default: default:
throw UnexpectedToken{}; throw UnexpectedToken{};
} }
@ -905,6 +909,21 @@ namespace Nz::ShaderLang
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression)); return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
} }
ShaderAst::StatementPtr Parser::ParseWhileStatement()
{
Expect(Advance(), TokenType::While);
Expect(Advance(), TokenType::OpenParenthesis);
ShaderAst::ExpressionPtr condition = ParseExpression();
Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::StatementPtr body = ParseStatement();
return ShaderBuilder::While(std::move(condition), std::move(body));
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs) ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
{ {
for (;;) for (;;)

View File

@ -944,12 +944,6 @@ namespace Nz
PushResultId(resultId); PushResultId(resultId);
} }
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node) void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
{ {
const ShaderAst::ExpressionType& resultType = GetExpressionType(node); const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
@ -1011,6 +1005,40 @@ namespace Nz
PushResultId(resultId); PushResultId(resultId);
} }
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderAst::WhileStatement& node)
{
assert(node.condition);
assert(node.body);
SpirvBlock headerBlock(m_writer);
SpirvBlock bodyBlock(m_writer);
SpirvBlock mergeBlock(m_writer);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
m_currentBlock = &headerBlock;
UInt32 expressionId = EvaluateExpression(node.condition);
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None);
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId());
m_currentBlock = &bodyBlock;
node.body->Visit(*this);
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(headerBlock));
m_functionBlocks.emplace_back(std::move(bodyBlock));
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_functionBlocks.back();
}
void SpirvAstVisitor::PushResultId(UInt32 value) void SpirvAstVisitor::PushResultId(UInt32 value)
{ {
m_resultIds.push_back(value); m_resultIds.push_back(value);