// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp

// NOTE(review): in this copy every include target and every template
// parameter/argument list (the text that sat between '<' and '>') is missing:
// see the bare `#include` lines below, `template struct overloaded`,
// `std::decay_t;`, `std::is_same_v)`, `std::get(rhs)`, ... The file cannot
// compile as-is; the angle-bracket contents must be restored from version control.
#include
#include
#include
#include
#include
#include

namespace Nz
{
	namespace
	{
		// "Overloaded lambdas" helper: inherits the call operator of every callable
		// it is built from, so several lambdas can be handed to std::visit together
		// (used by SpirvConstantCache::Write below).
		template struct overloaded : Ts... { using Ts::operator()...; };
		// Deduction guide so `overloaded { lambda1, lambda2 }` deduces its bases.
		template overloaded(Ts...)->overloaded;
	}

	// Structural equality predicate, used as the key-equality functor of the
	// Internal maps (ids, variableIds, structureSizes).
	// Two entries are equal when all of their fields match:
	//  - shared_ptr / unique_ptr / optional members are compared by pointee
	//    (two empty ones compare equal);
	//  - variants must hold the same alternative, which is then compared;
	//  - vectors are compared element-wise.
	// This is what lets Register() deduplicate structurally identical
	// types/constants/variables that were built as separate objects.
	struct SpirvConstantCache::Eq
	{
		bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const
		{
			return lhs.value == rhs.value;
		}

		bool Compare(const ConstantComposite& lhs, const ConstantComposite& rhs) const
		{
			return Compare(lhs.type, rhs.type) && Compare(lhs.values, rhs.values);
		}

		bool Compare(const ConstantScalar& lhs, const ConstantScalar& rhs) const
		{
			return lhs.value == rhs.value;
		}

		bool Compare(const Bool& /*lhs*/, const Bool& /*rhs*/) const
		{
			return true;
		}

		bool Compare(const Float& lhs, const Float& rhs) const
		{
			return lhs.width == rhs.width;
		}

		bool Compare(const Function& lhs, const Function& rhs) const
		{
			return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType);
		}

		bool Compare(const Image& lhs, const Image& rhs) const
		{
			return lhs.arrayed == rhs.arrayed
				&& lhs.dim == rhs.dim
				&& lhs.format == rhs.format
				&& lhs.multisampled == rhs.multisampled
				&& lhs.qualifier == rhs.qualifier
				&& Compare(lhs.sampledType, rhs.sampledType)
				&& lhs.depth == rhs.depth
				&& lhs.sampled == rhs.sampled;
		}

		bool Compare(const Integer& lhs, const Integer& rhs) const
		{
			return lhs.width == rhs.width && lhs.signedness == rhs.signedness;
		}

		bool Compare(const Matrix& lhs, const Matrix& rhs) const
		{
			return lhs.columnCount == rhs.columnCount && Compare(lhs.columnType, rhs.columnType);
		}

		bool Compare(const Pointer& lhs, const Pointer& rhs) const
		{
			return lhs.storageClass == rhs.storageClass && Compare(lhs.type, rhs.type);
		}

		bool Compare(const SampledImage& lhs, const SampledImage& rhs) const
		{
			return Compare(lhs.image, rhs.image);
		}

		// Structures are equal only if their names match too, so two
		// identically-laid-out structs with different names stay distinct.
		bool Compare(const Structure& lhs, const Structure& rhs) const
		{
			if (lhs.name != rhs.name)
				return false;

			if (!Compare(lhs.members, rhs.members))
				return false;

			return true;
		}

		bool Compare(const Structure::Member& lhs, const Structure::Member& rhs) const
		{
			if (!Compare(lhs.type, rhs.type))
				return false;

			if (lhs.name != rhs.name)
				return false;

			return true;
		}

		// Variables also compare their debug name and funcId, so two variables
		// of the same type are only merged when these match as well.
		bool Compare(const Variable& lhs, const Variable& rhs) const
		{
			if (lhs.debugName != rhs.debugName)
				return false;

			if (lhs.funcId != rhs.funcId)
				return false;

			if (!Compare(lhs.initializer, rhs.initializer))
				return false;

			if (lhs.storageClass != rhs.storageClass)
				return false;

			if (!Compare(lhs.type, rhs.type))
				return false;

			return true;
		}

		bool Compare(const Vector& lhs, const Vector& rhs) const
		{
			return Compare(lhs.componentType, rhs.componentType) && lhs.componentCount == rhs.componentCount;
		}

		bool Compare(const Void& /*lhs*/, const Void& /*rhs*/) const
		{
			return true;
		}

		bool Compare(const Constant& lhs, const Constant& rhs) const
		{
			return Compare(lhs.constant, rhs.constant);
		}

		bool Compare(const Type& lhs, const Type& rhs) const
		{
			return Compare(lhs.type, rhs.type);
		}

		// Optionals: both empty -> equal; one empty -> different; else compare values.
		template
		bool Compare(const std::optional& lhs, const std::optional& rhs) const
		{
			if (lhs.has_value() != rhs.has_value())
				return false;

			if (!lhs.has_value())
				return true;

			return Compare(*lhs, *rhs);
		}

		// Shared pointers are compared by pointee, not by address.
		template
		bool Compare(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const
		{
			if (bool(lhs) != bool(rhs))
				return false;

			if (!lhs)
				return true;

			return Compare(*lhs, *rhs);
		}

		// Variants must hold the same alternative; that alternative is then compared.
		template
		bool Compare(const std::variant& lhs, const std::variant& rhs) const
		{
			if (lhs.index() != rhs.index())
				return false;

			return std::visit([&](auto&& arg)
			{
				using U = std::decay_t;
				return Compare(arg, std::get(rhs));
			}, lhs);
		}

		// Element-wise comparison; sizes must match.
		template
		bool Compare(const std::vector& lhs, const std::vector& rhs) const
		{
			if (lhs.size() != rhs.size())
				return false;

			for (std::size_t i = 0; i < lhs.size(); ++i)
			{
				if (!Compare(lhs[i], rhs[i]))
					return false;
			}

			return true;
		}

		// Unique pointers are compared by pointee, not by address.
		template
		bool Compare(const std::unique_ptr& lhs, const std::unique_ptr& rhs) const
		{
			if (bool(lhs) != bool(rhs))
				return false;

			if (!lhs)
				return true;

			return Compare(*lhs, *rhs);
		}

		// Functor entry point used by the hash maps.
		template
		bool operator()(const T& lhs, const T& rhs) const
		{
			return Compare(lhs, rhs);
		}
	};

	// Walks a type, constant or variable and calls cache.Register() on everything
	// it references. The Register() overloads of SpirvConstantCache run this
	// before allocating an id for the object itself, so dependencies are
	// registered (and receive their result ids) before the objects using them.
	struct SpirvConstantCache::DepRegisterer
	{
		DepRegisterer(SpirvConstantCache& c) :
		cache(c)
		{
		}

		// Leaf types: nothing to register.
		void Register(const Bool&) {}
		void Register(const Float&) {}
		void Register(const Integer&) {}
		void Register(const Void&) {}

		void Register(const Image& image)
		{
			cache.Register(*image.sampledType);
		}

		void Register(const Function& func)
		{
			cache.Register(*func.returnType);
			Register(func.parameters);
		}

		void Register(const Matrix& vec)
		{
			assert(vec.columnType);
			cache.Register(*vec.columnType);
		}

		void Register(const Pointer& ptr)
		{
			assert(ptr.type);
			cache.Register(*ptr.type);
		}

		void Register(const SampledImage& sampledImage)
		{
			assert(sampledImage.image);
			cache.Register(*sampledImage.image);
		}

		void Register(const Structure& s)
		{
			Register(s.members);
		}

		void Register(const SpirvConstantCache::Structure::Member& m)
		{
			cache.Register(*m.type);
		}

		void Register(const Variable& variable)
		{
			assert(variable.type);
			cache.Register(*variable.type);
		}

		void Register(const Vector& vec)
		{
			assert(vec.componentType);
			cache.Register(*vec.componentType);
		}

		// A boolean constant needs the Bool type (looked up again in Write()).
		void Register(const ConstantBool&)
		{
			cache.Register({ Bool{} });
		}

		// A scalar constant needs the float/integer type that matches its stored
		// C++ type; the width/signedness pairs below must stay in sync with the
		// lookup in Write(const AnyConstant&, ...).
		void Register(const ConstantScalar& scalar)
		{
			std::visit([&](auto&& arg)
			{
				using T = std::decay_t;

				if constexpr (std::is_same_v)
					cache.Register({ Float{ 64 } });
				else if constexpr (std::is_same_v)
					cache.Register({ Float{ 32 } });
				else if constexpr (std::is_same_v)
					cache.Register({ Integer{ 32, true } });
				else if constexpr (std::is_same_v)
					cache.Register({ Integer{ 64, true } });
				else if constexpr (std::is_same_v)
					cache.Register({ Integer{ 32, false } });
				else if constexpr (std::is_same_v)
					cache.Register({ Integer{ 64, false } });
				else
					static_assert(AlwaysFalse::value, "non-exhaustive visitor");

			}, scalar.value);
		}

		// Composite constants need their own type and every component constant.
		void Register(const ConstantComposite& composite)
		{
			assert(composite.type);
			cache.Register(*composite.type);
			for (auto&& value : composite.values)
			{
				assert(value);
				cache.Register(*value);
			}
		}

		void Register(const Constant& c)
		{
			return Register(c.constant);
		}

		void Register(const Type& t)
		{
			return Register(t.type);
		}

		template
		void Register(const std::shared_ptr& ptr)
		{
			assert(ptr);
			return Register(*ptr);
		}

		// Empty optionals have no dependency to register.
		template
		void Register(const std::optional& opt)
		{
			if (opt)
				Register(*opt);
		}

		template
		void Register(const std::variant& v)
		{
			return std::visit([&](auto&& arg)
			{
				return Register(arg);
			}, v);
		}

		// Non-template overload: registers each pointee with the cache itself
		// (used for Function::parameters).
		void Register(const std::vector& lhs)
		{
			for (std::size_t i = 0; i < lhs.size(); ++i)
				cache.Register(*lhs[i]);
		}

		// Generic overload: recurses into each element's own dependencies
		// (used for Structure::members).
		template
		void Register(const std::vector& lhs)
		{
			for (std::size_t i = 0; i < lhs.size(); ++i)
				Register(lhs[i]);
		}

		template
		void Register(const std::unique_ptr& lhs)
		{
			assert(lhs);
			return Register(*lhs);
		}

		SpirvConstantCache& cache;
	};

	//< FIXME PLZ
	// Placeholder hasher: every key hashes to the same value, so all entries of
	// the Internal maps collide and each lookup effectively falls back to
	// comparing the key against the stored entries with Eq (linear per lookup).
	// Correct, but slow; a structural hash mirroring Eq would fix it.
	struct AnyHasher
	{
		template
		std::size_t operator()(const U&) const
		{
			return 42;
		}
	};

	// Private state of the cache (pimpl).
	//  - ids: every registered type or constant -> its SPIR-V result id.
	//    tsl::ordered_map keeps insertion order, and Write() iterates it, so
	//    objects are emitted in registration order (dependencies first).
	//  - variableIds: every registered variable -> its result id.
	//  - structureSizes: layout of each struct already emitted by WriteStruct(),
	//    reused when that struct appears as a member of another struct.
	//  - structCallback: resolves ShaderAst::StructType indices (see SetStructCallback).
	//  - nextResultId: reference to the caller-owned id counter; it must outlive the cache.
	struct SpirvConstantCache::Internal
	{
		Internal(UInt32& resultId) :
		nextResultId(resultId)
		{
		}

		tsl::ordered_map, UInt32 /*id*/, AnyHasher, Eq> ids;
		tsl::ordered_map variableIds;
		tsl::ordered_map structureSizes;
		StructCallback structCallback;
		UInt32& nextResultId;
	};

	// resultId is the shared id counter: it is incremented each time a new
	// object is registered, and is held by reference (it must outlive the cache).
	SpirvConstantCache::SpirvConstantCache(UInt32& resultId)
	{
		m_internal = std::make_unique(resultId);
	}

	SpirvConstantCache::SpirvConstantCache(SpirvConstantCache&& cache) noexcept = default;

	SpirvConstantCache::~SpirvConstantCache() = default;

	// Converts an AST constant value into a (not yet registered) cache constant:
	//  - booleans become ConstantBool;
	//  - the three scalar alternatives listed below become ConstantScalar;
	//  - 2/3/4-component vectors become a ConstantComposite whose component type
	//    is Float32 for the first vector alternative of each arity and Int32 for
	//    the other, with each component built recursively.
	auto SpirvConstantCache::BuildConstant(const ShaderAst::ConstantValue& value) const -> ConstantPtr
	{
		return std::make_shared(std::visit([&](auto&& arg) -> SpirvConstantCache::AnyConstant
		{
			using T = std::decay_t;

			if constexpr (std::is_same_v)
				return ConstantBool{ arg };
			else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v)
				return ConstantScalar{ arg };
			else if constexpr (std::is_same_v || std::is_same_v)
			{
				return ConstantComposite{
					BuildType(ShaderAst::VectorType{ 2, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
					{
						BuildConstant(arg.x),
						BuildConstant(arg.y)
					}
				};
			}
			else if constexpr (std::is_same_v || std::is_same_v)
			{
				return ConstantComposite{
					BuildType(ShaderAst::VectorType{ 3, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
					{
						BuildConstant(arg.x),
						BuildConstant(arg.y),
						BuildConstant(arg.z)
					}
				};
			}
			else if constexpr (std::is_same_v || std::is_same_v)
			{
				return ConstantComposite{
					BuildType(ShaderAst::VectorType{ 4, (std::is_same_v) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
					{
						BuildConstant(arg.x),
						BuildConstant(arg.y),
						BuildConstant(arg.z),
						BuildConstant(arg.w)
					}
				};
			}
			else
				static_assert(AlwaysFalse::value, "non-exhaustive visitor");
		}, value));
	}

	// Builds a function type. Every parameter is typed as a pointer in the
	// Function storage class rather than as a plain value.
	auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector& parameters) const -> TypePtr
	{
		std::vector parameterTypes;
		parameterTypes.reserve(parameters.size());

		for (const auto& parameterType : parameters)
			parameterTypes.push_back(BuildPointerType(parameterType, SpirvStorageClass::Function));

		return std::make_shared(Function{
			BuildType(retType),
			std::move(parameterTypes)
		});
	}

	// Builds a pointer to `type` in the given storage class.
	auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr
	{
		return std::make_shared(Pointer{
			BuildType(type),
			storageClass
		});
	}

	// Same as above, for a primitive pointee type.
	auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
	{
		return std::make_shared(Pointer{
			BuildType(type),
			storageClass
		});
	}

	// Dispatches to the BuildType overload matching the expression type's active alternative.
	auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
	{
		return std::visit([&](auto&& arg) -> TypePtr
		{
			return BuildType(arg);
		}, type);
	}

	// Always throws std::runtime_error.
	auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& /*type*/) const -> TypePtr
	{
		// No IdentifierType is expected (as they should have been resolved by now)
		throw std::runtime_error("unexpected identifier");
	}

	// Maps the AST primitive types to SPIR-V types: Boolean -> Bool,
	// Float32 -> 32-bit float, Int32/UInt32 -> signed/unsigned 32-bit integer.
	// Throws std::runtime_error for any other enum value.
	auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) const -> TypePtr
	{
		return std::make_shared([&]() -> AnyType
		{
			switch (type)
			{
				case ShaderAst::PrimitiveType::Boolean:
					return Bool{};

				case ShaderAst::PrimitiveType::Float32:
					return Float{ 32 };

				case ShaderAst::PrimitiveType::Int32:
					return Integer{ 32, true };

				case ShaderAst::PrimitiveType::UInt32:
					return Integer{ 32, false };
			}

			throw std::runtime_error("unexpected type");
		}());
	}

	// A matrix is built as `columnCount` columns, each column being a vector of
	// `rowCount` components of the matrix's component type.
	auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) const -> TypePtr
	{
		return std::make_shared(
			Matrix{
				BuildType(ShaderAst::VectorType {
					UInt32(type.rowCount), type.type
				}),
				UInt32(type.columnCount)
			});
	}

	// "No type" maps to Void.
	auto SpirvConstantCache::BuildType(const ShaderAst::NoType& /*type*/) const -> TypePtr
	{
		return std::make_shared(Void{});
	}

	// Builds a SampledImage around an Image that is marked as sampled.
	// Image fields not assigned here (depth, format, multisampled, qualifier,
	// and arrayed for non-array dims) keep whatever defaults the Image
	// declaration provides (NOTE(review): presumably set in the header — confirm).
	auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) const -> TypePtr
	{
		Image imageType;
		imageType.sampled = true;
		imageType.sampledType = BuildType(type.sampledType);

		switch (type.dim)
		{
			case ImageType::Cubemap:
				imageType.dim = SpirvDim::Cube;
				break;

			// Intentional fall-through: array variants set `arrayed` and then
			// share the dimension of their non-array counterpart.
			// NOTE(review): `[[fallthrough]];` would document this for the compiler.
			case ImageType::E1D_Array:
				imageType.arrayed = true;
			case ImageType::E1D:
				imageType.dim = SpirvDim::Dim1D;
				break;

			case ImageType::E2D_Array:
				imageType.arrayed = true;
			case ImageType::E2D:
				imageType.dim = SpirvDim::Dim2D;
				break;

			case ImageType::E3D:
				imageType.dim = SpirvDim::Dim3D;
				break;
		}

		return std::make_shared(SampledImage{ std::make_shared(imageType) });
	}

	// Resolves the struct index through the callback installed with
	// SetStructCallback() (asserted to be set) and builds the resulting description.
	auto SpirvConstantCache::BuildType(const ShaderAst::StructType& type) const -> TypePtr
	{
		assert(m_internal->structCallback);
		return BuildType(m_internal->structCallback(type.structIndex));
	}

	// Builds a Structure carrying the struct name and, in declaration order,
	// each member's name and built type.
	auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) const -> TypePtr
	{
		Structure sType;
		sType.name = structDesc.name;

		for (const auto& member : structDesc.members)
		{
			auto& sMembers = sType.members.emplace_back();
			sMembers.name = member.name;
			sMembers.type = BuildType(member.type);
		}

		return std::make_shared(std::move(sType));
	}

	auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) const -> TypePtr
	{
		return std::make_shared(Vector{ BuildType(type.type), UInt32(type.componentCount) });
	}

	// Builds the type contained in the uniform; asserts that the contained
	// variant holds the alternative it then extracts.
	auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr
	{
		assert(std::holds_alternative(type.containedType));
		return BuildType(std::get(type.containedType));
	}

	// Returns the id of an already-registered constant (does not register it).
	// Throws std::runtime_error if the constant was never registered.
	UInt32 SpirvConstantCache::GetId(const Constant& c)
	{
		auto it = m_internal->ids.find(c.constant);
		if (it == m_internal->ids.end())
			throw std::runtime_error("constant is not registered");

		return it->second;
	}

	// Returns the id of an already-registered type (does not register it).
	// Throws std::runtime_error if the type was never registered.
	UInt32 SpirvConstantCache::GetId(const Type& t)
	{
		auto it = m_internal->ids.find(t.type);
		if (it == m_internal->ids.end())
			throw std::runtime_error("type is not registered");

		return it->second;
	}

	// Returns the id of an already-registered variable (does not register it).
	// Throws std::runtime_error if the variable was never registered.
	UInt32 SpirvConstantCache::GetId(const Variable& v)
	{
		auto it = m_internal->variableIds.find(v);
		if (it == m_internal->variableIds.end())
			throw std::runtime_error("variable is not registered");

		return it->second;
	}

	// Registers a constant and everything it depends on, and returns its id.
	// A structurally equal constant (per Eq) that is already registered is
	// reused; otherwise a new id is taken from the shared counter.
	UInt32 SpirvConstantCache::Register(Constant c)
	{
		AnyConstant& constant = c.constant;

		// Dependencies first, so they get lower ids and are emitted earlier.
		DepRegisterer registerer(*this);
		registerer.Register(constant);

		std::size_t h = m_internal->ids.hash_function()(constant);
		auto it = m_internal->ids.find(constant, h);
		if (it == m_internal->ids.end())
		{
			UInt32 resultId = m_internal->nextResultId++;
			it = m_internal->ids.emplace(std::move(constant), resultId).first;
		}

		return it.value();
	}

	// Registers a type and everything it depends on, and returns its id.
	// Same deduplication rules as Register(Constant).
	UInt32 SpirvConstantCache::Register(Type t)
	{
		AnyType& type = t.type;

		DepRegisterer registerer(*this);
		registerer.Register(type);

		std::size_t h = m_internal->ids.hash_function()(type);
		auto it = m_internal->ids.find(type, h);
		if (it == m_internal->ids.end())
		{
			UInt32 resultId = m_internal->nextResultId++;
			it = m_internal->ids.emplace(std::move(type), resultId).first;
		}

		return it.value();
	}

	// Registers a variable (and its type) and returns its id. Variables live in
	// their own map and are emitted after all types and constants by Write().
	UInt32 SpirvConstantCache::Register(Variable v)
	{
		DepRegisterer registerer(*this);
		registerer.Register(v);

		std::size_t h = m_internal->variableIds.hash_function()(v);
		auto it = m_internal->variableIds.find(v, h);
		if (it == m_internal->variableIds.end())
		{
			UInt32 resultId = m_internal->nextResultId++;
			it = m_internal->variableIds.emplace(std::move(v), resultId).first;
		}

		return it.value();
	}

	// Installs the callback used by BuildType(StructType) to resolve struct indices.
	void SpirvConstantCache::SetStructCallback(StructCallback callback)
	{
		m_internal->structCallback = std::move(callback);
	}

	// Emits every registered type and constant (in registration order), then
	// every registered variable:
	//  - declarations (OpType*, OpConstant*, OpVariable) go to `constants`;
	//  - decorations go to `annotations`;
	//  - OpName / OpMemberName go to `debugInfos`.
	void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
	{
		for (auto&& [object, id] : m_internal->ids)
		{
			// Copied into a local, presumably because C++17 lambdas cannot
			// capture structured bindings.
			UInt32 resultId = id;

			std::visit(overloaded
			{
				[&](const AnyConstant& constant) { Write(constant, resultId, constants); },
				[&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); },
			}, object);
		}

		for (auto&& [variable, id] : m_internal->variableIds)
		{
			// Local aliases, presumably for the same lambda-capture reason as above.
			const auto& var = variable;
			UInt32 resultId = id;

			if (!variable.debugName.empty())
				debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName);

			// OpVariable: result type (the variable's pointer type), result id,
			// storage class, then the initializer constant id if there is one.
			constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender)
			{
				appender(GetId(*var.type));
				appender(resultId);
				appender(var.storageClass);

				if (var.initializer)
					appender(GetId((*var.initializer)->constant));
			});
		}
	}

	SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default;

	// Emits the declaration of a single constant into `constants`. The type ids
	// it looks up were registered beforehand by DepRegisterer.
	void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants)
	{
		std::visit([&](auto&& arg)
		{
			using ConstantType = std::decay_t;

			if constexpr (std::is_same_v)
				constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, GetId({ Bool{} }), resultId);
			else if constexpr (std::is_same_v)
			{
				// OpConstantComposite: type id, result id, then each component constant id.
				constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender)
				{
					appender(GetId(arg.type->type));
					appender(resultId);

					for (const auto& value : arg.values)
						appender(GetId(value->constant));
				});
			}
			else if constexpr (std::is_same_v)
			{
				std::visit([&](auto&& value)
				{
					using ValueType = std::decay_t;

					// Must mirror DepRegisterer::Register(const ConstantScalar&).
					UInt32 typeId;
					if constexpr (std::is_same_v)
						typeId = GetId({ Float{ 64 } });
					else if constexpr (std::is_same_v)
						typeId = GetId({ Float{ 32 } });
					else if constexpr (std::is_same_v)
						typeId = GetId({ Integer{ 32, true } });
					else if constexpr (std::is_same_v)
						typeId = GetId({ Integer{ 64, true } });
					else if constexpr (std::is_same_v)
						typeId = GetId({ Integer{ 32, false } });
					else if constexpr (std::is_same_v)
						typeId = GetId({ Integer{ 64, false } });
					else
						static_assert(AlwaysFalse::value, "non-exhaustive visitor");

					// The literal is the raw object representation of the value.
					// NOTE(review): SPIR-V wants the low-order word first for 64-bit
					// literals, so this presumably assumes a little-endian host.
					constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &value, sizeof(value) });

				}, arg.value);
			}
			else
				static_assert(AlwaysFalse::value, "non-exhaustive visitor");
		}, constant);
	}

	// Emits the declaration of a single type. Structures are delegated to
	// WriteStruct(), which also writes decorations and debug names.
	// Throws std::runtime_error for identifier types.
	void SpirvConstantCache::Write(const AnyType& type, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
	{
		std::visit([&](auto&& arg)
		{
			using T = std::decay_t;

			if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeBool, resultId);
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeFloat, resultId, arg.width);
			else if constexpr (std::is_same_v)
			{
				// OpTypeFunction: result id, return type id, then each parameter type id.
				constants.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender)
				{
					appender(resultId);
					appender(GetId(*arg.returnType));

					for (const auto& param : arg.parameters)
						appender(GetId(*param));
				});
			}
			else if constexpr (std::is_same_v)
				throw std::runtime_error("unexpected identifier");
			else if constexpr (std::is_same_v)
			{
				// Depth operand: 1 = depth image, 0 = not a depth image, 2 = unknown.
				UInt32 depth;
				if (arg.depth.has_value())
					depth = (*arg.depth) ? 1 : 0;
				else
					depth = 2;

				// Sampled operand: 1 = used with a sampler, 2 = used without one,
				// 0 = only known at run time.
				UInt32 sampled;
				if (arg.sampled.has_value())
					sampled = (*arg.sampled) ? 1 : 2; //< Yes/No
				else
					sampled = 0; //< Dunno

				// OpTypeImage operands in SPIR-V order; the access qualifier is optional.
				constants.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender)
				{
					appender(resultId);
					appender(GetId(*arg.sampledType));
					appender(arg.dim);
					appender(depth);
					appender(arg.arrayed);
					appender(arg.multisampled);
					appender(sampled);
					appender(arg.format);

					if (arg.qualifier)
						appender(*arg.qualifier);
				});
			}
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness);
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount);
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type));
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image));
			else if constexpr (std::is_same_v)
				WriteStruct(arg, resultId, annotations, constants, debugInfos);
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount);
			else if constexpr (std::is_same_v)
				constants.Append(SpirvOp::OpTypeVoid, resultId);
			else
				static_assert(AlwaysFalse::value, "non-exhaustive visitor");
		}, type);
	}

	// Emits OpTypeStruct and its OpName, a Block decoration, and per member an
	// OpMemberName and an Offset decoration (plus ColMajor/MatrixStride for
	// matrices). Offsets come from a std140 FieldOffsets layout, which is then
	// stored in structureSizes so this struct can later be nested in another one.
	// NOTE(review): every struct gets Block, including structs only used as members
	// of another struct — confirm the SPIR-V validator accepts that.
	void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
	{
		constants.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender)
		{
			appender(resultId);

			for (const auto& member : structData.members)
				appender(GetId(*member.type));
		});

		debugInfos.Append(SpirvOp::OpName, resultId, structData.name);

		annotations.Append(SpirvOp::OpDecorate, resultId, SpirvDecoration::Block);

		FieldOffsets structOffsets(StructLayout::Std140);

		for (std::size_t memberIndex = 0; memberIndex < structData.members.size(); ++memberIndex)
		{
			const auto& member = structData.members[memberIndex];
			debugInfos.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name);

			// Appends the member to the std140 layout and yields its byte offset.
			// Member types that cannot appear in the block throw std::runtime_error.
			std::size_t offset = std::visit([&](auto&& arg) -> std::size_t
			{
				using T = std::decay_t;

				if constexpr (std::is_same_v)
					return structOffsets.AddField(StructFieldType::Bool1);
				else if constexpr (std::is_same_v)
				{
					switch (arg.width)
					{
						case 32: return structOffsets.AddField(StructFieldType::Float1);
						case 64: return structOffsets.AddField(StructFieldType::Double1);
						default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width));
					}
				}
				else if constexpr (std::is_same_v)
					return structOffsets.AddField((arg.signedness) ? StructFieldType::Int1 : StructFieldType::UInt1);
				else if constexpr (std::is_same_v)
				{
					// Only matrices whose columns are float vectors (32 or 64 bits) are accepted.
					assert(std::holds_alternative(arg.columnType->type));
					Vector& columnVec = std::get(arg.columnType->type);

					if (!std::holds_alternative(columnVec.componentType->type))
						throw std::runtime_error("unexpected vector type");

					Float& vecType = std::get(columnVec.componentType->type);

					StructFieldType columnType;
					switch (vecType.width)
					{
						case 32: columnType = StructFieldType::Float1; break;
						case 64: columnType = StructFieldType::Double1; break;
						default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width));
					}

					// NOTE(review): the stride is hard-coded to 16. Under std140 a column of
					// three or four 64-bit floats is 32-byte aligned, so double matrices may
					// need a larger stride — confirm against FieldOffsets::AddMatrix.
					annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor);
					annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16);

					return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true);
				}
				else if constexpr (std::is_same_v)
					throw std::runtime_error("unhandled pointer in struct");
				else if constexpr (std::is_same_v)
				{
					// Nested struct: its layout was stored when it was written. Its
					// dependencies were registered first, so (ids being iterated in
					// insertion order) it is expected to have been emitted already.
					auto it = m_internal->structureSizes.find(arg);
					assert(it != m_internal->structureSizes.end());

					return structOffsets.AddStruct(it->second);
				}
				else if constexpr (std::is_same_v)
				{
					// The N-component field kinds are assumed to follow the 1-component
					// enumerator consecutively (Bool1, Bool2, ...), hence the "+ count - 1".
					if (std::holds_alternative(arg.componentType->type))
						return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Bool1) + arg.componentCount - 1));
					else if (std::holds_alternative(arg.componentType->type))
					{
						Float& floatData = std::get(arg.componentType->type);
						switch (floatData.width)
						{
							case 32: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Float1) + arg.componentCount - 1));
							case 64: return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Double1) + arg.componentCount - 1));
							default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width));
						}
					}
					else if (std::holds_alternative(arg.componentType->type))
					{
						// Only 32-bit integer vectors are accepted here.
						Integer& intData = std::get(arg.componentType->type);
						if (intData.width != 32)
							throw std::runtime_error("unexpected integer width " + std::to_string(intData.width));

						if (intData.signedness)
							return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::Int1) + arg.componentCount - 1));
						else
							return structOffsets.AddField(static_cast(UnderlyingCast(StructFieldType::UInt1) + arg.componentCount - 1));
					}
					else
						throw std::runtime_error("unexpected type for vector");
				}
				else if constexpr (std::is_same_v)
					throw std::runtime_error("unexpected function as struct member");
				else if constexpr (std::is_same_v)
					throw std::runtime_error("unexpected identifier");
				else if constexpr (std::is_same_v || std::is_same_v)
					throw std::runtime_error("unexpected opaque type as struct member");
				else if constexpr (std::is_same_v)
					throw std::runtime_error("unexpected void as struct member");
				else
					static_assert(AlwaysFalse::value, "non-exhaustive visitor");
			}, member.type->type);

			annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, offset);
		}

		// Remember this struct's layout for structs that embed it later.
		m_internal->structureSizes.emplace(structData, std::move(structOffsets));
	}
}