Shader: Add support for for-each statements and improve arrays

This commit is contained in:
Jérôme Leclercq
2022-01-02 22:02:11 +01:00
parent aac6e38da2
commit 4fe44339c5
30 changed files with 712 additions and 93 deletions

View File

@@ -64,6 +64,7 @@ namespace Nz::ShaderAst
virtual StatementPtr Clone(DeclareVariableStatement& node);
virtual StatementPtr Clone(DiscardStatement& node);
virtual StatementPtr Clone(ExpressionStatement& node);
virtual StatementPtr Clone(ForEachStatement& node);
virtual StatementPtr Clone(MultiStatement& node);
virtual StatementPtr Clone(NoOpStatement& node);
virtual StatementPtr Clone(ReturnStatement& node);

View File

@@ -54,6 +54,7 @@ namespace Nz::ShaderAst
inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs);
inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs);
inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs);
inline bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs);
inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs);
inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs);
inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs);

View File

@@ -458,6 +458,23 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs)
{
if (!Compare(lhs.isConst, rhs.isConst))
return false;
if (!Compare(lhs.varName, rhs.varName))
return false;
if (!Compare(lhs.expression, rhs.expression))
return false;
if (!Compare(lhs.statement, rhs.statement))
return false;
return true;
}
inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs)
{
if (!Compare(lhs.statements, rhs.statements))

View File

@@ -52,6 +52,7 @@ NAZARA_SHADERAST_STATEMENT(DeclareOptionStatement)
NAZARA_SHADERAST_STATEMENT(DeclareStructStatement)
NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement)
NAZARA_SHADERAST_STATEMENT(DiscardStatement)
NAZARA_SHADERAST_STATEMENT(ForEachStatement)
NAZARA_SHADERAST_STATEMENT(ExpressionStatement)
NAZARA_SHADERAST_STATEMENT(MultiStatement)
NAZARA_SHADERAST_STATEMENT(NoOpStatement)

View File

@@ -46,6 +46,7 @@ namespace Nz::ShaderAst
void Visit(DeclareVariableStatement& node) override;
void Visit(DiscardStatement& node) override;
void Visit(ExpressionStatement& node) override;
void Visit(ForEachStatement& node) override;
void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override;
void Visit(ReturnStatement& node) override;

View File

@@ -49,6 +49,7 @@ namespace Nz::ShaderAst
void Serialize(DeclareVariableStatement& node);
void Serialize(DiscardStatement& node);
void Serialize(ExpressionStatement& node);
void Serialize(ForEachStatement& node);
void Serialize(MultiStatement& node);
void Serialize(NoOpStatement& node);
void Serialize(ReturnStatement& node);

View File

@@ -340,6 +340,18 @@ namespace Nz::ShaderAst
ExpressionPtr expression;
};
struct NAZARA_SHADER_API ForEachStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr expression;
StatementPtr statement;
bool isConst = false;
};
struct NAZARA_SHADER_API MultiStatement : Statement
{
NodeType GetType() const override;
@@ -371,7 +383,12 @@ namespace Nz::ShaderAst
StatementPtr body;
};
#define NAZARA_SHADERAST_NODE(X) using X##Ptr = std::unique_ptr<X>;
#include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr);
inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType);
}

View File

@@ -13,6 +13,12 @@ namespace Nz::ShaderAst
return expr.cachedExpressionType.value();
}
ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
inline bool IsExpression(NodeType nodeType)
{
switch (nodeType)

View File

@@ -40,8 +40,9 @@ namespace Nz::ShaderAst
std::unordered_set<std::string> reservedIdentifiers;
std::unordered_map<std::size_t, ConstantValue> optionValues;
bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false;
bool removeCompoundAssignments = false;
bool removeOptionDeclaration = true;
bool removeOptionDeclaration = false;
bool removeScalarSwizzling = false;
bool splitMultipleBranches = false;
};
@@ -77,6 +78,7 @@ namespace Nz::ShaderAst
StatementPtr Clone(DeclareVariableStatement& node) override;
StatementPtr Clone(DiscardStatement& node) override;
StatementPtr Clone(ExpressionStatement& node) override;
StatementPtr Clone(ForEachStatement& node) override;
StatementPtr Clone(MultiStatement& node) override;
StatementPtr Clone(WhileStatement& node) override;
@@ -117,6 +119,8 @@ namespace Nz::ShaderAst
void SanitizeIdentifier(std::string& identifier);
void Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node);
void Validate(BinaryExpression& node);
void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration);
void Validate(CastExpression& node);
void Validate(DeclareVariableStatement& node);

View File

@@ -71,6 +71,7 @@ namespace Nz
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);
void AppendStatementList(std::vector<ShaderAst::StatementPtr>& statements);
void AppendVariableDeclaration(const ShaderAst::ExpressionType& varType, const std::string& varName);
void EnterScope();
void LeaveScope(bool skipLine = true);

View File

@@ -99,6 +99,7 @@ namespace Nz
void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareOptionStatement& node) override;
@@ -106,6 +107,7 @@ namespace Nz
void Visit(ShaderAst::DeclareVariableStatement& node) override;
void Visit(ShaderAst::DiscardStatement& node) override;
void Visit(ShaderAst::ExpressionStatement& node) override;
void Visit(ShaderAst::ForEachStatement& node) override;
void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override;

View File

@@ -19,7 +19,9 @@ namespace Nz::ShaderBuilder
{
struct AccessIndex
{
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, Int32 index) const;
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, const std::vector<Int32>& indexConstants) const;
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, ShaderAst::ExpressionPtr indexExpression) const;
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, std::vector<ShaderAst::ExpressionPtr> indexExpressions) const;
};
@@ -106,6 +108,12 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ExpressionStatement> operator()(ShaderAst::ExpressionPtr expression) const;
};
template<bool Const>
struct ForEach
{
inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const;
};
struct Identifier
{
inline std::unique_ptr<ShaderAst::IdentifierExpression> operator()(std::string name) const;
@@ -143,11 +151,16 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::UnaryExpression> operator()(ShaderAst::UnaryType op, ShaderAst::ExpressionPtr expression) const;
};
struct Variable
{
inline std::unique_ptr<ShaderAst::VariableExpression> operator()(std::size_t variableId, ShaderAst::ExpressionType expressionType) const;
};
struct While
{
inline std::unique_ptr<ShaderAst::WhileStatement> operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const;
};
}
}
constexpr Impl::AccessIndex AccessIndex;
constexpr Impl::AccessMember AccessMember;
@@ -160,6 +173,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::ConditionalStatement ConditionalStatement;
constexpr Impl::Constant Constant;
constexpr Impl::Branch<true> ConstBranch;
constexpr Impl::ForEach<false> ConstForEach;
constexpr Impl::DeclareConst DeclareConst;
constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareOption DeclareOption;
@@ -167,6 +181,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::DeclareVariable DeclareVariable;
constexpr Impl::ExpressionStatement ExpressionStatement;
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::ForEach<false> ForEach;
constexpr Impl::Identifier Identifier;
constexpr Impl::Intrinsic Intrinsic;
constexpr Impl::Multi MultiStatement;
@@ -174,6 +189,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::Return Return;
constexpr Impl::Swizzle Swizzle;
constexpr Impl::Unary Unary;
constexpr Impl::Variable Variable;
constexpr Impl::While While;
}

View File

@@ -16,6 +16,15 @@ namespace Nz::ShaderBuilder
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, Int32 index) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIndexExpression>();
accessMemberNode->expr = std::move(expr);
accessMemberNode->indices.push_back(ShaderBuilder::Constant(index));
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, const std::vector<Int32>& indexConstants) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIndexExpression>();
@@ -28,6 +37,15 @@ namespace Nz::ShaderBuilder
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, ShaderAst::ExpressionPtr indexExpression) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIndexExpression>();
accessMemberNode->expr = std::move(expr);
accessMemberNode->indices.push_back(std::move(indexExpression));
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, std::vector<ShaderAst::ExpressionPtr> indexExpressions) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIndexExpression>();
@@ -136,6 +154,7 @@ namespace Nz::ShaderBuilder
{
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
constantNode->value = std::move(value);
constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value);
return constantNode;
}
@@ -250,6 +269,18 @@ namespace Nz::ShaderBuilder
return expressionStatementNode;
}
template<bool Const>
std::unique_ptr<ShaderAst::ForEachStatement> Impl::ForEach<Const>::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const
{
auto forEachNode = std::make_unique<ShaderAst::ForEachStatement>();
forEachNode->isConst = Const;
forEachNode->expression = std::move(expression);
forEachNode->statement = std::move(statement);
forEachNode->varName = std::move(varName);
return forEachNode;
}
inline std::unique_ptr<ShaderAst::IdentifierExpression> Impl::Identifier::operator()(std::string name) const
{
auto identifierNode = std::make_unique<ShaderAst::IdentifierExpression>();
@@ -327,6 +358,15 @@ namespace Nz::ShaderBuilder
return unaryNode;
}
inline std::unique_ptr<ShaderAst::VariableExpression> Impl::Variable::operator()(std::size_t variableId, ShaderAst::ExpressionType expressionType) const
{
auto varNode = std::make_unique<ShaderAst::VariableExpression>();
varNode->variableId = variableId;
varNode->cachedExpressionType = std::move(expressionType);
return varNode;
}
inline std::unique_ptr<ShaderAst::WhileStatement> Impl::While::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr body) const
{
auto whileNode = std::make_unique<ShaderAst::WhileStatement>();

View File

@@ -88,6 +88,7 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr ParseConstStatement();
ShaderAst::StatementPtr ParseDiscardStatement();
ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::StatementPtr ParseForDeclaration();
std::vector<ShaderAst::StatementPtr> ParseFunctionBody();
ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter();

View File

@@ -31,6 +31,7 @@ NAZARA_SHADERLANG_TOKEN(Else)
NAZARA_SHADERLANG_TOKEN(EndOfStream)
NAZARA_SHADERLANG_TOKEN(External)
NAZARA_SHADERLANG_TOKEN(FloatingPointValue)
NAZARA_SHADERLANG_TOKEN(For)
NAZARA_SHADERLANG_TOKEN(FunctionDeclaration)
NAZARA_SHADERLANG_TOKEN(FunctionReturn)
NAZARA_SHADERLANG_TOKEN(GreaterThan)
@@ -38,6 +39,7 @@ NAZARA_SHADERLANG_TOKEN(GreaterThanEqual)
NAZARA_SHADERLANG_TOKEN(IntegerValue)
NAZARA_SHADERLANG_TOKEN(Identifier)
NAZARA_SHADERLANG_TOKEN(If)
NAZARA_SHADERLANG_TOKEN(In)
NAZARA_SHADERLANG_TOKEN(LessThan)
NAZARA_SHADERLANG_TOKEN(LessThanEqual)
NAZARA_SHADERLANG_TOKEN(Let)

View File

@@ -42,7 +42,8 @@ namespace Nz
struct Array
{
TypePtr elementType;
UInt32 length;
ConstantPtr length;
std::optional<UInt32> stride;
};
struct Bool {};
@@ -129,7 +130,7 @@ namespace Nz
struct ConstantScalar
{
std::variant<float, double, Nz::Int32, Nz::Int64, Nz::UInt32, Nz::UInt64> value;
std::variant<float, double, Int32, Int64, UInt32, UInt64> value;
};
using AnyConstant = std::variant<ConstantBool, ConstantComposite, ConstantScalar>;
@@ -174,6 +175,7 @@ namespace Nz
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
TypePtr BuildType(const ShaderAst::ArrayType& type) const;
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;
TypePtr BuildType(const ShaderAst::MatrixType& type) const;