Shader: Add proper support for alias
This commit is contained in:
parent
ce93b61c91
commit
05cf98477e
|
|
@ -42,6 +42,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
|
||||
virtual ExpressionPtr Clone(AccessIndexExpression& node);
|
||||
virtual ExpressionPtr Clone(AliasValueExpression& node);
|
||||
virtual ExpressionPtr Clone(AssignExpression& node);
|
||||
virtual ExpressionPtr Clone(BinaryExpression& node);
|
||||
virtual ExpressionPtr Clone(CallFunctionExpression& node);
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
|
||||
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
|
||||
inline bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs);
|
||||
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
|
||||
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
|
||||
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);
|
||||
|
|
|
|||
|
|
@ -248,6 +248,14 @@ namespace Nz::ShaderAst
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs)
|
||||
{
|
||||
if (!Compare(lhs.aliasId, rhs.aliasId))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
|
||||
{
|
||||
if (!Compare(lhs.op, rhs.op))
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@
|
|||
|
||||
NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(AliasValueExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void Visit(AccessIdentifierExpression& node) override;
|
||||
void Visit(AccessIndexExpression& node) override;
|
||||
void Visit(AliasValueExpression& node) override;
|
||||
void Visit(AssignExpression& node) override;
|
||||
void Visit(BinaryExpression& node) override;
|
||||
void Visit(CallFunctionExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void Serialize(AccessIdentifierExpression& node);
|
||||
void Serialize(AccessIndexExpression& node);
|
||||
void Serialize(AliasValueExpression& node);
|
||||
void Serialize(AssignExpression& node);
|
||||
void Serialize(BinaryExpression& node);
|
||||
void Serialize(CallFunctionExpression& node);
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void Visit(AccessIdentifierExpression& node) override;
|
||||
void Visit(AccessIndexExpression& node) override;
|
||||
void Visit(AliasValueExpression& node) override;
|
||||
void Visit(AssignExpression& node) override;
|
||||
void Visit(BinaryExpression& node) override;
|
||||
void Visit(CallFunctionExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,22 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
struct ContainedType;
|
||||
|
||||
struct NAZARA_SHADER_API AliasType
|
||||
{
|
||||
AliasType() = default;
|
||||
AliasType(const AliasType& alias);
|
||||
AliasType(AliasType&&) noexcept = default;
|
||||
|
||||
AliasType& operator=(const AliasType& alias);
|
||||
AliasType& operator=(AliasType&&) noexcept = default;
|
||||
|
||||
std::size_t aliasIndex;
|
||||
std::unique_ptr<ContainedType> targetType;
|
||||
|
||||
bool operator==(const AliasType& rhs) const;
|
||||
inline bool operator!=(const AliasType& rhs) const;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API ArrayType
|
||||
{
|
||||
ArrayType() = default;
|
||||
|
|
@ -134,7 +150,7 @@ namespace Nz::ShaderAst
|
|||
inline bool operator!=(const VectorType& rhs) const;
|
||||
};
|
||||
|
||||
using ExpressionType = std::variant<NoType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
|
||||
using ExpressionType = std::variant<NoType, AliasType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
|
||||
|
||||
struct ContainedType
|
||||
{
|
||||
|
|
@ -157,6 +173,7 @@ namespace Nz::ShaderAst
|
|||
std::vector<StructMember> members;
|
||||
};
|
||||
|
||||
inline bool IsAliasType(const ExpressionType& type);
|
||||
inline bool IsArrayType(const ExpressionType& type);
|
||||
inline bool IsFunctionType(const ExpressionType& type);
|
||||
inline bool IsIdentifierType(const ExpressionType& type);
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@
|
|||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
inline bool AliasType::operator!=(const AliasType& rhs) const
|
||||
{
|
||||
return !operator==(rhs);
|
||||
}
|
||||
|
||||
inline bool ArrayType::operator!=(const ArrayType& rhs) const
|
||||
{
|
||||
return !operator==(rhs);
|
||||
|
|
@ -130,6 +135,11 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
|
||||
|
||||
inline bool IsAliasType(const ExpressionType& type)
|
||||
{
|
||||
return std::holds_alternative<AliasType>(type);
|
||||
}
|
||||
|
||||
inline bool IsArrayType(const ExpressionType& type)
|
||||
{
|
||||
return std::holds_alternative<ArrayType>(type);
|
||||
|
|
|
|||
|
|
@ -82,6 +82,14 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr expr;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API AliasValueExpression : Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
std::size_t aliasId;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API AssignExpression : Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
|
|
@ -153,7 +161,7 @@ namespace Nz::ShaderAst
|
|||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
ShaderAst::ConstantValue value;
|
||||
ConstantValue value;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API FunctionExpression : Expression
|
||||
|
|
@ -207,7 +215,6 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr expression;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API VariableExpression : Expression
|
||||
struct NAZARA_SHADER_API VariableValueExpression : Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
|
|
@ -455,10 +462,12 @@ namespace Nz::ShaderAst
|
|||
|
||||
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
||||
|
||||
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
|
||||
inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr);
|
||||
inline const ExpressionType& GetExpressionType(Expression& expr);
|
||||
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
|
||||
inline bool IsExpression(NodeType nodeType);
|
||||
inline bool IsStatement(NodeType nodeType);
|
||||
|
||||
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType);
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/Ast/Nodes.inl>
|
||||
|
|
|
|||
|
|
@ -7,18 +7,29 @@
|
|||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr)
|
||||
inline const ExpressionType& GetExpressionType(Expression& expr)
|
||||
{
|
||||
assert(expr.cachedExpressionType);
|
||||
return expr.cachedExpressionType.value();
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr)
|
||||
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
|
||||
{
|
||||
assert(expr.cachedExpressionType);
|
||||
return expr.cachedExpressionType.value();
|
||||
}
|
||||
|
||||
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
|
||||
{
|
||||
if (IsAliasType(exprType))
|
||||
{
|
||||
const AliasType& alias = std::get<AliasType>(exprType);
|
||||
return alias.targetType->type;
|
||||
}
|
||||
else
|
||||
return exprType;
|
||||
}
|
||||
|
||||
inline bool IsExpression(NodeType nodeType)
|
||||
{
|
||||
switch (nodeType)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ namespace Nz::ShaderAst
|
|||
std::unordered_map<UInt32, ConstantValue> optionValues;
|
||||
bool makeVariableNameUnique = false;
|
||||
bool reduceLoopsToWhile = false;
|
||||
bool removeAliases = false;
|
||||
bool removeConstDeclaration = false;
|
||||
bool removeCompoundAssignments = false;
|
||||
bool removeMatrixCast = false;
|
||||
|
|
@ -71,6 +72,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
|
||||
ExpressionPtr Clone(AccessIndexExpression& node) override;
|
||||
ExpressionPtr Clone(AliasValueExpression& node) override;
|
||||
ExpressionPtr Clone(AssignExpression& node) override;
|
||||
ExpressionPtr Clone(BinaryExpression& node) override;
|
||||
ExpressionPtr Clone(CallFunctionExpression& node) override;
|
||||
|
|
@ -136,15 +138,16 @@ namespace Nz::ShaderAst
|
|||
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
||||
|
||||
const IdentifierData* ResolveAlias(const IdentifierData* identifier) const;
|
||||
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
|
||||
void ResolveFunctions();
|
||||
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
||||
std::size_t ResolveStruct(const AliasType& aliasType);
|
||||
std::size_t ResolveStruct(const ExpressionType& exprType);
|
||||
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
||||
std::size_t ResolveStruct(const StructType& structType);
|
||||
std::size_t ResolveStruct(const UniformType& uniformType);
|
||||
ExpressionType ResolveType(const ExpressionType& exprType);
|
||||
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue);
|
||||
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
|
||||
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||
|
||||
void SanitizeIdentifier(std::string& identifier);
|
||||
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ namespace Nz
|
|||
static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
|
||||
|
||||
private:
|
||||
void Append(const ShaderAst::AliasType& aliasType);
|
||||
void Append(const ShaderAst::ArrayType& type);
|
||||
void Append(ShaderAst::BuiltinEntry builtin);
|
||||
void Append(const ShaderAst::ExpressionType& type);
|
||||
|
|
@ -94,6 +95,7 @@ namespace Nz
|
|||
|
||||
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
|
||||
void Visit(ShaderAst::AccessIndexExpression& node) override;
|
||||
void Visit(ShaderAst::AliasValueExpression& node) override;
|
||||
void Visit(ShaderAst::AssignExpression& node) override;
|
||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ namespace Nz
|
|||
struct UnrollAttribute;
|
||||
struct UuidAttribute;
|
||||
|
||||
void Append(const ShaderAst::AliasType& type);
|
||||
void Append(const ShaderAst::ArrayType& type);
|
||||
void Append(const ShaderAst::ExpressionType& type);
|
||||
void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type);
|
||||
|
|
@ -92,6 +93,7 @@ namespace Nz
|
|||
void EnterScope();
|
||||
void LeaveScope(bool skipLine = true);
|
||||
|
||||
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
|
||||
void RegisterConstant(std::size_t constantIndex, std::string constantName);
|
||||
void RegisterStruct(std::size_t structIndex, std::string structName);
|
||||
void RegisterVariable(std::size_t varIndex, std::string varName);
|
||||
|
|
@ -102,6 +104,7 @@ namespace Nz
|
|||
|
||||
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
|
||||
void Visit(ShaderAst::AccessIndexExpression& node) override;
|
||||
void Visit(ShaderAst::AliasValueExpression& node) override;
|
||||
void Visit(ShaderAst::AssignExpression& node) override;
|
||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||
void Visit(ShaderAst::CastExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -48,8 +48,6 @@ namespace Nz
|
|||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||
void Visit(ShaderAst::CastExpression& node) override;
|
||||
void Visit(ShaderAst::ConstantValueExpression& node) override;
|
||||
void Visit(ShaderAst::DeclareAliasStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareConstStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareExternalStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareOptionStatement& node) override;
|
||||
|
|
@ -64,7 +62,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::ScopedStatement& node) override;
|
||||
void Visit(ShaderAst::SwizzleExpression& node) override;
|
||||
void Visit(ShaderAst::UnaryExpression& node) override;
|
||||
void Visit(ShaderAst::VariableExpression& node) override;
|
||||
void Visit(ShaderAst::VariableValueExpression& node) override;
|
||||
void Visit(ShaderAst::WhileStatement& node) override;
|
||||
|
||||
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ namespace Nz
|
|||
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
|
||||
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
|
||||
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
|
||||
TypePtr BuildType(const ShaderAst::AliasType& type) const;
|
||||
TypePtr BuildType(const ShaderAst::ArrayType& type) const;
|
||||
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
|
||||
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;
|
||||
|
|
|
|||
|
|
@ -300,6 +300,16 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(AliasValueExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<AliasValueExpression>();
|
||||
clone->aliasId = node.aliasId;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(AssignExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<AssignExpression>();
|
||||
|
|
|
|||
|
|
@ -19,6 +19,11 @@ namespace Nz::ShaderAst
|
|||
index->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(AliasValueExpression& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(AssignExpression& node)
|
||||
{
|
||||
node.left->Visit(*this);
|
||||
|
|
|
|||
|
|
@ -58,6 +58,11 @@ namespace Nz::ShaderAst
|
|||
Node(identifier);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(AliasValueExpression& node)
|
||||
{
|
||||
SizeT(node.aliasId);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(AssignExpression& node)
|
||||
{
|
||||
Enum(node.op);
|
||||
|
|
@ -485,6 +490,12 @@ namespace Nz::ShaderAst
|
|||
Type(arg.objectType->type);
|
||||
SizeT(arg.methodIndex);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, ShaderAst::AliasType>)
|
||||
{
|
||||
m_stream << UInt8(13);
|
||||
SizeT(arg.aliasIndex);
|
||||
Type(arg.targetType->type);
|
||||
}
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, type);
|
||||
|
|
@ -800,6 +811,22 @@ namespace Nz::ShaderAst
|
|||
type = std::move(methodType);
|
||||
}
|
||||
|
||||
case 13: //< AliasType
|
||||
{
|
||||
std::size_t aliasIndex;
|
||||
ExpressionType containedType;
|
||||
SizeT(aliasIndex);
|
||||
Type(containedType);
|
||||
|
||||
AliasType aliasType;
|
||||
aliasType.aliasIndex = aliasIndex;
|
||||
aliasType.targetType = std::make_unique<ContainedType>();
|
||||
aliasType.targetType->type = std::move(containedType);
|
||||
|
||||
type = std::move(aliasType);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,11 @@ namespace Nz::ShaderAst
|
|||
node.expr->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(AliasValueExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
|
|
|
|||
|
|
@ -9,6 +9,37 @@
|
|||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
AliasType::AliasType(const AliasType& alias) :
|
||||
aliasIndex(alias.aliasIndex)
|
||||
{
|
||||
assert(alias.targetType);
|
||||
targetType = std::make_unique<ContainedType>(*alias.targetType);
|
||||
}
|
||||
|
||||
AliasType& AliasType::operator=(const AliasType& alias)
|
||||
{
|
||||
aliasIndex = alias.aliasIndex;
|
||||
|
||||
assert(alias.targetType);
|
||||
targetType = std::make_unique<ContainedType>(*alias.targetType);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool AliasType::operator==(const AliasType& rhs) const
|
||||
{
|
||||
assert(targetType);
|
||||
assert(rhs.targetType);
|
||||
|
||||
if (aliasIndex != rhs.aliasIndex)
|
||||
return false;
|
||||
|
||||
if (targetType->type != rhs.targetType->type)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ArrayType::ArrayType(const ArrayType& array) :
|
||||
length(array.length)
|
||||
{
|
||||
|
|
@ -31,10 +62,10 @@ namespace Nz::ShaderAst
|
|||
assert(containedType);
|
||||
assert(rhs.containedType);
|
||||
|
||||
if (containedType->type != rhs.containedType->type)
|
||||
if (length != rhs.length)
|
||||
return false;
|
||||
|
||||
if (length != rhs.length)
|
||||
if (containedType->type != rhs.containedType->type)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -282,7 +282,7 @@ namespace Nz::ShaderAst
|
|||
if (identifier.empty())
|
||||
throw AstError{ "empty identifier" };
|
||||
|
||||
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
|
||||
const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr));
|
||||
// TODO: Add proper support for methods
|
||||
if (IsSamplerType(exprType))
|
||||
{
|
||||
|
|
@ -429,6 +429,25 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node)
|
||||
{
|
||||
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId));
|
||||
ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier);
|
||||
|
||||
if (m_context->options.removeAliases)
|
||||
return targetExpr;
|
||||
|
||||
AliasType aliasType;
|
||||
aliasType.aliasIndex = node.aliasId;
|
||||
aliasType.targetType = std::make_unique<ContainedType>();
|
||||
aliasType.targetType->type = *targetExpr->cachedExpressionType;
|
||||
|
||||
auto clone = static_unique_pointer_cast<AliasValueExpression>(AstCloner::Clone(node));
|
||||
clone->cachedExpressionType = std::move(aliasType);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
|
||||
{
|
||||
MandatoryExpr(node.left);
|
||||
|
|
@ -543,7 +562,7 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
||||
|
||||
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
||||
const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
||||
bool isMatrixCast = IsMatrixType(frontExprType);
|
||||
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
|
||||
{
|
||||
|
|
@ -785,6 +804,9 @@ namespace Nz::ShaderAst
|
|||
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
|
||||
if (m_context->options.removeAliases)
|
||||
return ShaderBuilder::NoOp();
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
|
@ -803,7 +825,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
ExpressionType expressionType = ResolveType(GetExpressionType(value));
|
||||
|
||||
if (clone->type.HasValue() && ResolveType(clone->type) != expressionType)
|
||||
if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType))
|
||||
throw AstError{ "constant expression doesn't match type" };
|
||||
|
||||
clone->type = expressionType;
|
||||
|
|
@ -852,12 +874,13 @@ namespace Nz::ShaderAst
|
|||
m_context->declaredExternalVar.insert(extVar.name);
|
||||
|
||||
ExpressionType resolvedType = ResolveType(extVar.type);
|
||||
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||
|
||||
ExpressionType varType;
|
||||
if (IsUniformType(resolvedType))
|
||||
varType = std::get<UniformType>(resolvedType).containedType;
|
||||
else if (IsSamplerType(resolvedType))
|
||||
varType = resolvedType;
|
||||
if (IsUniformType(targetType))
|
||||
varType = std::get<UniformType>(targetType).containedType;
|
||||
else if (IsSamplerType(targetType))
|
||||
varType = targetType;
|
||||
else
|
||||
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||
|
||||
|
|
@ -954,8 +977,9 @@ namespace Nz::ShaderAst
|
|||
throw AstError{ "empty option name" };
|
||||
|
||||
ExpressionType resolvedType = ResolveType(clone->optType);
|
||||
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||
|
||||
if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue))
|
||||
if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue))
|
||||
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
|
||||
|
||||
clone->optType = std::move(resolvedType);
|
||||
|
|
@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst
|
|||
ExpressionType resolvedType = ResolveType(member.type);
|
||||
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
|
||||
{
|
||||
if (IsPrimitiveType(resolvedType) && std::get<PrimitiveType>(resolvedType) == PrimitiveType::Boolean)
|
||||
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||
|
||||
if (IsPrimitiveType(targetType) && std::get<PrimitiveType>(targetType) == PrimitiveType::Boolean)
|
||||
throw AstError{ "boolean type is not allowed in std140 layout" };
|
||||
else if (IsStructType(resolvedType))
|
||||
else if (IsStructType(targetType))
|
||||
{
|
||||
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
|
||||
std::size_t structIndex = std::get<StructType>(targetType).structIndex;
|
||||
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
|
||||
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
||||
throw AstError{ "inner struct layout mismatch" };
|
||||
|
|
@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst
|
|||
AstExportVisitor exportVisitor;
|
||||
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
|
||||
|
||||
if (aliasStatements.empty())
|
||||
if (aliasStatements.empty() || m_context->options.removeAliases)
|
||||
return ShaderBuilder::NoOp();
|
||||
|
||||
// Register module and aliases
|
||||
|
|
@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return ResolveAlias(&it->data);
|
||||
return &it->data;
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
|
|
@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
if (identifier.name == identifierName)
|
||||
{
|
||||
if (functor(*ResolveAlias(&identifier.data)))
|
||||
if (functor(identifier.data))
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return ResolveAlias(&it->data);
|
||||
return &it->data;
|
||||
}
|
||||
|
||||
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
|
||||
|
|
@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
switch (identifierData->category)
|
||||
{
|
||||
case IdentifierCategory::Alias:
|
||||
{
|
||||
AliasValueExpression aliasValue;
|
||||
aliasValue.aliasId = identifierData->index;
|
||||
|
||||
return Clone(aliasValue);
|
||||
}
|
||||
|
||||
case IdentifierCategory::Constant:
|
||||
{
|
||||
// Replace IdentifierExpression by Constant(Value)Expression
|
||||
|
|
@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst
|
|||
return varIndex;
|
||||
}
|
||||
|
||||
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData*
|
||||
auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData*
|
||||
{
|
||||
while (identifier->category == IdentifierCategory::Alias)
|
||||
identifier = &m_context->aliases.Retrieve(identifier->index);
|
||||
|
|
@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
for (const auto& [funcIndex, funcData] : m_context->functions.values)
|
||||
{
|
||||
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
||||
if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
||||
throw AstError{ "discard can only be used in the fragment stage" };
|
||||
}
|
||||
}
|
||||
|
|
@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst
|
|||
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType)
|
||||
{
|
||||
return ResolveStruct(aliasType.targetType->type);
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> std::size_t
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>)
|
||||
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType> || std::is_same_v<T, AliasType>)
|
||||
return ResolveStruct(arg);
|
||||
else if constexpr (std::is_same_v<T, NoType> ||
|
||||
std::is_same_v<T, ArrayType> ||
|
||||
|
|
@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst
|
|||
return uniformType.containedType.structIndex;
|
||||
}
|
||||
|
||||
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType)
|
||||
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias)
|
||||
{
|
||||
if (!IsTypeExpression(exprType))
|
||||
return exprType;
|
||||
{
|
||||
if (resolveAlias || m_context->options.removeAliases)
|
||||
return ResolveAlias(exprType);
|
||||
else
|
||||
return exprType;
|
||||
}
|
||||
|
||||
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
||||
|
||||
|
|
@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst
|
|||
return std::get<ExpressionType>(type);
|
||||
}
|
||||
|
||||
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue)
|
||||
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias)
|
||||
{
|
||||
if (!exprTypeValue.HasValue())
|
||||
return {};
|
||||
|
||||
if (exprTypeValue.IsResultingValue())
|
||||
return ResolveType(exprTypeValue.GetResultingValue());
|
||||
return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias);
|
||||
|
||||
assert(exprTypeValue.IsExpression());
|
||||
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
|
||||
|
|
@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst
|
|||
//if (!IsTypeType(exprType))
|
||||
// throw AstError{ "type expected" };
|
||||
|
||||
return ResolveType(exprType);
|
||||
return ResolveType(exprType, resolveAlias);
|
||||
}
|
||||
|
||||
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
|
||||
|
|
@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
|
||||
{
|
||||
if (left != right)
|
||||
if (ResolveAlias(left) != ResolveAlias(right))
|
||||
throw AstError{ "Left expression type must match right expression type" };
|
||||
}
|
||||
|
||||
|
|
@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst
|
|||
std::size_t structIndex = ResolveStruct(exprType);
|
||||
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
|
||||
}
|
||||
else if (IsAliasType(exprType))
|
||||
{
|
||||
const AliasType& alias = std::get<AliasType>(exprType);
|
||||
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
|
||||
}
|
||||
else
|
||||
throw AstError{ "for now, only structs can be aliased" };
|
||||
}
|
||||
|
|
@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst
|
|||
case TypeParameterCategory::PrimitiveType:
|
||||
case TypeParameterCategory::StructType:
|
||||
{
|
||||
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr));
|
||||
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true);
|
||||
|
||||
switch (partialType.parameters[i])
|
||||
{
|
||||
|
|
@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
for (auto& index : node.indices)
|
||||
{
|
||||
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
|
||||
const ExpressionType& indexType = GetExpressionType(*index);
|
||||
if (!IsPrimitiveType(indexType))
|
||||
throw AstError{ "AccessIndex expects integer indices" };
|
||||
|
||||
|
|
@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
else if (IsStructType(exprType))
|
||||
{
|
||||
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
|
||||
const ExpressionType& indexType = GetExpressionType(*indexExpr);
|
||||
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
|
||||
throw AstError{ "struct can only be accessed with constant i32 indices" };
|
||||
|
||||
|
|
@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst
|
|||
std::size_t structIndex = ResolveStruct(exprType);
|
||||
const StructDescription* s = m_context->structs.Retrieve(structIndex);
|
||||
|
||||
exprType = ResolveType(s->members[index].type);
|
||||
exprType = ResolveType(s->members[index].type, true);
|
||||
}
|
||||
else if (IsMatrixType(exprType))
|
||||
{
|
||||
|
|
@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void SanitizeVisitor::Validate(CallFunctionExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
||||
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
||||
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
||||
|
||||
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
||||
|
|
@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst
|
|||
void SanitizeVisitor::Validate(CastExpression& node)
|
||||
{
|
||||
ExpressionType resolvedType = ResolveType(node.targetType);
|
||||
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
||||
|
||||
const auto& firstExprPtr = node.expressions.front();
|
||||
if (!firstExprPtr)
|
||||
throw AstError{ "expected at least one expression" };
|
||||
|
||||
if (IsMatrixType(resolvedType))
|
||||
if (IsMatrixType(targetType))
|
||||
{
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(resolvedType);
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
||||
|
||||
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
|
||||
if (IsMatrixType(firstExprType))
|
||||
|
|
@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst
|
|||
};
|
||||
|
||||
std::size_t componentCount = 0;
|
||||
std::size_t requiredComponents = GetComponentCount(resolvedType);
|
||||
std::size_t requiredComponents = GetComponentCount(targetType);
|
||||
|
||||
for (auto& exprPtr : node.expressions)
|
||||
{
|
||||
|
|
@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst
|
|||
case UnaryType::Minus:
|
||||
case UnaryType::Plus:
|
||||
{
|
||||
ShaderAst::PrimitiveType basicType;
|
||||
PrimitiveType basicType;
|
||||
if (IsPrimitiveType(exprType))
|
||||
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
||||
basicType = std::get<PrimitiveType>(exprType);
|
||||
else if (IsVectorType(exprType))
|
||||
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
||||
basicType = std::get<VectorType>(exprType).type;
|
||||
else
|
||||
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
|
||||
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ namespace Nz
|
|||
options.optionValues = std::move(optionValues);
|
||||
options.makeVariableNameUnique = true;
|
||||
options.reduceLoopsToWhile = true;
|
||||
options.removeAliases = true;
|
||||
options.removeCompoundAssignments = false;
|
||||
options.removeConstDeclaration = true;
|
||||
options.removeOptionDeclaration = true;
|
||||
|
|
@ -246,6 +247,11 @@ namespace Nz
|
|||
return ShaderAst::Sanitize(module, options, error);
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::AliasType& /*aliasType*/)
|
||||
{
|
||||
throw std::runtime_error("unexpected AliasType");
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
|
||||
{
|
||||
throw std::runtime_error("unexpected ArrayType");
|
||||
|
|
@ -689,11 +695,16 @@ namespace Nz
|
|||
builtin.identifier
|
||||
});
|
||||
}
|
||||
else if (member.locationIndex.HasValue())
|
||||
else
|
||||
{
|
||||
Append("layout(location = ");
|
||||
Append(member.locationIndex.GetResultingValue());
|
||||
Append(") ", keyword, " ");
|
||||
if (member.locationIndex.HasValue())
|
||||
{
|
||||
Append("layout(location = ");
|
||||
Append(member.locationIndex.GetResultingValue());
|
||||
Append(") ");
|
||||
}
|
||||
|
||||
Append(keyword, " ");
|
||||
AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name);
|
||||
AppendLine(";");
|
||||
|
||||
|
|
@ -805,6 +816,12 @@ namespace Nz
|
|||
Append("]");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::AliasValueExpression& /*node*/)
|
||||
{
|
||||
// all aliases should have been handled by sanitizer
|
||||
throw std::runtime_error("unexpected alias value, is shader sanitized?");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
|
||||
{
|
||||
node.left->Visit(*this);
|
||||
|
|
@ -1038,12 +1055,14 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
// all aliases should have been handled by sanitizer
|
||||
throw std::runtime_error("unexpected alias declaration, is shader sanitized?");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
// all consts should have been handled by sanitizer
|
||||
throw std::runtime_error("unexpected const declaration, is shader sanitized?");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
|
|
@ -1184,7 +1203,8 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
// all options should have been handled by sanitizer
|
||||
throw std::runtime_error("unexpected option declaration, is shader sanitized?");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||
|
|
@ -1247,7 +1267,7 @@ namespace Nz
|
|||
Append(";");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::ImportStatement& node)
|
||||
void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
|
||||
}
|
||||
|
|
@ -1280,8 +1300,8 @@ namespace Nz
|
|||
const auto& structData = Retrieve(m_currentState->structs, structIndex);
|
||||
|
||||
std::string outputStructVarName;
|
||||
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)
|
||||
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableExpression&>(*node.returnExpr).variableId);
|
||||
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableValueExpression)
|
||||
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableValueExpression&>(*node.returnExpr).variableId);
|
||||
else
|
||||
{
|
||||
AppendLine();
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ namespace Nz
|
|||
ShaderAst::Module* module;
|
||||
std::size_t currentModuleIndex;
|
||||
std::stringstream stream;
|
||||
std::unordered_map<std::size_t, Identifier> aliases;
|
||||
std::unordered_map<std::size_t, Identifier> constants;
|
||||
std::unordered_map<std::size_t, Identifier> structs;
|
||||
std::unordered_map<std::size_t, Identifier> variables;
|
||||
|
|
@ -164,6 +165,11 @@ namespace Nz
|
|||
m_environment = std::move(environment);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::AliasType& type)
|
||||
{
|
||||
AppendIdentifier(m_currentState->aliases, type.aliasIndex);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::ArrayType& type)
|
||||
{
|
||||
Append("array[", type.containedType->type, ", ", type.length, "]");
|
||||
|
|
@ -655,6 +661,16 @@ namespace Nz
|
|||
Append("}");
|
||||
}
|
||||
|
||||
void LangWriter::RegisterAlias(std::size_t aliasIndex, std::string aliasName)
|
||||
{
|
||||
State::Identifier identifier;
|
||||
identifier.moduleIndex = m_currentState->currentModuleIndex;
|
||||
identifier.name = std::move(aliasName);
|
||||
|
||||
assert(m_currentState->aliases.find(aliasIndex) == m_currentState->aliases.end());
|
||||
m_currentState->aliases.emplace(aliasIndex, std::move(identifier));
|
||||
}
|
||||
|
||||
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
|
||||
{
|
||||
State::Identifier identifier;
|
||||
|
|
@ -714,7 +730,7 @@ namespace Nz
|
|||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||
assert(IsStructType(exprType));
|
||||
|
||||
for (const std::string& identifier : node.identifiers)
|
||||
|
|
@ -725,7 +741,7 @@ namespace Nz
|
|||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||
assert(!IsStructType(exprType));
|
||||
|
||||
// Array access
|
||||
|
|
@ -744,6 +760,11 @@ namespace Nz
|
|||
Append("]");
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::AliasValueExpression& node)
|
||||
{
|
||||
AppendIdentifier(m_currentState->aliases, node.aliasId);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::AssignExpression& node)
|
||||
{
|
||||
node.left->Visit(*this);
|
||||
|
|
@ -840,7 +861,13 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::ConditionalExpression& node)
|
||||
{
|
||||
throw std::runtime_error("fixme");
|
||||
Append("const_select(");
|
||||
node.condition->Visit(*this);
|
||||
Append(", ");
|
||||
node.truePath->Visit(*this);
|
||||
Append(", ");
|
||||
node.falsePath->Visit(*this);
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::ConditionalStatement& node)
|
||||
|
|
@ -853,9 +880,8 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
||||
{
|
||||
//throw std::runtime_error("TODO"); //< missing registering
|
||||
|
||||
assert(node.aliasIndex);
|
||||
RegisterAlias(*node.aliasIndex, node.name);
|
||||
|
||||
Append("alias ", node.name, " = ");
|
||||
assert(node.expression);
|
||||
|
|
|
|||
|
|
@ -598,16 +598,6 @@ namespace Nz
|
|||
}, node.value);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
for (auto&& extVar : node.externalVars)
|
||||
|
|
@ -729,11 +719,6 @@ namespace Nz
|
|||
PopResultId();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ImportStatement& node)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
|
|
|
|||
|
|
@ -655,6 +655,28 @@ namespace Nz
|
|||
return typePtr;
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
|
||||
{
|
||||
bool wasInblockStruct = m_internal->isInBlockStruct;
|
||||
if (storageClass == SpirvStorageClass::Uniform)
|
||||
m_internal->isInBlockStruct = true;
|
||||
|
||||
auto typePtr = std::make_shared<Type>(Pointer{
|
||||
BuildType(type),
|
||||
storageClass
|
||||
});
|
||||
|
||||
m_internal->isInBlockStruct = wasInblockStruct;
|
||||
|
||||
return typePtr;
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::AliasType& /*type*/) const -> TypePtr
|
||||
{
|
||||
// No AliasType is expected (as they should have been resolved by now)
|
||||
throw std::runtime_error("unexpected alias");
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
|
||||
{
|
||||
const auto& containedType = type.containedType->type;
|
||||
|
|
@ -678,22 +700,6 @@ namespace Nz
|
|||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
|
||||
{
|
||||
bool wasInblockStruct = m_internal->isInBlockStruct;
|
||||
if (storageClass == SpirvStorageClass::Uniform)
|
||||
m_internal->isInBlockStruct = true;
|
||||
|
||||
auto typePtr = std::make_shared<Type>(Pointer{
|
||||
BuildType(type),
|
||||
storageClass
|
||||
});
|
||||
|
||||
m_internal->isInBlockStruct = wasInblockStruct;
|
||||
|
||||
return typePtr;
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> TypePtr
|
||||
|
|
|
|||
|
|
@ -505,6 +505,7 @@ namespace Nz
|
|||
ShaderAst::SanitizeVisitor::Options options;
|
||||
options.optionValues = states.optionValues;
|
||||
options.reduceLoopsToWhile = true;
|
||||
options.removeAliases = true;
|
||||
options.removeCompoundAssignments = true;
|
||||
options.removeMatrixCast = true;
|
||||
options.removeOptionDeclaration = true;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
#include <Engine/Shader/ShaderUtils.hpp>
|
||||
#include <Nazara/Core/File.hpp>
|
||||
#include <Nazara/Core/StringExt.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderLangParser.hpp>
|
||||
#include <catch2/catch.hpp>
|
||||
#include <cctype>
|
||||
|
||||
TEST_CASE("aliases", "[Shader]")
|
||||
{
|
||||
SECTION("Alias of structs")
|
||||
{
|
||||
std::string_view nzslSource = R"(
|
||||
[nzsl_version("1.0")]
|
||||
module;
|
||||
|
||||
struct Data
|
||||
{
|
||||
value: f32
|
||||
}
|
||||
|
||||
alias ExtData = Data;
|
||||
|
||||
external
|
||||
{
|
||||
[binding(0)] extData: uniform[ExtData]
|
||||
}
|
||||
|
||||
struct Input
|
||||
{
|
||||
value: f32
|
||||
}
|
||||
|
||||
alias In = Input;
|
||||
|
||||
struct Output
|
||||
{
|
||||
[location(0)] value: f32
|
||||
}
|
||||
|
||||
alias Out = Output;
|
||||
alias FragOut = Out;
|
||||
|
||||
[entry(frag)]
|
||||
fn main(input: In) -> FragOut
|
||||
{
|
||||
let output: Out;
|
||||
output.value = extData.value * input.value;
|
||||
return output;
|
||||
}
|
||||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
{
|
||||
Input input_;
|
||||
input_.value = _NzIn_value;
|
||||
|
||||
Output output_;
|
||||
output_.value = extData.value * input_.value;
|
||||
|
||||
_NzOut_value = output_.value;
|
||||
return;
|
||||
}
|
||||
)");
|
||||
|
||||
ExpectNZSL(*shaderModule, R"(
|
||||
[entry(frag)]
|
||||
fn main(input: In) -> FragOut
|
||||
{
|
||||
let output: Out;
|
||||
output.value = extData.value * input.value;
|
||||
return output;
|
||||
}
|
||||
)");
|
||||
|
||||
ExpectSPIRV(*shaderModule, R"(
|
||||
OpFunction
|
||||
OpLabel
|
||||
OpVariable
|
||||
OpVariable
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFMul
|
||||
OpAccessChain
|
||||
OpStore
|
||||
OpLoad
|
||||
OpCompositeExtract
|
||||
OpStore
|
||||
OpReturn
|
||||
OpFunctionEnd)");
|
||||
}
|
||||
}
|
||||
|
|
@ -262,6 +262,47 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32]
|
|||
{
|
||||
return input;
|
||||
}
|
||||
)");
|
||||
|
||||
}
|
||||
|
||||
WHEN("removing aliases")
|
||||
{
|
||||
std::string_view nzslSource = R"(
|
||||
[nzsl_version("1.0")]
|
||||
module;
|
||||
|
||||
struct inputStruct
|
||||
{
|
||||
value: f32
|
||||
}
|
||||
|
||||
alias Input = inputStruct;
|
||||
alias In = Input;
|
||||
|
||||
external
|
||||
{
|
||||
[set(0), binding(0)] data: uniform[In]
|
||||
}
|
||||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
Nz::ShaderAst::SanitizeVisitor::Options options;
|
||||
options.removeAliases = true;
|
||||
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options));
|
||||
|
||||
ExpectNZSL(*shaderModule, R"(
|
||||
struct inputStruct
|
||||
{
|
||||
value: f32
|
||||
}
|
||||
|
||||
external
|
||||
{
|
||||
[set(0), binding(0)] data: uniform[inputStruct]
|
||||
}
|
||||
)");
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -142,8 +142,10 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
|
|||
Nz::ShaderAst::AstReflect reflectVisitor;
|
||||
reflectVisitor.Reflect(*shader.rootNode, callbacks);
|
||||
|
||||
INFO("no entry point found");
|
||||
REQUIRE(entryShaderStage.has_value());
|
||||
{
|
||||
INFO("no entry point found");
|
||||
REQUIRE(entryShaderStage.has_value());
|
||||
}
|
||||
|
||||
Nz::GlslWriter writer;
|
||||
std::string output = writer.Generate(entryShaderStage, shader);
|
||||
|
|
|
|||
Loading…
Reference in New Issue