Shader: Add proper support for alias
This commit is contained in:
parent
ce93b61c91
commit
05cf98477e
|
|
@ -42,6 +42,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
|
virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
|
||||||
virtual ExpressionPtr Clone(AccessIndexExpression& node);
|
virtual ExpressionPtr Clone(AccessIndexExpression& node);
|
||||||
|
virtual ExpressionPtr Clone(AliasValueExpression& node);
|
||||||
virtual ExpressionPtr Clone(AssignExpression& node);
|
virtual ExpressionPtr Clone(AssignExpression& node);
|
||||||
virtual ExpressionPtr Clone(BinaryExpression& node);
|
virtual ExpressionPtr Clone(BinaryExpression& node);
|
||||||
virtual ExpressionPtr Clone(CallFunctionExpression& node);
|
virtual ExpressionPtr Clone(CallFunctionExpression& node);
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
|
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
|
||||||
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
|
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
|
||||||
|
inline bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs);
|
||||||
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
|
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
|
||||||
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
|
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
|
||||||
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);
|
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);
|
||||||
|
|
|
||||||
|
|
@ -248,6 +248,14 @@ namespace Nz::ShaderAst
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs)
|
||||||
|
{
|
||||||
|
if (!Compare(lhs.aliasId, rhs.aliasId))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
|
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
|
||||||
{
|
{
|
||||||
if (!Compare(lhs.op, rhs.op))
|
if (!Compare(lhs.op, rhs.op))
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@
|
||||||
|
|
||||||
NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression)
|
NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression)
|
NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression)
|
||||||
|
NAZARA_SHADERAST_EXPRESSION(AliasValueExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
|
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
|
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)
|
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void Visit(AccessIdentifierExpression& node) override;
|
void Visit(AccessIdentifierExpression& node) override;
|
||||||
void Visit(AccessIndexExpression& node) override;
|
void Visit(AccessIndexExpression& node) override;
|
||||||
|
void Visit(AliasValueExpression& node) override;
|
||||||
void Visit(AssignExpression& node) override;
|
void Visit(AssignExpression& node) override;
|
||||||
void Visit(BinaryExpression& node) override;
|
void Visit(BinaryExpression& node) override;
|
||||||
void Visit(CallFunctionExpression& node) override;
|
void Visit(CallFunctionExpression& node) override;
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void Serialize(AccessIdentifierExpression& node);
|
void Serialize(AccessIdentifierExpression& node);
|
||||||
void Serialize(AccessIndexExpression& node);
|
void Serialize(AccessIndexExpression& node);
|
||||||
|
void Serialize(AliasValueExpression& node);
|
||||||
void Serialize(AssignExpression& node);
|
void Serialize(AssignExpression& node);
|
||||||
void Serialize(BinaryExpression& node);
|
void Serialize(BinaryExpression& node);
|
||||||
void Serialize(CallFunctionExpression& node);
|
void Serialize(CallFunctionExpression& node);
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void Visit(AccessIdentifierExpression& node) override;
|
void Visit(AccessIdentifierExpression& node) override;
|
||||||
void Visit(AccessIndexExpression& node) override;
|
void Visit(AccessIndexExpression& node) override;
|
||||||
|
void Visit(AliasValueExpression& node) override;
|
||||||
void Visit(AssignExpression& node) override;
|
void Visit(AssignExpression& node) override;
|
||||||
void Visit(BinaryExpression& node) override;
|
void Visit(BinaryExpression& node) override;
|
||||||
void Visit(CallFunctionExpression& node) override;
|
void Visit(CallFunctionExpression& node) override;
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,22 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
struct ContainedType;
|
struct ContainedType;
|
||||||
|
|
||||||
|
struct NAZARA_SHADER_API AliasType
|
||||||
|
{
|
||||||
|
AliasType() = default;
|
||||||
|
AliasType(const AliasType& alias);
|
||||||
|
AliasType(AliasType&&) noexcept = default;
|
||||||
|
|
||||||
|
AliasType& operator=(const AliasType& alias);
|
||||||
|
AliasType& operator=(AliasType&&) noexcept = default;
|
||||||
|
|
||||||
|
std::size_t aliasIndex;
|
||||||
|
std::unique_ptr<ContainedType> targetType;
|
||||||
|
|
||||||
|
bool operator==(const AliasType& rhs) const;
|
||||||
|
inline bool operator!=(const AliasType& rhs) const;
|
||||||
|
};
|
||||||
|
|
||||||
struct NAZARA_SHADER_API ArrayType
|
struct NAZARA_SHADER_API ArrayType
|
||||||
{
|
{
|
||||||
ArrayType() = default;
|
ArrayType() = default;
|
||||||
|
|
@ -134,7 +150,7 @@ namespace Nz::ShaderAst
|
||||||
inline bool operator!=(const VectorType& rhs) const;
|
inline bool operator!=(const VectorType& rhs) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
using ExpressionType = std::variant<NoType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
|
using ExpressionType = std::variant<NoType, AliasType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
|
||||||
|
|
||||||
struct ContainedType
|
struct ContainedType
|
||||||
{
|
{
|
||||||
|
|
@ -157,6 +173,7 @@ namespace Nz::ShaderAst
|
||||||
std::vector<StructMember> members;
|
std::vector<StructMember> members;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline bool IsAliasType(const ExpressionType& type);
|
||||||
inline bool IsArrayType(const ExpressionType& type);
|
inline bool IsArrayType(const ExpressionType& type);
|
||||||
inline bool IsFunctionType(const ExpressionType& type);
|
inline bool IsFunctionType(const ExpressionType& type);
|
||||||
inline bool IsIdentifierType(const ExpressionType& type);
|
inline bool IsIdentifierType(const ExpressionType& type);
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,11 @@
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
|
inline bool AliasType::operator!=(const AliasType& rhs) const
|
||||||
|
{
|
||||||
|
return !operator==(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
inline bool ArrayType::operator!=(const ArrayType& rhs) const
|
inline bool ArrayType::operator!=(const ArrayType& rhs) const
|
||||||
{
|
{
|
||||||
return !operator==(rhs);
|
return !operator==(rhs);
|
||||||
|
|
@ -130,6 +135,11 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline bool IsAliasType(const ExpressionType& type)
|
||||||
|
{
|
||||||
|
return std::holds_alternative<AliasType>(type);
|
||||||
|
}
|
||||||
|
|
||||||
inline bool IsArrayType(const ExpressionType& type)
|
inline bool IsArrayType(const ExpressionType& type)
|
||||||
{
|
{
|
||||||
return std::holds_alternative<ArrayType>(type);
|
return std::holds_alternative<ArrayType>(type);
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,14 @@ namespace Nz::ShaderAst
|
||||||
ExpressionPtr expr;
|
ExpressionPtr expr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct NAZARA_SHADER_API AliasValueExpression : Expression
|
||||||
|
{
|
||||||
|
NodeType GetType() const override;
|
||||||
|
void Visit(AstExpressionVisitor& visitor) override;
|
||||||
|
|
||||||
|
std::size_t aliasId;
|
||||||
|
};
|
||||||
|
|
||||||
struct NAZARA_SHADER_API AssignExpression : Expression
|
struct NAZARA_SHADER_API AssignExpression : Expression
|
||||||
{
|
{
|
||||||
NodeType GetType() const override;
|
NodeType GetType() const override;
|
||||||
|
|
@ -153,7 +161,7 @@ namespace Nz::ShaderAst
|
||||||
NodeType GetType() const override;
|
NodeType GetType() const override;
|
||||||
void Visit(AstExpressionVisitor& visitor) override;
|
void Visit(AstExpressionVisitor& visitor) override;
|
||||||
|
|
||||||
ShaderAst::ConstantValue value;
|
ConstantValue value;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NAZARA_SHADER_API FunctionExpression : Expression
|
struct NAZARA_SHADER_API FunctionExpression : Expression
|
||||||
|
|
@ -207,7 +215,6 @@ namespace Nz::ShaderAst
|
||||||
ExpressionPtr expression;
|
ExpressionPtr expression;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NAZARA_SHADER_API VariableExpression : Expression
|
|
||||||
struct NAZARA_SHADER_API VariableValueExpression : Expression
|
struct NAZARA_SHADER_API VariableValueExpression : Expression
|
||||||
{
|
{
|
||||||
NodeType GetType() const override;
|
NodeType GetType() const override;
|
||||||
|
|
@ -455,10 +462,12 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
||||||
|
|
||||||
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
|
inline const ExpressionType& GetExpressionType(Expression& expr);
|
||||||
inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr);
|
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
|
||||||
inline bool IsExpression(NodeType nodeType);
|
inline bool IsExpression(NodeType nodeType);
|
||||||
inline bool IsStatement(NodeType nodeType);
|
inline bool IsStatement(NodeType nodeType);
|
||||||
|
|
||||||
|
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType);
|
||||||
}
|
}
|
||||||
|
|
||||||
#include <Nazara/Shader/Ast/Nodes.inl>
|
#include <Nazara/Shader/Ast/Nodes.inl>
|
||||||
|
|
|
||||||
|
|
@ -7,18 +7,29 @@
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr)
|
inline const ExpressionType& GetExpressionType(Expression& expr)
|
||||||
{
|
{
|
||||||
assert(expr.cachedExpressionType);
|
assert(expr.cachedExpressionType);
|
||||||
return expr.cachedExpressionType.value();
|
return expr.cachedExpressionType.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr)
|
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
|
||||||
{
|
{
|
||||||
assert(expr.cachedExpressionType);
|
assert(expr.cachedExpressionType);
|
||||||
return expr.cachedExpressionType.value();
|
return expr.cachedExpressionType.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
|
||||||
|
{
|
||||||
|
if (IsAliasType(exprType))
|
||||||
|
{
|
||||||
|
const AliasType& alias = std::get<AliasType>(exprType);
|
||||||
|
return alias.targetType->type;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return exprType;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool IsExpression(NodeType nodeType)
|
inline bool IsExpression(NodeType nodeType)
|
||||||
{
|
{
|
||||||
switch (nodeType)
|
switch (nodeType)
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ namespace Nz::ShaderAst
|
||||||
std::unordered_map<UInt32, ConstantValue> optionValues;
|
std::unordered_map<UInt32, ConstantValue> optionValues;
|
||||||
bool makeVariableNameUnique = false;
|
bool makeVariableNameUnique = false;
|
||||||
bool reduceLoopsToWhile = false;
|
bool reduceLoopsToWhile = false;
|
||||||
|
bool removeAliases = false;
|
||||||
bool removeConstDeclaration = false;
|
bool removeConstDeclaration = false;
|
||||||
bool removeCompoundAssignments = false;
|
bool removeCompoundAssignments = false;
|
||||||
bool removeMatrixCast = false;
|
bool removeMatrixCast = false;
|
||||||
|
|
@ -71,6 +72,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
|
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
|
||||||
ExpressionPtr Clone(AccessIndexExpression& node) override;
|
ExpressionPtr Clone(AccessIndexExpression& node) override;
|
||||||
|
ExpressionPtr Clone(AliasValueExpression& node) override;
|
||||||
ExpressionPtr Clone(AssignExpression& node) override;
|
ExpressionPtr Clone(AssignExpression& node) override;
|
||||||
ExpressionPtr Clone(BinaryExpression& node) override;
|
ExpressionPtr Clone(BinaryExpression& node) override;
|
||||||
ExpressionPtr Clone(CallFunctionExpression& node) override;
|
ExpressionPtr Clone(CallFunctionExpression& node) override;
|
||||||
|
|
@ -136,15 +138,16 @@ namespace Nz::ShaderAst
|
||||||
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
||||||
|
|
||||||
const IdentifierData* ResolveAlias(const IdentifierData* identifier) const;
|
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
|
||||||
void ResolveFunctions();
|
void ResolveFunctions();
|
||||||
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
||||||
|
std::size_t ResolveStruct(const AliasType& aliasType);
|
||||||
std::size_t ResolveStruct(const ExpressionType& exprType);
|
std::size_t ResolveStruct(const ExpressionType& exprType);
|
||||||
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
||||||
std::size_t ResolveStruct(const StructType& structType);
|
std::size_t ResolveStruct(const StructType& structType);
|
||||||
std::size_t ResolveStruct(const UniformType& uniformType);
|
std::size_t ResolveStruct(const UniformType& uniformType);
|
||||||
ExpressionType ResolveType(const ExpressionType& exprType);
|
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
|
||||||
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue);
|
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||||
|
|
||||||
void SanitizeIdentifier(std::string& identifier);
|
void SanitizeIdentifier(std::string& identifier);
|
||||||
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ namespace Nz
|
||||||
static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
|
static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void Append(const ShaderAst::AliasType& aliasType);
|
||||||
void Append(const ShaderAst::ArrayType& type);
|
void Append(const ShaderAst::ArrayType& type);
|
||||||
void Append(ShaderAst::BuiltinEntry builtin);
|
void Append(ShaderAst::BuiltinEntry builtin);
|
||||||
void Append(const ShaderAst::ExpressionType& type);
|
void Append(const ShaderAst::ExpressionType& type);
|
||||||
|
|
@ -94,6 +95,7 @@ namespace Nz
|
||||||
|
|
||||||
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
|
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
|
||||||
void Visit(ShaderAst::AccessIndexExpression& node) override;
|
void Visit(ShaderAst::AccessIndexExpression& node) override;
|
||||||
|
void Visit(ShaderAst::AliasValueExpression& node) override;
|
||||||
void Visit(ShaderAst::AssignExpression& node) override;
|
void Visit(ShaderAst::AssignExpression& node) override;
|
||||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ namespace Nz
|
||||||
struct UnrollAttribute;
|
struct UnrollAttribute;
|
||||||
struct UuidAttribute;
|
struct UuidAttribute;
|
||||||
|
|
||||||
|
void Append(const ShaderAst::AliasType& type);
|
||||||
void Append(const ShaderAst::ArrayType& type);
|
void Append(const ShaderAst::ArrayType& type);
|
||||||
void Append(const ShaderAst::ExpressionType& type);
|
void Append(const ShaderAst::ExpressionType& type);
|
||||||
void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type);
|
void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type);
|
||||||
|
|
@ -92,6 +93,7 @@ namespace Nz
|
||||||
void EnterScope();
|
void EnterScope();
|
||||||
void LeaveScope(bool skipLine = true);
|
void LeaveScope(bool skipLine = true);
|
||||||
|
|
||||||
|
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
|
||||||
void RegisterConstant(std::size_t constantIndex, std::string constantName);
|
void RegisterConstant(std::size_t constantIndex, std::string constantName);
|
||||||
void RegisterStruct(std::size_t structIndex, std::string structName);
|
void RegisterStruct(std::size_t structIndex, std::string structName);
|
||||||
void RegisterVariable(std::size_t varIndex, std::string varName);
|
void RegisterVariable(std::size_t varIndex, std::string varName);
|
||||||
|
|
@ -102,6 +104,7 @@ namespace Nz
|
||||||
|
|
||||||
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
|
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
|
||||||
void Visit(ShaderAst::AccessIndexExpression& node) override;
|
void Visit(ShaderAst::AccessIndexExpression& node) override;
|
||||||
|
void Visit(ShaderAst::AliasValueExpression& node) override;
|
||||||
void Visit(ShaderAst::AssignExpression& node) override;
|
void Visit(ShaderAst::AssignExpression& node) override;
|
||||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||||
void Visit(ShaderAst::CastExpression& node) override;
|
void Visit(ShaderAst::CastExpression& node) override;
|
||||||
|
|
|
||||||
|
|
@ -48,8 +48,6 @@ namespace Nz
|
||||||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||||
void Visit(ShaderAst::CastExpression& node) override;
|
void Visit(ShaderAst::CastExpression& node) override;
|
||||||
void Visit(ShaderAst::ConstantValueExpression& node) override;
|
void Visit(ShaderAst::ConstantValueExpression& node) override;
|
||||||
void Visit(ShaderAst::DeclareAliasStatement& node) override;
|
|
||||||
void Visit(ShaderAst::DeclareConstStatement& node) override;
|
|
||||||
void Visit(ShaderAst::DeclareExternalStatement& node) override;
|
void Visit(ShaderAst::DeclareExternalStatement& node) override;
|
||||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
|
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
|
||||||
void Visit(ShaderAst::DeclareOptionStatement& node) override;
|
void Visit(ShaderAst::DeclareOptionStatement& node) override;
|
||||||
|
|
@ -64,7 +62,7 @@ namespace Nz
|
||||||
void Visit(ShaderAst::ScopedStatement& node) override;
|
void Visit(ShaderAst::ScopedStatement& node) override;
|
||||||
void Visit(ShaderAst::SwizzleExpression& node) override;
|
void Visit(ShaderAst::SwizzleExpression& node) override;
|
||||||
void Visit(ShaderAst::UnaryExpression& node) override;
|
void Visit(ShaderAst::UnaryExpression& node) override;
|
||||||
void Visit(ShaderAst::VariableExpression& node) override;
|
void Visit(ShaderAst::VariableValueExpression& node) override;
|
||||||
void Visit(ShaderAst::WhileStatement& node) override;
|
void Visit(ShaderAst::WhileStatement& node) override;
|
||||||
|
|
||||||
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
|
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
|
||||||
|
|
|
||||||
|
|
@ -178,6 +178,7 @@ namespace Nz
|
||||||
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
|
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
|
||||||
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
|
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
|
||||||
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
|
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
|
||||||
|
TypePtr BuildType(const ShaderAst::AliasType& type) const;
|
||||||
TypePtr BuildType(const ShaderAst::ArrayType& type) const;
|
TypePtr BuildType(const ShaderAst::ArrayType& type) const;
|
||||||
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
|
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
|
||||||
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;
|
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;
|
||||||
|
|
|
||||||
|
|
@ -300,6 +300,16 @@ namespace Nz::ShaderAst
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ExpressionPtr AstCloner::Clone(AliasValueExpression& node)
|
||||||
|
{
|
||||||
|
auto clone = std::make_unique<AliasValueExpression>();
|
||||||
|
clone->aliasId = node.aliasId;
|
||||||
|
|
||||||
|
clone->cachedExpressionType = node.cachedExpressionType;
|
||||||
|
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
ExpressionPtr AstCloner::Clone(AssignExpression& node)
|
ExpressionPtr AstCloner::Clone(AssignExpression& node)
|
||||||
{
|
{
|
||||||
auto clone = std::make_unique<AssignExpression>();
|
auto clone = std::make_unique<AssignExpression>();
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,11 @@ namespace Nz::ShaderAst
|
||||||
index->Visit(*this);
|
index->Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AstRecursiveVisitor::Visit(AliasValueExpression& /*node*/)
|
||||||
|
{
|
||||||
|
/* nothing to do */
|
||||||
|
}
|
||||||
|
|
||||||
void AstRecursiveVisitor::Visit(AssignExpression& node)
|
void AstRecursiveVisitor::Visit(AssignExpression& node)
|
||||||
{
|
{
|
||||||
node.left->Visit(*this);
|
node.left->Visit(*this);
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,11 @@ namespace Nz::ShaderAst
|
||||||
Node(identifier);
|
Node(identifier);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AstSerializerBase::Serialize(AliasValueExpression& node)
|
||||||
|
{
|
||||||
|
SizeT(node.aliasId);
|
||||||
|
}
|
||||||
|
|
||||||
void AstSerializerBase::Serialize(AssignExpression& node)
|
void AstSerializerBase::Serialize(AssignExpression& node)
|
||||||
{
|
{
|
||||||
Enum(node.op);
|
Enum(node.op);
|
||||||
|
|
@ -485,6 +490,12 @@ namespace Nz::ShaderAst
|
||||||
Type(arg.objectType->type);
|
Type(arg.objectType->type);
|
||||||
SizeT(arg.methodIndex);
|
SizeT(arg.methodIndex);
|
||||||
}
|
}
|
||||||
|
else if constexpr (std::is_same_v<T, ShaderAst::AliasType>)
|
||||||
|
{
|
||||||
|
m_stream << UInt8(13);
|
||||||
|
SizeT(arg.aliasIndex);
|
||||||
|
Type(arg.targetType->type);
|
||||||
|
}
|
||||||
else
|
else
|
||||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||||
}, type);
|
}, type);
|
||||||
|
|
@ -800,6 +811,22 @@ namespace Nz::ShaderAst
|
||||||
type = std::move(methodType);
|
type = std::move(methodType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case 13: //< AliasType
|
||||||
|
{
|
||||||
|
std::size_t aliasIndex;
|
||||||
|
ExpressionType containedType;
|
||||||
|
SizeT(aliasIndex);
|
||||||
|
Type(containedType);
|
||||||
|
|
||||||
|
AliasType aliasType;
|
||||||
|
aliasType.aliasIndex = aliasIndex;
|
||||||
|
aliasType.targetType = std::make_unique<ContainedType>();
|
||||||
|
aliasType.targetType->type = std::move(containedType);
|
||||||
|
|
||||||
|
type = std::move(aliasType);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,11 @@ namespace Nz::ShaderAst
|
||||||
node.expr->Visit(*this);
|
node.expr->Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ShaderAstValueCategory::Visit(AliasValueExpression& /*node*/)
|
||||||
|
{
|
||||||
|
m_expressionCategory = ExpressionCategory::LValue;
|
||||||
|
}
|
||||||
|
|
||||||
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
|
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
|
||||||
{
|
{
|
||||||
m_expressionCategory = ExpressionCategory::RValue;
|
m_expressionCategory = ExpressionCategory::RValue;
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,37 @@
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
|
AliasType::AliasType(const AliasType& alias) :
|
||||||
|
aliasIndex(alias.aliasIndex)
|
||||||
|
{
|
||||||
|
assert(alias.targetType);
|
||||||
|
targetType = std::make_unique<ContainedType>(*alias.targetType);
|
||||||
|
}
|
||||||
|
|
||||||
|
AliasType& AliasType::operator=(const AliasType& alias)
|
||||||
|
{
|
||||||
|
aliasIndex = alias.aliasIndex;
|
||||||
|
|
||||||
|
assert(alias.targetType);
|
||||||
|
targetType = std::make_unique<ContainedType>(*alias.targetType);
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AliasType::operator==(const AliasType& rhs) const
|
||||||
|
{
|
||||||
|
assert(targetType);
|
||||||
|
assert(rhs.targetType);
|
||||||
|
|
||||||
|
if (aliasIndex != rhs.aliasIndex)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (targetType->type != rhs.targetType->type)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
ArrayType::ArrayType(const ArrayType& array) :
|
ArrayType::ArrayType(const ArrayType& array) :
|
||||||
length(array.length)
|
length(array.length)
|
||||||
{
|
{
|
||||||
|
|
@ -31,10 +62,10 @@ namespace Nz::ShaderAst
|
||||||
assert(containedType);
|
assert(containedType);
|
||||||
assert(rhs.containedType);
|
assert(rhs.containedType);
|
||||||
|
|
||||||
if (containedType->type != rhs.containedType->type)
|
if (length != rhs.length)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (length != rhs.length)
|
if (containedType->type != rhs.containedType->type)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|
|
||||||
|
|
@ -282,7 +282,7 @@ namespace Nz::ShaderAst
|
||||||
if (identifier.empty())
|
if (identifier.empty())
|
||||||
throw AstError{ "empty identifier" };
|
throw AstError{ "empty identifier" };
|
||||||
|
|
||||||
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
|
const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr));
|
||||||
// TODO: Add proper support for methods
|
// TODO: Add proper support for methods
|
||||||
if (IsSamplerType(exprType))
|
if (IsSamplerType(exprType))
|
||||||
{
|
{
|
||||||
|
|
@ -429,6 +429,25 @@ namespace Nz::ShaderAst
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node)
|
||||||
|
{
|
||||||
|
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId));
|
||||||
|
ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier);
|
||||||
|
|
||||||
|
if (m_context->options.removeAliases)
|
||||||
|
return targetExpr;
|
||||||
|
|
||||||
|
AliasType aliasType;
|
||||||
|
aliasType.aliasIndex = node.aliasId;
|
||||||
|
aliasType.targetType = std::make_unique<ContainedType>();
|
||||||
|
aliasType.targetType->type = *targetExpr->cachedExpressionType;
|
||||||
|
|
||||||
|
auto clone = static_unique_pointer_cast<AliasValueExpression>(AstCloner::Clone(node));
|
||||||
|
clone->cachedExpressionType = std::move(aliasType);
|
||||||
|
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
|
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
|
||||||
{
|
{
|
||||||
MandatoryExpr(node.left);
|
MandatoryExpr(node.left);
|
||||||
|
|
@ -543,7 +562,7 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
||||||
bool isMatrixCast = IsMatrixType(frontExprType);
|
bool isMatrixCast = IsMatrixType(frontExprType);
|
||||||
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
|
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
|
||||||
{
|
{
|
||||||
|
|
@ -785,6 +804,9 @@ namespace Nz::ShaderAst
|
||||||
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
|
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
|
||||||
Validate(*clone);
|
Validate(*clone);
|
||||||
|
|
||||||
|
if (m_context->options.removeAliases)
|
||||||
|
return ShaderBuilder::NoOp();
|
||||||
|
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -803,7 +825,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
ExpressionType expressionType = ResolveType(GetExpressionType(value));
|
ExpressionType expressionType = ResolveType(GetExpressionType(value));
|
||||||
|
|
||||||
if (clone->type.HasValue() && ResolveType(clone->type) != expressionType)
|
if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType))
|
||||||
throw AstError{ "constant expression doesn't match type" };
|
throw AstError{ "constant expression doesn't match type" };
|
||||||
|
|
||||||
clone->type = expressionType;
|
clone->type = expressionType;
|
||||||
|
|
@ -852,12 +874,13 @@ namespace Nz::ShaderAst
|
||||||
m_context->declaredExternalVar.insert(extVar.name);
|
m_context->declaredExternalVar.insert(extVar.name);
|
||||||
|
|
||||||
ExpressionType resolvedType = ResolveType(extVar.type);
|
ExpressionType resolvedType = ResolveType(extVar.type);
|
||||||
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||||
|
|
||||||
ExpressionType varType;
|
ExpressionType varType;
|
||||||
if (IsUniformType(resolvedType))
|
if (IsUniformType(targetType))
|
||||||
varType = std::get<UniformType>(resolvedType).containedType;
|
varType = std::get<UniformType>(targetType).containedType;
|
||||||
else if (IsSamplerType(resolvedType))
|
else if (IsSamplerType(targetType))
|
||||||
varType = resolvedType;
|
varType = targetType;
|
||||||
else
|
else
|
||||||
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||||
|
|
||||||
|
|
@ -954,8 +977,9 @@ namespace Nz::ShaderAst
|
||||||
throw AstError{ "empty option name" };
|
throw AstError{ "empty option name" };
|
||||||
|
|
||||||
ExpressionType resolvedType = ResolveType(clone->optType);
|
ExpressionType resolvedType = ResolveType(clone->optType);
|
||||||
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||||
|
|
||||||
if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue))
|
if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue))
|
||||||
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
|
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
|
||||||
|
|
||||||
clone->optType = std::move(resolvedType);
|
clone->optType = std::move(resolvedType);
|
||||||
|
|
@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst
|
||||||
ExpressionType resolvedType = ResolveType(member.type);
|
ExpressionType resolvedType = ResolveType(member.type);
|
||||||
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
|
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
|
||||||
{
|
{
|
||||||
if (IsPrimitiveType(resolvedType) && std::get<PrimitiveType>(resolvedType) == PrimitiveType::Boolean)
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||||
|
|
||||||
|
if (IsPrimitiveType(targetType) && std::get<PrimitiveType>(targetType) == PrimitiveType::Boolean)
|
||||||
throw AstError{ "boolean type is not allowed in std140 layout" };
|
throw AstError{ "boolean type is not allowed in std140 layout" };
|
||||||
else if (IsStructType(resolvedType))
|
else if (IsStructType(targetType))
|
||||||
{
|
{
|
||||||
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
|
std::size_t structIndex = std::get<StructType>(targetType).structIndex;
|
||||||
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
|
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
|
||||||
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
||||||
throw AstError{ "inner struct layout mismatch" };
|
throw AstError{ "inner struct layout mismatch" };
|
||||||
|
|
@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst
|
||||||
AstExportVisitor exportVisitor;
|
AstExportVisitor exportVisitor;
|
||||||
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
|
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
|
||||||
|
|
||||||
if (aliasStatements.empty())
|
if (aliasStatements.empty() || m_context->options.removeAliases)
|
||||||
return ShaderBuilder::NoOp();
|
return ShaderBuilder::NoOp();
|
||||||
|
|
||||||
// Register module and aliases
|
// Register module and aliases
|
||||||
|
|
@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return ResolveAlias(&it->data);
|
return &it->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename F>
|
template<typename F>
|
||||||
|
|
@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
if (identifier.name == identifierName)
|
if (identifier.name == identifierName)
|
||||||
{
|
{
|
||||||
if (functor(*ResolveAlias(&identifier.data)))
|
if (functor(identifier.data))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return ResolveAlias(&it->data);
|
return &it->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
|
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
|
||||||
|
|
@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
switch (identifierData->category)
|
switch (identifierData->category)
|
||||||
{
|
{
|
||||||
|
case IdentifierCategory::Alias:
|
||||||
|
{
|
||||||
|
AliasValueExpression aliasValue;
|
||||||
|
aliasValue.aliasId = identifierData->index;
|
||||||
|
|
||||||
|
return Clone(aliasValue);
|
||||||
|
}
|
||||||
|
|
||||||
case IdentifierCategory::Constant:
|
case IdentifierCategory::Constant:
|
||||||
{
|
{
|
||||||
// Replace IdentifierExpression by Constant(Value)Expression
|
// Replace IdentifierExpression by Constant(Value)Expression
|
||||||
|
|
@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst
|
||||||
return varIndex;
|
return varIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData*
|
auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData*
|
||||||
{
|
{
|
||||||
while (identifier->category == IdentifierCategory::Alias)
|
while (identifier->category == IdentifierCategory::Alias)
|
||||||
identifier = &m_context->aliases.Retrieve(identifier->index);
|
identifier = &m_context->aliases.Retrieve(identifier->index);
|
||||||
|
|
@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
for (const auto& [funcIndex, funcData] : m_context->functions.values)
|
for (const auto& [funcIndex, funcData] : m_context->functions.values)
|
||||||
{
|
{
|
||||||
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
||||||
throw AstError{ "discard can only be used in the fragment stage" };
|
throw AstError{ "discard can only be used in the fragment stage" };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType)
|
||||||
|
{
|
||||||
|
return ResolveStruct(aliasType.targetType->type);
|
||||||
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
||||||
{
|
{
|
||||||
return std::visit([&](auto&& arg) -> std::size_t
|
return std::visit([&](auto&& arg) -> std::size_t
|
||||||
{
|
{
|
||||||
using T = std::decay_t<decltype(arg)>;
|
using T = std::decay_t<decltype(arg)>;
|
||||||
|
|
||||||
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>)
|
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType> || std::is_same_v<T, AliasType>)
|
||||||
return ResolveStruct(arg);
|
return ResolveStruct(arg);
|
||||||
else if constexpr (std::is_same_v<T, NoType> ||
|
else if constexpr (std::is_same_v<T, NoType> ||
|
||||||
std::is_same_v<T, ArrayType> ||
|
std::is_same_v<T, ArrayType> ||
|
||||||
|
|
@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst
|
||||||
return uniformType.containedType.structIndex;
|
return uniformType.containedType.structIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType)
|
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias)
|
||||||
{
|
{
|
||||||
if (!IsTypeExpression(exprType))
|
if (!IsTypeExpression(exprType))
|
||||||
return exprType;
|
{
|
||||||
|
if (resolveAlias || m_context->options.removeAliases)
|
||||||
|
return ResolveAlias(exprType);
|
||||||
|
else
|
||||||
|
return exprType;
|
||||||
|
}
|
||||||
|
|
||||||
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
||||||
|
|
||||||
|
|
@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst
|
||||||
return std::get<ExpressionType>(type);
|
return std::get<ExpressionType>(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue)
|
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias)
|
||||||
{
|
{
|
||||||
if (!exprTypeValue.HasValue())
|
if (!exprTypeValue.HasValue())
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
if (exprTypeValue.IsResultingValue())
|
if (exprTypeValue.IsResultingValue())
|
||||||
return ResolveType(exprTypeValue.GetResultingValue());
|
return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias);
|
||||||
|
|
||||||
assert(exprTypeValue.IsExpression());
|
assert(exprTypeValue.IsExpression());
|
||||||
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
|
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
|
||||||
|
|
@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst
|
||||||
//if (!IsTypeType(exprType))
|
//if (!IsTypeType(exprType))
|
||||||
// throw AstError{ "type expected" };
|
// throw AstError{ "type expected" };
|
||||||
|
|
||||||
return ResolveType(exprType);
|
return ResolveType(exprType, resolveAlias);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
|
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
|
||||||
|
|
@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
|
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
|
||||||
{
|
{
|
||||||
if (left != right)
|
if (ResolveAlias(left) != ResolveAlias(right))
|
||||||
throw AstError{ "Left expression type must match right expression type" };
|
throw AstError{ "Left expression type must match right expression type" };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst
|
||||||
std::size_t structIndex = ResolveStruct(exprType);
|
std::size_t structIndex = ResolveStruct(exprType);
|
||||||
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
|
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
|
||||||
}
|
}
|
||||||
|
else if (IsAliasType(exprType))
|
||||||
|
{
|
||||||
|
const AliasType& alias = std::get<AliasType>(exprType);
|
||||||
|
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
|
||||||
|
}
|
||||||
else
|
else
|
||||||
throw AstError{ "for now, only structs can be aliased" };
|
throw AstError{ "for now, only structs can be aliased" };
|
||||||
}
|
}
|
||||||
|
|
@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst
|
||||||
case TypeParameterCategory::PrimitiveType:
|
case TypeParameterCategory::PrimitiveType:
|
||||||
case TypeParameterCategory::StructType:
|
case TypeParameterCategory::StructType:
|
||||||
{
|
{
|
||||||
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr));
|
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true);
|
||||||
|
|
||||||
switch (partialType.parameters[i])
|
switch (partialType.parameters[i])
|
||||||
{
|
{
|
||||||
|
|
@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
for (auto& index : node.indices)
|
for (auto& index : node.indices)
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
|
const ExpressionType& indexType = GetExpressionType(*index);
|
||||||
if (!IsPrimitiveType(indexType))
|
if (!IsPrimitiveType(indexType))
|
||||||
throw AstError{ "AccessIndex expects integer indices" };
|
throw AstError{ "AccessIndex expects integer indices" };
|
||||||
|
|
||||||
|
|
@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
else if (IsStructType(exprType))
|
else if (IsStructType(exprType))
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
|
const ExpressionType& indexType = GetExpressionType(*indexExpr);
|
||||||
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
|
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
|
||||||
throw AstError{ "struct can only be accessed with constant i32 indices" };
|
throw AstError{ "struct can only be accessed with constant i32 indices" };
|
||||||
|
|
||||||
|
|
@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst
|
||||||
std::size_t structIndex = ResolveStruct(exprType);
|
std::size_t structIndex = ResolveStruct(exprType);
|
||||||
const StructDescription* s = m_context->structs.Retrieve(structIndex);
|
const StructDescription* s = m_context->structs.Retrieve(structIndex);
|
||||||
|
|
||||||
exprType = ResolveType(s->members[index].type);
|
exprType = ResolveType(s->members[index].type, true);
|
||||||
}
|
}
|
||||||
else if (IsMatrixType(exprType))
|
else if (IsMatrixType(exprType))
|
||||||
{
|
{
|
||||||
|
|
@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void SanitizeVisitor::Validate(CallFunctionExpression& node)
|
void SanitizeVisitor::Validate(CallFunctionExpression& node)
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
||||||
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
||||||
|
|
||||||
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
||||||
|
|
@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst
|
||||||
void SanitizeVisitor::Validate(CastExpression& node)
|
void SanitizeVisitor::Validate(CastExpression& node)
|
||||||
{
|
{
|
||||||
ExpressionType resolvedType = ResolveType(node.targetType);
|
ExpressionType resolvedType = ResolveType(node.targetType);
|
||||||
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||||
|
|
||||||
const auto& firstExprPtr = node.expressions.front();
|
const auto& firstExprPtr = node.expressions.front();
|
||||||
if (!firstExprPtr)
|
if (!firstExprPtr)
|
||||||
throw AstError{ "expected at least one expression" };
|
throw AstError{ "expected at least one expression" };
|
||||||
|
|
||||||
if (IsMatrixType(resolvedType))
|
if (IsMatrixType(targetType))
|
||||||
{
|
{
|
||||||
const MatrixType& targetMatrixType = std::get<MatrixType>(resolvedType);
|
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
||||||
|
|
||||||
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
|
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
|
||||||
if (IsMatrixType(firstExprType))
|
if (IsMatrixType(firstExprType))
|
||||||
|
|
@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst
|
||||||
};
|
};
|
||||||
|
|
||||||
std::size_t componentCount = 0;
|
std::size_t componentCount = 0;
|
||||||
std::size_t requiredComponents = GetComponentCount(resolvedType);
|
std::size_t requiredComponents = GetComponentCount(targetType);
|
||||||
|
|
||||||
for (auto& exprPtr : node.expressions)
|
for (auto& exprPtr : node.expressions)
|
||||||
{
|
{
|
||||||
|
|
@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst
|
||||||
case UnaryType::Minus:
|
case UnaryType::Minus:
|
||||||
case UnaryType::Plus:
|
case UnaryType::Plus:
|
||||||
{
|
{
|
||||||
ShaderAst::PrimitiveType basicType;
|
PrimitiveType basicType;
|
||||||
if (IsPrimitiveType(exprType))
|
if (IsPrimitiveType(exprType))
|
||||||
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
basicType = std::get<PrimitiveType>(exprType);
|
||||||
else if (IsVectorType(exprType))
|
else if (IsVectorType(exprType))
|
||||||
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
basicType = std::get<VectorType>(exprType).type;
|
||||||
else
|
else
|
||||||
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
|
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -232,6 +232,7 @@ namespace Nz
|
||||||
options.optionValues = std::move(optionValues);
|
options.optionValues = std::move(optionValues);
|
||||||
options.makeVariableNameUnique = true;
|
options.makeVariableNameUnique = true;
|
||||||
options.reduceLoopsToWhile = true;
|
options.reduceLoopsToWhile = true;
|
||||||
|
options.removeAliases = true;
|
||||||
options.removeCompoundAssignments = false;
|
options.removeCompoundAssignments = false;
|
||||||
options.removeConstDeclaration = true;
|
options.removeConstDeclaration = true;
|
||||||
options.removeOptionDeclaration = true;
|
options.removeOptionDeclaration = true;
|
||||||
|
|
@ -246,6 +247,11 @@ namespace Nz
|
||||||
return ShaderAst::Sanitize(module, options, error);
|
return ShaderAst::Sanitize(module, options, error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GlslWriter::Append(const ShaderAst::AliasType& /*aliasType*/)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("unexpected AliasType");
|
||||||
|
}
|
||||||
|
|
||||||
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
|
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected ArrayType");
|
throw std::runtime_error("unexpected ArrayType");
|
||||||
|
|
@ -689,11 +695,16 @@ namespace Nz
|
||||||
builtin.identifier
|
builtin.identifier
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
else if (member.locationIndex.HasValue())
|
else
|
||||||
{
|
{
|
||||||
Append("layout(location = ");
|
if (member.locationIndex.HasValue())
|
||||||
Append(member.locationIndex.GetResultingValue());
|
{
|
||||||
Append(") ", keyword, " ");
|
Append("layout(location = ");
|
||||||
|
Append(member.locationIndex.GetResultingValue());
|
||||||
|
Append(") ");
|
||||||
|
}
|
||||||
|
|
||||||
|
Append(keyword, " ");
|
||||||
AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name);
|
AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name);
|
||||||
AppendLine(";");
|
AppendLine(";");
|
||||||
|
|
||||||
|
|
@ -805,6 +816,12 @@ namespace Nz
|
||||||
Append("]");
|
Append("]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GlslWriter::Visit(ShaderAst::AliasValueExpression& /*node*/)
|
||||||
|
{
|
||||||
|
// all aliases should have been handled by sanitizer
|
||||||
|
throw std::runtime_error("unexpected alias value, is shader sanitized?");
|
||||||
|
}
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
|
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
|
||||||
{
|
{
|
||||||
node.left->Visit(*this);
|
node.left->Visit(*this);
|
||||||
|
|
@ -1038,12 +1055,14 @@ namespace Nz
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
|
void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
|
||||||
{
|
{
|
||||||
/* nothing to do */
|
// all aliases should have been handled by sanitizer
|
||||||
|
throw std::runtime_error("unexpected alias declaration, is shader sanitized?");
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
||||||
{
|
{
|
||||||
/* nothing to do */
|
// all consts should have been handled by sanitizer
|
||||||
|
throw std::runtime_error("unexpected const declaration, is shader sanitized?");
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||||
|
|
@ -1184,7 +1203,8 @@ namespace Nz
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)
|
void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)
|
||||||
{
|
{
|
||||||
/* nothing to do */
|
// all options should have been handled by sanitizer
|
||||||
|
throw std::runtime_error("unexpected option declaration, is shader sanitized?");
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||||
|
|
@ -1247,7 +1267,7 @@ namespace Nz
|
||||||
Append(";");
|
Append(";");
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::ImportStatement& node)
|
void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
|
throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
|
||||||
}
|
}
|
||||||
|
|
@ -1280,8 +1300,8 @@ namespace Nz
|
||||||
const auto& structData = Retrieve(m_currentState->structs, structIndex);
|
const auto& structData = Retrieve(m_currentState->structs, structIndex);
|
||||||
|
|
||||||
std::string outputStructVarName;
|
std::string outputStructVarName;
|
||||||
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)
|
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableValueExpression)
|
||||||
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableExpression&>(*node.returnExpr).variableId);
|
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableValueExpression&>(*node.returnExpr).variableId);
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
AppendLine();
|
AppendLine();
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ namespace Nz
|
||||||
ShaderAst::Module* module;
|
ShaderAst::Module* module;
|
||||||
std::size_t currentModuleIndex;
|
std::size_t currentModuleIndex;
|
||||||
std::stringstream stream;
|
std::stringstream stream;
|
||||||
|
std::unordered_map<std::size_t, Identifier> aliases;
|
||||||
std::unordered_map<std::size_t, Identifier> constants;
|
std::unordered_map<std::size_t, Identifier> constants;
|
||||||
std::unordered_map<std::size_t, Identifier> structs;
|
std::unordered_map<std::size_t, Identifier> structs;
|
||||||
std::unordered_map<std::size_t, Identifier> variables;
|
std::unordered_map<std::size_t, Identifier> variables;
|
||||||
|
|
@ -164,6 +165,11 @@ namespace Nz
|
||||||
m_environment = std::move(environment);
|
m_environment = std::move(environment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LangWriter::Append(const ShaderAst::AliasType& type)
|
||||||
|
{
|
||||||
|
AppendIdentifier(m_currentState->aliases, type.aliasIndex);
|
||||||
|
}
|
||||||
|
|
||||||
void LangWriter::Append(const ShaderAst::ArrayType& type)
|
void LangWriter::Append(const ShaderAst::ArrayType& type)
|
||||||
{
|
{
|
||||||
Append("array[", type.containedType->type, ", ", type.length, "]");
|
Append("array[", type.containedType->type, ", ", type.length, "]");
|
||||||
|
|
@ -655,6 +661,16 @@ namespace Nz
|
||||||
Append("}");
|
Append("}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LangWriter::RegisterAlias(std::size_t aliasIndex, std::string aliasName)
|
||||||
|
{
|
||||||
|
State::Identifier identifier;
|
||||||
|
identifier.moduleIndex = m_currentState->currentModuleIndex;
|
||||||
|
identifier.name = std::move(aliasName);
|
||||||
|
|
||||||
|
assert(m_currentState->aliases.find(aliasIndex) == m_currentState->aliases.end());
|
||||||
|
m_currentState->aliases.emplace(aliasIndex, std::move(identifier));
|
||||||
|
}
|
||||||
|
|
||||||
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
|
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
|
||||||
{
|
{
|
||||||
State::Identifier identifier;
|
State::Identifier identifier;
|
||||||
|
|
@ -714,7 +730,7 @@ namespace Nz
|
||||||
{
|
{
|
||||||
Visit(node.expr, true);
|
Visit(node.expr, true);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||||
assert(IsStructType(exprType));
|
assert(IsStructType(exprType));
|
||||||
|
|
||||||
for (const std::string& identifier : node.identifiers)
|
for (const std::string& identifier : node.identifiers)
|
||||||
|
|
@ -725,7 +741,7 @@ namespace Nz
|
||||||
{
|
{
|
||||||
Visit(node.expr, true);
|
Visit(node.expr, true);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||||
assert(!IsStructType(exprType));
|
assert(!IsStructType(exprType));
|
||||||
|
|
||||||
// Array access
|
// Array access
|
||||||
|
|
@ -744,6 +760,11 @@ namespace Nz
|
||||||
Append("]");
|
Append("]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LangWriter::Visit(ShaderAst::AliasValueExpression& node)
|
||||||
|
{
|
||||||
|
AppendIdentifier(m_currentState->aliases, node.aliasId);
|
||||||
|
}
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::AssignExpression& node)
|
void LangWriter::Visit(ShaderAst::AssignExpression& node)
|
||||||
{
|
{
|
||||||
node.left->Visit(*this);
|
node.left->Visit(*this);
|
||||||
|
|
@ -840,7 +861,13 @@ namespace Nz
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::ConditionalExpression& node)
|
void LangWriter::Visit(ShaderAst::ConditionalExpression& node)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("fixme");
|
Append("const_select(");
|
||||||
|
node.condition->Visit(*this);
|
||||||
|
Append(", ");
|
||||||
|
node.truePath->Visit(*this);
|
||||||
|
Append(", ");
|
||||||
|
node.falsePath->Visit(*this);
|
||||||
|
Append(")");
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::ConditionalStatement& node)
|
void LangWriter::Visit(ShaderAst::ConditionalStatement& node)
|
||||||
|
|
@ -853,9 +880,8 @@ namespace Nz
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
||||||
{
|
{
|
||||||
//throw std::runtime_error("TODO"); //< missing registering
|
|
||||||
|
|
||||||
assert(node.aliasIndex);
|
assert(node.aliasIndex);
|
||||||
|
RegisterAlias(*node.aliasIndex, node.name);
|
||||||
|
|
||||||
Append("alias ", node.name, " = ");
|
Append("alias ", node.name, " = ");
|
||||||
assert(node.expression);
|
assert(node.expression);
|
||||||
|
|
|
||||||
|
|
@ -598,16 +598,6 @@ namespace Nz
|
||||||
}, node.value);
|
}, node.value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
|
|
||||||
{
|
|
||||||
/* nothing to do */
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
|
||||||
{
|
|
||||||
/* nothing to do */
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
|
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||||
{
|
{
|
||||||
for (auto&& extVar : node.externalVars)
|
for (auto&& extVar : node.externalVars)
|
||||||
|
|
@ -729,11 +719,6 @@ namespace Nz
|
||||||
PopResultId();
|
PopResultId();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::ImportStatement& node)
|
|
||||||
{
|
|
||||||
/* nothing to do */
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
|
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
|
||||||
{
|
{
|
||||||
switch (node.intrinsic)
|
switch (node.intrinsic)
|
||||||
|
|
|
||||||
|
|
@ -655,6 +655,28 @@ namespace Nz
|
||||||
return typePtr;
|
return typePtr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
|
||||||
|
{
|
||||||
|
bool wasInblockStruct = m_internal->isInBlockStruct;
|
||||||
|
if (storageClass == SpirvStorageClass::Uniform)
|
||||||
|
m_internal->isInBlockStruct = true;
|
||||||
|
|
||||||
|
auto typePtr = std::make_shared<Type>(Pointer{
|
||||||
|
BuildType(type),
|
||||||
|
storageClass
|
||||||
|
});
|
||||||
|
|
||||||
|
m_internal->isInBlockStruct = wasInblockStruct;
|
||||||
|
|
||||||
|
return typePtr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto SpirvConstantCache::BuildType(const ShaderAst::AliasType& /*type*/) const -> TypePtr
|
||||||
|
{
|
||||||
|
// No AliasType is expected (as they should have been resolved by now)
|
||||||
|
throw std::runtime_error("unexpected alias");
|
||||||
|
}
|
||||||
|
|
||||||
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
|
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
|
||||||
{
|
{
|
||||||
const auto& containedType = type.containedType->type;
|
const auto& containedType = type.containedType->type;
|
||||||
|
|
@ -678,22 +700,6 @@ namespace Nz
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
|
|
||||||
{
|
|
||||||
bool wasInblockStruct = m_internal->isInBlockStruct;
|
|
||||||
if (storageClass == SpirvStorageClass::Uniform)
|
|
||||||
m_internal->isInBlockStruct = true;
|
|
||||||
|
|
||||||
auto typePtr = std::make_shared<Type>(Pointer{
|
|
||||||
BuildType(type),
|
|
||||||
storageClass
|
|
||||||
});
|
|
||||||
|
|
||||||
m_internal->isInBlockStruct = wasInblockStruct;
|
|
||||||
|
|
||||||
return typePtr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
|
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
|
||||||
{
|
{
|
||||||
return std::visit([&](auto&& arg) -> TypePtr
|
return std::visit([&](auto&& arg) -> TypePtr
|
||||||
|
|
|
||||||
|
|
@ -505,6 +505,7 @@ namespace Nz
|
||||||
ShaderAst::SanitizeVisitor::Options options;
|
ShaderAst::SanitizeVisitor::Options options;
|
||||||
options.optionValues = states.optionValues;
|
options.optionValues = states.optionValues;
|
||||||
options.reduceLoopsToWhile = true;
|
options.reduceLoopsToWhile = true;
|
||||||
|
options.removeAliases = true;
|
||||||
options.removeCompoundAssignments = true;
|
options.removeCompoundAssignments = true;
|
||||||
options.removeMatrixCast = true;
|
options.removeMatrixCast = true;
|
||||||
options.removeOptionDeclaration = true;
|
options.removeOptionDeclaration = true;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
#include <Engine/Shader/ShaderUtils.hpp>
|
||||||
|
#include <Nazara/Core/File.hpp>
|
||||||
|
#include <Nazara/Core/StringExt.hpp>
|
||||||
|
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||||
|
#include <Nazara/Shader/ShaderLangParser.hpp>
|
||||||
|
#include <catch2/catch.hpp>
|
||||||
|
#include <cctype>
|
||||||
|
|
||||||
|
TEST_CASE("aliases", "[Shader]")
|
||||||
|
{
|
||||||
|
SECTION("Alias of structs")
|
||||||
|
{
|
||||||
|
std::string_view nzslSource = R"(
|
||||||
|
[nzsl_version("1.0")]
|
||||||
|
module;
|
||||||
|
|
||||||
|
struct Data
|
||||||
|
{
|
||||||
|
value: f32
|
||||||
|
}
|
||||||
|
|
||||||
|
alias ExtData = Data;
|
||||||
|
|
||||||
|
external
|
||||||
|
{
|
||||||
|
[binding(0)] extData: uniform[ExtData]
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Input
|
||||||
|
{
|
||||||
|
value: f32
|
||||||
|
}
|
||||||
|
|
||||||
|
alias In = Input;
|
||||||
|
|
||||||
|
struct Output
|
||||||
|
{
|
||||||
|
[location(0)] value: f32
|
||||||
|
}
|
||||||
|
|
||||||
|
alias Out = Output;
|
||||||
|
alias FragOut = Out;
|
||||||
|
|
||||||
|
[entry(frag)]
|
||||||
|
fn main(input: In) -> FragOut
|
||||||
|
{
|
||||||
|
let output: Out;
|
||||||
|
output.value = extData.value * input.value;
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
|
||||||
|
ExpectGLSL(*shaderModule, R"(
|
||||||
|
void main()
|
||||||
|
{
|
||||||
|
Input input_;
|
||||||
|
input_.value = _NzIn_value;
|
||||||
|
|
||||||
|
Output output_;
|
||||||
|
output_.value = extData.value * input_.value;
|
||||||
|
|
||||||
|
_NzOut_value = output_.value;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
|
||||||
|
ExpectNZSL(*shaderModule, R"(
|
||||||
|
[entry(frag)]
|
||||||
|
fn main(input: In) -> FragOut
|
||||||
|
{
|
||||||
|
let output: Out;
|
||||||
|
output.value = extData.value * input.value;
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
|
||||||
|
ExpectSPIRV(*shaderModule, R"(
|
||||||
|
OpFunction
|
||||||
|
OpLabel
|
||||||
|
OpVariable
|
||||||
|
OpVariable
|
||||||
|
OpAccessChain
|
||||||
|
OpLoad
|
||||||
|
OpAccessChain
|
||||||
|
OpLoad
|
||||||
|
OpFMul
|
||||||
|
OpAccessChain
|
||||||
|
OpStore
|
||||||
|
OpLoad
|
||||||
|
OpCompositeExtract
|
||||||
|
OpStore
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -262,6 +262,47 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32]
|
||||||
{
|
{
|
||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
)");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
WHEN("removing aliases")
|
||||||
|
{
|
||||||
|
std::string_view nzslSource = R"(
|
||||||
|
[nzsl_version("1.0")]
|
||||||
|
module;
|
||||||
|
|
||||||
|
struct inputStruct
|
||||||
|
{
|
||||||
|
value: f32
|
||||||
|
}
|
||||||
|
|
||||||
|
alias Input = inputStruct;
|
||||||
|
alias In = Input;
|
||||||
|
|
||||||
|
external
|
||||||
|
{
|
||||||
|
[set(0), binding(0)] data: uniform[In]
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
|
||||||
|
Nz::ShaderAst::SanitizeVisitor::Options options;
|
||||||
|
options.removeAliases = true;
|
||||||
|
|
||||||
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options));
|
||||||
|
|
||||||
|
ExpectNZSL(*shaderModule, R"(
|
||||||
|
struct inputStruct
|
||||||
|
{
|
||||||
|
value: f32
|
||||||
|
}
|
||||||
|
|
||||||
|
external
|
||||||
|
{
|
||||||
|
[set(0), binding(0)] data: uniform[inputStruct]
|
||||||
|
}
|
||||||
)");
|
)");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -142,8 +142,10 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
||||||
Nz::ShaderAst::AstReflect reflectVisitor;
|
Nz::ShaderAst::AstReflect reflectVisitor;
|
||||||
reflectVisitor.Reflect(*shader.rootNode, callbacks);
|
reflectVisitor.Reflect(*shader.rootNode, callbacks);
|
||||||
|
|
||||||
INFO("no entry point found");
|
{
|
||||||
REQUIRE(entryShaderStage.has_value());
|
INFO("no entry point found");
|
||||||
|
REQUIRE(entryShaderStage.has_value());
|
||||||
|
}
|
||||||
|
|
||||||
Nz::GlslWriter writer;
|
Nz::GlslWriter writer;
|
||||||
std::string output = writer.Generate(entryShaderStage, shader);
|
std::string output = writer.Generate(entryShaderStage, shader);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue