Shader: Add proper support for alias

This commit is contained in:
Jérôme Leclercq
2022-03-09 12:35:00 +01:00
parent ce93b61c91
commit 05cf98477e
31 changed files with 472 additions and 98 deletions

View File

@@ -300,6 +300,16 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(AliasValueExpression& node)
{
auto clone = std::make_unique<AliasValueExpression>();
clone->aliasId = node.aliasId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(AssignExpression& node)
{
auto clone = std::make_unique<AssignExpression>();

View File

@@ -19,6 +19,11 @@ namespace Nz::ShaderAst
index->Visit(*this);
}
void AstRecursiveVisitor::Visit(AliasValueExpression& /*node*/)
{
/* nothing to do */
}
void AstRecursiveVisitor::Visit(AssignExpression& node)
{
node.left->Visit(*this);

View File

@@ -58,6 +58,11 @@ namespace Nz::ShaderAst
Node(identifier);
}
void AstSerializerBase::Serialize(AliasValueExpression& node)
{
SizeT(node.aliasId);
}
void AstSerializerBase::Serialize(AssignExpression& node)
{
Enum(node.op);
@@ -485,6 +490,12 @@ namespace Nz::ShaderAst
Type(arg.objectType->type);
SizeT(arg.methodIndex);
}
else if constexpr (std::is_same_v<T, ShaderAst::AliasType>)
{
m_stream << UInt8(13);
SizeT(arg.aliasIndex);
Type(arg.targetType->type);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
@@ -800,6 +811,22 @@ namespace Nz::ShaderAst
type = std::move(methodType);
}
case 13: //< AliasType
{
std::size_t aliasIndex;
ExpressionType containedType;
SizeT(aliasIndex);
Type(containedType);
AliasType aliasType;
aliasType.aliasIndex = aliasIndex;
aliasType.targetType = std::make_unique<ContainedType>();
aliasType.targetType->type = std::move(containedType);
type = std::move(aliasType);
break;
}
default:
break;
}

View File

@@ -23,6 +23,11 @@ namespace Nz::ShaderAst
node.expr->Visit(*this);
}
void ShaderAstValueCategory::Visit(AliasValueExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::RValue;

View File

@@ -9,6 +9,37 @@
namespace Nz::ShaderAst
{
AliasType::AliasType(const AliasType& alias) :
aliasIndex(alias.aliasIndex)
{
assert(alias.targetType);
targetType = std::make_unique<ContainedType>(*alias.targetType);
}
AliasType& AliasType::operator=(const AliasType& alias)
{
aliasIndex = alias.aliasIndex;
assert(alias.targetType);
targetType = std::make_unique<ContainedType>(*alias.targetType);
return *this;
}
bool AliasType::operator==(const AliasType& rhs) const
{
assert(targetType);
assert(rhs.targetType);
if (aliasIndex != rhs.aliasIndex)
return false;
if (targetType->type != rhs.targetType->type)
return false;
return true;
}
ArrayType::ArrayType(const ArrayType& array) :
length(array.length)
{
@@ -31,10 +62,10 @@ namespace Nz::ShaderAst
assert(containedType);
assert(rhs.containedType);
if (containedType->type != rhs.containedType->type)
if (length != rhs.length)
return false;
if (length != rhs.length)
if (containedType->type != rhs.containedType->type)
return false;
return true;

View File

@@ -282,7 +282,7 @@ namespace Nz::ShaderAst
if (identifier.empty())
throw AstError{ "empty identifier" };
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr));
// TODO: Add proper support for methods
if (IsSamplerType(exprType))
{
@@ -429,6 +429,25 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node)
{
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId));
ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier);
if (m_context->options.removeAliases)
return targetExpr;
AliasType aliasType;
aliasType.aliasIndex = node.aliasId;
aliasType.targetType = std::make_unique<ContainedType>();
aliasType.targetType->type = *targetExpr->cachedExpressionType;
auto clone = static_unique_pointer_cast<AliasValueExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = std::move(aliasType);
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
{
MandatoryExpr(node.left);
@@ -543,7 +562,7 @@ namespace Nz::ShaderAst
{
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
bool isMatrixCast = IsMatrixType(frontExprType);
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
{
@@ -785,6 +804,9 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
Validate(*clone);
if (m_context->options.removeAliases)
return ShaderBuilder::NoOp();
return clone;
}
@@ -803,7 +825,7 @@ namespace Nz::ShaderAst
ExpressionType expressionType = ResolveType(GetExpressionType(value));
if (clone->type.HasValue() && ResolveType(clone->type) != expressionType)
if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType))
throw AstError{ "constant expression doesn't match type" };
clone->type = expressionType;
@@ -852,12 +874,13 @@ namespace Nz::ShaderAst
m_context->declaredExternalVar.insert(extVar.name);
ExpressionType resolvedType = ResolveType(extVar.type);
const ExpressionType& targetType = ResolveAlias(resolvedType);
ExpressionType varType;
if (IsUniformType(resolvedType))
varType = std::get<UniformType>(resolvedType).containedType;
else if (IsSamplerType(resolvedType))
varType = resolvedType;
if (IsUniformType(targetType))
varType = std::get<UniformType>(targetType).containedType;
else if (IsSamplerType(targetType))
varType = targetType;
else
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
@@ -954,8 +977,9 @@ namespace Nz::ShaderAst
throw AstError{ "empty option name" };
ExpressionType resolvedType = ResolveType(clone->optType);
const ExpressionType& targetType = ResolveAlias(resolvedType);
if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue))
if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue))
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
clone->optType = std::move(resolvedType);
@@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst
ExpressionType resolvedType = ResolveType(member.type);
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
{
if (IsPrimitiveType(resolvedType) && std::get<PrimitiveType>(resolvedType) == PrimitiveType::Boolean)
const ExpressionType& targetType = ResolveAlias(resolvedType);
if (IsPrimitiveType(targetType) && std::get<PrimitiveType>(targetType) == PrimitiveType::Boolean)
throw AstError{ "boolean type is not allowed in std140 layout" };
else if (IsStructType(resolvedType))
else if (IsStructType(targetType))
{
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
std::size_t structIndex = std::get<StructType>(targetType).structIndex;
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" };
@@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst
AstExportVisitor exportVisitor;
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
if (aliasStatements.empty())
if (aliasStatements.empty() || m_context->options.removeAliases)
return ShaderBuilder::NoOp();
// Register module and aliases
@@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst
return nullptr;
}
return ResolveAlias(&it->data);
return &it->data;
}
template<typename F>
@@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst
{
if (identifier.name == identifierName)
{
if (functor(*ResolveAlias(&identifier.data)))
if (functor(identifier.data))
return true;
}
@@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst
return nullptr;
}
return ResolveAlias(&it->data);
return &it->data;
}
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
@@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst
{
switch (identifierData->category)
{
case IdentifierCategory::Alias:
{
AliasValueExpression aliasValue;
aliasValue.aliasId = identifierData->index;
return Clone(aliasValue);
}
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
@@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst
return varIndex;
}
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData*
auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData*
{
while (identifier->category == IdentifierCategory::Alias)
identifier = &m_context->aliases.Retrieve(identifier->index);
@@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst
for (const auto& [funcIndex, funcData] : m_context->functions.values)
{
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
throw AstError{ "discard can only be used in the fragment stage" };
}
}
@@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst
}
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType)
{
return ResolveStruct(aliasType.targetType->type);
}
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
{
return std::visit([&](auto&& arg) -> std::size_t
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>)
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType> || std::is_same_v<T, AliasType>)
return ResolveStruct(arg);
else if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, ArrayType> ||
@@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst
return uniformType.containedType.structIndex;
}
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType)
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias)
{
if (!IsTypeExpression(exprType))
return exprType;
{
if (resolveAlias || m_context->options.removeAliases)
return ResolveAlias(exprType);
else
return exprType;
}
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
@@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst
return std::get<ExpressionType>(type);
}
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue)
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias)
{
if (!exprTypeValue.HasValue())
return {};
if (exprTypeValue.IsResultingValue())
return ResolveType(exprTypeValue.GetResultingValue());
return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias);
assert(exprTypeValue.IsExpression());
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
@@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst
//if (!IsTypeType(exprType))
// throw AstError{ "type expected" };
return ResolveType(exprType);
return ResolveType(exprType, resolveAlias);
}
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
@@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
{
if (left != right)
if (ResolveAlias(left) != ResolveAlias(right))
throw AstError{ "Left expression type must match right expression type" };
}
@@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType);
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
}
else if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
}
else
throw AstError{ "for now, only structs can be aliased" };
}
@@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst
case TypeParameterCategory::PrimitiveType:
case TypeParameterCategory::StructType:
{
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr));
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true);
switch (partialType.parameters[i])
{
@@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst
for (auto& index : node.indices)
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
const ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType))
throw AstError{ "AccessIndex expects integer indices" };
@@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst
}
else if (IsStructType(exprType))
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
const ExpressionType& indexType = GetExpressionType(*indexExpr);
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
throw AstError{ "struct can only be accessed with constant i32 indices" };
@@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->structs.Retrieve(structIndex);
exprType = ResolveType(s->members[index].type);
exprType = ResolveType(s->members[index].type, true);
}
else if (IsMatrixType(exprType))
{
@@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CallFunctionExpression& node)
{
const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
assert(std::holds_alternative<FunctionType>(targetFuncType));
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
@@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CastExpression& node)
{
ExpressionType resolvedType = ResolveType(node.targetType);
const ExpressionType& targetType = ResolveAlias(resolvedType);
const auto& firstExprPtr = node.expressions.front();
if (!firstExprPtr)
throw AstError{ "expected at least one expression" };
if (IsMatrixType(resolvedType))
if (IsMatrixType(targetType))
{
const MatrixType& targetMatrixType = std::get<MatrixType>(resolvedType);
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
if (IsMatrixType(firstExprType))
@@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst
};
std::size_t componentCount = 0;
std::size_t requiredComponents = GetComponentCount(resolvedType);
std::size_t requiredComponents = GetComponentCount(targetType);
for (auto& exprPtr : node.expressions)
{
@@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst
case UnaryType::Minus:
case UnaryType::Plus:
{
ShaderAst::PrimitiveType basicType;
PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
basicType = std::get<PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
basicType = std::get<VectorType>(exprType).type;
else
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };