Shader/Spirv: Put types and constants in the same section

This commit is contained in:
Jérôme Leclercq 2020-08-21 22:51:11 +02:00
parent cd23c01ace
commit 66a14721cb
3 changed files with 151 additions and 138 deletions

View File

@ -167,7 +167,7 @@ namespace Nz
UInt32 Register(Type t);
UInt32 Register(Variable v);
void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos, SpirvSection& types);
void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete;
SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept;
@ -183,7 +183,10 @@ namespace Nz
struct Eq;
struct Internal;
void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& debugInfos, SpirvSection& types);
void Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants);
void Write(const AnyType& type, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
std::unique_ptr<Internal> m_internal;
};

View File

@ -12,6 +12,12 @@
namespace Nz
{
namespace
{
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
}
struct SpirvConstantCache::Eq
{
bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const
@ -381,8 +387,7 @@ namespace Nz
{
}
tsl::ordered_map<AnyConstant, UInt32 /*id*/, AnyHasher, Eq> constantIds;
tsl::ordered_map<AnyType, UInt32 /*id*/, AnyHasher, Eq> typeIds;
tsl::ordered_map<std::variant<AnyConstant, AnyType>, UInt32 /*id*/, AnyHasher, Eq> ids;
tsl::ordered_map<Variable, UInt32 /*id*/, AnyHasher, Eq> variableIds;
tsl::ordered_map<Structure, FieldOffsets /*fieldOffsets*/, AnyHasher, Eq> structureSizes;
UInt32& nextResultId;
@ -399,8 +404,8 @@ namespace Nz
UInt32 SpirvConstantCache::GetId(const Constant& c)
{
auto it = m_internal->constantIds.find(c.constant);
if (it == m_internal->constantIds.end())
auto it = m_internal->ids.find(c.constant);
if (it == m_internal->ids.end())
throw std::runtime_error("constant is not registered");
return it->second;
@ -408,8 +413,8 @@ namespace Nz
UInt32 SpirvConstantCache::GetId(const Type& t)
{
auto it = m_internal->typeIds.find(t.type);
if (it == m_internal->typeIds.end())
auto it = m_internal->ids.find(t.type);
if (it == m_internal->ids.end())
throw std::runtime_error("constant is not registered");
return it->second;
@ -431,12 +436,12 @@ namespace Nz
DepRegisterer registerer(*this);
registerer.Register(constant);
std::size_t h = m_internal->typeIds.hash_function()(constant);
auto it = m_internal->constantIds.find(constant, h);
if (it == m_internal->constantIds.end())
std::size_t h = m_internal->ids.hash_function()(constant);
auto it = m_internal->ids.find(constant, h);
if (it == m_internal->ids.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->constantIds.emplace(std::move(constant), resultId).first;
it = m_internal->ids.emplace(std::move(constant), resultId).first;
}
return it.value();
@ -449,12 +454,12 @@ namespace Nz
DepRegisterer registerer(*this);
registerer.Register(type);
std::size_t h = m_internal->typeIds.hash_function()(type);
auto it = m_internal->typeIds.find(type, h);
if (it == m_internal->typeIds.end())
std::size_t h = m_internal->ids.hash_function()(type);
auto it = m_internal->ids.find(type, h);
if (it == m_internal->ids.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->typeIds.emplace(std::move(type), resultId).first;
it = m_internal->ids.emplace(std::move(type), resultId).first;
}
return it.value();
@ -476,131 +481,19 @@ namespace Nz
return it.value();
}
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos, SpirvSection& types)
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
{
for (auto&& [type, id] : m_internal->typeIds)
for (auto&& [type, id] : m_internal->ids)
{
UInt32 resultId = id;
std::visit([&](auto&& arg)
std::visit(overloaded
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, Bool>)
types.Append(SpirvOp::OpTypeBool, resultId);
else if constexpr (std::is_same_v<T, Float>)
types.Append(SpirvOp::OpTypeFloat, resultId, arg.width);
else if constexpr (std::is_same_v<T, Function>)
{
types.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender)
{
appender(resultId);
appender(GetId(*arg.returnType));
for (const auto& param : arg.parameters)
appender(GetId(*param));
});
}
else if constexpr (std::is_same_v<T, Image>)
{
UInt32 depth;
if (arg.depth.has_value())
depth = (*arg.depth) ? 1 : 0;
else
depth = 2;
UInt32 sampled;
if (arg.sampled.has_value())
sampled = (*arg.sampled) ? 1 : 0;
else
sampled = 2;
types.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender)
{
appender(resultId);
appender(GetId(*arg.sampledType));
appender(arg.dim);
appender(depth);
appender(arg.arrayed);
appender(arg.multisampled);
appender(sampled);
appender(arg.format);
if (arg.qualifier)
appender(*arg.qualifier);
});
}
else if constexpr (std::is_same_v<T, Integer>)
types.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness);
else if constexpr (std::is_same_v<T, Matrix>)
types.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount);
else if constexpr (std::is_same_v<T, Pointer>)
types.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type));
else if constexpr (std::is_same_v<T, SampledImage>)
types.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image));
else if constexpr (std::is_same_v<T, Structure>)
WriteStruct(arg, resultId, annotations, debugInfos, types);
else if constexpr (std::is_same_v<T, Vector>)
types.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount);
else if constexpr (std::is_same_v<T, Void>)
types.Append(SpirvOp::OpTypeVoid, resultId);
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
[&](const AnyConstant& constant) { Write(constant, resultId, constants); },
[&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); },
}, type);
}
for (auto&& [constant, id] : m_internal->constantIds)
{
UInt32 resultId = id;
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ConstantBool>)
constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId);
else if constexpr (std::is_same_v<T, ConstantComposite>)
{
constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender)
{
appender(GetId(arg.type->type));
appender(resultId);
for (const auto& value : arg.values)
appender(GetId(value->constant));
});
}
else if constexpr (std::is_same_v<T, ConstantScalar>)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
UInt32 typeId;
if constexpr (std::is_same_v<T, double>)
typeId = GetId({ Float{ 64 } });
else if constexpr (std::is_same_v<T, float>)
typeId = GetId({ Float{ 32 } });
else if constexpr (std::is_same_v<T, Int32>)
typeId = GetId({ Integer{ 32, 1 } });
else if constexpr (std::is_same_v<T, Int64>)
typeId = GetId({ Integer{ 64, 1 } });
else if constexpr (std::is_same_v<T, UInt32>)
typeId = GetId({ Integer{ 32, 0 } });
else if constexpr (std::is_same_v<T, UInt64>)
typeId = GetId({ Integer{ 64, 0 } });
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &arg, sizeof(arg) });
}, arg.value);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, constant);
}
for (auto&& [variable, id] : m_internal->variableIds)
{
UInt32 resultId = id;
@ -781,9 +674,128 @@ namespace Nz
}, type);
}
void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& debugInfos, SpirvSection& types)
void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants)
{
types.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender)
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ConstantBool>)
constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId);
else if constexpr (std::is_same_v<T, ConstantComposite>)
{
constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender)
{
appender(GetId(arg.type->type));
appender(resultId);
for (const auto& value : arg.values)
appender(GetId(value->constant));
});
}
else if constexpr (std::is_same_v<T, ConstantScalar>)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
UInt32 typeId;
if constexpr (std::is_same_v<T, double>)
typeId = GetId({ Float{ 64 } });
else if constexpr (std::is_same_v<T, float>)
typeId = GetId({ Float{ 32 } });
else if constexpr (std::is_same_v<T, Int32>)
typeId = GetId({ Integer{ 32, 1 } });
else if constexpr (std::is_same_v<T, Int64>)
typeId = GetId({ Integer{ 64, 1 } });
else if constexpr (std::is_same_v<T, UInt32>)
typeId = GetId({ Integer{ 32, 0 } });
else if constexpr (std::is_same_v<T, UInt64>)
typeId = GetId({ Integer{ 64, 0 } });
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &arg, sizeof(arg) });
}, arg.value);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, constant);
}
void SpirvConstantCache::Write(const AnyType& type, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, Bool>)
constants.Append(SpirvOp::OpTypeBool, resultId);
else if constexpr (std::is_same_v<T, Float>)
constants.Append(SpirvOp::OpTypeFloat, resultId, arg.width);
else if constexpr (std::is_same_v<T, Function>)
{
constants.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender)
{
appender(resultId);
appender(GetId(*arg.returnType));
for (const auto& param : arg.parameters)
appender(GetId(*param));
});
}
else if constexpr (std::is_same_v<T, Image>)
{
UInt32 depth;
if (arg.depth.has_value())
depth = (*arg.depth) ? 1 : 0;
else
depth = 2;
UInt32 sampled;
if (arg.sampled.has_value())
sampled = (*arg.sampled) ? 1 : 0;
else
sampled = 2;
constants.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender)
{
appender(resultId);
appender(GetId(*arg.sampledType));
appender(arg.dim);
appender(depth);
appender(arg.arrayed);
appender(arg.multisampled);
appender(sampled);
appender(arg.format);
if (arg.qualifier)
appender(*arg.qualifier);
});
}
else if constexpr (std::is_same_v<T, Integer>)
constants.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness);
else if constexpr (std::is_same_v<T, Matrix>)
constants.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount);
else if constexpr (std::is_same_v<T, Pointer>)
constants.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type));
else if constexpr (std::is_same_v<T, SampledImage>)
constants.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image));
else if constexpr (std::is_same_v<T, Structure>)
WriteStruct(arg, resultId, annotations, constants, debugInfos);
else if constexpr (std::is_same_v<T, Vector>)
constants.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount);
else if constexpr (std::is_same_v<T, Void>)
constants.Append(SpirvOp::OpTypeVoid, resultId);
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
}
void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
{
constants.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender)
{
appender(resultId);

View File

@ -192,7 +192,6 @@ namespace Nz
SpirvSection constants;
SpirvSection debugInfo;
SpirvSection annotations;
SpirvSection types;
SpirvSection instructions;
};
@ -398,7 +397,7 @@ namespace Nz
assert(entryPointIndex != std::numeric_limits<std::size_t>::max());
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo, m_currentState->types);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
AppendHeader();
@ -446,7 +445,6 @@ namespace Nz
MergeBlocks(ret, state.header);
MergeBlocks(ret, state.debugInfo);
MergeBlocks(ret, state.annotations);
MergeBlocks(ret, state.types);
MergeBlocks(ret, state.constants);
MergeBlocks(ret, state.instructions);