Shader: Add support for while loops
This commit is contained in:
parent
07199301df
commit
0f9060c45b
|
|
@ -67,6 +67,7 @@ namespace Nz::ShaderAst
|
|||
virtual StatementPtr Clone(MultiStatement& node);
|
||||
virtual StatementPtr Clone(NoOpStatement& node);
|
||||
virtual StatementPtr Clone(ReturnStatement& node);
|
||||
virtual StatementPtr Clone(WhileStatement& node);
|
||||
|
||||
#define NAZARA_SHADERAST_NODE(NodeType) void Visit(NodeType& node) override;
|
||||
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
||||
|
|
|
|||
|
|
@ -55,7 +55,8 @@ NAZARA_SHADERAST_STATEMENT(DiscardStatement)
|
|||
NAZARA_SHADERAST_STATEMENT(ExpressionStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(MultiStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(NoOpStatement)
|
||||
NAZARA_SHADERAST_STATEMENT_LAST(ReturnStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(ReturnStatement)
|
||||
NAZARA_SHADERAST_STATEMENT_LAST(WhileStatement)
|
||||
|
||||
#undef NAZARA_SHADERAST_EXPRESSION
|
||||
#undef NAZARA_SHADERAST_NODE
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ namespace Nz::ShaderAst
|
|||
void Visit(MultiStatement& node) override;
|
||||
void Visit(NoOpStatement& node) override;
|
||||
void Visit(ReturnStatement& node) override;
|
||||
void Visit(WhileStatement& node) override;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ namespace Nz::ShaderAst
|
|||
void Serialize(MultiStatement& node);
|
||||
void Serialize(NoOpStatement& node);
|
||||
void Serialize(ReturnStatement& node);
|
||||
void Serialize(WhileStatement& node);
|
||||
|
||||
protected:
|
||||
template<typename T> void Attribute(AttributeValue<T>& attribute);
|
||||
|
|
|
|||
|
|
@ -362,6 +362,15 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr returnExpr;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API WhileStatement : Statement
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
void Visit(AstStatementVisitor& visitor) override;
|
||||
|
||||
ExpressionPtr condition;
|
||||
StatementPtr body;
|
||||
};
|
||||
|
||||
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
|
||||
inline bool IsExpression(NodeType nodeType);
|
||||
inline bool IsStatement(NodeType nodeType);
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ namespace Nz::ShaderAst
|
|||
StatementPtr Clone(DiscardStatement& node) override;
|
||||
StatementPtr Clone(ExpressionStatement& node) override;
|
||||
StatementPtr Clone(MultiStatement& node) override;
|
||||
StatementPtr Clone(WhileStatement& node) override;
|
||||
|
||||
const Identifier* FindIdentifier(const std::string_view& identifierName) const;
|
||||
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::MultiStatement& node) override;
|
||||
void Visit(ShaderAst::NoOpStatement& node) override;
|
||||
void Visit(ShaderAst::ReturnStatement& node) override;
|
||||
void Visit(ShaderAst::WhileStatement& node) override;
|
||||
|
||||
static bool HasExplicitBinding(ShaderAst::StatementPtr& shader);
|
||||
static bool HasExplicitLocation(ShaderAst::StatementPtr& shader);
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::MultiStatement& node) override;
|
||||
void Visit(ShaderAst::NoOpStatement& node) override;
|
||||
void Visit(ShaderAst::ReturnStatement& node) override;
|
||||
void Visit(ShaderAst::WhileStatement& node) override;
|
||||
|
||||
struct State;
|
||||
|
||||
|
|
|
|||
|
|
@ -140,6 +140,11 @@ namespace Nz::ShaderBuilder
|
|||
{
|
||||
inline std::unique_ptr<ShaderAst::UnaryExpression> operator()(ShaderAst::UnaryType op, ShaderAst::ExpressionPtr expression) const;
|
||||
};
|
||||
|
||||
struct While
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::WhileStatement> operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const;
|
||||
};
|
||||
}
|
||||
|
||||
constexpr Impl::AccessIndex AccessIndex;
|
||||
|
|
@ -167,6 +172,7 @@ namespace Nz::ShaderBuilder
|
|||
constexpr Impl::Return Return;
|
||||
constexpr Impl::Swizzle Swizzle;
|
||||
constexpr Impl::Unary Unary;
|
||||
constexpr Impl::While While;
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/ShaderBuilder.inl>
|
||||
|
|
|
|||
|
|
@ -298,6 +298,15 @@ namespace Nz::ShaderBuilder
|
|||
|
||||
return unaryNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::WhileStatement> Impl::While::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const
|
||||
{
|
||||
auto whileNode = std::make_unique<ShaderAst::WhileStatement>();
|
||||
whileNode->condition = std::move(condition);
|
||||
whileNode->body = std::move(body);
|
||||
|
||||
return whileNode;
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ namespace Nz::ShaderLang
|
|||
std::vector<ShaderAst::StatementPtr> ParseStatementList();
|
||||
ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
|
||||
ShaderAst::StatementPtr ParseVariableDeclaration();
|
||||
ShaderAst::StatementPtr ParseWhileStatement();
|
||||
|
||||
// Expressions
|
||||
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ NAZARA_SHADERLANG_TOKEN(Option)
|
|||
NAZARA_SHADERLANG_TOKEN(Semicolon)
|
||||
NAZARA_SHADERLANG_TOKEN(Return)
|
||||
NAZARA_SHADERLANG_TOKEN(Struct)
|
||||
NAZARA_SHADERLANG_TOKEN(While)
|
||||
|
||||
#undef NAZARA_SHADERLANG_TOKEN
|
||||
#undef NAZARA_SHADERLANG_TOKEN_LAST
|
||||
|
|
|
|||
|
|
@ -61,8 +61,9 @@ namespace Nz
|
|||
void Visit(ShaderAst::NoOpStatement& node) override;
|
||||
void Visit(ShaderAst::ReturnStatement& node) override;
|
||||
void Visit(ShaderAst::SwizzleExpression& node) override;
|
||||
void Visit(ShaderAst::VariableExpression& node) override;
|
||||
void Visit(ShaderAst::UnaryExpression& node) override;
|
||||
void Visit(ShaderAst::VariableExpression& node) override;
|
||||
void Visit(ShaderAst::WhileStatement& node) override;
|
||||
|
||||
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
|
||||
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;
|
||||
|
|
|
|||
|
|
@ -193,6 +193,15 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(WhileStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<WhileStatement>();
|
||||
clone->condition = CloneExpression(node.condition);
|
||||
clone->body = CloneStatement(node.body);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<AccessIdentifierExpression>();
|
||||
|
|
|
|||
|
|
@ -175,4 +175,13 @@ namespace Nz::ShaderAst
|
|||
if (node.returnExpr)
|
||||
node.returnExpr->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(WhileStatement& node)
|
||||
{
|
||||
if (node.condition)
|
||||
node.condition->Visit(*this);
|
||||
|
||||
if (node.body)
|
||||
node.body->Visit(*this);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -318,6 +318,12 @@ namespace Nz::ShaderAst
|
|||
Node(node.returnExpr);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(WhileStatement& node)
|
||||
{
|
||||
Node(node.condition);
|
||||
Node(node.body);
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Serialize(StatementPtr& shader)
|
||||
{
|
||||
m_stream << s_magicNumber << s_currentVersion;
|
||||
|
|
|
|||
|
|
@ -460,12 +460,11 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
case Identifier::Type::Constant:
|
||||
{
|
||||
// Replace IdentifierExpression by ConstantExpression
|
||||
auto constantExpr = std::make_unique<ConstantExpression>();
|
||||
constantExpr->cachedExpressionType = GetExpressionType(m_context->constantValues[identifier->index]);
|
||||
constantExpr->constantId = identifier->index;
|
||||
// Replace IdentifierExpression by Constant(Value)Expression
|
||||
ConstantExpression constantExpr;
|
||||
constantExpr.constantId = identifier->index;
|
||||
|
||||
return constantExpr;
|
||||
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
|
||||
}
|
||||
|
||||
case Identifier::Type::Variable:
|
||||
|
|
@ -951,6 +950,19 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr SanitizeVisitor::Clone(WhileStatement& node)
|
||||
{
|
||||
MandatoryExpr(node.condition);
|
||||
MandatoryStatement(node.body);
|
||||
|
||||
auto clone = static_unique_pointer_cast<WhileStatement>(AstCloner::Clone(node));
|
||||
|
||||
if (GetExpressionType(*clone->condition) != ExpressionType{ PrimitiveType::Boolean })
|
||||
throw AstError{ "expected a boolean value" };
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
|
||||
{
|
||||
auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
|
||||
|
|
|
|||
|
|
@ -1179,6 +1179,17 @@ namespace Nz
|
|||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::WhileStatement& node)
|
||||
{
|
||||
Append("while (");
|
||||
node.condition->Visit(*this);
|
||||
AppendLine(")");
|
||||
|
||||
EnterScope();
|
||||
node.body->Visit(*this);
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
Visit(node.expression, true);
|
||||
|
|
|
|||
|
|
@ -930,6 +930,17 @@ namespace Nz
|
|||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::WhileStatement& node)
|
||||
{
|
||||
Append("while (");
|
||||
node.condition->Visit(*this);
|
||||
AppendLine(")");
|
||||
|
||||
EnterScope();
|
||||
node.body->Visit(*this);
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void LangWriter::AppendHeader()
|
||||
{
|
||||
// Nothing yet
|
||||
|
|
|
|||
|
|
@ -52,7 +52,8 @@ namespace Nz::ShaderLang
|
|||
{ "option", TokenType::Option },
|
||||
{ "return", TokenType::Return },
|
||||
{ "struct", TokenType::Struct },
|
||||
{ "true", TokenType::BoolTrue }
|
||||
{ "true", TokenType::BoolTrue },
|
||||
{ "while", TokenType::While }
|
||||
};
|
||||
|
||||
std::size_t currentPos = 0;
|
||||
|
|
|
|||
|
|
@ -742,6 +742,10 @@ namespace Nz::ShaderLang
|
|||
statement = ParseReturnStatement();
|
||||
break;
|
||||
|
||||
case TokenType::While:
|
||||
statement = ParseWhileStatement();
|
||||
break;
|
||||
|
||||
default:
|
||||
throw UnexpectedToken{};
|
||||
}
|
||||
|
|
@ -905,6 +909,21 @@ namespace Nz::ShaderLang
|
|||
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseWhileStatement()
|
||||
{
|
||||
Expect(Advance(), TokenType::While);
|
||||
|
||||
Expect(Advance(), TokenType::OpenParenthesis);
|
||||
|
||||
ShaderAst::ExpressionPtr condition = ParseExpression();
|
||||
|
||||
Expect(Advance(), TokenType::ClosingParenthesis);
|
||||
|
||||
ShaderAst::StatementPtr body = ParseStatement();
|
||||
|
||||
return ShaderBuilder::While(std::move(condition), std::move(body));
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
|
||||
{
|
||||
for (;;)
|
||||
|
|
|
|||
|
|
@ -944,12 +944,6 @@ namespace Nz
|
|||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
|
||||
PushResultId(loadVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
|
||||
|
|
@ -1011,6 +1005,40 @@ namespace Nz
|
|||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
|
||||
PushResultId(loadVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::WhileStatement& node)
|
||||
{
|
||||
assert(node.condition);
|
||||
assert(node.body);
|
||||
|
||||
SpirvBlock headerBlock(m_writer);
|
||||
SpirvBlock bodyBlock(m_writer);
|
||||
SpirvBlock mergeBlock(m_writer);
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
|
||||
m_currentBlock = &headerBlock;
|
||||
|
||||
UInt32 expressionId = EvaluateExpression(node.condition);
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpLoopMerge, mergeBlock.GetLabelId(), bodyBlock.GetLabelId(), SpirvLoopControl::None);
|
||||
m_currentBlock->Append(SpirvOp::OpBranchConditional, expressionId, bodyBlock.GetLabelId(), mergeBlock.GetLabelId());
|
||||
|
||||
m_currentBlock = &bodyBlock;
|
||||
node.body->Visit(*this);
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpBranch, headerBlock.GetLabelId());
|
||||
|
||||
m_functionBlocks.emplace_back(std::move(headerBlock));
|
||||
m_functionBlocks.emplace_back(std::move(bodyBlock));
|
||||
m_functionBlocks.emplace_back(std::move(mergeBlock));
|
||||
m_currentBlock = &m_functionBlocks.back();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::PushResultId(UInt32 value)
|
||||
{
|
||||
m_resultIds.push_back(value);
|
||||
|
|
|
|||
Loading…
Reference in New Issue