Shader: Add proper support for alias

This commit is contained in:
Jérôme Leclercq 2022-03-09 12:35:00 +01:00
parent ce93b61c91
commit 05cf98477e
31 changed files with 472 additions and 98 deletions

View File

@ -42,6 +42,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
virtual ExpressionPtr Clone(AccessIndexExpression& node);
virtual ExpressionPtr Clone(AliasValueExpression& node);
virtual ExpressionPtr Clone(AssignExpression& node);
virtual ExpressionPtr Clone(BinaryExpression& node);
virtual ExpressionPtr Clone(CallFunctionExpression& node);

View File

@ -34,6 +34,7 @@ namespace Nz::ShaderAst
inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs);
inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs);
inline bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs);
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs);
inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs);
inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs);

View File

@ -248,6 +248,14 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const AliasValueExpression& lhs, const AliasValueExpression& rhs)
{
if (!Compare(lhs.aliasId, rhs.aliasId))
return false;
return true;
}
inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs)
{
if (!Compare(lhs.op, rhs.op))

View File

@ -30,6 +30,7 @@
NAZARA_SHADERAST_EXPRESSION(AccessIdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(AccessIndexExpression)
NAZARA_SHADERAST_EXPRESSION(AliasValueExpression)
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)

View File

@ -22,6 +22,7 @@ namespace Nz::ShaderAst
void Visit(AccessIdentifierExpression& node) override;
void Visit(AccessIndexExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CallFunctionExpression& node) override;

View File

@ -25,6 +25,7 @@ namespace Nz::ShaderAst
void Serialize(AccessIdentifierExpression& node);
void Serialize(AccessIndexExpression& node);
void Serialize(AliasValueExpression& node);
void Serialize(AssignExpression& node);
void Serialize(BinaryExpression& node);
void Serialize(CallFunctionExpression& node);

View File

@ -33,6 +33,7 @@ namespace Nz::ShaderAst
void Visit(AccessIdentifierExpression& node) override;
void Visit(AccessIndexExpression& node) override;
void Visit(AliasValueExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CallFunctionExpression& node) override;

View File

@ -20,6 +20,22 @@ namespace Nz::ShaderAst
{
struct ContainedType;
struct NAZARA_SHADER_API AliasType
{
AliasType() = default;
AliasType(const AliasType& alias);
AliasType(AliasType&&) noexcept = default;
AliasType& operator=(const AliasType& alias);
AliasType& operator=(AliasType&&) noexcept = default;
std::size_t aliasIndex;
std::unique_ptr<ContainedType> targetType;
bool operator==(const AliasType& rhs) const;
inline bool operator!=(const AliasType& rhs) const;
};
struct NAZARA_SHADER_API ArrayType
{
ArrayType() = default;
@ -134,7 +150,7 @@ namespace Nz::ShaderAst
inline bool operator!=(const VectorType& rhs) const;
};
using ExpressionType = std::variant<NoType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
using ExpressionType = std::variant<NoType, AliasType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
struct ContainedType
{
@ -157,6 +173,7 @@ namespace Nz::ShaderAst
std::vector<StructMember> members;
};
inline bool IsAliasType(const ExpressionType& type);
inline bool IsArrayType(const ExpressionType& type);
inline bool IsFunctionType(const ExpressionType& type);
inline bool IsIdentifierType(const ExpressionType& type);

View File

@ -8,6 +8,11 @@
namespace Nz::ShaderAst
{
inline bool AliasType::operator!=(const AliasType& rhs) const
{
return !operator==(rhs);
}
inline bool ArrayType::operator!=(const ArrayType& rhs) const
{
return !operator==(rhs);
@ -130,6 +135,11 @@ namespace Nz::ShaderAst
}
inline bool IsAliasType(const ExpressionType& type)
{
return std::holds_alternative<AliasType>(type);
}
inline bool IsArrayType(const ExpressionType& type)
{
return std::holds_alternative<ArrayType>(type);

View File

@ -82,6 +82,14 @@ namespace Nz::ShaderAst
ExpressionPtr expr;
};
struct NAZARA_SHADER_API AliasValueExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t aliasId;
};
struct NAZARA_SHADER_API AssignExpression : Expression
{
NodeType GetType() const override;
@ -153,7 +161,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ShaderAst::ConstantValue value;
ConstantValue value;
};
struct NAZARA_SHADER_API FunctionExpression : Expression
@ -207,7 +215,6 @@ namespace Nz::ShaderAst
ExpressionPtr expression;
};
struct NAZARA_SHADER_API VariableExpression : Expression
struct NAZARA_SHADER_API VariableValueExpression : Expression
{
NodeType GetType() const override;
@ -455,10 +462,12 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
inline ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr);
inline const ExpressionType& GetExpressionType(Expression& expr);
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType);
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType);
}
#include <Nazara/Shader/Ast/Nodes.inl>

View File

@ -7,18 +7,29 @@
namespace Nz::ShaderAst
{
const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr)
inline const ExpressionType& GetExpressionType(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
ShaderAst::ExpressionType& GetExpressionTypeMut(ShaderAst::Expression& expr)
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
{
if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
return alias.targetType->type;
}
else
return exprType;
}
inline bool IsExpression(NodeType nodeType)
{
switch (nodeType)

View File

@ -47,6 +47,7 @@ namespace Nz::ShaderAst
std::unordered_map<UInt32, ConstantValue> optionValues;
bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false;
bool removeAliases = false;
bool removeConstDeclaration = false;
bool removeCompoundAssignments = false;
bool removeMatrixCast = false;
@ -71,6 +72,7 @@ namespace Nz::ShaderAst
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
ExpressionPtr Clone(AccessIndexExpression& node) override;
ExpressionPtr Clone(AliasValueExpression& node) override;
ExpressionPtr Clone(AssignExpression& node) override;
ExpressionPtr Clone(BinaryExpression& node) override;
ExpressionPtr Clone(CallFunctionExpression& node) override;
@ -136,15 +138,16 @@ namespace Nz::ShaderAst
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAlias(const IdentifierData* identifier) const;
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue);
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);

View File

@ -51,6 +51,7 @@ namespace Nz
static ShaderAst::ModulePtr Sanitize(const ShaderAst::Module& module, std::unordered_map<UInt32, ShaderAst::ConstantValue> optionValues, std::string* error = nullptr);
private:
void Append(const ShaderAst::AliasType& aliasType);
void Append(const ShaderAst::ArrayType& type);
void Append(ShaderAst::BuiltinEntry builtin);
void Append(const ShaderAst::ExpressionType& type);
@ -94,6 +95,7 @@ namespace Nz
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CallFunctionExpression& node) override;

View File

@ -50,6 +50,7 @@ namespace Nz
struct UnrollAttribute;
struct UuidAttribute;
void Append(const ShaderAst::AliasType& type);
void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type);
void Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type);
@ -92,6 +93,7 @@ namespace Nz
void EnterScope();
void LeaveScope(bool skipLine = true);
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName);
@ -102,6 +104,7 @@ namespace Nz
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;

View File

@ -48,8 +48,6 @@ namespace Nz
void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::DeclareAliasStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareOptionStatement& node) override;
@ -64,7 +62,7 @@ namespace Nz
void Visit(ShaderAst::ScopedStatement& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::UnaryExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::VariableValueExpression& node) override;
void Visit(ShaderAst::WhileStatement& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;

View File

@ -178,6 +178,7 @@ namespace Nz
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const;
TypePtr BuildType(const ShaderAst::AliasType& type) const;
TypePtr BuildType(const ShaderAst::ArrayType& type) const;
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;

View File

@ -300,6 +300,16 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(AliasValueExpression& node)
{
auto clone = std::make_unique<AliasValueExpression>();
clone->aliasId = node.aliasId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(AssignExpression& node)
{
auto clone = std::make_unique<AssignExpression>();

View File

@ -19,6 +19,11 @@ namespace Nz::ShaderAst
index->Visit(*this);
}
void AstRecursiveVisitor::Visit(AliasValueExpression& /*node*/)
{
/* nothing to do */
}
void AstRecursiveVisitor::Visit(AssignExpression& node)
{
node.left->Visit(*this);

View File

@ -58,6 +58,11 @@ namespace Nz::ShaderAst
Node(identifier);
}
void AstSerializerBase::Serialize(AliasValueExpression& node)
{
SizeT(node.aliasId);
}
void AstSerializerBase::Serialize(AssignExpression& node)
{
Enum(node.op);
@ -485,6 +490,12 @@ namespace Nz::ShaderAst
Type(arg.objectType->type);
SizeT(arg.methodIndex);
}
else if constexpr (std::is_same_v<T, ShaderAst::AliasType>)
{
m_stream << UInt8(13);
SizeT(arg.aliasIndex);
Type(arg.targetType->type);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
@ -800,6 +811,22 @@ namespace Nz::ShaderAst
type = std::move(methodType);
}
case 13: //< AliasType
{
std::size_t aliasIndex;
ExpressionType containedType;
SizeT(aliasIndex);
Type(containedType);
AliasType aliasType;
aliasType.aliasIndex = aliasIndex;
aliasType.targetType = std::make_unique<ContainedType>();
aliasType.targetType->type = std::move(containedType);
type = std::move(aliasType);
break;
}
default:
break;
}

View File

@ -23,6 +23,11 @@ namespace Nz::ShaderAst
node.expr->Visit(*this);
}
void ShaderAstValueCategory::Visit(AliasValueExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::RValue;

View File

@ -9,6 +9,37 @@
namespace Nz::ShaderAst
{
AliasType::AliasType(const AliasType& alias) :
aliasIndex(alias.aliasIndex)
{
assert(alias.targetType);
targetType = std::make_unique<ContainedType>(*alias.targetType);
}
AliasType& AliasType::operator=(const AliasType& alias)
{
aliasIndex = alias.aliasIndex;
assert(alias.targetType);
targetType = std::make_unique<ContainedType>(*alias.targetType);
return *this;
}
bool AliasType::operator==(const AliasType& rhs) const
{
assert(targetType);
assert(rhs.targetType);
if (aliasIndex != rhs.aliasIndex)
return false;
if (targetType->type != rhs.targetType->type)
return false;
return true;
}
ArrayType::ArrayType(const ArrayType& array) :
length(array.length)
{
@ -31,10 +62,10 @@ namespace Nz::ShaderAst
assert(containedType);
assert(rhs.containedType);
if (containedType->type != rhs.containedType->type)
if (length != rhs.length)
return false;
if (length != rhs.length)
if (containedType->type != rhs.containedType->type)
return false;
return true;

View File

@ -282,7 +282,7 @@ namespace Nz::ShaderAst
if (identifier.empty())
throw AstError{ "empty identifier" };
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr));
// TODO: Add proper support for methods
if (IsSamplerType(exprType))
{
@ -429,6 +429,25 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node)
{
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId));
ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier);
if (m_context->options.removeAliases)
return targetExpr;
AliasType aliasType;
aliasType.aliasIndex = node.aliasId;
aliasType.targetType = std::make_unique<ContainedType>();
aliasType.targetType->type = *targetExpr->cachedExpressionType;
auto clone = static_unique_pointer_cast<AliasValueExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = std::move(aliasType);
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
{
MandatoryExpr(node.left);
@ -543,7 +562,7 @@ namespace Nz::ShaderAst
{
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
bool isMatrixCast = IsMatrixType(frontExprType);
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
{
@ -785,6 +804,9 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
Validate(*clone);
if (m_context->options.removeAliases)
return ShaderBuilder::NoOp();
return clone;
}
@ -803,7 +825,7 @@ namespace Nz::ShaderAst
ExpressionType expressionType = ResolveType(GetExpressionType(value));
if (clone->type.HasValue() && ResolveType(clone->type) != expressionType)
if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType))
throw AstError{ "constant expression doesn't match type" };
clone->type = expressionType;
@ -852,12 +874,13 @@ namespace Nz::ShaderAst
m_context->declaredExternalVar.insert(extVar.name);
ExpressionType resolvedType = ResolveType(extVar.type);
const ExpressionType& targetType = ResolveAlias(resolvedType);
ExpressionType varType;
if (IsUniformType(resolvedType))
varType = std::get<UniformType>(resolvedType).containedType;
else if (IsSamplerType(resolvedType))
varType = resolvedType;
if (IsUniformType(targetType))
varType = std::get<UniformType>(targetType).containedType;
else if (IsSamplerType(targetType))
varType = targetType;
else
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
@ -954,8 +977,9 @@ namespace Nz::ShaderAst
throw AstError{ "empty option name" };
ExpressionType resolvedType = ResolveType(clone->optType);
const ExpressionType& targetType = ResolveAlias(resolvedType);
if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue))
if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue))
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
clone->optType = std::move(resolvedType);
@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst
ExpressionType resolvedType = ResolveType(member.type);
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
{
if (IsPrimitiveType(resolvedType) && std::get<PrimitiveType>(resolvedType) == PrimitiveType::Boolean)
const ExpressionType& targetType = ResolveAlias(resolvedType);
if (IsPrimitiveType(targetType) && std::get<PrimitiveType>(targetType) == PrimitiveType::Boolean)
throw AstError{ "boolean type is not allowed in std140 layout" };
else if (IsStructType(resolvedType))
else if (IsStructType(targetType))
{
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
std::size_t structIndex = std::get<StructType>(targetType).structIndex;
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" };
@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst
AstExportVisitor exportVisitor;
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
if (aliasStatements.empty())
if (aliasStatements.empty() || m_context->options.removeAliases)
return ShaderBuilder::NoOp();
// Register module and aliases
@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst
return nullptr;
}
return ResolveAlias(&it->data);
return &it->data;
}
template<typename F>
@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst
{
if (identifier.name == identifierName)
{
if (functor(*ResolveAlias(&identifier.data)))
if (functor(identifier.data))
return true;
}
@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst
return nullptr;
}
return ResolveAlias(&it->data);
return &it->data;
}
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst
{
switch (identifierData->category)
{
case IdentifierCategory::Alias:
{
AliasValueExpression aliasValue;
aliasValue.aliasId = identifierData->index;
return Clone(aliasValue);
}
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst
return varIndex;
}
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData*
auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData*
{
while (identifier->category == IdentifierCategory::Alias)
identifier = &m_context->aliases.Retrieve(identifier->index);
@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst
for (const auto& [funcIndex, funcData] : m_context->functions.values)
{
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
throw AstError{ "discard can only be used in the fragment stage" };
}
}
@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst
}
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType)
{
return ResolveStruct(aliasType.targetType->type);
}
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
{
return std::visit([&](auto&& arg) -> std::size_t
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>)
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType> || std::is_same_v<T, AliasType>)
return ResolveStruct(arg);
else if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, ArrayType> ||
@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst
return uniformType.containedType.structIndex;
}
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType)
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias)
{
if (!IsTypeExpression(exprType))
return exprType;
{
if (resolveAlias || m_context->options.removeAliases)
return ResolveAlias(exprType);
else
return exprType;
}
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst
return std::get<ExpressionType>(type);
}
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue)
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias)
{
if (!exprTypeValue.HasValue())
return {};
if (exprTypeValue.IsResultingValue())
return ResolveType(exprTypeValue.GetResultingValue());
return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias);
assert(exprTypeValue.IsExpression());
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst
//if (!IsTypeType(exprType))
// throw AstError{ "type expected" };
return ResolveType(exprType);
return ResolveType(exprType, resolveAlias);
}
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
{
if (left != right)
if (ResolveAlias(left) != ResolveAlias(right))
throw AstError{ "Left expression type must match right expression type" };
}
@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType);
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
}
else if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
}
else
throw AstError{ "for now, only structs can be aliased" };
}
@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst
case TypeParameterCategory::PrimitiveType:
case TypeParameterCategory::StructType:
{
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr));
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true);
switch (partialType.parameters[i])
{
@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst
for (auto& index : node.indices)
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
const ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType))
throw AstError{ "AccessIndex expects integer indices" };
@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst
}
else if (IsStructType(exprType))
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
const ExpressionType& indexType = GetExpressionType(*indexExpr);
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
throw AstError{ "struct can only be accessed with constant i32 indices" };
@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->structs.Retrieve(structIndex);
exprType = ResolveType(s->members[index].type);
exprType = ResolveType(s->members[index].type, true);
}
else if (IsMatrixType(exprType))
{
@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CallFunctionExpression& node)
{
const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
assert(std::holds_alternative<FunctionType>(targetFuncType));
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CastExpression& node)
{
ExpressionType resolvedType = ResolveType(node.targetType);
const ExpressionType& targetType = ResolveAlias(resolvedType);
const auto& firstExprPtr = node.expressions.front();
if (!firstExprPtr)
throw AstError{ "expected at least one expression" };
if (IsMatrixType(resolvedType))
if (IsMatrixType(targetType))
{
const MatrixType& targetMatrixType = std::get<MatrixType>(resolvedType);
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
if (IsMatrixType(firstExprType))
@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst
};
std::size_t componentCount = 0;
std::size_t requiredComponents = GetComponentCount(resolvedType);
std::size_t requiredComponents = GetComponentCount(targetType);
for (auto& exprPtr : node.expressions)
{
@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst
case UnaryType::Minus:
case UnaryType::Plus:
{
ShaderAst::PrimitiveType basicType;
PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
basicType = std::get<PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
basicType = std::get<VectorType>(exprType).type;
else
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };

View File

@ -232,6 +232,7 @@ namespace Nz
options.optionValues = std::move(optionValues);
options.makeVariableNameUnique = true;
options.reduceLoopsToWhile = true;
options.removeAliases = true;
options.removeCompoundAssignments = false;
options.removeConstDeclaration = true;
options.removeOptionDeclaration = true;
@ -246,6 +247,11 @@ namespace Nz
return ShaderAst::Sanitize(module, options, error);
}
void GlslWriter::Append(const ShaderAst::AliasType& /*aliasType*/)
{
throw std::runtime_error("unexpected AliasType");
}
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
{
throw std::runtime_error("unexpected ArrayType");
@ -689,11 +695,16 @@ namespace Nz
builtin.identifier
});
}
else if (member.locationIndex.HasValue())
else
{
Append("layout(location = ");
Append(member.locationIndex.GetResultingValue());
Append(") ", keyword, " ");
if (member.locationIndex.HasValue())
{
Append("layout(location = ");
Append(member.locationIndex.GetResultingValue());
Append(") ");
}
Append(keyword, " ");
AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name);
AppendLine(";");
@ -805,6 +816,12 @@ namespace Nz
Append("]");
}
void GlslWriter::Visit(ShaderAst::AliasValueExpression& /*node*/)
{
// all aliases should have been handled by sanitizer
throw std::runtime_error("unexpected alias value, is shader sanitized?");
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
{
node.left->Visit(*this);
@ -1038,12 +1055,14 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
{
/* nothing to do */
// all aliases should have been handled by sanitizer
throw std::runtime_error("unexpected alias declaration, is shader sanitized?");
}
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
// all consts should have been handled by sanitizer
throw std::runtime_error("unexpected const declaration, is shader sanitized?");
}
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
@ -1184,7 +1203,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)
{
/* nothing to do */
// all options should have been handled by sanitizer
throw std::runtime_error("unexpected option declaration, is shader sanitized?");
}
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
@ -1247,7 +1267,7 @@ namespace Nz
Append(";");
}
void GlslWriter::Visit(ShaderAst::ImportStatement& node)
void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/)
{
throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
}
@ -1280,8 +1300,8 @@ namespace Nz
const auto& structData = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableExpression&>(*node.returnExpr).variableId);
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableValueExpression)
outputStructVarName = Retrieve(m_currentState->variableNames, static_cast<ShaderAst::VariableValueExpression&>(*node.returnExpr).variableId);
else
{
AppendLine();

View File

@ -116,6 +116,7 @@ namespace Nz
ShaderAst::Module* module;
std::size_t currentModuleIndex;
std::stringstream stream;
std::unordered_map<std::size_t, Identifier> aliases;
std::unordered_map<std::size_t, Identifier> constants;
std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, Identifier> variables;
@ -164,6 +165,11 @@ namespace Nz
m_environment = std::move(environment);
}
void LangWriter::Append(const ShaderAst::AliasType& type)
{
AppendIdentifier(m_currentState->aliases, type.aliasIndex);
}
void LangWriter::Append(const ShaderAst::ArrayType& type)
{
Append("array[", type.containedType->type, ", ", type.length, "]");
@ -655,6 +661,16 @@ namespace Nz
Append("}");
}
void LangWriter::RegisterAlias(std::size_t aliasIndex, std::string aliasName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(aliasName);
assert(m_currentState->aliases.find(aliasIndex) == m_currentState->aliases.end());
m_currentState->aliases.emplace(aliasIndex, std::move(identifier));
}
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
{
State::Identifier identifier;
@ -714,7 +730,7 @@ namespace Nz
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers)
@ -725,7 +741,7 @@ namespace Nz
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(!IsStructType(exprType));
// Array access
@ -744,6 +760,11 @@ namespace Nz
Append("]");
}
void LangWriter::Visit(ShaderAst::AliasValueExpression& node)
{
AppendIdentifier(m_currentState->aliases, node.aliasId);
}
void LangWriter::Visit(ShaderAst::AssignExpression& node)
{
node.left->Visit(*this);
@ -840,7 +861,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ConditionalExpression& node)
{
throw std::runtime_error("fixme");
Append("const_select(");
node.condition->Visit(*this);
Append(", ");
node.truePath->Visit(*this);
Append(", ");
node.falsePath->Visit(*this);
Append(")");
}
void LangWriter::Visit(ShaderAst::ConditionalStatement& node)
@ -853,9 +880,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{
//throw std::runtime_error("TODO"); //< missing registering
assert(node.aliasIndex);
RegisterAlias(*node.aliasIndex, node.name);
Append("alias ", node.name, " = ");
assert(node.expression);

View File

@ -598,16 +598,6 @@ namespace Nz
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
{
for (auto&& extVar : node.externalVars)
@ -729,11 +719,6 @@ namespace Nz
PopResultId();
}
void SpirvAstVisitor::Visit(ShaderAst::ImportStatement& node)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)

View File

@ -655,6 +655,28 @@ namespace Nz
return typePtr;
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildType(const ShaderAst::AliasType& /*type*/) const -> TypePtr
{
// No AliasType is expected (as they should have been resolved by now)
throw std::runtime_error("unexpected alias");
}
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
{
const auto& containedType = type.containedType->type;
@ -678,22 +700,6 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
{
return std::visit([&](auto&& arg) -> TypePtr

View File

@ -505,6 +505,7 @@ namespace Nz
ShaderAst::SanitizeVisitor::Options options;
options.optionValues = states.optionValues;
options.reduceLoopsToWhile = true;
options.removeAliases = true;
options.removeCompoundAssignments = true;
options.removeMatrixCast = true;
options.removeOptionDeclaration = true;

View File

@ -0,0 +1,97 @@
#include <Engine/Shader/ShaderUtils.hpp>
#include <Nazara/Core/File.hpp>
#include <Nazara/Core/StringExt.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <catch2/catch.hpp>
#include <cctype>
TEST_CASE("aliases", "[Shader]")
{
SECTION("Alias of structs")
{
std::string_view nzslSource = R"(
[nzsl_version("1.0")]
module;
struct Data
{
value: f32
}
alias ExtData = Data;
external
{
[binding(0)] extData: uniform[ExtData]
}
struct Input
{
value: f32
}
alias In = Input;
struct Output
{
[location(0)] value: f32
}
alias Out = Output;
alias FragOut = Out;
[entry(frag)]
fn main(input: In) -> FragOut
{
let output: Out;
output.value = extData.value * input.value;
return output;
}
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
ExpectGLSL(*shaderModule, R"(
void main()
{
Input input_;
input_.value = _NzIn_value;
Output output_;
output_.value = extData.value * input_.value;
_NzOut_value = output_.value;
return;
}
)");
ExpectNZSL(*shaderModule, R"(
[entry(frag)]
fn main(input: In) -> FragOut
{
let output: Out;
output.value = extData.value * input.value;
return output;
}
)");
ExpectSPIRV(*shaderModule, R"(
OpFunction
OpLabel
OpVariable
OpVariable
OpAccessChain
OpLoad
OpAccessChain
OpLoad
OpFMul
OpAccessChain
OpStore
OpLoad
OpCompositeExtract
OpStore
OpReturn
OpFunctionEnd)");
}
}

View File

@ -262,6 +262,47 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32]
{
return input;
}
)");
}
WHEN("removing aliases")
{
std::string_view nzslSource = R"(
[nzsl_version("1.0")]
module;
struct inputStruct
{
value: f32
}
alias Input = inputStruct;
alias In = Input;
external
{
[set(0), binding(0)] data: uniform[In]
}
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
Nz::ShaderAst::SanitizeVisitor::Options options;
options.removeAliases = true;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, options));
ExpectNZSL(*shaderModule, R"(
struct inputStruct
{
value: f32
}
external
{
[set(0), binding(0)] data: uniform[inputStruct]
}
)");
}

View File

@ -142,8 +142,10 @@ void ExpectGLSL(Nz::ShaderAst::Module& shader, std::string_view expectedOutput)
Nz::ShaderAst::AstReflect reflectVisitor;
reflectVisitor.Reflect(*shader.rootNode, callbacks);
INFO("no entry point found");
REQUIRE(entryShaderStage.has_value());
{
INFO("no entry point found");
REQUIRE(entryShaderStage.has_value());
}
Nz::GlslWriter writer;
std::string output = writer.Generate(entryShaderStage, shader);