diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index 95172c757..279f71b07 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -167,7 +167,7 @@ namespace Nz UInt32 Register(Type t); UInt32 Register(Variable v); - void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos, SpirvSection& types); + void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete; SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; @@ -183,7 +183,10 @@ namespace Nz struct Eq; struct Internal; - void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& debugInfos, SpirvSection& types); + void Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants); + void Write(const AnyType& type, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); + + void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); std::unique_ptr m_internal; }; diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 02f7b4f34..db2bdb7ee 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -12,6 +12,12 @@ namespace Nz { + namespace + { + template struct overloaded : Ts... { using Ts::operator()...; }; + + template overloaded(Ts...)->overloaded; + } struct SpirvConstantCache::Eq { bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const @@ -381,8 +387,7 @@ namespace Nz { } - tsl::ordered_map constantIds; - tsl::ordered_map typeIds; + tsl::ordered_map, UInt32 /*id*/, AnyHasher, Eq> ids; tsl::ordered_map variableIds; tsl::ordered_map structureSizes; UInt32& nextResultId; @@ -399,8 +404,8 @@ namespace Nz UInt32 SpirvConstantCache::GetId(const Constant& c) { - auto it = m_internal->constantIds.find(c.constant); - if (it == m_internal->constantIds.end()) + auto it = m_internal->ids.find(c.constant); + if (it == m_internal->ids.end()) throw std::runtime_error("constant is not registered"); return it->second; @@ -408,8 +413,8 @@ namespace Nz UInt32 SpirvConstantCache::GetId(const Type& t) { - auto it = m_internal->typeIds.find(t.type); - if (it == m_internal->typeIds.end()) + auto it = m_internal->ids.find(t.type); + if (it == m_internal->ids.end()) throw std::runtime_error("constant is not registered"); return it->second; @@ -431,12 +436,12 @@ namespace Nz DepRegisterer registerer(*this); registerer.Register(constant); - std::size_t h = m_internal->typeIds.hash_function()(constant); - auto it = m_internal->constantIds.find(constant, h); - if (it == m_internal->constantIds.end()) + std::size_t h = m_internal->ids.hash_function()(constant); + auto it = m_internal->ids.find(constant, h); + if (it == m_internal->ids.end()) { UInt32 resultId = m_internal->nextResultId++; - it = m_internal->constantIds.emplace(std::move(constant), resultId).first; + it = m_internal->ids.emplace(std::move(constant), resultId).first; } return it.value(); @@ -449,12 +454,12 @@ namespace Nz DepRegisterer registerer(*this); registerer.Register(type); - std::size_t h = m_internal->typeIds.hash_function()(type); - auto it = m_internal->typeIds.find(type, h); - if (it == m_internal->typeIds.end()) + std::size_t h = m_internal->ids.hash_function()(type); + auto it = m_internal->ids.find(type, h); + if (it == m_internal->ids.end()) { UInt32 resultId = m_internal->nextResultId++; - it = m_internal->typeIds.emplace(std::move(type), resultId).first; + it = m_internal->ids.emplace(std::move(type), resultId).first; } return it.value(); @@ -476,131 +481,19 @@ namespace Nz return it.value(); } - void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos, SpirvSection& types) + void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) { - for (auto&& [type, id] : m_internal->typeIds) + for (auto&& [type, id] : m_internal->ids) { UInt32 resultId = id; - std::visit([&](auto&& arg) + std::visit(overloaded { - using T = std::decay_t; - - if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeBool, resultId); - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeFloat, resultId, arg.width); - else if constexpr (std::is_same_v) - { - types.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender) - { - appender(resultId); - appender(GetId(*arg.returnType)); - - for (const auto& param : arg.parameters) - appender(GetId(*param)); - }); - } - else if constexpr (std::is_same_v) - { - UInt32 depth; - if (arg.depth.has_value()) - depth = (*arg.depth) ? 1 : 0; - else - depth = 2; - - UInt32 sampled; - if (arg.sampled.has_value()) - sampled = (*arg.sampled) ? 1 : 0; - else - sampled = 2; - - types.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender) - { - appender(resultId); - appender(GetId(*arg.sampledType)); - appender(arg.dim); - appender(depth); - appender(arg.arrayed); - appender(arg.multisampled); - appender(sampled); - appender(arg.format); - - if (arg.qualifier) - appender(*arg.qualifier); - }); - } - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness); - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount); - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type)); - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image)); - else if constexpr (std::is_same_v) - WriteStruct(arg, resultId, annotations, debugInfos, types); - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount); - else if constexpr (std::is_same_v) - types.Append(SpirvOp::OpTypeVoid, resultId); - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + [&](const AnyConstant& constant) { Write(constant, resultId, constants); }, + [&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); }, }, type); } - for (auto&& [constant, id] : m_internal->constantIds) - { - UInt32 resultId = id; - - std::visit([&](auto&& arg) - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId); - else if constexpr (std::is_same_v) - { - constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender) - { - appender(GetId(arg.type->type)); - appender(resultId); - - for (const auto& value : arg.values) - appender(GetId(value->constant)); - }); - } - else if constexpr (std::is_same_v) - { - std::visit([&](auto&& arg) - { - using T = std::decay_t; - - UInt32 typeId; - if constexpr (std::is_same_v) - typeId = GetId({ Float{ 64 } }); - else if constexpr (std::is_same_v) - typeId = GetId({ Float{ 32 } }); - else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 32, 1 } }); - else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 64, 1 } }); - else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 32, 0 } }); - else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 64, 0 } }); - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - - constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &arg, sizeof(arg) }); - - }, arg.value); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, constant); - } - for (auto&& [variable, id] : m_internal->variableIds) { UInt32 resultId = id; @@ -781,9 +674,128 @@ namespace Nz }, type); } - void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& debugInfos, SpirvSection& types) + void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants) { - types.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender) + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId); + else if constexpr (std::is_same_v) + { + constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender) + { + appender(GetId(arg.type->type)); + appender(resultId); + + for (const auto& value : arg.values) + appender(GetId(value->constant)); + }); + } + else if constexpr (std::is_same_v) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + UInt32 typeId; + if constexpr (std::is_same_v) + typeId = GetId({ Float{ 64 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Float{ 32 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 32, 1 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 64, 1 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 32, 0 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 64, 0 } }); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + + constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &arg, sizeof(arg) }); + + }, arg.value); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, constant); + } + + void SpirvConstantCache::Write(const AnyType& type, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeBool, resultId); + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeFloat, resultId, arg.width); + else if constexpr (std::is_same_v) + { + constants.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender) + { + appender(resultId); + appender(GetId(*arg.returnType)); + + for (const auto& param : arg.parameters) + appender(GetId(*param)); + }); + } + else if constexpr (std::is_same_v) + { + UInt32 depth; + if (arg.depth.has_value()) + depth = (*arg.depth) ? 1 : 0; + else + depth = 2; + + UInt32 sampled; + if (arg.sampled.has_value()) + sampled = (*arg.sampled) ? 1 : 0; + else + sampled = 2; + + constants.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender) + { + appender(resultId); + appender(GetId(*arg.sampledType)); + appender(arg.dim); + appender(depth); + appender(arg.arrayed); + appender(arg.multisampled); + appender(sampled); + appender(arg.format); + + if (arg.qualifier) + appender(*arg.qualifier); + }); + } + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness); + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount); + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type)); + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image)); + else if constexpr (std::is_same_v) + WriteStruct(arg, resultId, annotations, constants, debugInfos); + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount); + else if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeVoid, resultId); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, type); + } + + void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) + { + constants.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender) { appender(resultId); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 92630d506..7a1f41bff 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -192,7 +192,6 @@ namespace Nz SpirvSection constants; SpirvSection debugInfo; SpirvSection annotations; - SpirvSection types; SpirvSection instructions; }; @@ -398,7 +397,7 @@ namespace Nz assert(entryPointIndex != std::numeric_limits::max()); - m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo, m_currentState->types); + m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo); AppendHeader(); @@ -446,7 +445,6 @@ namespace Nz MergeBlocks(ret, state.header); MergeBlocks(ret, state.debugInfo); MergeBlocks(ret, state.annotations); - MergeBlocks(ret, state.types); MergeBlocks(ret, state.constants); MergeBlocks(ret, state.instructions);