Shader/IndexRemapperVisitor: Fix unhandled cases

This commit is contained in:
SirLynix 2022-04-20 01:02:42 +02:00
parent 5a7bd44744
commit e11644a81d
3 changed files with 94 additions and 3 deletions

View File

@ -31,6 +31,7 @@ namespace Nz::ShaderAst
struct Callbacks struct Callbacks
{ {
std::function<std::size_t(std::size_t previousIndex)> aliasIndexGenerator;
std::function<std::size_t(std::size_t previousIndex)> constIndexGenerator; std::function<std::size_t(std::size_t previousIndex)> constIndexGenerator;
std::function<std::size_t(std::size_t previousIndex)> funcIndexGenerator; std::function<std::size_t(std::size_t previousIndex)> funcIndexGenerator;
std::function<std::size_t(std::size_t previousIndex) > structIndexGenerator; std::function<std::size_t(std::size_t previousIndex) > structIndexGenerator;
@ -39,17 +40,21 @@ namespace Nz::ShaderAst
}; };
private: private:
StatementPtr Clone(DeclareAliasStatement& node) override;
StatementPtr Clone(DeclareConstStatement& node) override; StatementPtr Clone(DeclareConstStatement& node) override;
StatementPtr Clone(DeclareExternalStatement& node) override; StatementPtr Clone(DeclareExternalStatement& node) override;
StatementPtr Clone(DeclareFunctionStatement& node) override; StatementPtr Clone(DeclareFunctionStatement& node) override;
StatementPtr Clone(DeclareStructStatement& node) override; StatementPtr Clone(DeclareStructStatement& node) override;
StatementPtr Clone(DeclareVariableStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override;
ExpressionPtr Clone(AliasValueExpression& node) override;
ExpressionPtr Clone(ConstantExpression& node) override;
ExpressionPtr Clone(FunctionExpression& node) override; ExpressionPtr Clone(FunctionExpression& node) override;
ExpressionPtr Clone(StructTypeExpression& node) override; ExpressionPtr Clone(StructTypeExpression& node) override;
ExpressionPtr Clone(VariableValueExpression& node) override; ExpressionPtr Clone(VariableValueExpression& node) override;
void HandleType(ExpressionValue<ExpressionType>& exprType); void HandleType(ExpressionValue<ExpressionType>& exprType);
ExpressionType RemapType(const ExpressionType& exprType);
struct Context; struct Context;
Context* m_context; Context* m_context;

View File

@ -21,6 +21,7 @@ namespace Nz::ShaderAst
struct IndexRemapperVisitor::Context struct IndexRemapperVisitor::Context
{ {
const IndexRemapperVisitor::Callbacks* callbacks; const IndexRemapperVisitor::Callbacks* callbacks;
std::unordered_map<std::size_t, std::size_t> newAliasIndices;
std::unordered_map<std::size_t, std::size_t> newConstIndices; std::unordered_map<std::size_t, std::size_t> newConstIndices;
std::unordered_map<std::size_t, std::size_t> newFuncIndices; std::unordered_map<std::size_t, std::size_t> newFuncIndices;
std::unordered_map<std::size_t, std::size_t> newStructIndices; std::unordered_map<std::size_t, std::size_t> newStructIndices;
@ -29,6 +30,7 @@ namespace Nz::ShaderAst
StatementPtr IndexRemapperVisitor::Clone(Statement& statement, const Callbacks& callbacks) StatementPtr IndexRemapperVisitor::Clone(Statement& statement, const Callbacks& callbacks)
{ {
assert(callbacks.aliasIndexGenerator);
assert(callbacks.constIndexGenerator); assert(callbacks.constIndexGenerator);
assert(callbacks.funcIndexGenerator); assert(callbacks.funcIndexGenerator);
assert(callbacks.structIndexGenerator); assert(callbacks.structIndexGenerator);
@ -42,6 +44,20 @@ namespace Nz::ShaderAst
return AstCloner::Clone(statement); return AstCloner::Clone(statement);
} }
StatementPtr IndexRemapperVisitor::Clone(DeclareAliasStatement& node)
{
DeclareAliasStatementPtr clone = StaticUniquePointerCast<DeclareAliasStatement>(AstCloner::Clone(node));
if (clone->aliasIndex)
{
std::size_t newAliasIndex = m_context->callbacks->aliasIndexGenerator(*clone->aliasIndex);
UniqueInsert(m_context->newAliasIndices, *clone->aliasIndex, newAliasIndex);
clone->aliasIndex = newAliasIndex;
}
return clone;
}
StatementPtr IndexRemapperVisitor::Clone(DeclareConstStatement& node) StatementPtr IndexRemapperVisitor::Clone(DeclareConstStatement& node)
{ {
DeclareConstStatementPtr clone = StaticUniquePointerCast<DeclareConstStatement>(AstCloner::Clone(node)); DeclareConstStatementPtr clone = StaticUniquePointerCast<DeclareConstStatement>(AstCloner::Clone(node));
@ -134,6 +150,26 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr IndexRemapperVisitor::Clone(AliasValueExpression& node)
{
AliasValueExpressionPtr clone = StaticUniquePointerCast<AliasValueExpression>(AstCloner::Clone(node));
if (clone->aliasId)
clone->aliasId = Retrieve(m_context->newAliasIndices, clone->aliasId);
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(ConstantExpression& node)
{
ConstantExpressionPtr clone = StaticUniquePointerCast<ConstantExpression>(AstCloner::Clone(node));
if (clone->constantId)
clone->constantId = Retrieve(m_context->newConstIndices, clone->constantId);
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node) ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node)
{ {
FunctionExpressionPtr clone = StaticUniquePointerCast<FunctionExpression>(AstCloner::Clone(node)); FunctionExpressionPtr clone = StaticUniquePointerCast<FunctionExpression>(AstCloner::Clone(node));
@ -170,10 +206,59 @@ namespace Nz::ShaderAst
return; return;
const auto& resultingType = exprType.GetResultingValue(); const auto& resultingType = exprType.GetResultingValue();
if (IsStructType(resultingType)) exprType = RemapType(resultingType);
}
ExpressionType IndexRemapperVisitor::RemapType(const ExpressionType& exprType)
{ {
std::size_t newStructIndex = Retrieve(m_context->newStructIndices, std::get<StructType>(resultingType).structIndex); if (IsAliasType(exprType))
exprType = ExpressionType{ StructType{ newStructIndex } }; {
const AliasType& aliasType = std::get<AliasType>(exprType);
AliasType remappedAliasType;
remappedAliasType.aliasIndex = Retrieve(m_context->newAliasIndices, aliasType.aliasIndex);
remappedAliasType.targetType = std::make_unique<ContainedType>();
remappedAliasType.targetType->type = RemapType(aliasType.targetType->type);
return remappedAliasType;
}
else if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
ArrayType remappedArrayType;
remappedArrayType.containedType = std::make_unique<ContainedType>();
remappedArrayType.containedType->type = RemapType(arrayType.containedType->type);
remappedArrayType.length = arrayType.length;
return remappedArrayType;
}
else if (IsFunctionType(exprType))
{
std::size_t newFuncIndex = Retrieve(m_context->newFuncIndices, std::get<FunctionType>(exprType).funcIndex);
return FunctionType{ newFuncIndex };
}
else if (IsMethodType(exprType))
{
const MethodType& methodType = std::get<MethodType>(exprType);
MethodType remappedMethodType;
remappedMethodType.methodIndex = methodType.methodIndex;
remappedMethodType.objectType = std::make_unique<ContainedType>();
remappedMethodType.objectType->type = RemapType(methodType.objectType->type);
return remappedMethodType;
}
else if (IsStructType(exprType))
{
std::size_t newStructIndex = Retrieve(m_context->newStructIndices, std::get<StructType>(exprType).structIndex);
return StructType{ newStructIndex };
}
else if (IsUniformType(exprType))
{
UniformType uniformType;
uniformType.containedType.structIndex = Retrieve(m_context->newStructIndices, std::get<UniformType>(exprType).containedType.structIndex);
return uniformType;
} }
} }
} }

View File

@ -1754,6 +1754,7 @@ namespace Nz::ShaderAst
// Remap already used indices // Remap already used indices
IndexRemapperVisitor::Callbacks indexCallbacks; IndexRemapperVisitor::Callbacks indexCallbacks;
indexCallbacks.aliasIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->aliases.RegisterNewIndex(true); };
indexCallbacks.constIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->constantValues.RegisterNewIndex(true); }; indexCallbacks.constIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->constantValues.RegisterNewIndex(true); };
indexCallbacks.funcIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->functions.RegisterNewIndex(true); }; indexCallbacks.funcIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->functions.RegisterNewIndex(true); };
indexCallbacks.structIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->structs.RegisterNewIndex(true); }; indexCallbacks.structIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->structs.RegisterNewIndex(true); };