Shader/IndexRemapperVisitor: Fix unhandled cases

This commit is contained in:
SirLynix 2022-04-20 01:02:42 +02:00
parent 5a7bd44744
commit e11644a81d
3 changed files with 94 additions and 3 deletions

View File

@ -31,6 +31,7 @@ namespace Nz::ShaderAst
struct Callbacks
{
std::function<std::size_t(std::size_t previousIndex)> aliasIndexGenerator;
std::function<std::size_t(std::size_t previousIndex)> constIndexGenerator;
std::function<std::size_t(std::size_t previousIndex)> funcIndexGenerator;
std::function<std::size_t(std::size_t previousIndex) > structIndexGenerator;
@ -39,17 +40,21 @@ namespace Nz::ShaderAst
};
private:
StatementPtr Clone(DeclareAliasStatement& node) override;
StatementPtr Clone(DeclareConstStatement& node) override;
StatementPtr Clone(DeclareExternalStatement& node) override;
StatementPtr Clone(DeclareFunctionStatement& node) override;
StatementPtr Clone(DeclareStructStatement& node) override;
StatementPtr Clone(DeclareVariableStatement& node) override;
ExpressionPtr Clone(AliasValueExpression& node) override;
ExpressionPtr Clone(ConstantExpression& node) override;
ExpressionPtr Clone(FunctionExpression& node) override;
ExpressionPtr Clone(StructTypeExpression& node) override;
ExpressionPtr Clone(VariableValueExpression& node) override;
void HandleType(ExpressionValue<ExpressionType>& exprType);
ExpressionType RemapType(const ExpressionType& exprType);
struct Context;
Context* m_context;

View File

@ -21,6 +21,7 @@ namespace Nz::ShaderAst
struct IndexRemapperVisitor::Context
{
const IndexRemapperVisitor::Callbacks* callbacks;
std::unordered_map<std::size_t, std::size_t> newAliasIndices;
std::unordered_map<std::size_t, std::size_t> newConstIndices;
std::unordered_map<std::size_t, std::size_t> newFuncIndices;
std::unordered_map<std::size_t, std::size_t> newStructIndices;
@ -29,6 +30,7 @@ namespace Nz::ShaderAst
StatementPtr IndexRemapperVisitor::Clone(Statement& statement, const Callbacks& callbacks)
{
assert(callbacks.aliasIndexGenerator);
assert(callbacks.constIndexGenerator);
assert(callbacks.funcIndexGenerator);
assert(callbacks.structIndexGenerator);
@ -42,6 +44,20 @@ namespace Nz::ShaderAst
return AstCloner::Clone(statement);
}
StatementPtr IndexRemapperVisitor::Clone(DeclareAliasStatement& node)
{
DeclareAliasStatementPtr clone = StaticUniquePointerCast<DeclareAliasStatement>(AstCloner::Clone(node));
if (clone->aliasIndex)
{
std::size_t newAliasIndex = m_context->callbacks->aliasIndexGenerator(*clone->aliasIndex);
UniqueInsert(m_context->newAliasIndices, *clone->aliasIndex, newAliasIndex);
clone->aliasIndex = newAliasIndex;
}
return clone;
}
StatementPtr IndexRemapperVisitor::Clone(DeclareConstStatement& node)
{
DeclareConstStatementPtr clone = StaticUniquePointerCast<DeclareConstStatement>(AstCloner::Clone(node));
@ -134,6 +150,26 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(AliasValueExpression& node)
{
AliasValueExpressionPtr clone = StaticUniquePointerCast<AliasValueExpression>(AstCloner::Clone(node));
if (clone->aliasId)
clone->aliasId = Retrieve(m_context->newAliasIndices, clone->aliasId);
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(ConstantExpression& node)
{
ConstantExpressionPtr clone = StaticUniquePointerCast<ConstantExpression>(AstCloner::Clone(node));
if (clone->constantId)
clone->constantId = Retrieve(m_context->newConstIndices, clone->constantId);
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node)
{
FunctionExpressionPtr clone = StaticUniquePointerCast<FunctionExpression>(AstCloner::Clone(node));
@ -170,10 +206,59 @@ namespace Nz::ShaderAst
return;
const auto& resultingType = exprType.GetResultingValue();
if (IsStructType(resultingType))
exprType = RemapType(resultingType);
}
ExpressionType IndexRemapperVisitor::RemapType(const ExpressionType& exprType)
{
if (IsAliasType(exprType))
{
std::size_t newStructIndex = Retrieve(m_context->newStructIndices, std::get<StructType>(resultingType).structIndex);
exprType = ExpressionType{ StructType{ newStructIndex } };
const AliasType& aliasType = std::get<AliasType>(exprType);
AliasType remappedAliasType;
remappedAliasType.aliasIndex = Retrieve(m_context->newAliasIndices, aliasType.aliasIndex);
remappedAliasType.targetType = std::make_unique<ContainedType>();
remappedAliasType.targetType->type = RemapType(aliasType.targetType->type);
return remappedAliasType;
}
else if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
ArrayType remappedArrayType;
remappedArrayType.containedType = std::make_unique<ContainedType>();
remappedArrayType.containedType->type = RemapType(arrayType.containedType->type);
remappedArrayType.length = arrayType.length;
return remappedArrayType;
}
else if (IsFunctionType(exprType))
{
std::size_t newFuncIndex = Retrieve(m_context->newFuncIndices, std::get<FunctionType>(exprType).funcIndex);
return FunctionType{ newFuncIndex };
}
else if (IsMethodType(exprType))
{
const MethodType& methodType = std::get<MethodType>(exprType);
MethodType remappedMethodType;
remappedMethodType.methodIndex = methodType.methodIndex;
remappedMethodType.objectType = std::make_unique<ContainedType>();
remappedMethodType.objectType->type = RemapType(methodType.objectType->type);
return remappedMethodType;
}
else if (IsStructType(exprType))
{
std::size_t newStructIndex = Retrieve(m_context->newStructIndices, std::get<StructType>(exprType).structIndex);
return StructType{ newStructIndex };
}
else if (IsUniformType(exprType))
{
UniformType uniformType;
uniformType.containedType.structIndex = Retrieve(m_context->newStructIndices, std::get<UniformType>(exprType).containedType.structIndex);
return uniformType;
}
}
}

View File

@ -1754,6 +1754,7 @@ namespace Nz::ShaderAst
// Remap already used indices
IndexRemapperVisitor::Callbacks indexCallbacks;
indexCallbacks.aliasIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->aliases.RegisterNewIndex(true); };
indexCallbacks.constIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->constantValues.RegisterNewIndex(true); };
indexCallbacks.funcIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->functions.RegisterNewIndex(true); };
indexCallbacks.structIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->structs.RegisterNewIndex(true); };