diff --git a/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp b/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp index dc8ebc63f..59954831f 100644 --- a/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp +++ b/include/Nazara/Shader/Ast/IndexRemapperVisitor.hpp @@ -31,6 +31,7 @@ namespace Nz::ShaderAst struct Callbacks { + std::function aliasIndexGenerator; std::function constIndexGenerator; std::function funcIndexGenerator; std::function structIndexGenerator; @@ -39,17 +40,21 @@ namespace Nz::ShaderAst }; private: + StatementPtr Clone(DeclareAliasStatement& node) override; StatementPtr Clone(DeclareConstStatement& node) override; StatementPtr Clone(DeclareExternalStatement& node) override; StatementPtr Clone(DeclareFunctionStatement& node) override; StatementPtr Clone(DeclareStructStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override; + ExpressionPtr Clone(AliasValueExpression& node) override; + ExpressionPtr Clone(ConstantExpression& node) override; ExpressionPtr Clone(FunctionExpression& node) override; ExpressionPtr Clone(StructTypeExpression& node) override; ExpressionPtr Clone(VariableValueExpression& node) override; void HandleType(ExpressionValue& exprType); + ExpressionType RemapType(const ExpressionType& exprType); struct Context; Context* m_context; diff --git a/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp b/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp index bcb785665..f8e8c82fb 100644 --- a/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp +++ b/src/Nazara/Shader/Ast/IndexRemapperVisitor.cpp @@ -21,6 +21,7 @@ namespace Nz::ShaderAst struct IndexRemapperVisitor::Context { const IndexRemapperVisitor::Callbacks* callbacks; + std::unordered_map newAliasIndices; std::unordered_map newConstIndices; std::unordered_map newFuncIndices; std::unordered_map newStructIndices; @@ -29,6 +30,7 @@ namespace Nz::ShaderAst StatementPtr IndexRemapperVisitor::Clone(Statement& statement, const Callbacks& callbacks) { + assert(callbacks.aliasIndexGenerator); assert(callbacks.constIndexGenerator); assert(callbacks.funcIndexGenerator); assert(callbacks.structIndexGenerator); @@ -42,6 +44,20 @@ namespace Nz::ShaderAst return AstCloner::Clone(statement); } + StatementPtr IndexRemapperVisitor::Clone(DeclareAliasStatement& node) + { + DeclareAliasStatementPtr clone = StaticUniquePointerCast(AstCloner::Clone(node)); + + if (clone->aliasIndex) + { + std::size_t newAliasIndex = m_context->callbacks->aliasIndexGenerator(*clone->aliasIndex); + UniqueInsert(m_context->newAliasIndices, *clone->aliasIndex, newAliasIndex); + clone->aliasIndex = newAliasIndex; + } + + return clone; + } + StatementPtr IndexRemapperVisitor::Clone(DeclareConstStatement& node) { DeclareConstStatementPtr clone = StaticUniquePointerCast(AstCloner::Clone(node)); @@ -134,6 +150,26 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr IndexRemapperVisitor::Clone(AliasValueExpression& node) + { + AliasValueExpressionPtr clone = StaticUniquePointerCast(AstCloner::Clone(node)); + + if (clone->aliasId) + clone->aliasId = Retrieve(m_context->newAliasIndices, clone->aliasId); + + return clone; + } + + ExpressionPtr IndexRemapperVisitor::Clone(ConstantExpression& node) + { + ConstantExpressionPtr clone = StaticUniquePointerCast(AstCloner::Clone(node)); + + if (clone->constantId) + clone->constantId = Retrieve(m_context->newConstIndices, clone->constantId); + + return clone; + } + ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node) { FunctionExpressionPtr clone = StaticUniquePointerCast(AstCloner::Clone(node)); @@ -170,10 +206,59 @@ namespace Nz::ShaderAst return; const auto& resultingType = exprType.GetResultingValue(); - if (IsStructType(resultingType)) + exprType = RemapType(resultingType); + } + + ExpressionType IndexRemapperVisitor::RemapType(const ExpressionType& exprType) + { + if (IsAliasType(exprType)) { - std::size_t newStructIndex = Retrieve(m_context->newStructIndices, std::get(resultingType).structIndex); - exprType = ExpressionType{ StructType{ newStructIndex } }; + const AliasType& aliasType = std::get(exprType); + + AliasType remappedAliasType; + remappedAliasType.aliasIndex = Retrieve(m_context->newAliasIndices, aliasType.aliasIndex); + remappedAliasType.targetType = std::make_unique(); + remappedAliasType.targetType->type = RemapType(aliasType.targetType->type); + + return remappedAliasType; + } + else if (IsArrayType(exprType)) + { + const ArrayType& arrayType = std::get(exprType); + + ArrayType remappedArrayType; + remappedArrayType.containedType = std::make_unique(); + remappedArrayType.containedType->type = RemapType(arrayType.containedType->type); + remappedArrayType.length = arrayType.length; + + return remappedArrayType; + } + else if (IsFunctionType(exprType)) + { + std::size_t newFuncIndex = Retrieve(m_context->newFuncIndices, std::get(exprType).funcIndex); + return FunctionType{ newFuncIndex }; + } + else if (IsMethodType(exprType)) + { + const MethodType& methodType = std::get(exprType); + + MethodType remappedMethodType; + remappedMethodType.methodIndex = methodType.methodIndex; + remappedMethodType.objectType = std::make_unique(); + remappedMethodType.objectType->type = RemapType(methodType.objectType->type); + + return remappedMethodType; + } + else if (IsStructType(exprType)) + { + std::size_t newStructIndex = Retrieve(m_context->newStructIndices, std::get(exprType).structIndex); + return StructType{ newStructIndex }; + } + else if (IsUniformType(exprType)) + { + UniformType uniformType; + uniformType.containedType.structIndex = Retrieve(m_context->newStructIndices, std::get(exprType).containedType.structIndex); + return uniformType; } } } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 21ca31b85..3e3ba1a0d 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -1754,6 +1754,7 @@ namespace Nz::ShaderAst // Remap already used indices IndexRemapperVisitor::Callbacks indexCallbacks; + indexCallbacks.aliasIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->aliases.RegisterNewIndex(true); }; indexCallbacks.constIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->constantValues.RegisterNewIndex(true); }; indexCallbacks.funcIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->functions.RegisterNewIndex(true); }; indexCallbacks.structIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->structs.RegisterNewIndex(true); };