From b8a52b93e8abdf268af09710377cf22d3fe7150d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sun, 23 Jan 2022 19:59:26 +0100 Subject: [PATCH] Shader/SPIRV: Handle arrays properly --- include/Nazara/Shader/SpirvConstantCache.hpp | 17 + src/Nazara/Shader/SpirvConstantCache.cpp | 334 +++++++++++++------ 2 files changed, 243 insertions(+), 108 deletions(-) diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index a22da4a9a..ad1924956 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -20,6 +20,7 @@ namespace Nz { + class FieldOffsets; class SpirvSection; class NAZARA_SHADER_API SpirvConstantCache @@ -108,6 +109,7 @@ namespace Nz { std::string name; TypePtr type; + mutable std::optional offset; }; std::string name; @@ -171,6 +173,7 @@ namespace Nz }; ConstantPtr BuildConstant(const ShaderAst::ConstantValue& value) const; + FieldOffsets BuildFieldOffsets(const Structure& structData) const; TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) const; TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const; TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const; @@ -195,6 +198,20 @@ namespace Nz UInt32 Register(Type t); UInt32 Register(Variable v); + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Array& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Bool& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Float& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Function& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Image& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Integer& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Matrix& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Pointer& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const SampledImage& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Structure& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Type& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const; + std::size_t RegisterArrayField(FieldOffsets& fieldOffsets, const Void& type, std::size_t arrayLength) const; + void SetStructCallback(StructCallback callback); void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index d3b9eb027..2d65f26e1 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -16,7 +17,27 @@ namespace Nz { template struct overloaded : Ts... { using Ts::operator()...; }; - template overloaded(Ts...)->overloaded; + template overloaded(Ts...) -> overloaded; + + StructFieldType TypeToStructFieldType(const SpirvConstantCache::AnyType& type) + { + if (std::holds_alternative(type)) + return StructFieldType::Bool1; + else if (std::holds_alternative(type)) + { + const auto& floatType = std::get(type); + assert(floatType.width == 32 || floatType.width == 64); + return (floatType.width == 32) ? StructFieldType::Float1 : StructFieldType::Double1; + } + else if (std::holds_alternative(type)) + { + const auto& intType = std::get(type); + assert(intType.width == 32); + return (intType.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1; + } + + throw std::runtime_error("unexpected type"); + } } struct SpirvConstantCache::Eq @@ -278,6 +299,7 @@ namespace Nz void Register(const Structure& s) { Register(s.members); + cache.BuildFieldOffsets(s); } void Register(const SpirvConstantCache::Structure::Member& m) @@ -408,6 +430,12 @@ namespace Nz struct SpirvConstantCache::Internal { + struct StructOffsets + { + FieldOffsets fieldOffsets; + std::vector offsets; + }; + Internal(UInt32& resultId) : nextResultId(resultId) { @@ -415,7 +443,6 @@ namespace Nz tsl::ordered_map, UInt32 /*id*/, AnyHasher, Eq> ids; tsl::ordered_map variableIds; - tsl::ordered_map structureSizes; StructCallback structCallback; UInt32& nextResultId; bool isInBlockStruct = false; @@ -480,6 +507,106 @@ namespace Nz }, value)); } + FieldOffsets SpirvConstantCache::BuildFieldOffsets(const Structure& structData) const + { + FieldOffsets structOffsets(StructLayout::Std140); + + for (const Structure::Member& member : structData.members) + { + member.offset = std::visit([&](auto&& arg) -> std::size_t + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + { + assert(std::holds_alternative(arg.length->constant)); + const auto& scalar = std::get(arg.length->constant); + assert(std::holds_alternative(scalar.value)); + std::size_t length = std::get(scalar.value); + + return RegisterArrayField(structOffsets, arg.elementType->type, length); + } + else if constexpr (std::is_same_v) + return structOffsets.AddField(StructFieldType::Bool1); + else if constexpr (std::is_same_v) + { + switch (arg.width) + { + case 32: return structOffsets.AddField(StructFieldType::Float1); + case 64: return structOffsets.AddField(StructFieldType::Double1); + default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width)); + } + } + else if constexpr (std::is_same_v) + return structOffsets.AddField((arg.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1); + else if constexpr (std::is_same_v) + { + assert(std::holds_alternative(arg.columnType->type)); + Vector& columnVec = std::get(arg.columnType->type); + + if (!std::holds_alternative(columnVec.componentType->type)) + throw std::runtime_error("unexpected vector type"); + + Float& vecType = std::get(columnVec.componentType->type); + + StructFieldType columnType; + switch (vecType.width) + { + case 32: columnType = StructFieldType::Float1; break; + case 64: columnType = StructFieldType::Double1; break; + default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width)); + } + + return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true); + } + else if constexpr (std::is_same_v) + throw std::runtime_error("unhandled pointer in struct"); + else if constexpr (std::is_same_v) + return structOffsets.AddStruct(BuildFieldOffsets(arg)); + else if constexpr (std::is_same_v) + { + if (std::holds_alternative(arg.componentType->type)) + return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Bool1) + arg.componentCount - 1)); + else if (std::holds_alternative(arg.componentType->type)) + { + Float& floatData = std::get(arg.componentType->type); + switch (floatData.width) + { + case 32: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Float1) + arg.componentCount - 1)); + case 64: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Double1) + arg.componentCount - 1)); + default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width)); + } + } + else if (std::holds_alternative(arg.componentType->type)) + { + Integer& intData = std::get(arg.componentType->type); + if (intData.width != 32) + throw std::runtime_error("unexpected integer width " + std::to_string(intData.width)); + + if (intData.signedness) + return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Int1) + arg.componentCount - 1)); + else + return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::UInt1) + arg.componentCount - 1)); + } + else + throw std::runtime_error("unexpected type for vector"); + } + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected function as struct member"); + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected identifier"); + else if constexpr (std::is_same_v || std::is_same_v) + throw std::runtime_error("unexpected opaque type as struct member"); + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected void as struct member"); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, member.type->type); + } + + return structOffsets; + } + auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) const -> TypePtr { std::vector parameterTypes; @@ -528,10 +655,24 @@ namespace Nz auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr { + const auto& containedType = type.containedType->type; + + TypePtr builtContainedType = BuildType(containedType); + + // ArrayStride + std::optional arrayStride; + if (m_internal->isInBlockStruct) + { + FieldOffsets fieldOffset(StructLayout::Std140); + RegisterArrayField(fieldOffset, builtContainedType->type, 1); + + arrayStride = SafeCast(fieldOffset.GetAlignedSize()); + } + return std::make_shared(Array{ - BuildType(type.containedType->type), + builtContainedType, BuildConstant(type.length.GetResultingValue()), - (m_internal->isInBlockStruct) ? std::make_optional(16) : std::nullopt + arrayStride }); } @@ -759,6 +900,83 @@ namespace Nz return it.value(); } + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Array& type, std::size_t arrayLength) const + { + FieldOffsets dummyStruct(fieldOffsets.GetLayout()); + RegisterArrayField(dummyStruct, type.elementType->type, std::get(std::get(type.length->constant).value)); + + return fieldOffsets.AddStructArray(dummyStruct, arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Bool& type, std::size_t arrayLength) const + { + return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Float& type, std::size_t arrayLength) const + { + return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Function& type, std::size_t arrayLength) const + { + throw std::runtime_error("unexpected Function"); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Image& type, std::size_t arrayLength) const + { + throw std::runtime_error("unexpected Image"); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Integer& type, std::size_t arrayLength) const + { + return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Matrix& type, std::size_t arrayLength) const + { + if (!std::holds_alternative(type.columnType->type)) + throw std::runtime_error("unexpected column type"); + + const Vector& vecType = std::get(type.columnType->type); + return fieldOffsets.AddMatrixArray(TypeToStructFieldType(vecType.componentType->type), type.columnCount, vecType.componentCount, true, arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Pointer& /*type*/, std::size_t /*arrayLength*/) const + { + throw std::runtime_error("unexpected Pointer (not implemented)"); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const SampledImage& /*type*/, std::size_t /*arrayLength*/) const + { + throw std::runtime_error("unexpected SampledImage"); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Structure& type, std::size_t arrayLength) const + { + auto innerFieldOffset = BuildFieldOffsets(type); + return fieldOffsets.AddStructArray(innerFieldOffset, arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Type& type, std::size_t arrayLength) const + { + return std::visit([&](auto&& arg) -> std::size_t + { + return RegisterArrayField(fieldOffsets, arg, arrayLength); + }, type.type); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const + { + assert(type.componentCount > 0 && type.componentCount <= 4); + return fieldOffsets.AddFieldArray(static_cast(UnderlyingCast(TypeToStructFieldType(type.componentType->type)) + type.componentCount), arrayLength); + } + + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Void& type, std::size_t arrayLength) const + { + throw std::runtime_error("unexpected Void"); + } + void SpirvConstantCache::SetStructCallback(StructCallback callback) { m_internal->structCallback = std::move(callback); @@ -941,125 +1159,25 @@ namespace Nz for (SpirvDecoration decoration : structData.decorations) annotations.Append(SpirvOp::OpDecorate, resultId, decoration); - FieldOffsets structOffsets(StructLayout::Std140); - for (std::size_t memberIndex = 0; memberIndex < structData.members.size(); ++memberIndex) { const auto& member = structData.members[memberIndex]; debugInfos.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name); - std::size_t offset = std::visit([&](auto&& arg) -> std::size_t + UInt32 offset = member.offset.value(); + + std::visit([&](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { - assert(std::holds_alternative(arg.length->constant)); - const auto& scalar = std::get(arg.length->constant); - assert(std::holds_alternative(scalar.value)); - std::size_t length = std::get(scalar.value); - - if (!std::holds_alternative(arg.elementType->type)) - throw std::runtime_error("todo"); - - // FIXME: Virer cette implémentation du ghetto - - const Float& fData = std::get(arg.elementType->type); - switch (fData.width) - { - case 32: return structOffsets.AddFieldArray(StructFieldType::Float1, length); - case 64: return structOffsets.AddFieldArray(StructFieldType::Double1, length); - default: throw std::runtime_error("unexpected float width " + std::to_string(fData.width)); - } - } - else if constexpr (std::is_same_v) - return structOffsets.AddField(StructFieldType::Bool1); - else if constexpr (std::is_same_v) - { - switch (arg.width) - { - case 32: return structOffsets.AddField(StructFieldType::Float1); - case 64: return structOffsets.AddField(StructFieldType::Double1); - default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width)); - } - } - else if constexpr (std::is_same_v) - return structOffsets.AddField((arg.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1); - else if constexpr (std::is_same_v) - { - assert(std::holds_alternative(arg.columnType->type)); - Vector& columnVec = std::get(arg.columnType->type); - - if (!std::holds_alternative(columnVec.componentType->type)) - throw std::runtime_error("unexpected vector type"); - - Float& vecType = std::get(columnVec.componentType->type); - - StructFieldType columnType; - switch (vecType.width) - { - case 32: columnType = StructFieldType::Float1; break; - case 64: columnType = StructFieldType::Double1; break; - default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width)); - } - annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor); annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16); - - return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true); } - else if constexpr (std::is_same_v) - throw std::runtime_error("unhandled pointer in struct"); - else if constexpr (std::is_same_v) - { - auto it = m_internal->structureSizes.find(arg); - assert(it != m_internal->structureSizes.end()); - - return structOffsets.AddStruct(it->second); - } - else if constexpr (std::is_same_v) - { - if (std::holds_alternative(arg.componentType->type)) - return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Bool1) + arg.componentCount - 1)); - else if (std::holds_alternative(arg.componentType->type)) - { - Float& floatData = std::get(arg.componentType->type); - switch (floatData.width) - { - case 32: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Float1) + arg.componentCount - 1)); - case 64: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Double1) + arg.componentCount - 1)); - default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width)); - } - } - else if (std::holds_alternative(arg.componentType->type)) - { - Integer& intData = std::get(arg.componentType->type); - if (intData.width != 32) - throw std::runtime_error("unexpected integer width " + std::to_string(intData.width)); - - if (intData.signedness) - return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Int1) + arg.componentCount - 1)); - else - return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::UInt1) + arg.componentCount - 1)); - } - else - throw std::runtime_error("unexpected type for vector"); - } - else if constexpr (std::is_same_v) - throw std::runtime_error("unexpected function as struct member"); - else if constexpr (std::is_same_v) - throw std::runtime_error("unexpected identifier"); - else if constexpr (std::is_same_v || std::is_same_v) - throw std::runtime_error("unexpected opaque type as struct member"); - else if constexpr (std::is_same_v) - throw std::runtime_error("unexpected void as struct member"); - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, member.type->type); annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, offset); } - - m_internal->structureSizes.emplace(structData, std::move(structOffsets)); } }