|
|
|
|
@@ -282,7 +282,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (identifier.empty())
|
|
|
|
|
throw AstError{ "empty identifier" };
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
|
|
|
|
|
const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr));
|
|
|
|
|
// TODO: Add proper support for methods
|
|
|
|
|
if (IsSamplerType(exprType))
|
|
|
|
|
{
|
|
|
|
|
@@ -429,6 +429,25 @@ namespace Nz::ShaderAst
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node)
|
|
|
|
|
{
|
|
|
|
|
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId));
|
|
|
|
|
ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier);
|
|
|
|
|
|
|
|
|
|
if (m_context->options.removeAliases)
|
|
|
|
|
return targetExpr;
|
|
|
|
|
|
|
|
|
|
AliasType aliasType;
|
|
|
|
|
aliasType.aliasIndex = node.aliasId;
|
|
|
|
|
aliasType.targetType = std::make_unique<ContainedType>();
|
|
|
|
|
aliasType.targetType->type = *targetExpr->cachedExpressionType;
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<AliasValueExpression>(AstCloner::Clone(node));
|
|
|
|
|
clone->cachedExpressionType = std::move(aliasType);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
|
|
|
|
|
{
|
|
|
|
|
MandatoryExpr(node.left);
|
|
|
|
|
@@ -543,7 +562,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
{
|
|
|
|
|
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
|
|
|
|
|
|
|
|
|
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
|
|
|
|
const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
|
|
|
|
bool isMatrixCast = IsMatrixType(frontExprType);
|
|
|
|
|
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
|
|
|
|
|
{
|
|
|
|
|
@@ -785,6 +804,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
if (m_context->options.removeAliases)
|
|
|
|
|
return ShaderBuilder::NoOp();
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -803,7 +825,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
ExpressionType expressionType = ResolveType(GetExpressionType(value));
|
|
|
|
|
|
|
|
|
|
if (clone->type.HasValue() && ResolveType(clone->type) != expressionType)
|
|
|
|
|
if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType))
|
|
|
|
|
throw AstError{ "constant expression doesn't match type" };
|
|
|
|
|
|
|
|
|
|
clone->type = expressionType;
|
|
|
|
|
@@ -852,12 +874,13 @@ namespace Nz::ShaderAst
|
|
|
|
|
m_context->declaredExternalVar.insert(extVar.name);
|
|
|
|
|
|
|
|
|
|
ExpressionType resolvedType = ResolveType(extVar.type);
|
|
|
|
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
|
|
|
|
|
|
|
|
|
ExpressionType varType;
|
|
|
|
|
if (IsUniformType(resolvedType))
|
|
|
|
|
varType = std::get<UniformType>(resolvedType).containedType;
|
|
|
|
|
else if (IsSamplerType(resolvedType))
|
|
|
|
|
varType = resolvedType;
|
|
|
|
|
if (IsUniformType(targetType))
|
|
|
|
|
varType = std::get<UniformType>(targetType).containedType;
|
|
|
|
|
else if (IsSamplerType(targetType))
|
|
|
|
|
varType = targetType;
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
|
|
|
|
|
|
|
|
|
@@ -954,8 +977,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
throw AstError{ "empty option name" };
|
|
|
|
|
|
|
|
|
|
ExpressionType resolvedType = ResolveType(clone->optType);
|
|
|
|
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
|
|
|
|
|
|
|
|
|
if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue))
|
|
|
|
|
if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue))
|
|
|
|
|
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
|
|
|
|
|
|
|
|
|
|
clone->optType = std::move(resolvedType);
|
|
|
|
|
@@ -1009,11 +1033,13 @@ namespace Nz::ShaderAst
|
|
|
|
|
ExpressionType resolvedType = ResolveType(member.type);
|
|
|
|
|
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
|
|
|
|
|
{
|
|
|
|
|
if (IsPrimitiveType(resolvedType) && std::get<PrimitiveType>(resolvedType) == PrimitiveType::Boolean)
|
|
|
|
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
|
|
|
|
|
|
|
|
|
if (IsPrimitiveType(targetType) && std::get<PrimitiveType>(targetType) == PrimitiveType::Boolean)
|
|
|
|
|
throw AstError{ "boolean type is not allowed in std140 layout" };
|
|
|
|
|
else if (IsStructType(resolvedType))
|
|
|
|
|
else if (IsStructType(targetType))
|
|
|
|
|
{
|
|
|
|
|
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
|
|
|
|
|
std::size_t structIndex = std::get<StructType>(targetType).structIndex;
|
|
|
|
|
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
|
|
|
|
|
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
|
|
|
|
throw AstError{ "inner struct layout mismatch" };
|
|
|
|
|
@@ -1461,7 +1487,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
AstExportVisitor exportVisitor;
|
|
|
|
|
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
|
|
|
|
|
|
|
|
|
|
if (aliasStatements.empty())
|
|
|
|
|
if (aliasStatements.empty() || m_context->options.removeAliases)
|
|
|
|
|
return ShaderBuilder::NoOp();
|
|
|
|
|
|
|
|
|
|
// Register module and aliases
|
|
|
|
|
@@ -1546,7 +1572,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ResolveAlias(&it->data);
|
|
|
|
|
return &it->data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename F>
|
|
|
|
|
@@ -1556,7 +1582,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
{
|
|
|
|
|
if (identifier.name == identifierName)
|
|
|
|
|
{
|
|
|
|
|
if (functor(*ResolveAlias(&identifier.data)))
|
|
|
|
|
if (functor(identifier.data))
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1570,7 +1596,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ResolveAlias(&it->data);
|
|
|
|
|
return &it->data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
|
|
|
|
|
@@ -1626,6 +1652,14 @@ namespace Nz::ShaderAst
|
|
|
|
|
{
|
|
|
|
|
switch (identifierData->category)
|
|
|
|
|
{
|
|
|
|
|
case IdentifierCategory::Alias:
|
|
|
|
|
{
|
|
|
|
|
AliasValueExpression aliasValue;
|
|
|
|
|
aliasValue.aliasId = identifierData->index;
|
|
|
|
|
|
|
|
|
|
return Clone(aliasValue);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IdentifierCategory::Constant:
|
|
|
|
|
{
|
|
|
|
|
// Replace IdentifierExpression by Constant(Value)Expression
|
|
|
|
|
@@ -2124,7 +2158,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
return varIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData*
|
|
|
|
|
auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData*
|
|
|
|
|
{
|
|
|
|
|
while (identifier->category == IdentifierCategory::Alias)
|
|
|
|
|
identifier = &m_context->aliases.Retrieve(identifier->index);
|
|
|
|
|
@@ -2181,7 +2215,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
for (const auto& [funcIndex, funcData] : m_context->functions.values)
|
|
|
|
|
{
|
|
|
|
|
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
|
|
|
|
if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
|
|
|
|
throw AstError{ "discard can only be used in the fragment stage" };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -2203,13 +2237,18 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType)
|
|
|
|
|
{
|
|
|
|
|
return ResolveStruct(aliasType.targetType->type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
|
|
|
|
{
|
|
|
|
|
return std::visit([&](auto&& arg) -> std::size_t
|
|
|
|
|
{
|
|
|
|
|
using T = std::decay_t<decltype(arg)>;
|
|
|
|
|
|
|
|
|
|
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>)
|
|
|
|
|
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType> || std::is_same_v<T, AliasType>)
|
|
|
|
|
return ResolveStruct(arg);
|
|
|
|
|
else if constexpr (std::is_same_v<T, NoType> ||
|
|
|
|
|
std::is_same_v<T, ArrayType> ||
|
|
|
|
|
@@ -2251,10 +2290,15 @@ namespace Nz::ShaderAst
|
|
|
|
|
return uniformType.containedType.structIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType)
|
|
|
|
|
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias)
|
|
|
|
|
{
|
|
|
|
|
if (!IsTypeExpression(exprType))
|
|
|
|
|
return exprType;
|
|
|
|
|
{
|
|
|
|
|
if (resolveAlias || m_context->options.removeAliases)
|
|
|
|
|
return ResolveAlias(exprType);
|
|
|
|
|
else
|
|
|
|
|
return exprType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
|
|
|
|
|
|
|
|
|
@@ -2265,13 +2309,13 @@ namespace Nz::ShaderAst
|
|
|
|
|
return std::get<ExpressionType>(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue)
|
|
|
|
|
ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias)
|
|
|
|
|
{
|
|
|
|
|
if (!exprTypeValue.HasValue())
|
|
|
|
|
return {};
|
|
|
|
|
|
|
|
|
|
if (exprTypeValue.IsResultingValue())
|
|
|
|
|
return ResolveType(exprTypeValue.GetResultingValue());
|
|
|
|
|
return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias);
|
|
|
|
|
|
|
|
|
|
assert(exprTypeValue.IsExpression());
|
|
|
|
|
ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression());
|
|
|
|
|
@@ -2281,7 +2325,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
//if (!IsTypeType(exprType))
|
|
|
|
|
// throw AstError{ "type expected" };
|
|
|
|
|
|
|
|
|
|
return ResolveType(exprType);
|
|
|
|
|
return ResolveType(exprType, resolveAlias);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
|
|
|
|
|
@@ -2334,7 +2378,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const
|
|
|
|
|
{
|
|
|
|
|
if (left != right)
|
|
|
|
|
if (ResolveAlias(left) != ResolveAlias(right))
|
|
|
|
|
throw AstError{ "Left expression type must match right expression type" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -2359,6 +2403,11 @@ namespace Nz::ShaderAst
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
|
|
|
|
|
}
|
|
|
|
|
else if (IsAliasType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const AliasType& alias = std::get<AliasType>(exprType);
|
|
|
|
|
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "for now, only structs can be aliased" };
|
|
|
|
|
}
|
|
|
|
|
@@ -2400,7 +2449,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
case TypeParameterCategory::PrimitiveType:
|
|
|
|
|
case TypeParameterCategory::StructType:
|
|
|
|
|
{
|
|
|
|
|
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr));
|
|
|
|
|
ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true);
|
|
|
|
|
|
|
|
|
|
switch (partialType.parameters[i])
|
|
|
|
|
{
|
|
|
|
|
@@ -2440,7 +2489,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
for (auto& index : node.indices)
|
|
|
|
|
{
|
|
|
|
|
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
|
|
|
|
|
const ExpressionType& indexType = GetExpressionType(*index);
|
|
|
|
|
if (!IsPrimitiveType(indexType))
|
|
|
|
|
throw AstError{ "AccessIndex expects integer indices" };
|
|
|
|
|
|
|
|
|
|
@@ -2459,7 +2508,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
}
|
|
|
|
|
else if (IsStructType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
|
|
|
|
|
const ExpressionType& indexType = GetExpressionType(*indexExpr);
|
|
|
|
|
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
|
|
|
|
|
throw AstError{ "struct can only be accessed with constant i32 indices" };
|
|
|
|
|
|
|
|
|
|
@@ -2470,7 +2519,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
const StructDescription* s = m_context->structs.Retrieve(structIndex);
|
|
|
|
|
|
|
|
|
|
exprType = ResolveType(s->members[index].type);
|
|
|
|
|
exprType = ResolveType(s->members[index].type, true);
|
|
|
|
|
}
|
|
|
|
|
else if (IsMatrixType(exprType))
|
|
|
|
|
{
|
|
|
|
|
@@ -2538,7 +2587,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CallFunctionExpression& node)
|
|
|
|
|
{
|
|
|
|
|
const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
|
|
|
|
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
|
|
|
|
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
|
|
|
|
|
|
|
|
|
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
|
|
|
|
@@ -2564,14 +2613,15 @@ namespace Nz::ShaderAst
|
|
|
|
|
void SanitizeVisitor::Validate(CastExpression& node)
|
|
|
|
|
{
|
|
|
|
|
ExpressionType resolvedType = ResolveType(node.targetType);
|
|
|
|
|
const ExpressionType& targetType = ResolveAlias(resolvedType);
|
|
|
|
|
|
|
|
|
|
const auto& firstExprPtr = node.expressions.front();
|
|
|
|
|
if (!firstExprPtr)
|
|
|
|
|
throw AstError{ "expected at least one expression" };
|
|
|
|
|
|
|
|
|
|
if (IsMatrixType(resolvedType))
|
|
|
|
|
if (IsMatrixType(targetType))
|
|
|
|
|
{
|
|
|
|
|
const MatrixType& targetMatrixType = std::get<MatrixType>(resolvedType);
|
|
|
|
|
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
|
|
|
|
|
|
|
|
|
|
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
|
|
|
|
|
if (IsMatrixType(firstExprType))
|
|
|
|
|
@@ -2614,7 +2664,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::size_t componentCount = 0;
|
|
|
|
|
std::size_t requiredComponents = GetComponentCount(resolvedType);
|
|
|
|
|
std::size_t requiredComponents = GetComponentCount(targetType);
|
|
|
|
|
|
|
|
|
|
for (auto& exprPtr : node.expressions)
|
|
|
|
|
{
|
|
|
|
|
@@ -2885,11 +2935,11 @@ namespace Nz::ShaderAst
|
|
|
|
|
case UnaryType::Minus:
|
|
|
|
|
case UnaryType::Plus:
|
|
|
|
|
{
|
|
|
|
|
ShaderAst::PrimitiveType basicType;
|
|
|
|
|
PrimitiveType basicType;
|
|
|
|
|
if (IsPrimitiveType(exprType))
|
|
|
|
|
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
|
|
|
|
basicType = std::get<PrimitiveType>(exprType);
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
|
|
|
|
basicType = std::get<VectorType>(exprType).type;
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
|
|
|
|
|
|
|
|
|
|
|