Shader: Add proper support for alias

This commit is contained in:
Jérôme Leclercq 2022-03-09 12:35:00 +01:00
parent ce93b61c91
commit 05cf98477e
31 changed files with 472 additions and 98 deletions

View File

@ -42,6 +42,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(AccessIdentifierExpression& node); virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
virtual ExpressionPtr Clone(AccessIndexExpression& node); virtual ExpressionPtr Clone(AccessIndexExpression& node);
virtual ExpressionPtr Clone(AliasValueExpression& node);
virtual ExpressionPtr Clone(AssignExpression& node); virtual ExpressionPtr Clone(AssignExpression& node);
virtual ExpressionPtr Clone(BinaryExpression& node); virtual ExpressionPtr Clone(BinaryExpression& node);
virtual ExpressionPtr Clone(CallFunctionExpression& node); virtual ExpressionPtr Clone(CallFunctionExpression& node);

View File

@ -34,6 +34,7 @@ namespace Nz::ShaderAst
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs); inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs); inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
inline bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs);
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs); inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs); inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs); inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);

View File

@ -248,6 +248,14 @@ namespace Nz::ShaderAst
return true; return true;
} }
bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs)
{
if (!Compare(lhs.aliasId, rhs.aliasId))
return false;
return true;
}
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs) inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
{ {
if (!Compare(lhs.op, rhs.op)) if (!Compare(lhs.op, rhs.op))

View File

@ -30,6 +30,7 @@
NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression) NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression) NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression)
NAZARA_SHADERAST_EXPRESSION(AliasValueExpression)
NAZARA_SHADERAST_EXPRESSION(AssignExpression) NAZARA_SHADERAST_EXPRESSION(AssignExpression)
NAZARA_SHADERAST_EXPRESSION(BinaryExpression) NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression) NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)

View File

@ -22,6 +22,7 @@ namespace Nz::ShaderAst
void Visit(AccessIdentifierExpression& node) override; void Visit(AccessIdentifierExpression& node) override;
void Visit(AccessIndexExpression& node) override; void Visit(AccessIndexExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(AssignExpression& node) override; void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override; void Visit(BinaryExpression& node) override;
void Visit(CallFunctionExpression& node) override; void Visit(CallFunctionExpression& node) override;

View File

@ -25,6 +25,7 @@ namespace Nz::ShaderAst
void Serialize(AccessIdentifierExpression& node); void Serialize(AccessIdentifierExpression& node);
void Serialize(AccessIndexExpression& node); void Serialize(AccessIndexExpression& node);
void Serialize(AliasValueExpression& node);
void Serialize(AssignExpression& node); void Serialize(AssignExpression& node);
void Serialize(BinaryExpression& node); void Serialize(BinaryExpression& node);
void Serialize(CallFunctionExpression& node); void Serialize(CallFunctionExpression& node);

View File

@ -33,6 +33,7 @@ namespace Nz::ShaderAst
void Visit(AccessIdentifierExpression& node) override; void Visit(AccessIdentifierExpression& node) override;
void Visit(AccessIndexExpression& node) override; void Visit(AccessIndexExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(AssignExpression& node) override; void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override; void Visit(BinaryExpression& node) override;
void Visit(CallFunctionExpression& node) override; void Visit(CallFunctionExpression& node) override;

View File

@ -20,6 +20,22 @@ namespace Nz::ShaderAst
{ {
struct ContainedType; struct ContainedType;
struct NAZARA_SHADER_API AliasType
{
AliasType() = default;
AliasType(const AliasType& alias);
AliasType(AliasType&&) noexcept = default;
AliasType& operator=(const AliasType& alias);
AliasType& operator=(AliasType&&) noexcept = default;
std::size_t aliasIndex;
std::unique_ptr<ContainedType> targetType;
bool operator==(const AliasType& rhs) const;
inline bool operator!=(const AliasType& rhs) const;
};
struct NAZARA_SHADER_API ArrayType struct NAZARA_SHADER_API ArrayType
{ {
ArrayType() = default; ArrayType() = default;
@ -134,7 +150,7 @@ namespace Nz::ShaderAst
inline bool operator!=(const VectorType& rhs) const; inline bool operator!=(const VectorType& rhs) const;
}; };
using ExpressionType = std::variant<NoType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>; using ExpressionType = std::variant<NoType, AliasType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
struct ContainedType struct ContainedType
{ {
@ -157,6 +173,7 @@ namespace Nz::ShaderAst
std::vector<StructMember> members; std::vector<StructMember> members;
}; };
inline bool IsAliasType(const ExpressionType& type);
inline bool IsArrayType(const ExpressionType& type); inline bool IsArrayType(const ExpressionType& type);
inline bool IsFunctionType(const ExpressionType& type); inline bool IsFunctionType(const ExpressionType& type);
inline bool IsIdentifierType(const ExpressionType& type); inline bool IsIdentifierType(const ExpressionType& type);

View File

@ -8,6 +8,11 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
inline bool AliasType::operator!=(const AliasType& rhs) const
{
return !operator==(rhs);
}
inline bool ArrayType::operator!=(const ArrayType& rhs) const inline bool ArrayType::operator!=(const ArrayType& rhs) const
{ {
return !operator==(rhs); return !operator==(rhs);
@ -130,6 +135,11 @@ namespace Nz::ShaderAst
} }
inline bool IsAliasType(const ExpressionType& type)
{
return std::holds_alternative<AliasType>(type);
}
inline bool IsArrayType(const ExpressionType& type) inline bool IsArrayType(const ExpressionType& type)
{ {
return std::holds_alternative<ArrayType>(type); return std::holds_alternative<ArrayType>(type);

View File

@ -82,6 +82,14 @@ namespace Nz::ShaderAst
ExpressionPtr expr; ExpressionPtr expr;
}; };
struct NAZARA_SHADER_API AliasValueExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t aliasId;
};
struct NAZARA_SHADER_API AssignExpression : Expression struct NAZARA_SHADER_API AssignExpression : Expression
{ {
NodeType GetType() const override; NodeType GetType() const override;
@ -153,7 +161,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override; NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override; void Visit(AstExpressionVisitor& visitor) override;
ShaderAst::ConstantValue value; ConstantValue value;
}; };
struct NAZARA_SHADER_API FunctionExpression : Expression struct NAZARA_SHADER_API FunctionExpression : Expression
@ -207,7 +215,6 @@ namespace Nz::ShaderAst
ExpressionPtr expression; ExpressionPtr expression;
}; };
struct NAZARA_SHADER_API VariableExpression : Expression
struct NAZARA_SHADER_API VariableValueExpression : Expression struct NAZARA_SHADER_API VariableValueExpression : Expression
{ {
NodeType GetType() const override; NodeType GetType() const override;
@ -455,10 +462,12 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/Ast/AstNodeList.hpp> #include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr); inline const ExpressionType& GetExpressionType(Expression& expr);
inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr); inline ExpressionType& GetExpressionTypeMut(Expression& expr);
inline bool IsExpression(NodeType nodeType); inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType); inline bool IsStatement(NodeType nodeType);
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType);
} }
#include <Nazara/Shader/Ast/Nodes.inl> #include <Nazara/Shader/Ast/Nodes.inl>

View File

@ -7,18 +7,29 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) inline const ExpressionType& GetExpressionType(Expression& expr)
{ {
assert(expr.cachedExpressionType); assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value(); return expr.cachedExpressionType.value();
} }
ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr) inline ExpressionType& GetExpressionTypeMut(Expression& expr)
{ {
assert(expr.cachedExpressionType); assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value(); return expr.cachedExpressionType.value();
} }
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
{
if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
return alias.targetType->type;
}
else
return exprType;
}
inline bool IsExpression(NodeType nodeType) inline bool IsExpression(NodeType nodeType)
{ {
switch (nodeType) switch (nodeType)

View File

@ -47,6 +47,7 @@ namespace Nz::ShaderAst
std::unordered_map<UInt32, ConstantValue> optionValues; std::unordered_map<UInt32, ConstantValue> optionValues;
bool makeVariableNameUnique = false; bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false; bool reduceLoopsToWhile = false;
bool removeAliases = false;
bool removeConstDeclaration = false; bool removeConstDeclaration = false;
bool removeCompoundAssignments = false; bool removeCompoundAssignments = false;
bool removeMatrixCast = false; bool removeMatrixCast = false;
@ -71,6 +72,7 @@ namespace Nz::ShaderAst
ExpressionPtr Clone(AccessIdentifierExpression& node) override; ExpressionPtr Clone(AccessIdentifierExpression& node) override;
ExpressionPtr Clone(AccessIndexExpression& node) override; ExpressionPtr Clone(AccessIndexExpression& node) override;
ExpressionPtr Clone(AliasValueExpression& node) override;
ExpressionPtr Clone(AssignExpression& node) override; ExpressionPtr Clone(AssignExpression& node) override;
ExpressionPtr Clone(BinaryExpression& node) override; ExpressionPtr Clone(BinaryExpression& node) override;
ExpressionPtr Clone(CallFunctionExpression& node) override; ExpressionPtr Clone(CallFunctionExpression& node) override;
@ -136,15 +138,16 @@ namespace Nz::ShaderAst
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {}); std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {}); std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAlias(const IdentifierData* identifier) const; const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions(); void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType); std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType); ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue); ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier); void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error); MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);

View File

@ -51,6 +51,7 @@ namespace Nz
static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr); static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
private: private:
void Append(const ShaderAst::AliasType& aliasType);
void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ArrayType& type);
void Append(ShaderAst::BuiltinEntry builtin); void Append(ShaderAst::BuiltinEntry builtin);
void Append(const ShaderAst::ExpressionType& type); void Append(const ShaderAst::ExpressionType& type);
@ -94,6 +95,7 @@ namespace Nz
void Visit(ShaderAst::AccessIdentifierExpression& node) override; void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CallFunctionExpression& node) override;

View File

@ -50,6 +50,7 @@ namespace Nz
struct UnrollAttribute; struct UnrollAttribute;
struct UuidAttribute; struct UuidAttribute;
void Append(const ShaderAst::AliasType& type);
void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type); void Append(const ShaderAst::ExpressionType& type);
void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type); void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type);
@ -92,6 +93,7 @@ namespace Nz
void EnterScope(); void EnterScope();
void LeaveScope(bool skipLine = true); void LeaveScope(bool skipLine = true);
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
void RegisterConstant(std::size_t constantIndex, std::string constantName); void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterStruct(std::size_t structIndex, std::string structName); void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName); void RegisterVariable(std::size_t varIndex, std::string varName);
@ -102,6 +104,7 @@ namespace Nz
void Visit(ShaderAst::AccessIdentifierExpression& node) override; void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::CastExpression& node) override;

View File

@ -48,8 +48,6 @@ namespace Nz
void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConstantValueExpression& node) override; void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::DeclareAliasStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareOptionStatement& node) override; void Visit(ShaderAst::DeclareOptionStatement& node) override;
@ -64,7 +62,7 @@ namespace Nz
void Visit(ShaderAst::ScopedStatement& node) override; void Visit(ShaderAst::ScopedStatement& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override; void Visit(ShaderAst::VariableValueExpression& node) override;
void Visit(ShaderAst::WhileStatement& node) override; void Visit(ShaderAst::WhileStatement& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete; SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;

View File

@ -178,6 +178,7 @@ namespace Nz
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
TypePtr BuildType(const ShaderAst::AliasType& type) const;
TypePtr BuildType(const ShaderAst::ArrayType& type) const; TypePtr BuildType(const ShaderAst::ArrayType& type) const;
TypePtr BuildType(const ShaderAst::ExpressionType& type) const; TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
TypePtr BuildType(const ShaderAst::IdentifierType& type) const; TypePtr BuildType(const ShaderAst::IdentifierType& type) const;

View File

@ -300,6 +300,16 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr AstCloner::Clone(AliasValueExpression& node)
{
auto clone = std::make_unique<AliasValueExpression>();
clone->aliasId = node.aliasId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(AssignExpression& node) ExpressionPtr AstCloner::Clone(AssignExpression& node)
{ {
auto clone = std::make_unique<AssignExpression>(); auto clone = std::make_unique<AssignExpression>();

View File

@ -19,6 +19,11 @@ namespace Nz::ShaderAst
index->Visit(*this); index->Visit(*this);
} }
void AstRecursiveVisitor::Visit(AliasValueExpression& /*node*/)
{
/* nothing to do */
}
void AstRecursiveVisitor::Visit(AssignExpression& node) void AstRecursiveVisitor::Visit(AssignExpression& node)
{ {
node.left->Visit(*this); node.left->Visit(*this);

View File

@ -58,6 +58,11 @@ namespace Nz::ShaderAst
Node(identifier); Node(identifier);
} }
void AstSerializerBase::Serialize(AliasValueExpression& node)
{
SizeT(node.aliasId);
}
void AstSerializerBase::Serialize(AssignExpression& node) void AstSerializerBase::Serialize(AssignExpression& node)
{ {
Enum(node.op); Enum(node.op);
@ -485,6 +490,12 @@ namespace Nz::ShaderAst
Type(arg.objectType->type); Type(arg.objectType->type);
SizeT(arg.methodIndex); SizeT(arg.methodIndex);
} }
else if constexpr (std::is_same_v<T, ShaderAst::AliasType>)
{
m_stream << UInt8(13);
SizeT(arg.aliasIndex);
Type(arg.targetType->type);
}
else else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor"); static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type); }, type);
@ -800,6 +811,22 @@ namespace Nz::ShaderAst
type = std::move(methodType); type = std::move(methodType);
} }
case 13: //< AliasType
{
std::size_t aliasIndex;
ExpressionType containedType;
SizeT(aliasIndex);
Type(containedType);
AliasType aliasType;
aliasType.aliasIndex = aliasIndex;
aliasType.targetType = std::make_unique<ContainedType>();
aliasType.targetType->type = std::move(containedType);
type = std::move(aliasType);
break;
}
default: default:
break; break;
} }

View File

@ -23,6 +23,11 @@ namespace Nz::ShaderAst
node.expr->Visit(*this); node.expr->Visit(*this);
} }
void ShaderAstValueCategory::Visit(AliasValueExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/) void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
{ {
m_expressionCategory = ExpressionCategory::RValue; m_expressionCategory = ExpressionCategory::RValue;

View File

@ -9,6 +9,37 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
AliasType::AliasType(const AliasType& alias) :
aliasIndex(alias.aliasIndex)
{
assert(alias.targetType);
targetType = std::make_unique<ContainedType>(*alias.targetType);
}
AliasType& AliasType::operator=(const AliasType& alias)
{
aliasIndex = alias.aliasIndex;
assert(alias.targetType);
targetType = std::make_unique<ContainedType>(*alias.targetType);
return *this;
}
bool AliasType::operator==(const AliasType& rhs) const
{
assert(targetType);
assert(rhs.targetType);
if (aliasIndex != rhs.aliasIndex)
return false;
if (targetType->type != rhs.targetType->type)
return false;
return true;
}
ArrayType::ArrayType(const ArrayType& array) : ArrayType::ArrayType(const ArrayType& array) :
length(array.length) length(array.length)
{ {
@ -31,10 +62,10 @@ namespace Nz::ShaderAst
assert(containedType); assert(containedType);
assert(rhs.containedType); assert(rhs.containedType);
if (containedType->type != rhs.containedType->type) if (length != rhs.length)
return false; return false;
if (length != rhs.length) if (containedType->type != rhs.containedType->type)
return false; return false;
return true; return true;

View File

@ -282,7 +282,7 @@ namespace Nz::ShaderAst
if (identifier.empty()) if (identifier.empty())
throw AstError{ "empty identifier" }; throw AstError{ "empty identifier" };
const ExpressionType& exprType = GetExpressionType(*indexedExpr); const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr));
// TODO: Add proper support for methods // TODO: Add proper support for methods
if (IsSamplerType(exprType)) if (IsSamplerType(exprType))
{ {
@ -429,6 +429,25 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node)
{
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId));
ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier);
if (m_context->options.removeAliases)
return targetExpr;
AliasType aliasType;
aliasType.aliasIndex = node.aliasId;
aliasType.targetType = std::make_unique<ContainedType>();
aliasType.targetType->type = *targetExpr->cachedExpressionType;
auto clone = static_unique_pointer_cast<AliasValueExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = std::move(aliasType);
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node) ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
{ {
MandatoryExpr(node.left); MandatoryExpr(node.left);
@ -543,7 +562,7 @@ namespace Nz::ShaderAst
{ {
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType); const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front()); const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
bool isMatrixCast = IsMatrixType(frontExprType); bool isMatrixCast = IsMatrixType(frontExprType);
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType) if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
{ {
@ -785,6 +804,9 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
Validate(*clone); Validate(*clone);
if (m_context->options.removeAliases)
return ShaderBuilder::NoOp();
return clone; return clone;
} }
@ -803,7 +825,7 @@ namespace Nz::ShaderAst
ExpressionType expressionType = ResolveType(GetExpressionType(value)); ExpressionType expressionType = ResolveType(GetExpressionType(value));
if (clone->type.HasValue() && ResolveType(clone->type) != expressionType) if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType))
throw AstError{ "constant expression doesn't match type" }; throw AstError{ "constant expression doesn't match type" };
clone->type = expressionType; clone->type = expressionType;
@ -852,12 +874,13 @@ namespace Nz::ShaderAst
m_context->declaredExternalVar.insert(extVar.name); m_context->declaredExternalVar.insert(extVar.name);
ExpressionType resolvedType = ResolveType(extVar.type); ExpressionType resolvedType = ResolveType(extVar.type);
const ExpressionType& targetType = ResolveAlias(resolvedType);
ExpressionType varType; ExpressionType varType;
if (IsUniformType(resolvedType)) if (IsUniformType(targetType))
varType = std::get<UniformType>(resolvedType).containedType; varType = std::get<UniformType>(targetType).containedType;
else if (IsSamplerType(resolvedType)) else if (IsSamplerType(targetType))
varType = resolvedType; varType = targetType;
else else
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
@ -954,8 +977,9 @@ namespace Nz::ShaderAst
throw AstError{ "empty option name" }; throw AstError{ "empty option name" };
ExpressionType resolvedType = ResolveType(clone->optType); ExpressionType resolvedType = ResolveType(clone->optType);
const ExpressionType& targetType = ResolveAlias(resolvedType);
if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue)) if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue))
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" }; throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
clone->optType = std::move(resolvedType); clone->optType = std::move(resolvedType);
@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst
ExpressionType resolvedType = ResolveType(member.type); ExpressionType resolvedType = ResolveType(member.type);
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140) if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
{ {
if (IsPrimitiveType(resolvedType) && std::get<PrimitiveType>(resolvedType) == PrimitiveType::Boolean) const ExpressionType& targetType = ResolveAlias(resolvedType);
if (IsPrimitiveType(targetType) && std::get<PrimitiveType>(targetType) == PrimitiveType::Boolean)
throw AstError{ "boolean type is not allowed in std140 layout" }; throw AstError{ "boolean type is not allowed in std140 layout" };
else if (IsStructType(resolvedType)) else if (IsStructType(targetType))
{ {
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex; std::size_t structIndex = std::get<StructType>(targetType).structIndex;
const StructDescription* desc = m_context->structs.Retrieve(structIndex); const StructDescription* desc = m_context->structs.Retrieve(structIndex);
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" }; throw AstError{ "inner struct layout mismatch" };
@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst
AstExportVisitor exportVisitor; AstExportVisitor exportVisitor;
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks); exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
if (aliasStatements.empty()) if (aliasStatements.empty() || m_context->options.removeAliases)
return ShaderBuilder::NoOp(); return ShaderBuilder::NoOp();
// Register module and aliases // Register module and aliases
@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst
return nullptr; return nullptr;
} }
return ResolveAlias(&it->data); return &it->data;
} }
template<typename F> template<typename F>
@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst
{ {
if (identifier.name == identifierName) if (identifier.name == identifierName)
{ {
if (functor(*ResolveAlias(&identifier.data))) if (functor(identifier.data))
return true; return true;
} }
@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst
return nullptr; return nullptr;
} }
return ResolveAlias(&it->data); return &it->data;
} }
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst
{ {
switch (identifierData->category) switch (identifierData->category)
{ {
case IdentifierCategory::Alias:
{
AliasValueExpression aliasValue;
aliasValue.aliasId = identifierData->index;
return Clone(aliasValue);
}
case IdentifierCategory::Constant: case IdentifierCategory::Constant:
{ {
// Replace IdentifierExpression by Constant(Value)Expression // Replace IdentifierExpression by Constant(Value)Expression
@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst
return varIndex; return varIndex;
} }
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData* auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData*
{ {
while (identifier->category == IdentifierCategory::Alias) while (identifier->category == IdentifierCategory::Alias)
identifier = &m_context->aliases.Retrieve(identifier->index); identifier = &m_context->aliases.Retrieve(identifier->index);
@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst
for (const auto& [funcIndex, funcData] : m_context->functions.values) for (const auto& [funcIndex, funcData] : m_context->functions.values)
{ {
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
throw AstError{ "discard can only be used in the fragment stage" }; throw AstError{ "discard can only be used in the fragment stage" };
} }
} }
@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst
} }
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType)
{
return ResolveStruct(aliasType.targetType->type);
}
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType) std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
{ {
return std::visit([&](auto&& arg) -> std::size_t return std::visit([&](auto&& arg) -> std::size_t
{ {
using T = std::decay_t<decltype(arg)>; using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>) if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType> || std::is_same_v<T, AliasType>)
return ResolveStruct(arg); return ResolveStruct(arg);
else if constexpr (std::is_same_v<T, NoType> || else if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, ArrayType> || std::is_same_v<T, ArrayType> ||
@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst
return uniformType.containedType.structIndex; return uniformType.containedType.structIndex;
} }
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType) ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias)
{ {
if (!IsTypeExpression(exprType)) if (!IsTypeExpression(exprType))
return exprType; {
if (resolveAlias || m_context->options.removeAliases)
return ResolveAlias(exprType);
else
return exprType;
}
std::size_t typeIndex = std::get<Type>(exprType).typeIndex; std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst
return std::get<ExpressionType>(type); return std::get<ExpressionType>(type);
} }
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue) ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias)
{ {
if (!exprTypeValue.HasValue()) if (!exprTypeValue.HasValue())
return {}; return {};
if (exprTypeValue.IsResultingValue()) if (exprTypeValue.IsResultingValue())
return ResolveType(exprTypeValue.GetResultingValue()); return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias);
assert(exprTypeValue.IsExpression()); assert(exprTypeValue.IsExpression());
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression()); ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst
//if (!IsTypeType(exprType)) //if (!IsTypeType(exprType))
// throw AstError{ "type expected" }; // throw AstError{ "type expected" };
return ResolveType(exprType); return ResolveType(exprType, resolveAlias);
} }
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier) void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
{ {
if (left != right) if (ResolveAlias(left) != ResolveAlias(right))
throw AstError{ "Left expression type must match right expression type" }; throw AstError{ "Left expression type must match right expression type" };
} }
@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex); node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
} }
else if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
}
else else
throw AstError{ "for now, only structs can be aliased" }; throw AstError{ "for now, only structs can be aliased" };
} }
@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst
case TypeParameterCategory::PrimitiveType: case TypeParameterCategory::PrimitiveType:
case TypeParameterCategory::StructType: case TypeParameterCategory::StructType:
{ {
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr)); ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true);
switch (partialType.parameters[i]) switch (partialType.parameters[i])
{ {
@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst
for (auto& index : node.indices) for (auto& index : node.indices)
{ {
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index); const ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType)) if (!IsPrimitiveType(indexType))
throw AstError{ "AccessIndex expects integer indices" }; throw AstError{ "AccessIndex expects integer indices" };
@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst
} }
else if (IsStructType(exprType)) else if (IsStructType(exprType))
{ {
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); const ExpressionType& indexType = GetExpressionType(*indexExpr);
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
throw AstError{ "struct can only be accessed with constant i32 indices" }; throw AstError{ "struct can only be accessed with constant i32 indices" };
@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->structs.Retrieve(structIndex); const StructDescription* s = m_context->structs.Retrieve(structIndex);
exprType = ResolveType(s->members[index].type); exprType = ResolveType(s->members[index].type, true);
} }
else if (IsMatrixType(exprType)) else if (IsMatrixType(exprType))
{ {
@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CallFunctionExpression& node) void SanitizeVisitor::Validate(CallFunctionExpression& node)
{ {
const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction); const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
assert(std::holds_alternative<FunctionType>(targetFuncType)); assert(std::holds_alternative<FunctionType>(targetFuncType));
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex; std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CastExpression& node) void SanitizeVisitor::Validate(CastExpression& node)
{ {
ExpressionType resolvedType = ResolveType(node.targetType); ExpressionType resolvedType = ResolveType(node.targetType);
const ExpressionType& targetType = ResolveAlias(resolvedType);
const auto& firstExprPtr = node.expressions.front(); const auto& firstExprPtr = node.expressions.front();
if (!firstExprPtr) if (!firstExprPtr)
throw AstError{ "expected at least one expression" }; throw AstError{ "expected at least one expression" };
if (IsMatrixType(resolvedType)) if (IsMatrixType(targetType))
{ {
const MatrixType& targetMatrixType = std::get<MatrixType>(resolvedType); const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr); const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
if (IsMatrixType(firstExprType)) if (IsMatrixType(firstExprType))
@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst
}; };
std::size_t componentCount = 0; std::size_t componentCount = 0;
std::size_t requiredComponents = GetComponentCount(resolvedType); std::size_t requiredComponents = GetComponentCount(targetType);
for (auto& exprPtr : node.expressions) for (auto& exprPtr : node.expressions)
{ {
@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst
case UnaryType::Minus: case UnaryType::Minus:
case UnaryType::Plus: case UnaryType::Plus:
{ {
ShaderAst::PrimitiveType basicType; PrimitiveType basicType;
if (IsPrimitiveType(exprType)) if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType); basicType = std::get<PrimitiveType>(exprType);
else if (IsVectorType(exprType)) else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type; basicType = std::get<VectorType>(exprType).type;
else else
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" }; throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };

View File

@ -232,6 +232,7 @@ namespace Nz
options.optionValues = std::move(optionValues); options.optionValues = std::move(optionValues);
options.makeVariableNameUnique = true; options.makeVariableNameUnique = true;
options.reduceLoopsToWhile = true; options.reduceLoopsToWhile = true;
options.removeAliases = true;
options.removeCompoundAssignments = false; options.removeCompoundAssignments = false;
options.removeConstDeclaration = true; options.removeConstDeclaration = true;
options.removeOptionDeclaration = true; options.removeOptionDeclaration = true;
@ -246,6 +247,11 @@ namespace Nz
return ShaderAst::Sanitize(module, options, error); return ShaderAst::Sanitize(module, options, error);
} }
void GlslWriter::Append(const ShaderAst::AliasType& /*aliasType*/)
{
throw std::runtime_error("unexpected AliasType");
}
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/) void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
{ {
throw std::runtime_error("unexpected ArrayType"); throw std::runtime_error("unexpected ArrayType");
@ -689,11 +695,16 @@ namespace Nz
builtin.identifier builtin.identifier
}); });
} }
else if (member.locationIndex.HasValue()) else
{ {
Append("layout(location = "); if (member.locationIndex.HasValue())
Append(member.locationIndex.GetResultingValue()); {
Append(") ", keyword, " "); Append("layout(location = ");
Append(member.locationIndex.GetResultingValue());
Append(") ");
}
Append(keyword, " ");
AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name); AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name);
AppendLine(";"); AppendLine(";");
@ -805,6 +816,12 @@ namespace Nz
Append("]"); Append("]");
} }
void GlslWriter::Visit(ShaderAst::AliasValueExpression& /*node*/)
{
// all aliases should have been handled by sanitizer
throw std::runtime_error("unexpected alias value, is shader sanitized?");
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node) void GlslWriter::Visit(ShaderAst::AssignExpression& node)
{ {
node.left->Visit(*this); node.left->Visit(*this);
@ -1038,12 +1055,14 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/) void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
{ {
/* nothing to do */ // all aliases should have been handled by sanitizer
throw std::runtime_error("unexpected alias declaration, is shader sanitized?");
} }
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/) void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{ {
/* nothing to do */ // all consts should have been handled by sanitizer
throw std::runtime_error("unexpected const declaration, is shader sanitized?");
} }
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
@ -1184,7 +1203,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/) void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)
{ {
/* nothing to do */ // all options should have been handled by sanitizer
throw std::runtime_error("unexpected option declaration, is shader sanitized?");
} }
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
@ -1247,7 +1267,7 @@ namespace Nz
Append(";"); Append(";");
} }
void GlslWriter::Visit(ShaderAst::ImportStatement& node) void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/)
{ {
throw std::runtime_error("unexpected import statement, is the shader sanitized properly?"); throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
} }
@ -1280,8 +1300,8 @@ namespace Nz
const auto& structData = Retrieve(m_currentState->structs, structIndex); const auto& structData = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName; std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression) if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableValueExpression)
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableExpression&>(*node.returnExpr).variableId); outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableValueExpression&>(*node.returnExpr).variableId);
else else
{ {
AppendLine(); AppendLine();

View File

@ -116,6 +116,7 @@ namespace Nz
ShaderAst::Module* module; ShaderAst::Module* module;
std::size_t currentModuleIndex; std::size_t currentModuleIndex;
std::stringstream stream; std::stringstream stream;
std::unordered_map<std::size_t, Identifier> aliases;
std::unordered_map<std::size_t, Identifier> constants; std::unordered_map<std::size_t, Identifier> constants;
std::unordered_map<std::size_t, Identifier> structs; std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, Identifier> variables; std::unordered_map<std::size_t, Identifier> variables;
@ -164,6 +165,11 @@ namespace Nz
m_environment = std::move(environment); m_environment = std::move(environment);
} }
void LangWriter::Append(const ShaderAst::AliasType& type)
{
AppendIdentifier(m_currentState->aliases, type.aliasIndex);
}
void LangWriter::Append(const ShaderAst::ArrayType& type) void LangWriter::Append(const ShaderAst::ArrayType& type)
{ {
Append("array[", type.containedType->type, ", ", type.length, "]"); Append("array[", type.containedType->type, ", ", type.length, "]");
@ -655,6 +661,16 @@ namespace Nz
Append("}"); Append("}");
} }
void LangWriter::RegisterAlias(std::size_t aliasIndex, std::string aliasName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(aliasName);
assert(m_currentState->aliases.find(aliasIndex) == m_currentState->aliases.end());
m_currentState->aliases.emplace(aliasIndex, std::move(identifier));
}
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName) void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
{ {
State::Identifier identifier; State::Identifier identifier;
@ -714,7 +730,7 @@ namespace Nz
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(IsStructType(exprType)); assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers) for (const std::string& identifier : node.identifiers)
@ -725,7 +741,7 @@ namespace Nz
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(!IsStructType(exprType)); assert(!IsStructType(exprType));
// Array access // Array access
@ -744,6 +760,11 @@ namespace Nz
Append("]"); Append("]");
} }
void LangWriter::Visit(ShaderAst::AliasValueExpression& node)
{
AppendIdentifier(m_currentState->aliases, node.aliasId);
}
void LangWriter::Visit(ShaderAst::AssignExpression& node) void LangWriter::Visit(ShaderAst::AssignExpression& node)
{ {
node.left->Visit(*this); node.left->Visit(*this);
@ -840,7 +861,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ConditionalExpression& node) void LangWriter::Visit(ShaderAst::ConditionalExpression& node)
{ {
throw std::runtime_error("fixme"); Append("const_select(");
node.condition->Visit(*this);
Append(", ");
node.truePath->Visit(*this);
Append(", ");
node.falsePath->Visit(*this);
Append(")");
} }
void LangWriter::Visit(ShaderAst::ConditionalStatement& node) void LangWriter::Visit(ShaderAst::ConditionalStatement& node)
@ -853,9 +880,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node) void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{ {
//throw std::runtime_error("TODO"); //< missing registering
assert(node.aliasIndex); assert(node.aliasIndex);
RegisterAlias(*node.aliasIndex, node.name);
Append("alias ", node.name, " = "); Append("alias ", node.name, " = ");
assert(node.expression); assert(node.expression);

View File

@ -598,16 +598,6 @@ namespace Nz
}, node.value); }, node.value);
} }
void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node) void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
{ {
for (auto&& extVar : node.externalVars) for (auto&& extVar : node.externalVars)
@ -729,11 +719,6 @@ namespace Nz
PopResultId(); PopResultId();
} }
void SpirvAstVisitor::Visit(ShaderAst::ImportStatement& node)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node) void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
{ {
switch (node.intrinsic) switch (node.intrinsic)

View File

@ -655,6 +655,28 @@ namespace Nz
return typePtr; return typePtr;
} }
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildType(const ShaderAst::AliasType& /*type*/) const -> TypePtr
{
// No AliasType is expected (as they should have been resolved by now)
throw std::runtime_error("unexpected alias");
}
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
{ {
const auto& containedType = type.containedType->type; const auto& containedType = type.containedType->type;
@ -678,22 +700,6 @@ namespace Nz
}); });
} }
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
{ {
return std::visit([&](auto&& arg) -> TypePtr return std::visit([&](auto&& arg) -> TypePtr

View File

@ -505,6 +505,7 @@ namespace Nz
ShaderAst::SanitizeVisitor::Options options; ShaderAst::SanitizeVisitor::Options options;
options.optionValues = states.optionValues; options.optionValues = states.optionValues;
options.reduceLoopsToWhile = true; options.reduceLoopsToWhile = true;
options.removeAliases = true;
options.removeCompoundAssignments = true; options.removeCompoundAssignments = true;
options.removeMatrixCast = true; options.removeMatrixCast = true;
options.removeOptionDeclaration = true; options.removeOptionDeclaration = true;

View File

@ -0,0 +1,97 @@
#include <Engine/Shader/ShaderUtils.hpp>
#include <Nazara/Core/File.hpp>
#include <Nazara/Core/StringExt.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <catch2/catch.hpp>
#include <cctype>
TEST_CASE("aliases", "[Shader]")
{
SECTION("Alias of structs")
{
std::string_view nzslSource = R"(
[nzsl_version("1.0")]
module;
struct Data
{
value: f32
}
alias ExtData = Data;
external
{
[binding(0)] extData: uniform[ExtData]
}
struct Input
{
value: f32
}
alias In = Input;
struct Output
{
[location(0)] value: f32
}
alias Out = Output;
alias FragOut = Out;
[entry(frag)]
fn main(input: In) -> FragOut
{
let output: Out;
output.value = extData.value * input.value;
return output;
}
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
ExpectGLSL(*shaderModule, R"(
void main()
{
Input input_;
input_.value = _NzIn_value;
Output output_;
output_.value = extData.value * input_.value;
_NzOut_value = output_.value;
return;
}
)");
ExpectNZSL(*shaderModule, R"(
[entry(frag)]
fn main(input: In) -> FragOut
{
let output: Out;
output.value = extData.value * input.value;
return output;
}
)");
ExpectSPIRV(*shaderModule, R"(
OpFunction
OpLabel
OpVariable
OpVariable
OpAccessChain
OpLoad
OpAccessChain
OpLoad
OpFMul
OpAccessChain
OpStore
OpLoad
OpCompositeExtract
OpStore
OpReturn
OpFunctionEnd)");
}
}

View File

@ -262,6 +262,47 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32]
{ {
return input; return input;
} }
)");
}
WHEN("removing aliases")
{
std::string_view nzslSource = R"(
[nzsl_version("1.0")]
module;
struct inputStruct
{
value: f32
}
alias Input = inputStruct;
alias In = Input;
external
{
[set(0), binding(0)] data: uniform[In]
}
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
Nz::ShaderAst::SanitizeVisitor::Options options;
options.removeAliases = true;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options));
ExpectNZSL(*shaderModule, R"(
struct inputStruct
{
value: f32
}
external
{
[set(0), binding(0)] data: uniform[inputStruct]
}
)"); )");
} }

View File

@ -142,8 +142,10 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
Nz::ShaderAst::AstReflect reflectVisitor; Nz::ShaderAst::AstReflect reflectVisitor;
reflectVisitor.Reflect(*shader.rootNode, callbacks); reflectVisitor.Reflect(*shader.rootNode, callbacks);
INFO("no entry point found"); {
REQUIRE(entryShaderStage.has_value()); INFO("no entry point found");
REQUIRE(entryShaderStage.has_value());
}
Nz::GlslWriter writer; Nz::GlslWriter writer;
std::string output = writer.Generate(entryShaderStage, shader); std::string output = writer.Generate(entryShaderStage, shader);