|
|
|
|
@@ -95,110 +95,114 @@ namespace Nz::ShaderAst
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const ExpressionType& SanitizeVisitor::CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices)
|
|
|
|
|
{
|
|
|
|
|
std::size_t structIndex = ResolveStruct(structType);
|
|
|
|
|
|
|
|
|
|
*structIndices++ = structIndex;
|
|
|
|
|
|
|
|
|
|
assert(structIndex < m_structs.size());
|
|
|
|
|
const StructDescription& s = m_structs[structIndex];
|
|
|
|
|
|
|
|
|
|
auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
|
|
|
|
|
if (memberIt == s.members.end())
|
|
|
|
|
throw AstError{ "unknown field " + memberIdentifier[0] };
|
|
|
|
|
|
|
|
|
|
const auto& member = *memberIt;
|
|
|
|
|
|
|
|
|
|
if (remainingMembers > 1)
|
|
|
|
|
return CheckField(member.type, memberIdentifier + 1, remainingMembers - 1, structIndices);
|
|
|
|
|
else
|
|
|
|
|
return member.type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node)
|
|
|
|
|
{
|
|
|
|
|
auto structExpr = CloneExpression(MandatoryExpr(node.expr));
|
|
|
|
|
if (node.identifiers.empty())
|
|
|
|
|
throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*structExpr);
|
|
|
|
|
if (IsVectorType(exprType))
|
|
|
|
|
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr));
|
|
|
|
|
for (std::size_t i = 0; i < node.identifiers.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
const VectorType& swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
const std::string& identifier = node.identifiers[i];
|
|
|
|
|
|
|
|
|
|
// Swizzle expression
|
|
|
|
|
auto swizzle = std::make_unique<SwizzleExpression>();
|
|
|
|
|
swizzle->expression = std::move(structExpr);
|
|
|
|
|
|
|
|
|
|
// FIXME: Handle properly multiple identifiers (treat recursively)
|
|
|
|
|
if (node.memberIdentifiers.size() != 1)
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
|
|
|
|
|
const std::string& swizzleStr = node.memberIdentifiers.front();
|
|
|
|
|
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
|
|
|
|
|
swizzle->componentCount = swizzleStr.size();
|
|
|
|
|
|
|
|
|
|
if (swizzle->componentCount > 1)
|
|
|
|
|
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
|
|
|
|
|
else
|
|
|
|
|
swizzle->cachedExpressionType = swizzledVec.type;
|
|
|
|
|
|
|
|
|
|
for (std::size_t i = 0; i < swizzle->componentCount; ++i)
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
|
|
|
|
|
if (IsStructType(exprType))
|
|
|
|
|
{
|
|
|
|
|
switch (swizzleStr[i])
|
|
|
|
|
// Transform to AccessIndexExpression
|
|
|
|
|
AccessIndexExpression* accessIndexPtr;
|
|
|
|
|
if (indexedExpr->GetType() != NodeType::AccessIndexExpression)
|
|
|
|
|
{
|
|
|
|
|
case 'r':
|
|
|
|
|
case 'x':
|
|
|
|
|
case 's':
|
|
|
|
|
swizzle->components[i] = SwizzleComponent::First;
|
|
|
|
|
break;
|
|
|
|
|
std::unique_ptr<AccessIndexExpression> accessIndex = std::make_unique<AccessIndexExpression>();
|
|
|
|
|
accessIndex->expr = std::move(indexedExpr);
|
|
|
|
|
|
|
|
|
|
case 'g':
|
|
|
|
|
case 'y':
|
|
|
|
|
case 't':
|
|
|
|
|
swizzle->components[i] = SwizzleComponent::Second;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'b':
|
|
|
|
|
case 'z':
|
|
|
|
|
case 'p':
|
|
|
|
|
swizzle->components[i] = SwizzleComponent::Third;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'a':
|
|
|
|
|
case 'w':
|
|
|
|
|
case 'q':
|
|
|
|
|
swizzle->components[i] = SwizzleComponent::Fourth;
|
|
|
|
|
break;
|
|
|
|
|
accessIndexPtr = accessIndex.get();
|
|
|
|
|
indexedExpr = std::move(accessIndex);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
accessIndexPtr = static_cast<AccessIndexExpression*>(indexedExpr.get());
|
|
|
|
|
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
assert(structIndex < m_structs.size());
|
|
|
|
|
const StructDescription& s = m_structs[structIndex];
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == identifier; });
|
|
|
|
|
if (it == s.members.end())
|
|
|
|
|
throw AstError{ "unknown field " + identifier };
|
|
|
|
|
|
|
|
|
|
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s.members.begin(), it))));
|
|
|
|
|
accessIndexPtr->cachedExpressionType = ResolveType(it->type);
|
|
|
|
|
}
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
{
|
|
|
|
|
// Swizzle expression
|
|
|
|
|
const VectorType& swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
|
|
|
|
|
return swizzle;
|
|
|
|
|
auto swizzle = std::make_unique<SwizzleExpression>();
|
|
|
|
|
swizzle->expression = std::move(indexedExpr);
|
|
|
|
|
|
|
|
|
|
if (node.identifiers.size() - i != 1)
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
|
|
|
|
|
const std::string& swizzleStr = node.identifiers[i];
|
|
|
|
|
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
|
|
|
|
|
swizzle->componentCount = swizzleStr.size();
|
|
|
|
|
|
|
|
|
|
if (swizzle->componentCount > 1)
|
|
|
|
|
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
|
|
|
|
|
else
|
|
|
|
|
swizzle->cachedExpressionType = swizzledVec.type;
|
|
|
|
|
|
|
|
|
|
for (std::size_t j = 0; j < swizzle->componentCount; ++j)
|
|
|
|
|
{
|
|
|
|
|
switch (swizzleStr[j])
|
|
|
|
|
{
|
|
|
|
|
case 'r':
|
|
|
|
|
case 'x':
|
|
|
|
|
case 's':
|
|
|
|
|
swizzle->components[j] = SwizzleComponent::First;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'g':
|
|
|
|
|
case 'y':
|
|
|
|
|
case 't':
|
|
|
|
|
swizzle->components[j] = SwizzleComponent::Second;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'b':
|
|
|
|
|
case 'z':
|
|
|
|
|
case 'p':
|
|
|
|
|
swizzle->components[j] = SwizzleComponent::Third;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'a':
|
|
|
|
|
case 'w':
|
|
|
|
|
case 'q':
|
|
|
|
|
swizzle->components[j] = SwizzleComponent::Fourth;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
indexedExpr = std::move(swizzle);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "unexpected type (only struct and vectors can be indexed with identifiers)" }; //< TODO: Add support for arrays
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Transform to AccessIndexExpression
|
|
|
|
|
auto accessMemberIndex = std::make_unique<AccessIndexExpression>();
|
|
|
|
|
accessMemberIndex->expr = std::move(structExpr);
|
|
|
|
|
return indexedExpr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StackArray<std::size_t> structIndices = NazaraStackArrayNoInit(std::size_t, node.memberIdentifiers.size());
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(AccessIndexExpression& node)
|
|
|
|
|
{
|
|
|
|
|
MandatoryExpr(node.expr);
|
|
|
|
|
for (auto& index : node.indices)
|
|
|
|
|
MandatoryExpr(index);
|
|
|
|
|
|
|
|
|
|
accessMemberIndex->cachedExpressionType = ResolveType(CheckField(exprType, node.memberIdentifiers.data(), node.memberIdentifiers.size(), structIndices.data()));
|
|
|
|
|
auto clone = static_unique_pointer_cast<AccessIndexExpression>(AstCloner::Clone(node));
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size());
|
|
|
|
|
for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
std::size_t structIndex = structIndices[i];
|
|
|
|
|
assert(structIndex < m_structs.size());
|
|
|
|
|
const StructDescription& structDesc = m_structs[structIndex];
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; });
|
|
|
|
|
assert(it != structDesc.members.end());
|
|
|
|
|
|
|
|
|
|
accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return accessMemberIndex;
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
|
|
|
|
|
@@ -1204,6 +1208,61 @@ namespace Nz::ShaderAst
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(AccessIndexExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.indices.empty())
|
|
|
|
|
throw AstError{ "AccessIndexExpression must have at least one index" };
|
|
|
|
|
|
|
|
|
|
for (auto& index : node.indices)
|
|
|
|
|
{
|
|
|
|
|
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
|
|
|
|
|
if (!IsPrimitiveType(indexType) || std::get<PrimitiveType>(indexType) != PrimitiveType::Int32)
|
|
|
|
|
throw AstError{ "AccessIndex expects Int32 indices" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionType exprType = GetExpressionType(*node.expr);
|
|
|
|
|
for (std::size_t i = 0; i < node.indices.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
if (IsStructType(exprType))
|
|
|
|
|
{
|
|
|
|
|
auto& indexExpr = node.indices[i];
|
|
|
|
|
|
|
|
|
|
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
|
|
|
|
|
if (indexExpr->GetType() != NodeType::ConstantExpression)
|
|
|
|
|
throw AstError{ "struct can only be accessed with constant indices" };
|
|
|
|
|
|
|
|
|
|
ConstantExpression& constantExpr = static_cast<ConstantExpression&>(*indexExpr);
|
|
|
|
|
|
|
|
|
|
Int32 index = std::get<Int32>(constantExpr.value);
|
|
|
|
|
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
assert(structIndex < m_structs.size());
|
|
|
|
|
const StructDescription& s = m_structs[structIndex];
|
|
|
|
|
|
|
|
|
|
exprType = ResolveType(s.members[index].type);
|
|
|
|
|
}
|
|
|
|
|
else if (IsMatrixType(exprType))
|
|
|
|
|
{
|
|
|
|
|
// Matrix index (ex: mat[2])
|
|
|
|
|
const MatrixType& matrixType = std::get<MatrixType>(exprType);
|
|
|
|
|
|
|
|
|
|
//TODO: Handle row-major matrices
|
|
|
|
|
exprType = VectorType{ matrixType.rowCount, matrixType.type };
|
|
|
|
|
}
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
{
|
|
|
|
|
// Swizzle expression with one component (ex: vec[2])
|
|
|
|
|
const VectorType& swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
|
|
|
|
|
exprType = swizzledVec.type;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = std::move(exprType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
|
|
|
|
|
{
|
|
|
|
|
if (referenceDeclaration->entryStage)
|
|
|
|
|
|