Shader/SPIRV: Handle arrays properly

This commit is contained in:
Jérôme Leclercq 2022-01-23 19:59:26 +01:00
parent 2463e471cc
commit b8a52b93e8
2 changed files with 243 additions and 108 deletions

View File

@ -20,6 +20,7 @@
namespace Nz
{
class FieldOffsets;
class SpirvSection;
class NAZARA_SHADER_API SpirvConstantCache
@ -108,6 +109,7 @@ namespace Nz
{
std::string name;
TypePtr type;
mutable std::optional<UInt32> offset;
};
std::string name;
@ -171,6 +173,7 @@ namespace Nz
};
ConstantPtr BuildConstant(const ShaderAst::ConstantValue& value) const;
FieldOffsets BuildFieldOffsets(const Structure& structData) const;
TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) const;
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
@ -195,6 +198,20 @@ namespace Nz
UInt32 Register(Type t);
UInt32 Register(Variable v);
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Array& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Bool& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Float& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Function& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Image& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Integer& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Matrix& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Pointer& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const SampledImage& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Structure& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Type& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const;
std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Void& type, std::size_t arrayLength) const;
void SetStructCallback(StructCallback callback);
void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);

View File

@ -7,6 +7,7 @@
#include <Nazara/Shader/Ast/Nodes.hpp>
#include <Nazara/Utility/FieldOffsets.hpp>
#include <tsl/ordered_map.h>
#include <cassert>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
@ -16,7 +17,27 @@ namespace Nz
{
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
StructFieldType TypeToStructFieldType(const SpirvConstantCache::AnyType& type)
{
if (std::holds_alternative<SpirvConstantCache::Bool>(type))
return StructFieldType::Bool1;
else if (std::holds_alternative<SpirvConstantCache::Float>(type))
{
const auto& floatType = std::get<SpirvConstantCache::Float>(type);
assert(floatType.width == 32 || floatType.width == 64);
return (floatType.width == 32) ? StructFieldType::Float1 : StructFieldType::Double1;
}
else if (std::holds_alternative<SpirvConstantCache::Integer>(type))
{
const auto& intType = std::get<SpirvConstantCache::Integer>(type);
assert(intType.width == 32);
return (intType.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1;
}
throw std::runtime_error("unexpected type");
}
}
struct SpirvConstantCache::Eq
@ -278,6 +299,7 @@ namespace Nz
void Register(const Structure& s)
{
Register(s.members);
cache.BuildFieldOffsets(s);
}
void Register(const SpirvConstantCache::Structure::Member& m)
@ -408,6 +430,12 @@ namespace Nz
struct SpirvConstantCache::Internal
{
struct StructOffsets
{
FieldOffsets fieldOffsets;
std::vector<UInt32> offsets;
};
Internal(UInt32& resultId) :
nextResultId(resultId)
{
@ -415,7 +443,6 @@ namespace Nz
tsl::ordered_map<std::variant<AnyConstant, AnyType>, UInt32 /*id*/, AnyHasher, Eq> ids;
tsl::ordered_map<Variable, UInt32 /*id*/, AnyHasher, Eq> variableIds;
tsl::ordered_map<Structure, FieldOffsets /*fieldOffsets*/, AnyHasher, Eq> structureSizes;
StructCallback structCallback;
UInt32& nextResultId;
bool isInBlockStruct = false;
@ -480,6 +507,106 @@ namespace Nz
}, value));
}
FieldOffsets SpirvConstantCache::BuildFieldOffsets(const Structure& structData) const
{
FieldOffsets structOffsets(StructLayout::Std140);
for (const Structure::Member& member : structData.members)
{
member.offset = std::visit([&](auto&& arg) -> std::size_t
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, Array>)
{
assert(std::holds_alternative<ConstantScalar>(arg.length->constant));
const auto& scalar = std::get<ConstantScalar>(arg.length->constant);
assert(std::holds_alternative<UInt32>(scalar.value));
std::size_t length = std::get<UInt32>(scalar.value);
return RegisterArrayField(structOffsets, arg.elementType->type, length);
}
else if constexpr (std::is_same_v<T, Bool>)
return structOffsets.AddField(StructFieldType::Bool1);
else if constexpr (std::is_same_v<T, Float>)
{
switch (arg.width)
{
case 32: return structOffsets.AddField(StructFieldType::Float1);
case 64: return structOffsets.AddField(StructFieldType::Double1);
default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width));
}
}
else if constexpr (std::is_same_v<T, Integer>)
return structOffsets.AddField((arg.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1);
else if constexpr (std::is_same_v<T, Matrix>)
{
assert(std::holds_alternative<Vector>(arg.columnType->type));
Vector& columnVec = std::get<Vector>(arg.columnType->type);
if (!std::holds_alternative<Float>(columnVec.componentType->type))
throw std::runtime_error("unexpected vector type");
Float& vecType = std::get<Float>(columnVec.componentType->type);
StructFieldType columnType;
switch (vecType.width)
{
case 32: columnType = StructFieldType::Float1; break;
case 64: columnType = StructFieldType::Double1; break;
default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width));
}
return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true);
}
else if constexpr (std::is_same_v<T, Pointer>)
throw std::runtime_error("unhandled pointer in struct");
else if constexpr (std::is_same_v<T, Structure>)
return structOffsets.AddStruct(BuildFieldOffsets(arg));
else if constexpr (std::is_same_v<T, Vector>)
{
if (std::holds_alternative<Bool>(arg.componentType->type))
return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Bool1) + arg.componentCount - 1));
else if (std::holds_alternative<Float>(arg.componentType->type))
{
Float& floatData = std::get<Float>(arg.componentType->type);
switch (floatData.width)
{
case 32: return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Float1) + arg.componentCount - 1));
case 64: return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Double1) + arg.componentCount - 1));
default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width));
}
}
else if (std::holds_alternative<Integer>(arg.componentType->type))
{
Integer& intData = std::get<Integer>(arg.componentType->type);
if (intData.width != 32)
throw std::runtime_error("unexpected integer width " + std::to_string(intData.width));
if (intData.signedness)
return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Int1) + arg.componentCount - 1));
else
return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::UInt1) + arg.componentCount - 1));
}
else
throw std::runtime_error("unexpected type for vector");
}
else if constexpr (std::is_same_v<T, Function>)
throw std::runtime_error("unexpected function as struct member");
else if constexpr (std::is_same_v<T, Identifier>)
throw std::runtime_error("unexpected identifier");
else if constexpr (std::is_same_v<T, Image> || std::is_same_v<T, SampledImage>)
throw std::runtime_error("unexpected opaque type as struct member");
else if constexpr (std::is_same_v<T, Void>)
throw std::runtime_error("unexpected void as struct member");
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, member.type->type);
}
return structOffsets;
}
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) const -> TypePtr
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
@ -528,10 +655,24 @@ namespace Nz
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
{
const auto& containedType = type.containedType->type;
TypePtr builtContainedType = BuildType(containedType);
// ArrayStride
std::optional<UInt32> arrayStride;
if (m_internal->isInBlockStruct)
{
FieldOffsets fieldOffset(StructLayout::Std140);
RegisterArrayField(fieldOffset, builtContainedType->type, 1);
arrayStride = SafeCast<UInt32>(fieldOffset.GetAlignedSize());
}
return std::make_shared<Type>(Array{
BuildType(type.containedType->type),
builtContainedType,
BuildConstant(type.length.GetResultingValue()),
(m_internal->isInBlockStruct) ? std::make_optional<UInt32>(16) : std::nullopt
arrayStride
});
}
@ -759,6 +900,83 @@ namespace Nz
return it.value();
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Array& type, std::size_t arrayLength) const
{
FieldOffsets dummyStruct(fieldOffsets.GetLayout());
RegisterArrayField(dummyStruct, type.elementType->type, std::get<UInt32>(std::get<ConstantScalar>(type.length->constant).value));
return fieldOffsets.AddStructArray(dummyStruct, arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Bool& type, std::size_t arrayLength) const
{
return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Float& type, std::size_t arrayLength) const
{
return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Function& type, std::size_t arrayLength) const
{
throw std::runtime_error("unexpected Function");
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Image& type, std::size_t arrayLength) const
{
throw std::runtime_error("unexpected Image");
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Integer& type, std::size_t arrayLength) const
{
return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Matrix& type, std::size_t arrayLength) const
{
if (!std::holds_alternative<Vector>(type.columnType->type))
throw std::runtime_error("unexpected column type");
const Vector& vecType = std::get<Vector>(type.columnType->type);
return fieldOffsets.AddMatrixArray(TypeToStructFieldType(vecType.componentType->type), type.columnCount, vecType.componentCount, true, arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Pointer& /*type*/, std::size_t /*arrayLength*/) const
{
throw std::runtime_error("unexpected Pointer (not implemented)");
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const SampledImage& /*type*/, std::size_t /*arrayLength*/) const
{
throw std::runtime_error("unexpected SampledImage");
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Structure& type, std::size_t arrayLength) const
{
auto innerFieldOffset = BuildFieldOffsets(type);
return fieldOffsets.AddStructArray(innerFieldOffset, arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Type& type, std::size_t arrayLength) const
{
return std::visit([&](auto&& arg) -> std::size_t
{
return RegisterArrayField(fieldOffsets, arg, arrayLength);
}, type.type);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const
{
assert(type.componentCount > 0 && type.componentCount <= 4);
return fieldOffsets.AddFieldArray(static_cast<StructFieldType>(UnderlyingCast(TypeToStructFieldType(type.componentType->type)) + type.componentCount), arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Void& type, std::size_t arrayLength) const
{
throw std::runtime_error("unexpected Void");
}
void SpirvConstantCache::SetStructCallback(StructCallback callback)
{
m_internal->structCallback = std::move(callback);
@ -941,125 +1159,25 @@ namespace Nz
for (SpirvDecoration decoration : structData.decorations)
annotations.Append(SpirvOp::OpDecorate, resultId, decoration);
FieldOffsets structOffsets(StructLayout::Std140);
for (std::size_t memberIndex = 0; memberIndex < structData.members.size(); ++memberIndex)
{
const auto& member = structData.members[memberIndex];
debugInfos.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name);
std::size_t offset = std::visit([&](auto&& arg) -> std::size_t
UInt32 offset = member.offset.value();
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, Array>)
if constexpr (std::is_same_v<T, Matrix>)
{
assert(std::holds_alternative<ConstantScalar>(arg.length->constant));
const auto& scalar = std::get<ConstantScalar>(arg.length->constant);
assert(std::holds_alternative<UInt32>(scalar.value));
std::size_t length = std::get<UInt32>(scalar.value);
if (!std::holds_alternative<Float>(arg.elementType->type))
throw std::runtime_error("todo");
// FIXME: Virer cette implémentation du ghetto
const Float& fData = std::get<Float>(arg.elementType->type);
switch (fData.width)
{
case 32: return structOffsets.AddFieldArray(StructFieldType::Float1, length);
case 64: return structOffsets.AddFieldArray(StructFieldType::Double1, length);
default: throw std::runtime_error("unexpected float width " + std::to_string(fData.width));
}
}
else if constexpr (std::is_same_v<T, Bool>)
return structOffsets.AddField(StructFieldType::Bool1);
else if constexpr (std::is_same_v<T, Float>)
{
switch (arg.width)
{
case 32: return structOffsets.AddField(StructFieldType::Float1);
case 64: return structOffsets.AddField(StructFieldType::Double1);
default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width));
}
}
else if constexpr (std::is_same_v<T, Integer>)
return structOffsets.AddField((arg.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1);
else if constexpr (std::is_same_v<T, Matrix>)
{
assert(std::holds_alternative<Vector>(arg.columnType->type));
Vector& columnVec = std::get<Vector>(arg.columnType->type);
if (!std::holds_alternative<Float>(columnVec.componentType->type))
throw std::runtime_error("unexpected vector type");
Float& vecType = std::get<Float>(columnVec.componentType->type);
StructFieldType columnType;
switch (vecType.width)
{
case 32: columnType = StructFieldType::Float1; break;
case 64: columnType = StructFieldType::Double1; break;
default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width));
}
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor);
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16);
return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true);
}
else if constexpr (std::is_same_v<T, Pointer>)
throw std::runtime_error("unhandled pointer in struct");
else if constexpr (std::is_same_v<T, Structure>)
{
auto it = m_internal->structureSizes.find(arg);
assert(it != m_internal->structureSizes.end());
return structOffsets.AddStruct(it->second);
}
else if constexpr (std::is_same_v<T, Vector>)
{
if (std::holds_alternative<Bool>(arg.componentType->type))
return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Bool1) + arg.componentCount - 1));
else if (std::holds_alternative<Float>(arg.componentType->type))
{
Float& floatData = std::get<Float>(arg.componentType->type);
switch (floatData.width)
{
case 32: return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Float1) + arg.componentCount - 1));
case 64: return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Double1) + arg.componentCount - 1));
default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width));
}
}
else if (std::holds_alternative<Integer>(arg.componentType->type))
{
Integer& intData = std::get<Integer>(arg.componentType->type);
if (intData.width != 32)
throw std::runtime_error("unexpected integer width " + std::to_string(intData.width));
if (intData.signedness)
return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::Int1) + arg.componentCount - 1));
else
return structOffsets.AddField(static_cast<StructFieldType>(UnderlyingCast(StructFieldType::UInt1) + arg.componentCount - 1));
}
else
throw std::runtime_error("unexpected type for vector");
}
else if constexpr (std::is_same_v<T, Function>)
throw std::runtime_error("unexpected function as struct member");
else if constexpr (std::is_same_v<T, Identifier>)
throw std::runtime_error("unexpected identifier");
else if constexpr (std::is_same_v<T, Image> || std::is_same_v<T, SampledImage>)
throw std::runtime_error("unexpected opaque type as struct member");
else if constexpr (std::is_same_v<T, Void>)
throw std::runtime_error("unexpected void as struct member");
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, member.type->type);
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, offset);
}
m_internal->structureSizes.emplace(structData, std::move(structOffsets));
}
}