Shader: Add proper support for alias

This commit is contained in:
Jérôme Leclercq
2022-03-09 12:35:00 +01:00
parent ce93b61c91
commit 05cf98477e
31 changed files with 472 additions and 98 deletions

View File

@@ -42,6 +42,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
virtual ExpressionPtr Clone(AccessIndexExpression& node);
virtual ExpressionPtr Clone(AliasValueExpression& node);
virtual ExpressionPtr Clone(AssignExpression& node);
virtual ExpressionPtr Clone(BinaryExpression& node);
virtual ExpressionPtr Clone(CallFunctionExpression& node);

View File

@@ -34,6 +34,7 @@ namespace Nz::ShaderAst
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
inline bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs);
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);

View File

@@ -248,6 +248,14 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs)
{
if (!Compare(lhs.aliasId, rhs.aliasId))
return false;
return true;
}
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
{
if (!Compare(lhs.op, rhs.op))

View File

@@ -30,6 +30,7 @@
NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression)
NAZARA_SHADERAST_EXPRESSION(AliasValueExpression)
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)

View File

@@ -22,6 +22,7 @@ namespace Nz::ShaderAst
void Visit(AccessIdentifierExpression& node) override;
void Visit(AccessIndexExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CallFunctionExpression& node) override;

View File

@@ -25,6 +25,7 @@ namespace Nz::ShaderAst
void Serialize(AccessIdentifierExpression& node);
void Serialize(AccessIndexExpression& node);
void Serialize(AliasValueExpression& node);
void Serialize(AssignExpression& node);
void Serialize(BinaryExpression& node);
void Serialize(CallFunctionExpression& node);

View File

@@ -33,6 +33,7 @@ namespace Nz::ShaderAst
void Visit(AccessIdentifierExpression& node) override;
void Visit(AccessIndexExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CallFunctionExpression& node) override;

View File

@@ -20,6 +20,22 @@ namespace Nz::ShaderAst
{
struct ContainedType;
struct NAZARA_SHADER_API AliasType
{
AliasType() = default;
AliasType(const AliasType& alias);
AliasType(AliasType&&) noexcept = default;
AliasType& operator=(const AliasType& alias);
AliasType& operator=(AliasType&&) noexcept = default;
std::size_t aliasIndex;
std::unique_ptr<ContainedType> targetType;
bool operator==(const AliasType& rhs) const;
inline bool operator!=(const AliasType& rhs) const;
};
struct NAZARA_SHADER_API ArrayType
{
ArrayType() = default;
@@ -134,7 +150,7 @@ namespace Nz::ShaderAst
inline bool operator!=(const VectorType& rhs) const;
};
using ExpressionType = std::variant<NoType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
using ExpressionType = std::variant<NoType, AliasType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
struct ContainedType
{
@@ -157,6 +173,7 @@ namespace Nz::ShaderAst
std::vector<StructMember> members;
};
inline bool IsAliasType(const ExpressionType& type);
inline bool IsArrayType(const ExpressionType& type);
inline bool IsFunctionType(const ExpressionType& type);
inline bool IsIdentifierType(const ExpressionType& type);

View File

@@ -8,6 +8,11 @@
namespace Nz::ShaderAst
{
inline bool AliasType::operator!=(const AliasType& rhs) const
{
return !operator==(rhs);
}
inline bool ArrayType::operator!=(const ArrayType& rhs) const
{
return !operator==(rhs);
@@ -130,6 +135,11 @@ namespace Nz::ShaderAst
}
inline bool IsAliasType(const ExpressionType& type)
{
return std::holds_alternative<AliasType>(type);
}
inline bool IsArrayType(const ExpressionType& type)
{
return std::holds_alternative<ArrayType>(type);

View File

@@ -82,6 +82,14 @@ namespace Nz::ShaderAst
ExpressionPtr expr;
};
struct NAZARA_SHADER_API AliasValueExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t aliasId;
};
struct NAZARA_SHADER_API AssignExpression : Expression
{
NodeType GetType() const override;
@@ -153,7 +161,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ShaderAst::ConstantValue value;
ConstantValue value;
};
struct NAZARA_SHADER_API FunctionExpression : Expression
@@ -207,7 +215,6 @@ namespace Nz::ShaderAst
ExpressionPtr expression;
};
struct NAZARA_SHADER_API VariableExpression : Expression
struct NAZARA_SHADER_API VariableValueExpression : Expression
{
NodeType GetType() const override;
@@ -455,10 +462,12 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr);
inline const ExpressionType& GetExpressionType(Expression& expr);
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType);
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType);
}
#include <Nazara/Shader/Ast/Nodes.inl>

View File

@@ -7,18 +7,29 @@
namespace Nz::ShaderAst
{
const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr)
inline const ExpressionType& GetExpressionType(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr)
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
{
if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
return alias.targetType->type;
}
else
return exprType;
}
inline bool IsExpression(NodeType nodeType)
{
switch (nodeType)

View File

@@ -47,6 +47,7 @@ namespace Nz::ShaderAst
std::unordered_map<UInt32, ConstantValue> optionValues;
bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false;
bool removeAliases = false;
bool removeConstDeclaration = false;
bool removeCompoundAssignments = false;
bool removeMatrixCast = false;
@@ -71,6 +72,7 @@ namespace Nz::ShaderAst
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
ExpressionPtr Clone(AccessIndexExpression& node) override;
ExpressionPtr Clone(AliasValueExpression& node) override;
ExpressionPtr Clone(AssignExpression& node) override;
ExpressionPtr Clone(BinaryExpression& node) override;
ExpressionPtr Clone(CallFunctionExpression& node) override;
@@ -136,15 +138,16 @@ namespace Nz::ShaderAst
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAlias(const IdentifierData* identifier) const;
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue);
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);

View File

@@ -51,6 +51,7 @@ namespace Nz
static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
private:
void Append(const ShaderAst::AliasType& aliasType);
void Append(const ShaderAst::ArrayType& type);
void Append(ShaderAst::BuiltinEntry builtin);
void Append(const ShaderAst::ExpressionType& type);
@@ -94,6 +95,7 @@ namespace Nz
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CallFunctionExpression& node) override;

View File

@@ -50,6 +50,7 @@ namespace Nz
struct UnrollAttribute;
struct UuidAttribute;
void Append(const ShaderAst::AliasType& type);
void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type);
void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type);
@@ -92,6 +93,7 @@ namespace Nz
void EnterScope();
void LeaveScope(bool skipLine = true);
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName);
@@ -102,6 +104,7 @@ namespace Nz
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;

View File

@@ -48,8 +48,6 @@ namespace Nz
void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::DeclareAliasStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareOptionStatement& node) override;
@@ -64,7 +62,7 @@ namespace Nz
void Visit(ShaderAst::ScopedStatement& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::UnaryExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::VariableValueExpression& node) override;
void Visit(ShaderAst::WhileStatement& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;

View File

@@ -178,6 +178,7 @@ namespace Nz
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
TypePtr BuildType(const ShaderAst::AliasType& type) const;
TypePtr BuildType(const ShaderAst::ArrayType& type) const;
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;