diff --git a/examples/bin/frag.shader b/examples/bin/frag.shader index cd71692bd..7ecc8951b 100644 Binary files a/examples/bin/frag.shader and b/examples/bin/frag.shader differ diff --git a/examples/bin/test.spirv b/examples/bin/test.spirv index 3d47fe4e8..b88584775 100644 Binary files a/examples/bin/test.spirv and b/examples/bin/test.spirv differ diff --git a/include/Nazara/Shader/ShaderConstantValue.hpp b/include/Nazara/Shader/ShaderConstantValue.hpp new file mode 100644 index 000000000..27c9e1d7e --- /dev/null +++ b/include/Nazara/Shader/ShaderConstantValue.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_CONSTANTVALUE_HPP +#define NAZARA_SHADER_CONSTANTVALUE_HPP + +#include +#include +#include +#include +#include + +namespace Nz +{ + using ShaderConstantValue = std::variant< + bool, + float, + Int32, + UInt32, + Vector2f, + Vector3f, + Vector4f, + Vector2i32, + Vector3i32, + Vector4i32 + >; +} + +#endif diff --git a/include/Nazara/Shader/ShaderEnums.hpp b/include/Nazara/Shader/ShaderEnums.hpp index 2ddbb5b97..ed322e3e6 100644 --- a/include/Nazara/Shader/ShaderEnums.hpp +++ b/include/Nazara/Shader/ShaderEnums.hpp @@ -29,8 +29,11 @@ namespace Nz::ShaderNodes Int4, //< ivec4 Mat4x4, //< mat4 Sampler2D, //< sampler2D - - Void //< void + Void, //< void + UInt1, //< uint + UInt2, //< uvec2 + UInt3, //< uvec3 + UInt4 //< uvec4 }; enum class BinaryType diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index 833af1248..af96afeda 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -222,19 +223,7 @@ namespace Nz ShaderExpressionType GetExpressionType() const override; void Visit(ShaderAstVisitor& visitor) override; - using Variant = std::variant< - bool, - float, - Int32, - Vector2f, - Vector3f, - Vector4f, - Vector2i32, - Vector3i32, - Vector4i32 - >; - - Variant value; + ShaderConstantValue value; template static std::shared_ptr Build(const T& value); }; diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp new file mode 100644 index 000000000..95172c757 --- /dev/null +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -0,0 +1,194 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVCONSTANTCACHE_HPP +#define NAZARA_SPIRVCONSTANTCACHE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + class ShaderAst; + class SpirvSection; + + class NAZARA_SHADER_API SpirvConstantCache + { + public: + SpirvConstantCache(UInt32& resultId); + SpirvConstantCache(const SpirvConstantCache& cache) = delete; + SpirvConstantCache(SpirvConstantCache&& cache) noexcept; + ~SpirvConstantCache(); + + struct Constant; + struct Type; + + using ConstantPtr = std::shared_ptr; + using TypePtr = std::shared_ptr; + + struct Bool {}; + + struct Float + { + UInt32 width; + }; + + struct Integer + { + UInt32 width; + bool signedness; + }; + + struct Void {}; + + struct Vector + { + TypePtr componentType; + UInt32 componentCount; + }; + + struct Matrix + { + TypePtr columnType; + UInt32 columnCount; + }; + + struct Image + { + std::optional qualifier; + std::optional depth; + std::optional sampled; + SpirvDim dim; + SpirvImageFormat format; + TypePtr sampledType; + bool arrayed; + bool multisampled; + }; + + struct Pointer + { + TypePtr type; + SpirvStorageClass storageClass; + }; + + struct Function + { + TypePtr returnType; + std::vector parameters; + }; + + struct SampledImage + { + TypePtr image; + }; + + struct Structure + { + struct Member + { + std::string name; + TypePtr type; + }; + + std::string name; + std::vector members; + }; + + using AnyType = std::variant; + + struct ConstantBool + { + bool value; + }; + + struct ConstantComposite + { + TypePtr type; + std::vector values; + }; + + struct ConstantScalar + { + std::variant value; + }; + + using AnyConstant = std::variant; + + struct Variable + { + std::string debugName; + TypePtr type; + SpirvStorageClass storageClass; + std::optional initializer; + }; + + using BaseType = std::variant; + using CompositeValue = std::variant; + using PointerOrBaseType = std::variant; + using PrimitiveType = std::variant; + using ScalarType = std::variant; + + struct Constant + { + Constant(AnyConstant c) : + constant(std::move(c)) + { + } + + AnyConstant constant; + }; + + struct Type + { + Type(AnyType c) : + type(std::move(c)) + { + } + + AnyType type; + }; + + UInt32 GetId(const Constant& c); + UInt32 GetId(const Type& t); + UInt32 GetId(const Variable& v); + + UInt32 Register(Constant c); + UInt32 Register(Type t); + UInt32 Register(Variable v); + + void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos, SpirvSection& types); + + SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete; + SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; + + static ConstantPtr BuildConstant(const ShaderConstantValue& value); + static TypePtr BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass); + static TypePtr BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass); + static TypePtr BuildType(const ShaderNodes::BasicType& type); + static TypePtr BuildType(const ShaderAst& shader, const ShaderExpressionType& type); + + private: + struct DepRegisterer; + struct Eq; + struct Internal; + + void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& debugInfos, SpirvSection& types); + + std::unique_ptr m_internal; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvConstantCache.inl b/include/Nazara/Shader/SpirvConstantCache.inl new file mode 100644 index 000000000..b007dbb7f --- /dev/null +++ b/include/Nazara/Shader/SpirvConstantCache.inl @@ -0,0 +1,12 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ +} + +#include diff --git a/include/Nazara/Shader/SpirvSection.inl b/include/Nazara/Shader/SpirvSection.inl index f4770e3f3..e86eb59ff 100644 --- a/include/Nazara/Shader/SpirvSection.inl +++ b/include/Nazara/Shader/SpirvSection.inl @@ -96,7 +96,7 @@ namespace Nz } template - unsigned int SpirvSection::CountWord(const T& value) + unsigned int SpirvSection::CountWord(const T& /*value*/) { return 1; } diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 71a617eaa..579ad1fef 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -11,9 +11,10 @@ #include #include #include +#include #include #include -#include +#include #include #include #include @@ -47,20 +48,22 @@ namespace Nz UInt32 AllocateResultId(); - void AppendConstants(); void AppendHeader(); - void AppendStructType(std::size_t structIndex, UInt32 resultId); - void AppendTypes(); UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr); - UInt32 GetConstantId(const ShaderNodes::Constant::Variant& value) const; + UInt32 GetConstantId(const ShaderConstantValue& value) const; + UInt32 GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters); + UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const; UInt32 GetTypeId(const ShaderExpressionType& type) const; void PushResultId(UInt32 value); UInt32 PopResultId(); UInt32 ReadVariable(ExtVar& var); + UInt32 RegisterConstant(const ShaderConstantValue& value); + UInt32 RegisterFunctionType(ShaderExpressionType retType, const std::vector& parameters); + UInt32 RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass); UInt32 RegisterType(ShaderExpressionType type); using ShaderAstVisitor::Visit; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index b17c8d4a3..1e911ce13 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -231,6 +231,10 @@ namespace Nz case ShaderNodes::BasicType::Int4: return Append("ivec4"); case ShaderNodes::BasicType::Mat4x4: return Append("mat4"); case ShaderNodes::BasicType::Sampler2D: return Append("sampler2D"); + case ShaderNodes::BasicType::UInt1: return Append("uint"); + case ShaderNodes::BasicType::UInt2: return Append("uvec2"); + case ShaderNodes::BasicType::UInt3: return Append("uvec3"); + case ShaderNodes::BasicType::UInt4: return Append("uvec4"); case ShaderNodes::BasicType::Void: return Append("void"); } } @@ -459,7 +463,7 @@ namespace Nz if constexpr (std::is_same_v) Append((arg) ? "true" : "false"); - else if constexpr (std::is_same_v || std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) Append(std::to_string(arg)); else if constexpr (std::is_same_v || std::is_same_v) Append("vec2(" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ")"); diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index 6d45cc711..396cd3d98 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -193,18 +193,19 @@ namespace Nz Value(value); }; - static_assert(std::variant_size_v == 9); + static_assert(std::variant_size_v == 10); switch (typeIndex) { case 0: SerializeValue(bool()); break; case 1: SerializeValue(float()); break; case 2: SerializeValue(Int32()); break; - case 3: SerializeValue(Vector2f()); break; - case 4: SerializeValue(Vector3f()); break; - case 5: SerializeValue(Vector4f()); break; - case 6: SerializeValue(Vector2i32()); break; - case 7: SerializeValue(Vector3i32()); break; - case 8: SerializeValue(Vector4i32()); break; + case 3: SerializeValue(UInt32()); break; + case 4: SerializeValue(Vector2f()); break; + case 5: SerializeValue(Vector3f()); break; + case 6: SerializeValue(Vector4f()); break; + case 7: SerializeValue(Vector2i32()); break; + case 8: SerializeValue(Vector3i32()); break; + case 9: SerializeValue(Vector4i32()); break; default: throw std::runtime_error("unexpected data type"); } } diff --git a/src/Nazara/Shader/ShaderNodes.cpp b/src/Nazara/Shader/ShaderNodes.cpp index 2f72a1ea0..15c7e40f0 100644 --- a/src/Nazara/Shader/ShaderNodes.cpp +++ b/src/Nazara/Shader/ShaderNodes.cpp @@ -114,12 +114,16 @@ namespace Nz::ShaderNodes case BasicType::Int2: case BasicType::Int3: case BasicType::Int4: + case BasicType::UInt2: + case BasicType::UInt3: + case BasicType::UInt4: exprType = leftExprType; break; case BasicType::Float1: case BasicType::Int1: case BasicType::Mat4x4: + case BasicType::UInt1: exprType = rightExprType; break; @@ -165,6 +169,8 @@ namespace Nz::ShaderNodes return ShaderNodes::BasicType::Float1; else if constexpr (std::is_same_v) return ShaderNodes::BasicType::Int1; + else if constexpr (std::is_same_v) + return ShaderNodes::BasicType::Int1; else if constexpr (std::is_same_v) return ShaderNodes::BasicType::Float2; else if constexpr (std::is_same_v) diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp new file mode 100644 index 000000000..02f7b4f34 --- /dev/null +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -0,0 +1,897 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + struct SpirvConstantCache::Eq + { + bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const + { + return lhs.value == rhs.value; + } + + bool Compare(const ConstantComposite& lhs, const ConstantComposite& rhs) const + { + return Compare(lhs.type, rhs.type) && Compare(lhs.values, rhs.values); + } + + bool Compare(const ConstantScalar& lhs, const ConstantScalar& rhs) const + { + return lhs.value == rhs.value; + } + + bool Compare(const Bool& /*lhs*/, const Bool& /*rhs*/) const + { + return true; + } + + bool Compare(const Float& lhs, const Float& rhs) const + { + return lhs.width == rhs.width; + } + + bool Compare(const Function& lhs, const Function& rhs) const + { + return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType); + } + + bool Compare(const Image& lhs, const Image& rhs) const + { + return lhs.arrayed == rhs.arrayed + && lhs.dim == rhs.dim + && lhs.format == rhs.format + && lhs.multisampled == rhs.multisampled + && lhs.qualifier == rhs.qualifier + && Compare(lhs.sampledType, rhs.sampledType) + && lhs.depth == rhs.depth + && lhs.sampled == rhs.sampled; + } + + bool Compare(const Integer& lhs, const Integer& rhs) const + { + return lhs.width == rhs.width && lhs.signedness == rhs.signedness; + } + + bool Compare(const Matrix& lhs, const Matrix& rhs) const + { + return lhs.columnCount == rhs.columnCount && Compare(lhs.columnType, rhs.columnType); + } + + bool Compare(const Pointer& lhs, const Pointer& rhs) const + { + return lhs.storageClass == rhs.storageClass && Compare(lhs.type, rhs.type); + } + + bool Compare(const SampledImage& lhs, const SampledImage& rhs) const + { + return Compare(lhs.image, rhs.image); + } + + bool Compare(const Structure& lhs, const Structure& rhs) const + { + if (lhs.name != rhs.name) + return false; + + if (!Compare(lhs.members, rhs.members)) + return false; + + return true; + } + + bool Compare(const Structure::Member& lhs, const Structure::Member& rhs) const + { + if (!Compare(lhs.type, rhs.type)) + return false; + + if (lhs.name != rhs.name) + return false; + + return true; + } + + bool Compare(const Variable& lhs, const Variable& rhs) const + { + if (lhs.debugName != rhs.debugName) + return false; + + if (!Compare(lhs.initializer, rhs.initializer)) + return false; + + if (lhs.storageClass != rhs.storageClass) + return false; + + if (!Compare(lhs.type, rhs.type)) + return false; + + return true; + } + + bool Compare(const Vector& lhs, const Vector& rhs) const + { + return Compare(lhs.componentType, rhs.componentType) && lhs.componentCount == rhs.componentCount; + } + + bool Compare(const Void& /*lhs*/, const Void& /*rhs*/) const + { + return true; + } + + + bool Compare(const Constant& lhs, const Constant& rhs) const + { + return Compare(lhs.constant, rhs.constant); + } + + bool Compare(const Type& lhs, const Type& rhs) const + { + return Compare(lhs.type, rhs.type); + } + + + template + bool Compare(const std::optional& lhs, const std::optional& rhs) const + { + if (lhs.has_value() != rhs.has_value()) + return false; + + if (!lhs.has_value()) + return true; + + return Compare(*lhs, *rhs); + } + + template + bool Compare(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const + { + if (bool(lhs) != bool(rhs)) + return false; + + if (!lhs) + return true; + + return Compare(*lhs, *rhs); + } + + template + bool Compare(const std::variant& lhs, const std::variant& rhs) const + { + if (lhs.index() != rhs.index()) + return false; + + return std::visit([&](auto&& arg) + { + using U = std::decay_t; + return Compare(arg, std::get(rhs)); + }, lhs); + } + + template + bool Compare(const std::vector& lhs, const std::vector& rhs) const + { + if (lhs.size() != rhs.size()) + return false; + + for (std::size_t i = 0; i < lhs.size(); ++i) + { + if (!Compare(lhs[i], rhs[i])) + return false; + } + + return true; + } + + template + bool Compare(const std::unique_ptr& lhs, const std::unique_ptr& rhs) const + { + if (bool(lhs) != bool(rhs)) + return false; + + if (!lhs) + return true; + + return Compare(*lhs, *rhs); + } + + template + bool operator()(const T& lhs, const T& rhs) const + { + return Compare(lhs, rhs); + } + }; + + struct SpirvConstantCache::DepRegisterer + { + DepRegisterer(SpirvConstantCache& c) : + cache(c) + { + } + + void Register(const Bool&) {} + void Register(const Float&) {} + void Register(const Integer&) {} + void Register(const Void&) {} + + void Register(const Image& image) + { + Register(image.sampledType); + } + + void Register(const Function& func) + { + Register(func.returnType); + Register(func.parameters); + } + + void Register(const Matrix& vec) + { + assert(vec.columnType); + cache.Register(*vec.columnType); + } + + void Register(const Pointer& ptr) + { + assert(ptr.type); + cache.Register(*ptr.type); + } + + void Register(const SampledImage& sampledImage) + { + assert(sampledImage.image); + cache.Register(*sampledImage.image); + } + + void Register(const Structure& s) + { + Register(s.members); + } + + void Register(const SpirvConstantCache::Structure::Member& m) + { + cache.Register(*m.type); + } + + void Register(const Variable& variable) + { + assert(variable.type); + cache.Register(*variable.type); + } + + void Register(const Vector& vec) + { + assert(vec.componentType); + cache.Register(*vec.componentType); + } + + void Register(const ConstantBool&) + { + cache.Register({ Bool{} }); + } + + void Register(const ConstantScalar& scalar) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + cache.Register({ Float{ 64 } }); + else if constexpr (std::is_same_v) + cache.Register({ Float{ 32 } }); + else if constexpr (std::is_same_v) + cache.Register({ Integer{ 32, 1 } }); + else if constexpr (std::is_same_v) + cache.Register({ Integer{ 64, 1 } }); + else if constexpr (std::is_same_v) + cache.Register({ Integer{ 32, 0 } }); + else if constexpr (std::is_same_v) + cache.Register({ Integer{ 64, 0 } }); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + + }, scalar.value); + } + + void Register(const ConstantComposite& composite) + { + assert(composite.type); + cache.Register(*composite.type); + + for (auto&& value : composite.values) + { + assert(value); + cache.Register(*value); + } + } + + + void Register(const Constant& c) + { + return Register(c.constant); + } + + void Register(const Type& t) + { + return Register(t.type); + } + + + template + void Register(const std::shared_ptr& ptr) + { + assert(ptr); + return Register(*ptr); + } + + template + void Register(const std::optional& opt) + { + if (opt) + Register(*opt); + } + + template + void Register(const std::variant& v) + { + return std::visit([&](auto&& arg) + { + return Register(arg); + }, v); + } + + template + void Register(const std::vector& lhs) + { + for (std::size_t i = 0; i < lhs.size(); ++i) + Register(lhs[i]); + } + + template + void Register(const std::unique_ptr& lhs) + { + assert(lhs); + return Register(*lhs); + } + + SpirvConstantCache& cache; + }; + + //< FIXME PLZ + struct AnyHasher + { + template + std::size_t operator()(const U&) const + { + return 42; + } + }; + + struct SpirvConstantCache::Internal + { + Internal(UInt32& resultId) : + nextResultId(resultId) + { + } + + tsl::ordered_map constantIds; + tsl::ordered_map typeIds; + tsl::ordered_map variableIds; + tsl::ordered_map structureSizes; + UInt32& nextResultId; + }; + + SpirvConstantCache::SpirvConstantCache(UInt32& resultId) + { + m_internal = std::make_unique(resultId); + } + + SpirvConstantCache::SpirvConstantCache(SpirvConstantCache&& cache) noexcept = default; + + SpirvConstantCache::~SpirvConstantCache() = default; + + UInt32 SpirvConstantCache::GetId(const Constant& c) + { + auto it = m_internal->constantIds.find(c.constant); + if (it == m_internal->constantIds.end()) + throw std::runtime_error("constant is not registered"); + + return it->second; + } + + UInt32 SpirvConstantCache::GetId(const Type& t) + { + auto it = m_internal->typeIds.find(t.type); + if (it == m_internal->typeIds.end()) + throw std::runtime_error("constant is not registered"); + + return it->second; + } + + UInt32 SpirvConstantCache::GetId(const Variable& v) + { + auto it = m_internal->variableIds.find(v); + if (it == m_internal->variableIds.end()) + throw std::runtime_error("variable is not registered"); + + return it->second; + } + + UInt32 SpirvConstantCache::Register(Constant c) + { + AnyConstant& constant = c.constant; + + DepRegisterer registerer(*this); + registerer.Register(constant); + + std::size_t h = m_internal->typeIds.hash_function()(constant); + auto it = m_internal->constantIds.find(constant, h); + if (it == m_internal->constantIds.end()) + { + UInt32 resultId = m_internal->nextResultId++; + it = m_internal->constantIds.emplace(std::move(constant), resultId).first; + } + + return it.value(); + } + + UInt32 SpirvConstantCache::Register(Type t) + { + AnyType& type = t.type; + + DepRegisterer registerer(*this); + registerer.Register(type); + + std::size_t h = m_internal->typeIds.hash_function()(type); + auto it = m_internal->typeIds.find(type, h); + if (it == m_internal->typeIds.end()) + { + UInt32 resultId = m_internal->nextResultId++; + it = m_internal->typeIds.emplace(std::move(type), resultId).first; + } + + return it.value(); + } + + UInt32 SpirvConstantCache::Register(Variable v) + { + DepRegisterer registerer(*this); + registerer.Register(v); + + std::size_t h = m_internal->variableIds.hash_function()(v); + auto it = m_internal->variableIds.find(v, h); + if (it == m_internal->variableIds.end()) + { + UInt32 resultId = m_internal->nextResultId++; + it = m_internal->variableIds.emplace(std::move(v), resultId).first; + } + + return it.value(); + } + + void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos, SpirvSection& types) + { + for (auto&& [type, id] : m_internal->typeIds) + { + UInt32 resultId = id; + + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeBool, resultId); + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeFloat, resultId, arg.width); + else if constexpr (std::is_same_v) + { + types.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender) + { + appender(resultId); + appender(GetId(*arg.returnType)); + + for (const auto& param : arg.parameters) + appender(GetId(*param)); + }); + } + else if constexpr (std::is_same_v) + { + UInt32 depth; + if (arg.depth.has_value()) + depth = (*arg.depth) ? 1 : 0; + else + depth = 2; + + UInt32 sampled; + if (arg.sampled.has_value()) + sampled = (*arg.sampled) ? 1 : 0; + else + sampled = 2; + + types.AppendVariadic(SpirvOp::OpTypeImage, [&](const auto& appender) + { + appender(resultId); + appender(GetId(*arg.sampledType)); + appender(arg.dim); + appender(depth); + appender(arg.arrayed); + appender(arg.multisampled); + appender(sampled); + appender(arg.format); + + if (arg.qualifier) + appender(*arg.qualifier); + }); + } + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeInt, resultId, arg.width, arg.signedness); + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeMatrix, resultId, GetId(*arg.columnType), arg.columnCount); + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypePointer, resultId, arg.storageClass, GetId(*arg.type)); + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeSampledImage, resultId, GetId(*arg.image)); + else if constexpr (std::is_same_v) + WriteStruct(arg, resultId, annotations, debugInfos, types); + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeVector, resultId, GetId(*arg.componentType), arg.componentCount); + else if constexpr (std::is_same_v) + types.Append(SpirvOp::OpTypeVoid, resultId); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, type); + } + + for (auto&& [constant, id] : m_internal->constantIds) + { + UInt32 resultId = id; + + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId); + else if constexpr (std::is_same_v) + { + constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender) + { + appender(GetId(arg.type->type)); + appender(resultId); + + for (const auto& value : arg.values) + appender(GetId(value->constant)); + }); + } + else if constexpr (std::is_same_v) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + UInt32 typeId; + if constexpr (std::is_same_v) + typeId = GetId({ Float{ 64 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Float{ 32 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 32, 1 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 64, 1 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 32, 0 } }); + else if constexpr (std::is_same_v) + typeId = GetId({ Integer{ 64, 0 } }); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + + constants.Append(SpirvOp::OpConstant, typeId, resultId, SpirvSection::Raw{ &arg, sizeof(arg) }); + + }, arg.value); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, constant); + } + + for (auto&& [variable, id] : m_internal->variableIds) + { + UInt32 resultId = id; + + if (!variable.debugName.empty()) + debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName); + + constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender) + { + appender(GetId(*variable.type)); + appender(resultId); + appender(variable.storageClass); + + if (variable.initializer) + appender(GetId((*variable.initializer)->constant)); + }); + } + } + + SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default; + + auto SpirvConstantCache::BuildConstant(const ShaderConstantValue& value) -> ConstantPtr + { + return std::make_shared(std::visit([&](auto&& arg) -> SpirvConstantCache::AnyConstant + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return ConstantBool{ arg }; + else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) + return ConstantScalar{ arg }; + else if constexpr (std::is_same_v || std::is_same_v) + { + return ConstantComposite{ + BuildType((std::is_same_v) ? ShaderNodes::BasicType::Float2 : ShaderNodes::BasicType::Int2), + { + BuildConstant(arg.x), + BuildConstant(arg.y) + } + }; + } + else if constexpr (std::is_same_v || std::is_same_v) + { + return ConstantComposite{ + BuildType((std::is_same_v) ? ShaderNodes::BasicType::Float3 : ShaderNodes::BasicType::Int3), + { + BuildConstant(arg.x), + BuildConstant(arg.y), + BuildConstant(arg.z) + } + }; + } + else if constexpr (std::is_same_v || std::is_same_v) + { + return ConstantComposite{ + BuildType((std::is_same_v) ? ShaderNodes::BasicType::Float4 : ShaderNodes::BasicType::Int4), + { + BuildConstant(arg.x), + BuildConstant(arg.y), + BuildConstant(arg.z), + BuildConstant(arg.w) + } + }; + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, value)); + } + + auto SpirvConstantCache::BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass) -> TypePtr + { + return std::make_shared(SpirvConstantCache::Pointer{ + SpirvConstantCache::BuildType(type), + storageClass + }); + } + + auto SpirvConstantCache::BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr + { + return std::make_shared(SpirvConstantCache::Pointer{ + SpirvConstantCache::BuildType(shader, type), + storageClass + }); + } + + auto SpirvConstantCache::BuildType(const ShaderNodes::BasicType& type) -> TypePtr + { + return std::make_shared([&]() -> AnyType + { + switch (type) + { + case ShaderNodes::BasicType::Boolean: + return Bool{}; + + case ShaderNodes::BasicType::Float1: + return Float{ 32 }; + + case ShaderNodes::BasicType::Int1: + return Integer{ 32, 1 }; + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + { + auto vecType = BuildType(ShaderNodes::Node::GetComponentType(type)); + UInt32 componentCount = ShaderNodes::Node::GetComponentCount(type); + + return Vector{ vecType, componentCount }; + } + + case ShaderNodes::BasicType::Mat4x4: + return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u }; + + case ShaderNodes::BasicType::UInt1: + return Integer{ 32, 0 }; + + case ShaderNodes::BasicType::Void: + return Void{}; + + case ShaderNodes::BasicType::Sampler2D: + { + auto imageType = Image{ + {}, //< qualifier + {}, //< depth + {}, //< sampled + SpirvDim::Dim2D, //< dim + SpirvImageFormat::Unknown, //< format + BuildType(ShaderNodes::BasicType::Float1), //< sampledType + false, //< arrayed, + false //< multisampled + }; + + return SampledImage{ std::make_shared(imageType) }; + } + } + + throw std::runtime_error("unexpected type"); + }()); + } + + auto SpirvConstantCache::BuildType(const ShaderAst& shader, const ShaderExpressionType& type) -> TypePtr + { + return std::visit([&](auto&& arg) -> TypePtr + { + using T = std::decay_t; + if constexpr (std::is_same_v) + return BuildType(arg); + else if constexpr (std::is_same_v) + { + // Register struct members type + const auto& structs = shader.GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); + if (it == structs.end()) + throw std::runtime_error("struct " + arg + " has not been defined"); + + const ShaderAst::Struct& s = *it; + + Structure sType; + sType.name = s.name; + + for (const auto& member : s.members) + { + auto& sMembers = sType.members.emplace_back(); + sMembers.name = member.name; + sMembers.type = BuildType(shader, member.type); + } + + return std::make_shared(std::move(sType)); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, type); + } + + void SpirvConstantCache::WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& debugInfos, SpirvSection& types) + { + types.AppendVariadic(SpirvOp::OpTypeStruct, [&](const auto& appender) + { + appender(resultId); + + for (const auto& member : structData.members) + appender(GetId(*member.type)); + }); + + debugInfos.Append(SpirvOp::OpName, resultId, structData.name); + + annotations.Append(SpirvOp::OpDecorate, resultId, SpirvDecoration::Block); + + FieldOffsets structOffsets(StructLayout_Std140); + + for (std::size_t memberIndex = 0; memberIndex < structData.members.size(); ++memberIndex) + { + const auto& member = structData.members[memberIndex]; + debugInfos.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name); + + std::size_t offset = std::visit([&](auto&& arg) -> std::size_t + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return structOffsets.AddField(StructFieldType_Bool1); + else if constexpr (std::is_same_v) + { + switch (arg.width) + { + case 32: return structOffsets.AddField(StructFieldType_Float1); + case 64: return structOffsets.AddField(StructFieldType_Double1); + default: throw std::runtime_error("unexpected float width " + std::to_string(arg.width)); + } + } + else if constexpr (std::is_same_v) + return structOffsets.AddField((arg.signedness) ? StructFieldType_Int1 : StructFieldType_UInt1); + else if constexpr (std::is_same_v) + { + assert(std::holds_alternative(arg.columnType->type)); + Vector& columnVec = std::get(arg.columnType->type); + + if (!std::holds_alternative(columnVec.componentType->type)) + throw std::runtime_error("unexpected vector type"); + + Float& vecType = std::get(columnVec.componentType->type); + + StructFieldType columnType; + switch (vecType.width) + { + case 32: columnType = StructFieldType_Float1; break; + case 64: columnType = StructFieldType_Double1; break; + default: throw std::runtime_error("unexpected float width " + std::to_string(vecType.width)); + } + + annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor); + annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16); + + return structOffsets.AddMatrix(columnType, arg.columnCount, columnVec.componentCount, true); + } + else if constexpr (std::is_same_v) + throw std::runtime_error("unhandled pointer in struct"); + else if constexpr (std::is_same_v) + { + auto it = m_internal->structureSizes.find(arg); + assert(it != m_internal->structureSizes.end()); + + return structOffsets.AddStruct(it->second); + } + else if constexpr (std::is_same_v) + { + if (std::holds_alternative(arg.componentType->type)) + return structOffsets.AddField(static_cast(StructFieldType_Bool1 + arg.componentCount - 1)); + else if (std::holds_alternative(arg.componentType->type)) + { + Float& floatData = std::get(arg.componentType->type); + switch (floatData.width) + { + case 32: return structOffsets.AddField(static_cast(StructFieldType_Float1 + arg.componentCount - 1)); + case 64: return structOffsets.AddField(static_cast(StructFieldType_Double1 + arg.componentCount - 1)); + default: throw std::runtime_error("unexpected float width " + std::to_string(floatData.width)); + } + } + else if (std::holds_alternative(arg.componentType->type)) + { + Integer& intData = std::get(arg.componentType->type); + if (intData.width != 32) + throw std::runtime_error("unexpected integer width " + std::to_string(intData.width)); + + if (intData.signedness) + return structOffsets.AddField(static_cast(StructFieldType_Int1 + arg.componentCount - 1)); + else + return structOffsets.AddField(static_cast(StructFieldType_UInt1 + arg.componentCount - 1)); + } + else + throw std::runtime_error("unexpected type for vector"); + } + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected function as struct member"); + else if constexpr (std::is_same_v || std::is_same_v) + throw std::runtime_error("unexpected opaque type as struct member"); + else if constexpr (std::is_same_v) + throw std::runtime_error("unexpected void as struct member"); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, member.type->type); + + annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, offset); + } + + m_internal->structureSizes.emplace(structData, std::move(structOffsets)); + } +} diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index fb767d0e8..398ee8e71 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -4,10 +4,10 @@ #include #include -#include #include #include #include +#include #include #include #include @@ -24,23 +24,25 @@ namespace Nz { namespace { - using ConstantVariant = ShaderNodes::Constant::Variant; - class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor { public: using BuiltinContainer = std::unordered_set>; - using ConstantContainer = tsl::ordered_set; using ExtInstList = std::unordered_set; using LocalContainer = std::unordered_set>; using ParameterContainer = std::unordered_set< std::shared_ptr>; + PreVisitor(SpirvConstantCache& constantCache) : + m_constantCache(constantCache) + { + } + using ShaderAstRecursiveVisitor::Visit; using ShaderVarVisitor::Visit; void Visit(ShaderNodes::AccessMember& node) override { - constants.emplace(Int32(node.memberIndex)); + m_constantCache.Register(*SpirvConstantCache::BuildConstant(UInt32(node.memberIndex))); ShaderAstRecursiveVisitor::Visit(node); } @@ -49,35 +51,8 @@ namespace Nz { std::visit([&](auto&& arg) { - using T = std::decay_t; - - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) - constants.emplace(arg); - else if constexpr (std::is_same_v || std::is_same_v) - { - constants.emplace(arg.x); - constants.emplace(arg.y); - constants.emplace(arg); - } - else if constexpr (std::is_same_v || std::is_same_v) - { - constants.emplace(arg.x); - constants.emplace(arg.y); - constants.emplace(arg.z); - constants.emplace(arg); - } - else if constexpr (std::is_same_v || std::is_same_v) - { - constants.emplace(arg.x); - constants.emplace(arg.y); - constants.emplace(arg.z); - constants.emplace(arg.w); - constants.emplace(arg); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, - node.value); + m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg)); + }, node.value); ShaderAstRecursiveVisitor::Visit(node); } @@ -118,7 +93,7 @@ namespace Nz builtinVars.insert(std::static_pointer_cast(var.shared_from_this())); } - void Visit(ShaderNodes::InputVariable& var) override + void Visit(ShaderNodes::InputVariable& /*var*/) override { /* Handled by ShaderAst */ } @@ -128,7 +103,7 @@ namespace Nz localVars.insert(std::static_pointer_cast(var.shared_from_this())); } - void Visit(ShaderNodes::OutputVariable& var) override + void Visit(ShaderNodes::OutputVariable& /*var*/) override { /* Handled by ShaderAst */ } @@ -138,32 +113,18 @@ namespace Nz paramVars.insert(std::static_pointer_cast(var.shared_from_this())); } - void Visit(ShaderNodes::UniformVariable& var) override + void Visit(ShaderNodes::UniformVariable& /*var*/) override { /* Handled by ShaderAst */ } BuiltinContainer builtinVars; - ConstantContainer constants; ExtInstList extInsts; LocalContainer localVars; ParameterContainer paramVars; - }; - class AssignVisitor : public ShaderAstRecursiveVisitor - { - public: - void Visit(ShaderNodes::AccessMember& node) override - { - } - - void Visit(ShaderNodes::Identifier& node) override - { - } - - void Visit(ShaderNodes::SwizzleOp& node) override - { - } + private: + SpirvConstantCache& m_constantCache; }; template @@ -202,6 +163,11 @@ namespace Nz struct SpirvWriter::State { + State() : + constantTypeCache(nextVarIndex) + { + } + struct Func { UInt32 typeId; @@ -209,18 +175,16 @@ namespace Nz std::vector paramsId; }; - std::unordered_map extensionInstructions; - std::unordered_map builtinIds; - std::unordered_map varToResult; - tsl::ordered_map constantIds; - tsl::ordered_map typeIds; - std::vector funcs; tsl::ordered_map inputIds; tsl::ordered_map outputIds; tsl::ordered_map uniformIds; - std::vector> structFields; + std::unordered_map extensionInstructions; + std::unordered_map builtinIds; + std::unordered_map varToResult; + std::vector funcs; std::vector resultIds; UInt32 nextVarIndex = 1; + SpirvConstantCache constantTypeCache; //< init after nextVarIndex // Output SpirvSection header; @@ -251,13 +215,11 @@ namespace Nz m_currentState = nullptr; }); - state.structFields.resize(shader.GetStructCount()); - std::vector functionStatements; ShaderAstCloner cloner; - PreVisitor preVisitor; + PreVisitor preVisitor(state.constantTypeCache); for (const auto& func : shader.GetFunctions()) { functionStatements.emplace_back(cloner.Clone(func.statement)); @@ -277,13 +239,16 @@ namespace Nz } for (const auto& input : shader.GetInputs()) - RegisterType(input.type); + RegisterPointerType(input.type, SpirvStorageClass::Input); for (const auto& output : shader.GetOutputs()) - RegisterType(output.type); + RegisterPointerType(output.type, SpirvStorageClass::Output); for (const auto& uniform : shader.GetUniforms()) - RegisterType(uniform.type); + RegisterPointerType(uniform.type, SpirvStorageClass::Uniform); + + for (const auto& func : shader.GetFunctions()) + RegisterFunctionType(func.returnType, func.parameters); for (const auto& local : preVisitor.localVars) RegisterType(local->type); @@ -291,104 +256,103 @@ namespace Nz for (const auto& builtin : preVisitor.builtinVars) RegisterType(builtin->type); - // Register constant types - for (const auto& constant : preVisitor.constants) - { - std::visit([&](auto&& arg) - { - using T = std::decay_t; - RegisterType(GetBasicType()); - }, constant); - } - - AppendTypes(); - // Register result id and debug infos for global variables/functions for (const auto& builtin : preVisitor.builtinVars) { - const ShaderExpressionType& builtinExprType = builtin->type; - assert(std::holds_alternative(builtinExprType)); - - ShaderNodes::BasicType builtinType = std::get(builtinExprType); - - ExtVar builtinData; - builtinData.pointerTypeId = AllocateResultId(); - builtinData.typeId = GetTypeId(builtinType); - builtinData.varId = AllocateResultId(); - - SpvBuiltIn spvBuiltin; - std::string debugName; + SpirvConstantCache::Variable variable; + SpirvBuiltIn builtinDecoration; switch (builtin->entry) { case ShaderNodes::BuiltinEntry::VertexPosition: - debugName = "builtin_VertexPosition"; - spvBuiltin = SpvBuiltInPosition; + variable.debugName = "builtin_VertexPosition"; + variable.storageClass = SpirvStorageClass::Output; + + builtinDecoration = SpirvBuiltIn::Position; break; default: throw std::runtime_error("unexpected builtin type"); } - state.debugInfo.Append(SpirvOp::OpName, builtinData.varId, debugName); - state.types.Append(SpirvOp::OpTypePointer, builtinData.pointerTypeId, SpvStorageClassOutput, builtinData.typeId); - state.types.Append(SpirvOp::OpVariable, builtinData.pointerTypeId, builtinData.varId, SpvStorageClassOutput); + const ShaderExpressionType& builtinExprType = builtin->type; + assert(std::holds_alternative(builtinExprType)); - state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, spvBuiltin); + ShaderNodes::BasicType builtinType = std::get(builtinExprType); + + variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass); + + UInt32 varId = m_currentState->constantTypeCache.Register(variable); + + ExtVar builtinData; + builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass); + builtinData.typeId = GetTypeId(builtinType); + builtinData.varId = varId; + + state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, builtinDecoration); state.builtinIds.emplace(builtin->entry, builtinData); } for (const auto& input : shader.GetInputs()) { + SpirvConstantCache::Variable variable; + variable.debugName = input.name; + variable.storageClass = SpirvStorageClass::Input; + variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass); + + UInt32 varId = m_currentState->constantTypeCache.Register(variable); + ExtVar inputData; - inputData.pointerTypeId = AllocateResultId(); + inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass); inputData.typeId = GetTypeId(input.type); - inputData.varId = AllocateResultId(); + inputData.varId = varId; - state.inputIds.emplace(input.name, inputData); - - state.debugInfo.Append(SpirvOp::OpName, inputData.varId, input.name); - state.types.Append(SpirvOp::OpTypePointer, inputData.pointerTypeId, SpvStorageClassInput, inputData.typeId); - state.types.Append(SpirvOp::OpVariable, inputData.pointerTypeId, inputData.varId, SpvStorageClassInput); + state.inputIds.emplace(input.name, std::move(inputData)); if (input.locationIndex) - state.annotations.Append(SpirvOp::OpDecorate, inputData.varId, SpvDecorationLocation, *input.locationIndex); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *input.locationIndex); } for (const auto& output : shader.GetOutputs()) { + SpirvConstantCache::Variable variable; + variable.debugName = output.name; + variable.storageClass = SpirvStorageClass::Output; + variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass); + + UInt32 varId = m_currentState->constantTypeCache.Register(variable); + ExtVar outputData; - outputData.pointerTypeId = AllocateResultId(); + outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass); outputData.typeId = GetTypeId(output.type); - outputData.varId = AllocateResultId(); + outputData.varId = varId; - state.outputIds.emplace(output.name, outputData); - - state.debugInfo.Append(SpirvOp::OpName, outputData.varId, output.name); - state.types.Append(SpirvOp::OpTypePointer, outputData.pointerTypeId, SpvStorageClassOutput, outputData.typeId); - state.types.Append(SpirvOp::OpVariable, outputData.pointerTypeId, outputData.varId, SpvStorageClassOutput); + state.outputIds.emplace(output.name, std::move(outputData)); if (output.locationIndex) - state.annotations.Append(SpirvOp::OpDecorate, outputData.varId, SpvDecorationLocation, *output.locationIndex); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *output.locationIndex); } for (const auto& uniform : shader.GetUniforms()) { + SpirvConstantCache::Variable variable; + variable.debugName = uniform.name; + variable.storageClass = SpirvStorageClass::Uniform; + variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass); + + UInt32 varId = m_currentState->constantTypeCache.Register(variable); + ExtVar uniformData; - uniformData.pointerTypeId = AllocateResultId(); + uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass); uniformData.typeId = GetTypeId(uniform.type); - uniformData.varId = AllocateResultId(); + uniformData.varId = varId; - state.uniformIds.emplace(uniform.name, uniformData); - - state.debugInfo.Append(SpirvOp::OpName, uniformData.varId, uniform.name); - state.types.Append(SpirvOp::OpTypePointer, uniformData.pointerTypeId, SpvStorageClassUniform, uniformData.typeId); - state.types.Append(SpirvOp::OpVariable, uniformData.pointerTypeId, uniformData.varId, SpvStorageClassUniform); + state.uniformIds.emplace(uniform.name, std::move(uniformData)); if (uniform.bindingIndex) { - state.annotations.Append(SpirvOp::OpDecorate, uniformData.varId, SpvDecorationBinding, *uniform.bindingIndex); - state.annotations.Append(SpirvOp::OpDecorate, uniformData.varId, SpvDecorationDescriptorSet, 0); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationBinding, *uniform.bindingIndex); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationDescriptorSet, 0); } } @@ -396,26 +360,11 @@ namespace Nz { auto& funcData = state.funcs.emplace_back(); funcData.id = AllocateResultId(); - funcData.typeId = AllocateResultId(); + funcData.typeId = GetFunctionTypeId(func.returnType, func.parameters); state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name); - - state.types.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender) - { - appender(funcData.typeId); - appender(GetTypeId(func.returnType)); - - for (const auto& param : func.parameters) - appender(GetTypeId(param.type)); - }); } - // Register constants - for (const auto& constant : preVisitor.constants) - state.constantIds[constant] = AllocateResultId(); - - AppendConstants(); - std::size_t entryPointIndex = std::numeric_limits::max(); for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex) @@ -448,11 +397,13 @@ namespace Nz assert(entryPointIndex != std::numeric_limits::max()); + m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo, m_currentState->types); + AppendHeader(); SpvExecutionModel execModel; const auto& entryFuncData = shader.GetFunction(entryPointIndex); - const auto& entryFunc = m_currentState->funcs[entryPointIndex]; + const auto& entryFunc = state.funcs[entryPointIndex]; assert(m_context.shader); switch (m_context.shader->GetStage()) @@ -471,21 +422,19 @@ namespace Nz // OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos - std::size_t nameSize = state.header.CountWord(entryFuncData.name); - state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender) { appender(execModel); appender(entryFunc.id); appender(entryFuncData.name); - for (const auto& [name, varData] : m_currentState->builtinIds) + for (const auto& [name, varData] : state.builtinIds) appender(varData.varId); - for (const auto& [name, varData] : m_currentState->inputIds) + for (const auto& [name, varData] : state.inputIds) appender(varData.varId); - for (const auto& [name, varData] : m_currentState->outputIds) + for (const auto& [name, varData] : state.outputIds) appender(varData.varId); }); @@ -513,31 +462,6 @@ namespace Nz return m_currentState->nextVarIndex++; } - void SpirvWriter::AppendConstants() - { - for (const auto& [value, resultId] : m_currentState->constantIds) - { - UInt32 constantId = resultId; - std::visit([&](auto&& arg) - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - m_currentState->constants.Append((arg) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, constantId); - else if constexpr (std::is_same_v || std::is_same_v) - m_currentState->constants.Append(SpirvOp::OpConstant, GetTypeId(GetBasicType()), constantId, SpirvSection::Raw{ &arg, sizeof(arg) }); - else if constexpr (std::is_same_v || std::is_same_v) - m_currentState->constants.Append(SpirvOp::OpConstantComposite, GetTypeId(GetBasicType()), constantId, GetConstantId(arg.x), GetConstantId(arg.y)); - else if constexpr (std::is_same_v || std::is_same_v) - m_currentState->constants.Append(SpirvOp::OpConstantComposite, GetTypeId(GetBasicType()), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z)); - else if constexpr (std::is_same_v || std::is_same_v) - m_currentState->constants.Append(SpirvOp::OpConstantComposite, GetTypeId(GetBasicType()), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z), GetConstantId(arg.w)); - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, value); - } - } - void SpirvWriter::AppendHeader() { m_currentState->header.Append(SpvMagicNumber); //< Spir-V magic number @@ -557,180 +481,41 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); } - void SpirvWriter::AppendStructType(std::size_t structIndex, UInt32 resultId) - { - const ShaderAst::Struct& s = m_context.shader->GetStruct(structIndex); - - m_currentState->types.Append(SpirvOp::OpTypeStruct, SpirvSection::OpSize{ static_cast(1 + 1 + s.members.size()) }); - m_currentState->types.Append(resultId); - - m_currentState->debugInfo.Append(SpirvOp::OpName, resultId, s.name); - - m_currentState->annotations.Append(SpirvOp::OpDecorate, resultId, SpvDecorationBlock); - - FieldOffsets structOffsets(StructLayout_Std140); - - for (std::size_t memberIndex = 0; memberIndex < s.members.size(); ++memberIndex) - { - const auto& member = s.members[memberIndex]; - m_currentState->types.Append(GetTypeId(member.type)); - m_currentState->debugInfo.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name); - - std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - std::size_t offset = [&] { - switch (arg) - { - case ShaderNodes::BasicType::Boolean: return structOffsets.AddField(StructFieldType_Bool1); - case ShaderNodes::BasicType::Float1: return structOffsets.AddField(StructFieldType_Float1); - case ShaderNodes::BasicType::Float2: return structOffsets.AddField(StructFieldType_Float2); - case ShaderNodes::BasicType::Float3: return structOffsets.AddField(StructFieldType_Float3); - case ShaderNodes::BasicType::Float4: return structOffsets.AddField(StructFieldType_Float4); - case ShaderNodes::BasicType::Int1: return structOffsets.AddField(StructFieldType_Int1); - case ShaderNodes::BasicType::Int2: return structOffsets.AddField(StructFieldType_Int2); - case ShaderNodes::BasicType::Int3: return structOffsets.AddField(StructFieldType_Int3); - case ShaderNodes::BasicType::Int4: return structOffsets.AddField(StructFieldType_Int4); - case ShaderNodes::BasicType::Mat4x4: return structOffsets.AddMatrix(StructFieldType_Float1, 4, 4, true); - case ShaderNodes::BasicType::Sampler2D: throw std::runtime_error("unexpected sampler2D as struct member"); - case ShaderNodes::BasicType::Void: throw std::runtime_error("unexpected void as struct member"); - } - - assert(false); - throw std::runtime_error("unhandled type"); - }(); - - m_currentState->annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpvDecorationOffset, offset); - - if (arg == ShaderNodes::BasicType::Mat4x4) - { - m_currentState->annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpvDecorationColMajor); - m_currentState->annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpvDecorationMatrixStride, 16); - } - } - else if constexpr (std::is_same_v) - { - // Register struct members type - const auto& structs = m_context.shader->GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); - if (it == structs.end()) - throw std::runtime_error("struct " + arg + " has not been defined"); - - std::size_t nestedStructIndex = std::distance(structs.begin(), it); - std::optional nestedFieldOffset = m_currentState->structFields[nestedStructIndex]; - if (!nestedFieldOffset) - throw std::runtime_error("struct dependency cycle"); - - structOffsets.AddStruct(nestedFieldOffset.value()); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, member.type); - } - - m_currentState->structFields[structIndex] = structOffsets; - } - - void SpirvWriter::AppendTypes() - { - for (const auto& [type, typeId] : m_currentState->typeIds.values_container()) - { - UInt32 resultId = typeId; - - // Register sub-types, if any - std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - switch (arg) - { - case ShaderNodes::BasicType::Boolean: - m_currentState->types.Append(SpirvOp::OpTypeBool, resultId); - break; - - case ShaderNodes::BasicType::Float1: - m_currentState->types.Append(SpirvOp::OpTypeFloat, resultId, 32); - break; - - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - { - ShaderNodes::BasicType baseType = ShaderNodes::Node::GetComponentType(arg); - - UInt32 vecSize = UInt32(arg) - UInt32(baseType) + 1; - - m_currentState->types.Append(SpirvOp::OpTypeVector, resultId, GetTypeId(baseType), vecSize); - break; - } - - case ShaderNodes::BasicType::Int1: - m_currentState->types.Append(SpirvOp::OpTypeInt, resultId, 32, 1); - break; - - case ShaderNodes::BasicType::Mat4x4: - { - m_currentState->types.Append(SpirvOp::OpTypeMatrix, resultId, GetTypeId(ShaderNodes::BasicType::Float4), 4); - break; - } - - case ShaderNodes::BasicType::Sampler2D: - { - UInt32 imageTypeId = resultId - 1; - - m_currentState->types.Append(SpirvOp::OpTypeImage, imageTypeId, GetTypeId(ShaderNodes::BasicType::Float1), SpvDim2D, 0, 0, 0, 1, SpvImageFormatUnknown); - m_currentState->types.Append(SpirvOp::OpTypeSampledImage, resultId, imageTypeId); - break; - } - - case ShaderNodes::BasicType::Void: - m_currentState->types.Append(SpirvOp::OpTypeVoid, resultId); - break; - } - } - else if constexpr (std::is_same_v) - { - // Register struct members type - const auto& structs = m_context.shader->GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); - if (it == structs.end()) - throw std::runtime_error("struct " + arg + " has not been defined"); - - std::size_t structIndex = std::distance(structs.begin(), it); - AppendStructType(structIndex, resultId); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, type); - } - } - UInt32 SpirvWriter::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr) { Visit(expr); return PopResultId(); } - UInt32 SpirvWriter::GetConstantId(const ShaderNodes::Constant::Variant& value) const + UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const { - auto typeIt = m_currentState->constantIds.find(value); - assert(typeIt != m_currentState->constantIds.end()); + return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); + } - return typeIt->second; + UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters) + { + std::vector parameterTypes; + parameterTypes.reserve(parameters.size()); + + for (const auto& parameter : parameters) + parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type)); + + return m_currentState->constantTypeCache.GetId({ + SpirvConstantCache::Function { + SpirvConstantCache::BuildType(*m_context.shader, retType), + std::move(parameterTypes) + } + }); + } + + UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const + { + return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass)); } UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const { - auto typeIt = m_currentState->typeIds.find(type); - assert(typeIt != m_currentState->typeIds.end()); - - return typeIt->second; + return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type)); } void SpirvWriter::PushResultId(UInt32 value) @@ -762,68 +547,42 @@ namespace Nz return var.valueId.value(); } + UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value) + { + return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value)); + } + + UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector& parameters) + { + std::vector parameterTypes; + parameterTypes.reserve(parameters.size()); + + for (const auto& parameter : parameters) + parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type)); + + return m_currentState->constantTypeCache.Register({ + SpirvConstantCache::Function { + SpirvConstantCache::BuildType(*m_context.shader, retType), + std::move(parameterTypes) + } + }); + } + + UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass) + { + return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass)); + } + UInt32 SpirvWriter::RegisterType(ShaderExpressionType type) { - auto it = m_currentState->typeIds.find(type); - if (it == m_currentState->typeIds.end()) - { - // Register sub-types, if any - std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - switch (arg) - { - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Void: - break; //< Nothing to do - - // In SPIR-V, vec3 (for example) depends on float - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::Mat4x4: - RegisterType(ShaderNodes::Node::GetComponentType(arg)); - break; - - case ShaderNodes::BasicType::Sampler2D: - RegisterType(ShaderNodes::BasicType::Float1); - AllocateResultId(); //< Reserve a result id for the image type - break; - } - } - else if constexpr (std::is_same_v) - { - // Register struct members type - const auto& structs = m_context.shader->GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); - if (it == structs.end()) - throw std::runtime_error("struct " + arg + " has not been defined"); - - const ShaderAst::Struct& s = *it; - for (const auto& member : s.members) - RegisterType(member.type); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, type); - - it = m_currentState->typeIds.emplace(std::move(type), AllocateResultId()).first; - } - - return it->second; + assert(m_currentState); + return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type)); } void SpirvWriter::Visit(ShaderNodes::AccessMember& node) { UInt32 pointerId; - SpvStorageClass storage; + SpirvStorageClass storage; switch (node.structExpr->GetType()) { @@ -848,7 +607,7 @@ namespace Nz auto it = m_currentState->inputIds.find(inputVar.name); assert(it != m_currentState->inputIds.end()); - storage = SpvStorageClassInput; + storage = SpirvStorageClass::Input; pointerId = it->second.varId; break; @@ -860,7 +619,7 @@ namespace Nz auto it = m_currentState->outputIds.find(outputVar.name); assert(it != m_currentState->outputIds.end()); - storage = SpvStorageClassOutput; + storage = SpirvStorageClass::Output; pointerId = it->second.varId; break; @@ -872,7 +631,7 @@ namespace Nz auto it = m_currentState->uniformIds.find(uniformVar.name); assert(it != m_currentState->uniformIds.end()); - storage = SpvStorageClassUniform; + storage = SpirvStorageClass::Uniform; pointerId = it->second.varId; break; @@ -892,11 +651,9 @@ namespace Nz } UInt32 memberPointerId = AllocateResultId(); - UInt32 pointerType = AllocateResultId(); + UInt32 pointerType = RegisterPointerType(node.exprType, storage); //< FIXME UInt32 typeId = GetTypeId(node.exprType); - UInt32 indexId = GetConstantId(Int32(node.memberIndex)); - - m_currentState->types.Append(SpirvOp::OpTypePointer, pointerType, storage, typeId); + UInt32 indexId = GetConstantId(UInt32(node.memberIndex)); m_currentState->instructions.Append(SpirvOp::OpAccessChain, pointerType, memberPointerId, pointerId, indexId); @@ -1002,6 +759,10 @@ namespace Nz case ShaderNodes::BasicType::Int2: case ShaderNodes::BasicType::Int3: case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: return SpirvOp::OpIAdd; case ShaderNodes::BasicType::Boolean: @@ -1026,6 +787,10 @@ namespace Nz case ShaderNodes::BasicType::Int2: case ShaderNodes::BasicType::Int3: case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: return SpirvOp::OpISub; case ShaderNodes::BasicType::Boolean: @@ -1052,6 +817,12 @@ namespace Nz case ShaderNodes::BasicType::Int4: return SpirvOp::OpSDiv; + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + return SpirvOp::OpUDiv; + case ShaderNodes::BasicType::Boolean: case ShaderNodes::BasicType::Sampler2D: case ShaderNodes::BasicType::Void: @@ -1077,6 +848,10 @@ namespace Nz case ShaderNodes::BasicType::Int2: case ShaderNodes::BasicType::Int3: case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: return SpirvOp::OpIEqual; case ShaderNodes::BasicType::Sampler2D: @@ -1141,6 +916,10 @@ namespace Nz case ShaderNodes::BasicType::Int2: case ShaderNodes::BasicType::Int3: case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: return SpirvOp::OpIMul; case ShaderNodes::BasicType::Mat4x4: