Shader/NZSL: Add support for array indexing

This commit is contained in:
Jérôme Leclercq
2021-06-01 16:22:41 +02:00
parent 0f3c0abb96
commit 4465e230af
17 changed files with 1139 additions and 976 deletions

View File

@@ -161,7 +161,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node)
{
auto clone = std::make_unique<AccessIdentifierExpression>();
clone->memberIdentifiers = node.memberIdentifiers;
clone->identifiers = node.identifiers;
clone->expr = CloneExpression(node.expr);
clone->cachedExpressionType = node.cachedExpressionType;
@@ -172,9 +172,12 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(AccessIndexExpression& node)
{
auto clone = std::make_unique<AccessIndexExpression>();
clone->memberIndices = node.memberIndices;
clone->expr = CloneExpression(node.expr);
clone->indices.reserve(node.indices.size());
for (auto& parameter : node.indices)
clone->indices.push_back(CloneExpression(parameter));
clone->cachedExpressionType = node.cachedExpressionType;
return clone;

View File

@@ -15,6 +15,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(AccessIndexExpression& node)
{
node.expr->Visit(*this);
for (auto& index : node.indices)
index->Visit(*this);
}
void AstRecursiveVisitor::Visit(AssignExpression& node)

View File

@@ -37,8 +37,8 @@ namespace Nz::ShaderAst
{
Node(node.expr);
Container(node.memberIdentifiers);
for (std::string& identifier : node.memberIdentifiers)
Container(node.identifiers);
for (std::string& identifier : node.identifiers)
Value(identifier);
}
@@ -46,9 +46,9 @@ namespace Nz::ShaderAst
{
Node(node.expr);
Container(node.memberIndices);
for (std::size_t& identifier : node.memberIndices)
SizeT(identifier);
Container(node.indices);
for (auto& identifier : node.indices)
Node(identifier);
}
void AstSerializerBase::Serialize(AssignExpression& node)

View File

@@ -95,110 +95,114 @@ namespace Nz::ShaderAst
return clone;
}
const ExpressionType& SanitizeVisitor::CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices)
{
std::size_t structIndex = ResolveStruct(structType);
*structIndices++ = structIndex;
assert(structIndex < m_structs.size());
const StructDescription& s = m_structs[structIndex];
auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
if (memberIt == s.members.end())
throw AstError{ "unknown field " + memberIdentifier[0] };
const auto& member = *memberIt;
if (remainingMembers > 1)
return CheckField(member.type, memberIdentifier + 1, remainingMembers - 1, structIndices);
else
return member.type;
}
ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node)
{
auto structExpr = CloneExpression(MandatoryExpr(node.expr));
if (node.identifiers.empty())
throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
const ExpressionType& exprType = GetExpressionType(*structExpr);
if (IsVectorType(exprType))
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr));
for (std::size_t i = 0; i < node.identifiers.size(); ++i)
{
const VectorType& swizzledVec = std::get<VectorType>(exprType);
const std::string& identifier = node.identifiers[i];
// Swizzle expression
auto swizzle = std::make_unique<SwizzleExpression>();
swizzle->expression = std::move(structExpr);
// FIXME: Handle properly multiple identifiers (treat recursively)
if (node.memberIdentifiers.size() != 1)
throw AstError{ "invalid swizzle" };
const std::string& swizzleStr = node.memberIdentifiers.front();
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
throw AstError{ "invalid swizzle" };
swizzle->componentCount = swizzleStr.size();
if (swizzle->componentCount > 1)
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
else
swizzle->cachedExpressionType = swizzledVec.type;
for (std::size_t i = 0; i < swizzle->componentCount; ++i)
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
if (IsStructType(exprType))
{
switch (swizzleStr[i])
// Transform to AccessIndexExpression
AccessIndexExpression* accessIndexPtr;
if (indexedExpr->GetType() != NodeType::AccessIndexExpression)
{
case 'r':
case 'x':
case 's':
swizzle->components[i] = SwizzleComponent::First;
break;
std::unique_ptr<AccessIndexExpression> accessIndex = std::make_unique<AccessIndexExpression>();
accessIndex->expr = std::move(indexedExpr);
case 'g':
case 'y':
case 't':
swizzle->components[i] = SwizzleComponent::Second;
break;
case 'b':
case 'z':
case 'p':
swizzle->components[i] = SwizzleComponent::Third;
break;
case 'a':
case 'w':
case 'q':
swizzle->components[i] = SwizzleComponent::Fourth;
break;
accessIndexPtr = accessIndex.get();
indexedExpr = std::move(accessIndex);
}
else
accessIndexPtr = static_cast<AccessIndexExpression*>(indexedExpr.get());
std::size_t structIndex = ResolveStruct(exprType);
assert(structIndex < m_structs.size());
const StructDescription& s = m_structs[structIndex];
auto it = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == identifier; });
if (it == s.members.end())
throw AstError{ "unknown field " + identifier };
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s.members.begin(), it))));
accessIndexPtr->cachedExpressionType = ResolveType(it->type);
}
else if (IsVectorType(exprType))
{
// Swizzle expression
const VectorType& swizzledVec = std::get<VectorType>(exprType);
return swizzle;
auto swizzle = std::make_unique<SwizzleExpression>();
swizzle->expression = std::move(indexedExpr);
if (node.identifiers.size() - i != 1)
throw AstError{ "invalid swizzle" };
const std::string& swizzleStr = node.identifiers[i];
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
throw AstError{ "invalid swizzle" };
swizzle->componentCount = swizzleStr.size();
if (swizzle->componentCount > 1)
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
else
swizzle->cachedExpressionType = swizzledVec.type;
for (std::size_t j = 0; j < swizzle->componentCount; ++j)
{
switch (swizzleStr[j])
{
case 'r':
case 'x':
case 's':
swizzle->components[j] = SwizzleComponent::First;
break;
case 'g':
case 'y':
case 't':
swizzle->components[j] = SwizzleComponent::Second;
break;
case 'b':
case 'z':
case 'p':
swizzle->components[j] = SwizzleComponent::Third;
break;
case 'a':
case 'w':
case 'q':
swizzle->components[j] = SwizzleComponent::Fourth;
break;
}
}
indexedExpr = std::move(swizzle);
}
else
throw AstError{ "unexpected type (only struct and vectors can be indexed with identifiers)" }; //< TODO: Add support for arrays
}
// Transform to AccessIndexExpression
auto accessMemberIndex = std::make_unique<AccessIndexExpression>();
accessMemberIndex->expr = std::move(structExpr);
return indexedExpr;
}
StackArray<std::size_t> structIndices = NazaraStackArrayNoInit(std::size_t, node.memberIdentifiers.size());
ExpressionPtr SanitizeVisitor::Clone(AccessIndexExpression& node)
{
MandatoryExpr(node.expr);
for (auto& index : node.indices)
MandatoryExpr(index);
accessMemberIndex->cachedExpressionType = ResolveType(CheckField(exprType, node.memberIdentifiers.data(), node.memberIdentifiers.size(), structIndices.data()));
auto clone = static_unique_pointer_cast<AccessIndexExpression>(AstCloner::Clone(node));
Validate(*clone);
accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size());
for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i)
{
std::size_t structIndex = structIndices[i];
assert(structIndex < m_structs.size());
const StructDescription& structDesc = m_structs[structIndex];
auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; });
assert(it != structDesc.members.end());
accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it);
}
return accessMemberIndex;
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
@@ -1204,6 +1208,61 @@ namespace Nz::ShaderAst
}
}
void SanitizeVisitor::Validate(AccessIndexExpression& node)
{
if (node.indices.empty())
throw AstError{ "AccessIndexExpression must have at least one index" };
for (auto& index : node.indices)
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType) || std::get<PrimitiveType>(indexType) != PrimitiveType::Int32)
throw AstError{ "AccessIndex expects Int32 indices" };
}
ExpressionType exprType = GetExpressionType(*node.expr);
for (std::size_t i = 0; i < node.indices.size(); ++i)
{
if (IsStructType(exprType))
{
auto& indexExpr = node.indices[i];
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
if (indexExpr->GetType() != NodeType::ConstantExpression)
throw AstError{ "struct can only be accessed with constant indices" };
ConstantExpression& constantExpr = static_cast<ConstantExpression&>(*indexExpr);
Int32 index = std::get<Int32>(constantExpr.value);
std::size_t structIndex = ResolveStruct(exprType);
assert(structIndex < m_structs.size());
const StructDescription& s = m_structs[structIndex];
exprType = ResolveType(s.members[index].type);
}
else if (IsMatrixType(exprType))
{
// Matrix index (ex: mat[2])
const MatrixType& matrixType = std::get<MatrixType>(exprType);
//TODO: Handle row-major matrices
exprType = VectorType{ matrixType.rowCount, matrixType.type };
}
else if (IsVectorType(exprType))
{
// Swizzle expression with one component (ex: vec[2])
const VectorType& swizzledVec = std::get<VectorType>(exprType);
exprType = swizzledVec.type;
}
else
throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays
}
node.cachedExpressionType = std::move(exprType);
}
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
{
if (referenceDeclaration->entryStage)