Shader: Add initial support for arrays

This commit is contained in:
Jérôme Leclercq
2022-01-01 23:01:31 +01:00
parent 89c7bbf197
commit 1f15328fdd
22 changed files with 781 additions and 57 deletions

View File

@@ -24,6 +24,7 @@ namespace Nz::ShaderAst
AstCloner(AstCloner&&) = delete;
~AstCloner() = default;
template<typename T> AttributeValue<T> Clone(const AttributeValue<T>& attribute);
ExpressionPtr Clone(Expression& statement);
StatementPtr Clone(Statement& statement);
@@ -31,7 +32,6 @@ namespace Nz::ShaderAst
AstCloner& operator=(AstCloner&&) = delete;
protected:
template<typename T> AttributeValue<T> CloneAttribute(const AttributeValue<T>& attribute);
inline ExpressionPtr CloneExpression(const ExpressionPtr& expr);
inline StatementPtr CloneStatement(const StatementPtr& statement);
@@ -83,6 +83,7 @@ namespace Nz::ShaderAst
std::vector<StatementPtr> m_statementStack;
};
template<typename T> AttributeValue<T> Clone(const AttributeValue<T>& attribute);
inline ExpressionPtr Clone(Expression& node);
inline StatementPtr Clone(Statement& node);
}

View File

@@ -8,7 +8,7 @@
namespace Nz::ShaderAst
{
template<typename T>
AttributeValue<T> AstCloner::CloneAttribute(const AttributeValue<T>& attribute)
AttributeValue<T> AstCloner::Clone(const AttributeValue<T>& attribute)
{
if (!attribute.HasValue())
return {};
@@ -38,6 +38,14 @@ namespace Nz::ShaderAst
return CloneStatement(*statement);
}
template<typename T>
AttributeValue<T> Clone(const AttributeValue<T>& attribute)
{
AstCloner cloner;
return cloner.Clone(attribute);
}
inline ExpressionPtr Clone(Expression& node)
{
AstCloner cloner;

View File

@@ -0,0 +1,65 @@
// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Engine - Shader module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_AST_ASTCOMPARE_HPP
#define NAZARA_SHADER_AST_ASTCOMPARE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <Nazara/Shader/Ast/Nodes.hpp>
#include <vector>
namespace Nz::ShaderAst
{
inline bool Compare(const Expression& lhs, const Expression& rhs);
inline bool Compare(const Statement& lhs, const Statement& rhs);
template<typename T> bool Compare(const T& lhs, const T& rhs);
template<typename T, std::size_t S> bool Compare(const std::array<T, S>& lhs, const std::array<T, S>& rhs);
template<typename T> bool Compare(const std::vector<T>& lhs, const std::vector<T>& rhs);
template<typename T> bool Compare(const AttributeValue<T>& lhs, const AttributeValue<T>& rhs);
inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs);
inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs);
inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs);
inline bool Compare(const StructDescription& lhs, const StructDescription& rhs);
inline bool Compare(const StructDescription::StructMember& lhs, const StructDescription::StructMember& rhs);
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);
inline bool Compare(const CallMethodExpression& lhs, const CallMethodExpression& rhs);
inline bool Compare(const CastExpression& lhs, const CastExpression& rhs);
inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs);
inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs);
inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs);
inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs);
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs);
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);
inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs);
inline bool Compare(const ConditionalStatement& lhs, const ConditionalStatement& rhs);
inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs);
inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs);
inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs);
inline bool Compare(const DeclareOptionStatement& lhs, const DeclareOptionStatement& rhs);
inline bool Compare(const DeclareStructStatement& lhs, const DeclareStructStatement& rhs);
inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs);
inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs);
inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs);
inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs);
inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs);
inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs);
inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs);
}
#include <Nazara/Shader/Ast/AstCompare.inl>
#endif // NAZARA_SHADER_AST_ASTCOMPARE_HPP

View File

@@ -0,0 +1,494 @@
// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Engine - Shader module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/AstCompare.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline bool Compare(const Expression& lhs, const Expression& rhs)
{
if (lhs.GetType() != rhs.GetType())
return false;
switch (lhs.GetType())
{
case NodeType::None: break;
#define NAZARA_SHADERAST_EXPRESSION(Node) case NodeType::Node: return Compare(static_cast<const Node&>(lhs), static_cast<const Node&>(lhs));
#include <Nazara/Shader/Ast/AstNodeList.hpp>
default: throw std::runtime_error("unexpected node type");
}
return true;
}
inline bool Compare(const Statement& lhs, const Statement& rhs)
{
if (lhs.GetType() != rhs.GetType())
return false;
switch (lhs.GetType())
{
case NodeType::None: break;
#define NAZARA_SHADERAST_STATEMENT(Node) case NodeType::Node: return Compare(static_cast<const Node&>(lhs), static_cast<const Node&>(lhs));
#include <Nazara/Shader/Ast/AstNodeList.hpp>
default: throw std::runtime_error("unexpected node type");
}
return false;
}
template<typename T>
bool Compare(const T& lhs, const T& rhs)
{
return lhs == rhs;
}
template<typename T, std::size_t S>
bool Compare(const std::array<T, S>& lhs, const std::array<T, S>& rhs)
{
for (std::size_t i = 0; i < S; ++i)
{
if (!Compare(lhs[i], rhs[i]))
return false;
}
return true;
}
template<typename T>
bool Compare(const std::vector<T>& lhs, const std::vector<T>& rhs)
{
if (lhs.size() != rhs.size())
return false;
for (std::size_t i = 0; i < lhs.size(); ++i)
{
if (!Compare(lhs[i], rhs[i]))
return false;
}
return true;
}
template<typename T>
bool Compare(const AttributeValue<T>& lhs, const AttributeValue<T>& rhs)
{
if (!Compare(lhs.HasValue(), rhs.HasValue()))
return false;
if (!Compare(lhs.IsResultingValue(), rhs.IsResultingValue()))
return false;
if (!Compare(lhs.IsExpression(), rhs.IsExpression()))
return false;
if (lhs.IsExpression())
{
if (!Compare(lhs.GetExpression(), rhs.GetExpression()))
return false;
}
else if (lhs.IsResultingValue())
{
if (!Compare(lhs.GetResultingValue(), rhs.GetResultingValue()))
return false;
}
return true;
}
inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs)
{
if (!Compare(lhs.condition, rhs.condition))
return false;
if (!Compare(lhs.statement, rhs.statement))
return false;
return true;
}
inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs)
{
if (!Compare(lhs.bindingIndex, rhs.bindingIndex))
return false;
if (!Compare(lhs.bindingSet, rhs.bindingSet))
return false;
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.type, rhs.type))
return false;
return true;
}
inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs)
{
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.type, rhs.type))
return false;
return true;
}
inline bool Compare(const StructDescription& lhs, const StructDescription& rhs)
{
if (!Compare(lhs.layout, rhs.layout))
return false;
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.members, rhs.members))
return false;
return true;
}
inline bool Compare(const StructDescription::StructMember& lhs, const StructDescription::StructMember& rhs)
{
if (!Compare(lhs.builtin, rhs.builtin))
return false;
if (!Compare(lhs.cond, rhs.cond))
return false;
if (!Compare(lhs.locationIndex, rhs.locationIndex))
return false;
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.type, rhs.type))
return false;
return true;
}
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs)
{
if (!Compare(*lhs.expr, *rhs.expr))
return false;
if (!Compare(lhs.identifiers, rhs.identifiers))
return false;
return true;
}
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs)
{
if (!Compare(*lhs.expr, *rhs.expr))
return false;
if (!Compare(lhs.indices, rhs.indices))
return false;
return true;
}
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
{
if (!Compare(lhs.op, rhs.op))
return false;
if (!Compare(lhs.left, rhs.left))
return false;
if (!Compare(lhs.right, rhs.right))
return false;
return true;
}
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs)
{
if (!Compare(lhs.op, rhs.op))
return false;
if (!Compare(lhs.left, rhs.left))
return false;
if (!Compare(lhs.right, rhs.right))
return false;
return true;
}
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs)
{
if (!Compare(lhs.targetFunction, rhs.targetFunction))
return false;
if (!Compare(lhs.parameters, rhs.parameters))
return false;
return true;
}
inline bool Compare(const CallMethodExpression& lhs, const CallMethodExpression& rhs)
{
if (!Compare(lhs.methodName, rhs.methodName))
return false;
if (!Compare(lhs.object, rhs.object))
return false;
if (!Compare(lhs.parameters, rhs.parameters))
return false;
return true;
}
inline bool Compare(const CastExpression& lhs, const CastExpression& rhs)
{
if (!Compare(lhs.targetType, rhs.targetType))
return false;
if (!Compare(lhs.expressions, rhs.expressions))
return false;
return true;
}
inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs)
{
if (!Compare(lhs.condition, rhs.condition))
return false;
if (!Compare(lhs.truePath, rhs.truePath))
return false;
if (!Compare(lhs.falsePath, rhs.falsePath))
return false;
return true;
}
inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs)
{
if (!Compare(lhs.constantId, rhs.constantId))
return false;
return true;
}
inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs)
{
if (!Compare(lhs.value, rhs.value))
return false;
return true;
}
inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs)
{
if (!Compare(lhs.identifier, rhs.identifier))
return false;
return true;
}
inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs)
{
if (!Compare(lhs.intrinsic, rhs.intrinsic))
return false;
if (!Compare(lhs.parameters, rhs.parameters))
return false;
return true;
}
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs)
{
if (!Compare(lhs.componentCount, rhs.componentCount))
return false;
if (!Compare(lhs.expression, rhs.expression))
return false;
if (!Compare(lhs.components, rhs.components))
return false;
return true;
}
inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs)
{
if (!Compare(lhs.variableId, rhs.variableId))
return false;
return true;
}
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs)
{
if (!Compare(lhs.op, rhs.op))
return false;
if (!Compare(lhs.expression, rhs.expression))
return false;
return true;
}
inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs)
{
if (!Compare(lhs.isConst, rhs.isConst))
return false;
if (!Compare(lhs.elseStatement, rhs.elseStatement))
return false;
if (!Compare(lhs.condStatements, rhs.condStatements))
return false;
return true;
}
inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs)
{
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.type, rhs.type))
return false;
if (!Compare(lhs.expression, rhs.expression))
return false;
return true;
}
inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs)
{
if (!Compare(lhs.bindingSet, rhs.bindingSet))
return false;
if (!Compare(lhs.externalVars, rhs.externalVars))
return false;
return true;
}
inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs)
{
if (!Compare(lhs.depthWrite, rhs.depthWrite))
return false;
if (!Compare(lhs.earlyFragmentTests, rhs.earlyFragmentTests))
return false;
if (!Compare(lhs.entryStage, rhs.entryStage))
return false;
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.parameters, rhs.parameters))
return false;
if (!Compare(lhs.returnType, rhs.returnType))
return false;
if (!Compare(lhs.statements, rhs.statements))
return false;
return true;
}
inline bool Compare(const DeclareOptionStatement& lhs, const DeclareOptionStatement& rhs)
{
if (!Compare(lhs.optName, rhs.optName))
return false;
if (!Compare(lhs.optType, rhs.optType))
return false;
if (!Compare(lhs.defaultValue, rhs.defaultValue))
return false;
return true;
}
inline bool Compare(const DeclareStructStatement& lhs, const DeclareStructStatement& rhs)
{
if (!Compare(lhs.description, rhs.description))
return false;
return true;
}
inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs)
{
if (!Compare(lhs.varName, rhs.varName))
return false;
if (!Compare(lhs.varType, rhs.varType))
return false;
if (!Compare(lhs.initialExpression, rhs.initialExpression))
return false;
return true;
}
inline bool Compare(const DiscardStatement& /*lhs*/, const DiscardStatement& /*rhs*/)
{
return true;
}
inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs)
{
if (!Compare(lhs.expression, rhs.expression))
return false;
return true;
}
inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs)
{
if (!Compare(lhs.statements, rhs.statements))
return false;
return true;
}
inline bool Compare(const NoOpStatement& /*lhs*/, const NoOpStatement& /*rhs*/)
{
return true;
}
inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs)
{
if (!Compare(lhs.returnExpr, rhs.returnExpr))
return false;
return true;
}
inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs)
{
if (!Compare(lhs.condition, rhs.condition))
return false;
if (!Compare(lhs.body, rhs.body))
return false;
return true;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@@ -18,6 +18,24 @@
namespace Nz::ShaderAst
{
struct ContainedType;
struct NAZARA_SHADER_API ArrayType
{
ArrayType() = default;
ArrayType(const ArrayType& array);
ArrayType(ArrayType&&) noexcept = default;
ArrayType& operator=(const ArrayType& array);
ArrayType& operator=(ArrayType&&) noexcept = default;
AttributeValue<UInt32> length;
std::unique_ptr<ContainedType> containedType;
bool operator==(const ArrayType& rhs) const;
inline bool operator!=(const ArrayType& rhs) const;
};
struct IdentifierType //< Alias or struct
{
std::string name;
@@ -76,7 +94,12 @@ namespace Nz::ShaderAst
inline bool operator!=(const VectorType& rhs) const;
};
using ExpressionType = std::variant<NoType, IdentifierType, PrimitiveType, MatrixType, SamplerType, StructType, UniformType, VectorType>;
using ExpressionType = std::variant<NoType, ArrayType, IdentifierType, PrimitiveType, MatrixType, SamplerType, StructType, UniformType, VectorType>;
struct ContainedType
{
ExpressionType type;
};
struct StructDescription
{
@@ -94,6 +117,7 @@ namespace Nz::ShaderAst
std::vector<StructMember> members;
};
inline bool IsArrayType(const ExpressionType& type);
inline bool IsIdentifierType(const ExpressionType& type);
inline bool IsMatrixType(const ExpressionType& type);
inline bool IsNoType(const ExpressionType& type);

View File

@@ -8,6 +8,12 @@
namespace Nz::ShaderAst
{
inline bool ArrayType::operator!=(const ArrayType& rhs) const
{
return !operator==(rhs);
}
inline bool IdentifierType::operator==(const IdentifierType& rhs) const
{
return name == rhs.name;
@@ -84,6 +90,11 @@ namespace Nz::ShaderAst
}
bool IsArrayType(const ExpressionType& type)
{
return std::holds_alternative<ArrayType>(type);
}
inline bool IsIdentifierType(const ExpressionType& type)
{
return std::holds_alternative<IdentifierType>(type);

View File

@@ -50,6 +50,7 @@ namespace Nz
static ShaderAst::StatementPtr Sanitize(ShaderAst::Statement& ast, std::unordered_map<std::size_t, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
private:
void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type);
void Append(ShaderAst::BuiltinEntry builtin);
void Append(const ShaderAst::IdentifierType& identifierType);

View File

@@ -46,6 +46,7 @@ namespace Nz
struct LocationAttribute;
struct SetAttribute;
void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type);
void Append(const ShaderAst::IdentifierType& identifierType);
void Append(const ShaderAst::MatrixType& matrixType);

View File

@@ -112,6 +112,7 @@ namespace Nz::ShaderLang
ShaderAst::ExpressionPtr ParsePrimaryExpression();
ShaderAst::ExpressionPtr ParseVariableAssignation();
ShaderAst::ExpressionType ParseArrayType();
ShaderAst::AttributeType ParseIdentifierAsAttributeType();
const std::string& ParseIdentifierAsName();
ShaderAst::PrimitiveType ParsePrimitiveType();

View File

@@ -39,6 +39,12 @@ namespace Nz
using ConstantPtr = std::shared_ptr<Constant>;
using TypePtr = std::shared_ptr<Type>;
struct Array
{
TypePtr elementType;
UInt32 length;
};
struct Bool {};
struct Float
@@ -108,7 +114,7 @@ namespace Nz
std::vector<SpirvDecoration> decorations;
};
using AnyType = std::variant<Bool, Float, Function, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
using AnyType = std::variant<Array, Bool, Float, Function, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
struct ConstantBool
{