// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) // This file is part of the "Nazara Engine - Shader module" // For conditions of distribution and use, see copyright notice in Config.hpp #include #include #include #include #include #include #include #include namespace Nz { namespace { template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; StructFieldType TypeToStructFieldType(const SpirvConstantCache::AnyType& type) { if (std::holds_alternative(type)) return StructFieldType::Bool1; else if (std::holds_alternative(type)) { const auto& floatType = std::get(type); assert(floatType.width == 32 || floatType.width == 64); return (floatType.width == 32) ? StructFieldType::Float1 : StructFieldType::Double1; } else if (std::holds_alternative(type)) { const auto& intType = std::get(type); assert(intType.width == 32); return (intType.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1; } throw std::runtime_error("unexpected type"); } } struct SpirvConstantCache::Eq { bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const { return lhs.value == rhs.value; } bool Compare(const ConstantComposite& lhs, const ConstantComposite& rhs) const { return Compare(lhs.type, rhs.type) && Compare(lhs.values, rhs.values); } bool Compare(const ConstantScalar& lhs, const ConstantScalar& rhs) const { return lhs.value == rhs.value; } bool Compare(const Array& lhs, const Array& rhs) const { return Compare(lhs.length, rhs.length) && Compare(lhs.elementType, rhs.elementType) && lhs.stride == rhs.stride; } bool Compare(const Bool& /*lhs*/, const Bool& /*rhs*/) const { return true; } bool Compare(const Float& lhs, const Float& rhs) const { return lhs.width == rhs.width; } bool Compare(const Function& lhs, const Function& rhs) const { return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType); } bool Compare(const Image& lhs, const Image& rhs) const { return lhs.arrayed == rhs.arrayed && lhs.dim == rhs.dim && lhs.format == rhs.format && lhs.multisampled == rhs.multisampled && lhs.qualifier == rhs.qualifier && Compare(lhs.sampledType, rhs.sampledType) && lhs.depth == rhs.depth && lhs.sampled == rhs.sampled; } bool Compare(const Integer& lhs, const Integer& rhs) const { return lhs.width == rhs.width && lhs.signedness == rhs.signedness; } bool Compare(const Matrix& lhs, const Matrix& rhs) const { return lhs.columnCount == rhs.columnCount && Compare(lhs.columnType, rhs.columnType); } bool Compare(const Pointer& lhs, const Pointer& rhs) const { return lhs.storageClass == rhs.storageClass && Compare(lhs.type, rhs.type); } bool Compare(const SampledImage& lhs, const SampledImage& rhs) const { return Compare(lhs.image, rhs.image); } bool Compare(const Structure& lhs, const Structure& rhs) const { if (lhs.name != rhs.name) return false; if (lhs.decorations != rhs.decorations) return false; if (!Compare(lhs.members, rhs.members)) return false; return true; } bool Compare(const Structure::Member& lhs, const Structure::Member& rhs) const { if (!Compare(lhs.type, rhs.type)) return false; if (lhs.name != rhs.name) return false; return true; } bool Compare(const Variable& lhs, const Variable& rhs) const { if (lhs.debugName != rhs.debugName) return false; if (lhs.funcId != rhs.funcId) return false; if (!Compare(lhs.initializer, rhs.initializer)) return false; if (lhs.storageClass != rhs.storageClass) return false; if (!Compare(lhs.type, rhs.type)) return false; return true; } bool Compare(const Vector& lhs, const Vector& rhs) const { return Compare(lhs.componentType, rhs.componentType) && lhs.componentCount == rhs.componentCount; } bool Compare(const Void& /*lhs*/, const Void& /*rhs*/) const { return true; } bool Compare(const Constant& lhs, const Constant& rhs) const { return Compare(lhs.constant, rhs.constant); } bool Compare(const Type& lhs, const Type& rhs) const { return Compare(lhs.type, rhs.type); } template bool Compare(const std::optional& lhs, const std::optional& rhs) const { if (lhs.has_value() != rhs.has_value()) return false; if (!lhs.has_value()) return true; return Compare(*lhs, *rhs); } template bool Compare(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const { if (bool(lhs) != bool(rhs)) return false; if (!lhs) return true; return Compare(*lhs, *rhs); } template bool Compare(const std::variant& lhs, const std::variant& rhs) const { if (lhs.index() != rhs.index()) return false; return std::visit([&](auto&& arg) { using U = std::decay_t; return Compare(arg, std::get(rhs)); }, lhs); } template bool Compare(const std::vector& lhs, const std::vector& rhs) const { if (lhs.size() != rhs.size()) return false; for (std::size_t i = 0; i < lhs.size(); ++i) { if (!Compare(lhs[i], rhs[i])) return false; } return true; } template bool Compare(const std::unique_ptr& lhs, const std::unique_ptr& rhs) const { if (bool(lhs) != bool(rhs)) return false; if (!lhs) return true; return Compare(*lhs, *rhs); } template bool operator()(const T& lhs, const T& rhs) const { return Compare(lhs, rhs); } }; struct SpirvConstantCache::DepRegisterer { DepRegisterer(SpirvConstantCache& c) : cache(c) { } void Register(const Array& array) { assert(array.elementType); cache.Register(*array.elementType); assert(array.length); cache.Register(*array.length); } void Register(const Bool&) {} void Register(const Float&) {} void Register(const Integer&) {} void Register(const Void&) {} void Register(const Image& image) { cache.Register(*image.sampledType); } void Register(const Function& func) { cache.Register(*func.returnType); Register(func.parameters); } void Register(const Matrix& vec) { assert(vec.columnType); cache.Register(*vec.columnType); } void Register(const Pointer& ptr) { assert(ptr.type); cache.Register(*ptr.type); } void Register(const SampledImage& sampledImage) { assert(sampledImage.image); cache.Register(*sampledImage.image); } void Register(const Structure& s) { Register(s.members); cache.BuildFieldOffsets(s); } void Register(const SpirvConstantCache::Structure::Member& m) { cache.Register(*m.type); } void Register(const Variable& variable) { assert(variable.type); cache.Register(*variable.type); } void Register(const Vector& vec) { assert(vec.componentType); cache.Register(*vec.componentType); } void Register(const ConstantBool&) { cache.Register({ Bool{} }); } void Register(const ConstantScalar& scalar) { std::visit([&](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) cache.Register({ Float{ 64 } }); else if constexpr (std::is_same_v) cache.Register({ Float{ 32 } }); else if constexpr (std::is_same_v) cache.Register({ Integer{ 32, true } }); else if constexpr (std::is_same_v) cache.Register({ Integer{ 64, true } }); else if constexpr (std::is_same_v) cache.Register({ Integer{ 32, false } }); else if constexpr (std::is_same_v) cache.Register({ Integer{ 64, false } }); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, scalar.value); } void Register(const ConstantComposite& composite) { assert(composite.type); cache.Register(*composite.type); for (auto&& value : composite.values) { assert(value); cache.Register(*value); } } void Register(const Constant& c) { return Register(c.constant); } void Register(const Type& t) { return Register(t.type); } template void Register(const std::shared_ptr& ptr) { assert(ptr); return Register(*ptr); } template void Register(const std::optional& opt) { if (opt) Register(*opt); } template void Register(const std::variant& v) { return std::visit([&](auto&& arg) { return Register(arg); }, v); } void Register(const std::vector& lhs) { for (std::size_t i = 0; i < lhs.size(); ++i) cache.Register(*lhs[i]); } template void Register(const std::vector& lhs) { for (std::size_t i = 0; i < lhs.size(); ++i) Register(lhs[i]); } template void Register(const std::unique_ptr& lhs) { assert(lhs); return Register(*lhs); } SpirvConstantCache& cache; }; //< FIXME PLZ struct AnyHasher { template std::size_t operator()(const U&) const { return 42; } }; struct SpirvConstantCache::Internal { struct StructOffsets { FieldOffsets fieldOffsets; std::vector offsets; }; Internal(UInt32& resultId) : nextResultId(resultId) { } tsl::ordered_map, UInt32 /*id*/, AnyHasher, Eq> ids; tsl::ordered_map variableIds; StructCallback structCallback; UInt32& nextResultId; bool isInBlockStruct = false; }; SpirvConstantCache::SpirvConstantCache(UInt32& resultId) { m_internal = std::make_unique(resultId); } SpirvConstantCache::SpirvConstantCache(SpirvConstantCache&& cache) noexcept = default; SpirvConstantCache::~SpirvConstantCache() = default; auto SpirvConstantCache::BuildConstant(const ShaderAst::ConstantValue& value) const -> ConstantPtr { return std::make_shared(std::visit([&](auto&& arg) -> SpirvConstantCache::AnyConstant { using T = std::decay_t; if constexpr (std::is_same_v) throw std::runtime_error("invalid type (value expected)"); else if constexpr (std::is_same_v) throw std::runtime_error("unexpected string litteral"); else if constexpr (std::is_same_v) return ConstantBool{ arg }; else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) return ConstantScalar{ arg }; else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ BuildType(ShaderAst::VectorType{ 2, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }), { BuildConstant(arg.x), BuildConstant(arg.y) } }; } else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ BuildType(ShaderAst::VectorType{ 3, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }), { BuildConstant(arg.x), BuildConstant(arg.y), BuildConstant(arg.z) } }; } else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ BuildType(ShaderAst::VectorType{ 4, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }), { BuildConstant(arg.x), BuildConstant(arg.y), BuildConstant(arg.z), BuildConstant(arg.w) } }; } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, value)); } FieldOffsets SpirvConstantCache::BuildFieldOffsets(const Structure& structData) const { FieldOffsets structOffsets(StructLayout::Std140); for (const Structure::Member& member : structData.members) { member.offset = SafeCast(std::visit([&](auto&& arg) -> std::size_t { using T = std::decay_t; if constexpr (std::is_same_v) { assert(std::holds_alternative(arg.length->constant)); const auto& scalar = std::get(arg.length->constant); assert(std::holds_alternative(scalar.value)); std::size_t length = std::get(scalar.value); return RegisterArrayField(structOffsets, arg.elementType->type, length); } else if constexpr (std::is_same_v) return structOffsets.AddField(StructFieldType::Bool1); else if constexpr (std::is_same_v) { switch (arg.width) { case 32: return structOffsets.AddField(StructFieldType::Float1); case 64: return structOffsets.AddField(StructFieldType::Double1); default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width)); } } else if constexpr (std::is_same_v) return structOffsets.AddField((arg.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1); else if constexpr (std::is_same_v) { assert(std::holds_alternative(arg.columnType->type)); Vector& columnVec = std::get(arg.columnType->type); if (!std::holds_alternative(columnVec.componentType->type)) throw std::runtime_error("unexpected vector type"); Float& vecType = std::get(columnVec.componentType->type); StructFieldType columnType; switch (vecType.width) { case 32: columnType = StructFieldType::Float1; break; case 64: columnType = StructFieldType::Double1; break; default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width)); } return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true); } else if constexpr (std::is_same_v) throw std::runtime_error("unhandled pointer in struct"); else if constexpr (std::is_same_v) return structOffsets.AddStruct(BuildFieldOffsets(arg)); else if constexpr (std::is_same_v) { if (std::holds_alternative(arg.componentType->type)) return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Bool1) + arg.componentCount - 1)); else if (std::holds_alternative(arg.componentType->type)) { Float& floatData = std::get(arg.componentType->type); switch (floatData.width) { case 32: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Float1) + arg.componentCount - 1)); case 64: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Double1) + arg.componentCount - 1)); default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width)); } } else if (std::holds_alternative(arg.componentType->type)) { Integer& intData = std::get(arg.componentType->type); if (intData.width != 32) throw std::runtime_error("unexpected integer width " + std::to_string(intData.width)); if (intData.signedness) return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Int1) + arg.componentCount - 1)); else return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::UInt1) + arg.componentCount - 1)); } else throw std::runtime_error("unexpected type for vector"); } else if constexpr (std::is_same_v) throw std::runtime_error("unexpected function as struct member"); else if constexpr (std::is_same_v) throw std::runtime_error("unexpected identifier"); else if constexpr (std::is_same_v || std::is_same_v) throw std::runtime_error("unexpected opaque type as struct member"); else if constexpr (std::is_same_v) throw std::runtime_error("unexpected void as struct member"); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, member.type->type)); } return structOffsets; } auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) const -> TypePtr { std::vector parameterTypes; parameterTypes.reserve(parameters.size()); for (const auto& parameterType : parameters) parameterTypes.push_back(BuildPointerType(parameterType, SpirvStorageClass::Function)); return std::make_shared(Function{ BuildType(retType), std::move(parameterTypes) }); } auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr { bool wasInblockStruct = m_internal->isInBlockStruct; if (storageClass == SpirvStorageClass::Uniform) m_internal->isInBlockStruct = true; auto typePtr = std::make_shared(Pointer{ BuildType(type), storageClass }); m_internal->isInBlockStruct = wasInblockStruct; return typePtr; } auto SpirvConstantCache::BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const -> TypePtr { bool wasInblockStruct = m_internal->isInBlockStruct; if (storageClass == SpirvStorageClass::Uniform) m_internal->isInBlockStruct = true; auto typePtr = std::make_shared(Pointer{ type, storageClass }); m_internal->isInBlockStruct = wasInblockStruct; return typePtr; } auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr { const auto& containedType = type.containedType->type; TypePtr builtContainedType = BuildType(containedType); // ArrayStride std::optional arrayStride; if (m_internal->isInBlockStruct) { FieldOffsets fieldOffset(StructLayout::Std140); RegisterArrayField(fieldOffset, builtContainedType->type, 1); arrayStride = SafeCast(fieldOffset.GetAlignedSize()); } return std::make_shared(Array{ builtContainedType, BuildConstant(type.length), arrayStride }); } auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr { bool wasInblockStruct = m_internal->isInBlockStruct; if (storageClass == SpirvStorageClass::Uniform) m_internal->isInBlockStruct = true; auto typePtr = std::make_shared(Pointer{ BuildType(type), storageClass }); m_internal->isInBlockStruct = wasInblockStruct; return typePtr; } auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr { return std::visit([&](auto&& arg) -> TypePtr { return BuildType(arg); }, type); } auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& /*type*/) const -> TypePtr { // No IdentifierType is expected (as they should have been resolved by now) throw std::runtime_error("unexpected identifier"); } auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) const -> TypePtr { return std::make_shared([&]() -> AnyType { switch (type) { case ShaderAst::PrimitiveType::Boolean: return Bool{}; case ShaderAst::PrimitiveType::Float32: return Float{ 32 }; case ShaderAst::PrimitiveType::Int32: return Integer{ 32, true }; case ShaderAst::PrimitiveType::UInt32: return Integer{ 32, false }; case ShaderAst::PrimitiveType::String: break; } throw std::runtime_error("unexpected type"); }()); } auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) const -> TypePtr { return std::make_shared( Matrix{ BuildType(ShaderAst::VectorType { UInt32(type.rowCount), type.type }), UInt32(type.columnCount) }); } auto SpirvConstantCache::BuildType(const ShaderAst::NoType& /*type*/) const -> TypePtr { return std::make_shared(Void{}); } auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) const -> TypePtr { Image imageType; imageType.sampled = true; imageType.sampledType = BuildType(type.sampledType); switch (type.dim) { case ImageType::Cubemap: imageType.dim = SpirvDim::Cube; break; case ImageType::E1D_Array: imageType.arrayed = true; [[fallthrough]]; case ImageType::E1D: imageType.dim = SpirvDim::Dim1D; break; case ImageType::E2D_Array: imageType.arrayed = true; [[fallthrough]]; case ImageType::E2D: imageType.dim = SpirvDim::Dim2D; break; case ImageType::E3D: imageType.dim = SpirvDim::Dim3D; break; } return std::make_shared(SampledImage{ std::make_shared(imageType) }); } auto SpirvConstantCache::BuildType(const ShaderAst::StructType& type) const -> TypePtr { assert(m_internal->structCallback); return BuildType(m_internal->structCallback(type.structIndex)); } auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc, std::vector decorations) const -> TypePtr { Structure sType; sType.name = structDesc.name; sType.decorations = std::move(decorations); bool wasInBlock = m_internal->isInBlockStruct; if (!wasInBlock) m_internal->isInBlockStruct = std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::Block) != sType.decorations.end(); for (const auto& member : structDesc.members) { if (member.cond.HasValue() && !member.cond.GetResultingValue()) continue; auto& sMembers = sType.members.emplace_back(); sMembers.name = member.name; sMembers.type = BuildType(member.type.GetResultingValue()); } m_internal->isInBlockStruct = wasInBlock; return std::make_shared(std::move(sType)); } auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) const -> TypePtr { return std::make_shared(Vector{ BuildType(type.type), UInt32(type.componentCount) }); } auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr { return BuildType(type.containedType); } UInt32 SpirvConstantCache::GetId(const Constant& c) { auto it = m_internal->ids.find(c.constant); if (it == m_internal->ids.end()) throw std::runtime_error("constant is not registered"); return it->second; } UInt32 SpirvConstantCache::GetId(const Type& t) { auto it = m_internal->ids.find(t.type); if (it == m_internal->ids.end()) throw std::runtime_error("type is not registered"); return it->second; } UInt32 SpirvConstantCache::GetId(const Variable& v) { auto it = m_internal->variableIds.find(v); if (it == m_internal->variableIds.end()) throw std::runtime_error("variable is not registered"); return it->second; } UInt32 SpirvConstantCache::Register(Constant c) { AnyConstant& constant = c.constant; DepRegisterer registerer(*this); registerer.Register(constant); std::size_t h = m_internal->ids.hash_function()(constant); auto it = m_internal->ids.find(constant, h); if (it == m_internal->ids.end()) { UInt32 resultId = m_internal->nextResultId++; it = m_internal->ids.emplace(std::move(constant), resultId).first; } return it.value(); } UInt32 SpirvConstantCache::Register(Type t) { AnyType& type = t.type; DepRegisterer registerer(*this); registerer.Register(type); std::size_t h = m_internal->ids.hash_function()(type); auto it = m_internal->ids.find(type, h); if (it == m_internal->ids.end()) { UInt32 resultId = m_internal->nextResultId++; it = m_internal->ids.emplace(std::move(type), resultId).first; } return it.value(); } UInt32 SpirvConstantCache::Register(Variable v) { DepRegisterer registerer(*this); registerer.Register(v); std::size_t h = m_internal->variableIds.hash_function()(v); auto it = m_internal->variableIds.find(v, h); if (it == m_internal->variableIds.end()) { UInt32 resultId = m_internal->nextResultId++; it = m_internal->variableIds.emplace(std::move(v), resultId).first; } return it.value(); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Array& type, std::size_t arrayLength) const { FieldOffsets dummyStruct(fieldOffsets.GetLayout()); RegisterArrayField(dummyStruct, type.elementType->type, std::get(std::get(type.length->constant).value)); return fieldOffsets.AddStructArray(dummyStruct, arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Bool& type, std::size_t arrayLength) const { return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Float& type, std::size_t arrayLength) const { return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Function& /*type*/, std::size_t /*arrayLength*/) const { throw std::runtime_error("unexpected Function"); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Image& /*type*/, std::size_t /*arrayLength*/) const { throw std::runtime_error("unexpected Image"); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Integer& type, std::size_t arrayLength) const { return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Matrix& type, std::size_t arrayLength) const { if (!std::holds_alternative(type.columnType->type)) throw std::runtime_error("unexpected column type"); const Vector& vecType = std::get(type.columnType->type); return fieldOffsets.AddMatrixArray(TypeToStructFieldType(vecType.componentType->type), type.columnCount, vecType.componentCount, true, arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Pointer& /*type*/, std::size_t /*arrayLength*/) const { throw std::runtime_error("unexpected Pointer (not implemented)"); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const SampledImage& /*type*/, std::size_t /*arrayLength*/) const { throw std::runtime_error("unexpected SampledImage"); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Structure& type, std::size_t arrayLength) const { auto innerFieldOffset = BuildFieldOffsets(type); return fieldOffsets.AddStructArray(innerFieldOffset, arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Type& type, std::size_t arrayLength) const { return std::visit([&](auto&& arg) -> std::size_t { return RegisterArrayField(fieldOffsets, arg, arrayLength); }, type.type); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const { assert(type.componentCount > 0 && type.componentCount <= 4); return fieldOffsets.AddFieldArray(static_cast(UnderlyingCast(TypeToStructFieldType(type.componentType->type)) + type.componentCount), arrayLength); } std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Void& type, std::size_t arrayLength) const { throw std::runtime_error("unexpected Void"); } void SpirvConstantCache::SetStructCallback(StructCallback callback) { m_internal->structCallback = std::move(callback); } void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) { for (auto&& [object, id] : m_internal->ids) { UInt32 resultId = id; std::visit(overloaded { [&](const AnyConstant& constant) { Write(constant, resultId, constants); }, [&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); }, }, object); } for (auto&& [variable, id] : m_internal->variableIds) { const auto& var = variable; UInt32 resultId = id; if (!variable.debugName.empty()) debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName); constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender) { appender(GetId(*var.type)); appender(resultId); appender(var.storageClass); if (var.initializer) appender(GetId((*var.initializer)->constant)); }); } } SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default; void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants) { std::visit([&](auto&& arg) { using ConstantType = std::decay_t; if constexpr (std::is_same_v) constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, GetId({ Bool{} }), resultId); else if constexpr (std::is_same_v) { constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender) { appender(GetId(arg.type->type)); appender(resultId); for (const auto& value : arg.values) appender(GetId(value->constant)); }); } else if constexpr (std::is_same_v) { std::visit([&](auto&& value) { using ValueType = std::decay_t; UInt32 typeId; if constexpr (std::is_same_v) typeId = GetId({ Float{ 64 } }); else if constexpr (std::is_same_v) typeId = GetId({ Float{ 32 } }); else if constexpr (std::is_same_v) typeId = GetId({ Integer{ 32, true } }); else if constexpr (std::is_same_v) typeId = GetId({ Integer{ 64, true } }); else if constexpr (std::is_same_v) typeId = GetId({ Integer{ 32, false } }); else if constexpr (std::is_same_v) typeId = GetId({ Integer{ 64, false } }); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &value, sizeof(value) }); }, arg.value); } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, constant); } void SpirvConstantCache::Write(const AnyType& type, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) { std::visit([&](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) { constants.Append(SpirvOp::OpTypeArray, resultId, GetId(*arg.elementType), GetId(*arg.length)); if (arg.stride) annotations.Append(SpirvOp::OpDecorate, resultId, SpirvDecoration::ArrayStride, *arg.stride); } else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeBool, resultId); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeFloat, resultId, arg.width); else if constexpr (std::is_same_v) { constants.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender) { appender(resultId); appender(GetId(*arg.returnType)); for (const auto& param : arg.parameters) appender(GetId(*param)); }); } else if constexpr (std::is_same_v) throw std::runtime_error("unexpected identifier"); else if constexpr (std::is_same_v) { UInt32 depth; if (arg.depth.has_value()) depth = (*arg.depth) ? 1 : 0; else depth = 2; UInt32 sampled; if (arg.sampled.has_value()) sampled = (*arg.sampled) ? 1 : 2; //< Yes/No else sampled = 0; //< Dunno constants.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender) { appender(resultId); appender(GetId(*arg.sampledType)); appender(arg.dim); appender(depth); appender(arg.arrayed); appender(arg.multisampled); appender(sampled); appender(arg.format); if (arg.qualifier) appender(*arg.qualifier); }); } else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type)); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image)); else if constexpr (std::is_same_v) WriteStruct(arg, resultId, annotations, constants, debugInfos); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeVoid, resultId); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, type); } void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos) { constants.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender) { appender(resultId); for (const auto& member : structData.members) appender(GetId(*member.type)); }); debugInfos.Append(SpirvOp::OpName, resultId, structData.name); for (SpirvDecoration decoration : structData.decorations) annotations.Append(SpirvOp::OpDecorate, resultId, decoration); for (std::size_t memberIndex = 0; memberIndex < structData.members.size(); ++memberIndex) { const auto& member = structData.members[memberIndex]; debugInfos.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name); UInt32 offset = member.offset.value(); std::visit([&](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) { annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor); annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16); } }, member.type->type); annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, offset); } } }