From 7a5f91f74089c39f0ebeacfc607374b542eeefc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Tue, 4 Aug 2020 01:35:30 +0200 Subject: [PATCH] SpivWriter WIP We have debug label, annotations, types and constants. The big part missing is instructions --- include/Nazara/Renderer/SpirvWriter.hpp | 51 +- include/Nazara/Renderer/SpirvWriter.inl | 37 +- src/Nazara/Renderer/SpirvWriter.cpp | 682 +++++++++++++++++++----- 3 files changed, 594 insertions(+), 176 deletions(-) diff --git a/include/Nazara/Renderer/SpirvWriter.hpp b/include/Nazara/Renderer/SpirvWriter.hpp index e95bc5a4e..5786ea55b 100644 --- a/include/Nazara/Renderer/SpirvWriter.hpp +++ b/include/Nazara/Renderer/SpirvWriter.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -41,33 +42,49 @@ namespace Nz private: struct Opcode; + struct Raw; + struct WordCount; - inline std::size_t Append(const char* str); - inline std::size_t Append(const std::string_view& str); - inline std::size_t Append(const std::string& str); - std::size_t Append(UInt32 value); - std::size_t Append(const Opcode& opcode, unsigned int wordCount); - inline std::size_t Append(std::initializer_list codepoints); - template std::size_t Append(Opcode opcode, const Args&... args); - template std::size_t Append(T value); + struct Section + { + inline std::size_t Append(const char* str); + inline std::size_t Append(const std::string_view& str); + inline std::size_t Append(const std::string& str); + inline std::size_t Append(UInt32 value); + std::size_t Append(const Opcode& opcode, const WordCount& wordCount); + std::size_t Append(const Raw& raw); + inline std::size_t Append(std::initializer_list codepoints); + template std::size_t Append(Opcode opcode, const Args&... args); + template std::size_t Append(T value); + + inline unsigned int CountWord(const char* str); + inline unsigned int CountWord(const std::string_view& str); + inline unsigned int CountWord(const std::string& str); + unsigned int CountWord(const Raw& raw); + template unsigned int CountWord(const T& value); + template unsigned int CountWord(const T1& value, const T2& value2, const Args&... rest); + + inline std::size_t GetOutputOffset() const; + + std::vector data; + }; UInt32 AllocateResultId(); + void AppendConstants(); void AppendHeader(); + void AppendStructType(std::size_t structIndex, UInt32 resultId); void AppendTypes(); - inline unsigned int CountWord(const char* str); - unsigned int CountWord(const std::string_view& str); - inline unsigned int CountWord(const std::string& str); - template unsigned int CountWord(const T& value); - template unsigned int CountWord(const T1& value, const T2& value2, const Args&... rest); + UInt32 GetConstantId(const ShaderNodes::Constant::Variant& value) const; + UInt32 GetTypeId(const ShaderExpressionType& type) const; - std::size_t GetOutputOffset() const; + void PushResultId(UInt32 value); + UInt32 PopResultId(); - UInt32 ProcessType(ShaderExpressionType type); + UInt32 RegisterType(ShaderExpressionType type); using ShaderVisitor::Visit; - void Visit(const ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired = false); void Visit(const ShaderNodes::AccessMember& node) override; void Visit(const ShaderNodes::AssignOp& node) override; void Visit(const ShaderNodes::Branch& node) override; @@ -82,7 +99,7 @@ namespace Nz void Visit(const ShaderNodes::StatementBlock& node) override; void Visit(const ShaderNodes::SwizzleOp& node) override; - static void MergeBlocks(std::vector& output, const std::vector& from); + static void MergeBlocks(std::vector& output, const Section& from); struct Context { diff --git a/include/Nazara/Renderer/SpirvWriter.inl b/include/Nazara/Renderer/SpirvWriter.inl index 5d0bbd328..2e75f1271 100644 --- a/include/Nazara/Renderer/SpirvWriter.inl +++ b/include/Nazara/Renderer/SpirvWriter.inl @@ -9,12 +9,12 @@ namespace Nz { - inline std::size_t SpirvWriter::Append(const char* str) + inline std::size_t SpirvWriter::Section::Append(const char* str) { return Append(std::string_view(str)); } - inline std::size_t SpirvWriter::Append(const std::string_view& str) + inline std::size_t SpirvWriter::Section::Append(const std::string_view& str) { std::size_t offset = GetOutputOffset(); @@ -35,12 +35,20 @@ namespace Nz return offset; } - inline std::size_t SpirvWriter::Append(const std::string& str) + inline std::size_t SpirvWriter::Section::Append(const std::string& str) { return Append(std::string_view(str)); } - inline std::size_t SpirvWriter::Append(std::initializer_list codepoints) + inline std::size_t SpirvWriter::Section::Append(UInt32 value) + { + std::size_t offset = GetOutputOffset(); + data.push_back(value); + + return offset; + } + + inline std::size_t SpirvWriter::Section::Append(std::initializer_list codepoints) { std::size_t offset = GetOutputOffset(); @@ -51,10 +59,10 @@ namespace Nz } template - inline std::size_t SpirvWriter::Append(Opcode opcode, const Args&... args) + std::size_t SpirvWriter::Section::Append(Opcode opcode, const Args&... args) { unsigned int wordCount = 1 + (CountWord(args) + ... + 0); - std::size_t offset = Append(opcode, wordCount); + std::size_t offset = Append(opcode, WordCount{ wordCount }); if constexpr (sizeof...(args) > 0) (Append(args), ...); @@ -62,37 +70,42 @@ namespace Nz } template - inline std::size_t SpirvWriter::Append(T value) + std::size_t SpirvWriter::Section::Append(T value) { return Append(static_cast(value)); } template - inline unsigned int SpirvWriter::CountWord(const T& value) + unsigned int SpirvWriter::Section::CountWord(const T& value) { return 1; } template - unsigned int SpirvWriter::CountWord(const T1& value, const T2& value2, const Args&... rest) + unsigned int SpirvWriter::Section::CountWord(const T1& value, const T2& value2, const Args&... rest) { return CountWord(value) + CountWord(value2) + (CountWord(rest) + ...); } - inline unsigned int SpirvWriter::CountWord(const char* str) + inline unsigned int SpirvWriter::Section::CountWord(const char* str) { return CountWord(std::string_view(str)); } - inline unsigned int Nz::SpirvWriter::CountWord(const std::string& str) + inline unsigned int Nz::SpirvWriter::Section::CountWord(const std::string& str) { return CountWord(std::string_view(str)); } - inline unsigned int SpirvWriter::CountWord(const std::string_view& str) + inline unsigned int SpirvWriter::Section::CountWord(const std::string_view& str) { return (static_cast(str.size() + 1) + sizeof(UInt32) - 1) / sizeof(UInt32); //< + 1 for null character } + + std::size_t SpirvWriter::Section::GetOutputOffset() const + { + return data.size(); + } } #include diff --git a/src/Nazara/Renderer/SpirvWriter.cpp b/src/Nazara/Renderer/SpirvWriter.cpp index 24d9fcd84..1ccf23f97 100644 --- a/src/Nazara/Renderer/SpirvWriter.cpp +++ b/src/Nazara/Renderer/SpirvWriter.cpp @@ -4,8 +4,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -18,10 +20,13 @@ namespace Nz { namespace { + using ConstantVariant = ShaderNodes::Constant::Variant; + class PreVisitor : public ShaderRecursiveVisitor, public ShaderVarVisitor { public: using BuiltinContainer = std::unordered_set>; + using ConstantContainer = tsl::ordered_set; using ExtInstList = std::unordered_set; using LocalContainer = std::unordered_set>; using ParameterContainer = std::unordered_set< std::shared_ptr>; @@ -29,14 +34,55 @@ namespace Nz using ShaderRecursiveVisitor::Visit; using ShaderVarVisitor::Visit; + void Visit(const ShaderNodes::Constant& node) override + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v || std::is_same_v) + constants.emplace(arg); + else if constexpr (std::is_same_v) + { + constants.emplace(arg.x); + constants.emplace(arg.y); + constants.emplace(arg); + } + else if constexpr (std::is_same_v) + { + constants.emplace(arg.x); + constants.emplace(arg.y); + constants.emplace(arg.z); + constants.emplace(arg); + } + else if constexpr (std::is_same_v) + { + constants.emplace(arg.x); + constants.emplace(arg.y); + constants.emplace(arg.z); + constants.emplace(arg.w); + constants.emplace(arg); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, + node.value); + + ShaderRecursiveVisitor::Visit(node); + } + void Visit(const ShaderNodes::DeclareVariable& node) override { Visit(node.variable); + + ShaderRecursiveVisitor::Visit(node); } void Visit(const ShaderNodes::Identifier& node) override { Visit(node.var); + + ShaderRecursiveVisitor::Visit(node); } void Visit(const ShaderNodes::IntrinsicCall& node) override @@ -87,6 +133,7 @@ namespace Nz } BuiltinContainer builtinVars; + ConstantContainer constants; ExtInstList extInsts; LocalContainer localVars; ParameterContainer paramVars; @@ -98,24 +145,51 @@ namespace Nz SpvOp op; }; + struct SpirvWriter::Raw + { + const void* ptr; + std::size_t size; + }; + + struct SpirvWriter::WordCount + { + unsigned int wc; + }; + struct SpirvWriter::State { - std::size_t boundIndex; + struct Func + { + UInt32 typeId; + UInt32 id; + std::vector paramsId; + }; + + struct ExtVar + { + UInt32 pointerTypeId; + UInt32 varId; + }; + std::unordered_map extensionInstructions; std::unordered_map builtinIds; + tsl::ordered_map constantIds; tsl::ordered_map typeIds; - std::vector funcIds; - std::vector funcTypeIds; - std::vector inputIds; - std::vector outputIds; - std::vector uniformIds; + std::vector funcs; + std::vector inputIds; + std::vector outputIds; + std::vector uniformIds; + std::vector> structFields; + std::vector resultIds; UInt32 nextVarIndex = 1; // Output - std::vector* output; - std::vector header; - std::vector info; - std::vector instructions; + Section header; + Section constants; + Section debugInfo; + Section annotations; + Section types; + Section instructions; }; SpirvWriter::SpirvWriter() : @@ -138,74 +212,154 @@ namespace Nz m_currentState = nullptr; }); + state.structFields.resize(shader.GetStructCount()); + state.annotations.Append(Opcode{ SpvOpNop }); + state.constants.Append(Opcode{ SpvOpNop }); + state.debugInfo.Append(Opcode{ SpvOpNop }); + state.types.Append(Opcode{ SpvOpNop }); + PreVisitor preVisitor; for (const auto& func : shader.GetFunctions()) preVisitor.Visit(func.statement); // Register all extended instruction sets for (const std::string& extInst : preVisitor.extInsts) - m_currentState->extensionInstructions[extInst] = AllocateResultId(); + state.extensionInstructions[extInst] = AllocateResultId(); // Register all types - state.output = &state.instructions; - for (const auto& func : shader.GetFunctions()) { - ProcessType(func.returnType); + RegisterType(func.returnType); for (const auto& param : func.parameters) - ProcessType(param.type); - - m_currentState->funcTypeIds.push_back(AllocateResultId()); + RegisterType(param.type); } for (const auto& input : shader.GetInputs()) - ProcessType(input.type); + RegisterType(input.type); for (const auto& output : shader.GetOutputs()) - ProcessType(output.type); + RegisterType(output.type); for (const auto& uniform : shader.GetUniforms()) - ProcessType(uniform.type); + RegisterType(uniform.type); for (const auto& local : preVisitor.localVars) - ProcessType(local->type); + RegisterType(local->type); + + // Register constant types + for (const auto& constant : preVisitor.constants) + { + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + RegisterType(ShaderNodes::BasicType::Boolean); + else if constexpr (std::is_same_v) + RegisterType(ShaderNodes::BasicType::Float1); + else if constexpr (std::is_same_v) + RegisterType(ShaderNodes::BasicType::Float2); + else if constexpr (std::is_same_v) + RegisterType(ShaderNodes::BasicType::Float3); + else if constexpr (std::is_same_v) + RegisterType(ShaderNodes::BasicType::Float4); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, constant); + } + + AppendTypes(); // Register result id and debug infos for global variables/functions - state.output = &state.info; - for (const auto& input : shader.GetInputs()) { - UInt32 resultId = AllocateResultId(); - Append(Opcode{ SpvOpName }, resultId, input.name); + auto& inputData = state.inputIds.emplace_back(); + inputData.pointerTypeId = AllocateResultId(); + inputData.varId = AllocateResultId(); - m_currentState->inputIds.push_back(resultId); + state.debugInfo.Append(Opcode{ SpvOpName }, inputData.varId, input.name); + state.types.Append(Opcode{ SpvOpTypePointer }, inputData.pointerTypeId, SpvStorageClassInput, GetTypeId(input.type)); + state.types.Append(Opcode{ SpvOpVariable }, inputData.pointerTypeId, inputData.varId, SpvStorageClassInput); + + if (input.locationIndex) + state.annotations.Append(Opcode{ SpvOpDecorate }, inputData.varId, SpvDecorationLocation, *input.locationIndex); } for (const auto& output : shader.GetOutputs()) { - UInt32 resultId = AllocateResultId(); - Append(Opcode{ SpvOpName }, resultId, output.name); + auto& outputData = state.outputIds.emplace_back(); + outputData.pointerTypeId = AllocateResultId(); + outputData.varId = AllocateResultId(); - m_currentState->outputIds.push_back(resultId); + state.debugInfo.Append(Opcode{ SpvOpName }, outputData.varId, output.name); + state.types.Append(Opcode{ SpvOpTypePointer }, outputData.pointerTypeId, SpvStorageClassOutput, GetTypeId(output.type)); + state.types.Append(Opcode{ SpvOpVariable }, outputData.pointerTypeId, outputData.varId, SpvStorageClassOutput); + + if (output.locationIndex) + state.annotations.Append(Opcode{ SpvOpDecorate }, outputData.varId, SpvDecorationLocation, *output.locationIndex); } for (const auto& uniform : shader.GetUniforms()) { - UInt32 resultId = AllocateResultId(); - Append(Opcode{ SpvOpName }, resultId, uniform.name); + auto& uniformData = state.uniformIds.emplace_back(); + uniformData.pointerTypeId = AllocateResultId(); + uniformData.varId = AllocateResultId(); - m_currentState->uniformIds.push_back(resultId); + state.debugInfo.Append(Opcode{ SpvOpName }, uniformData.varId, uniform.name); + state.types.Append(Opcode{ SpvOpTypePointer }, uniformData.pointerTypeId, SpvStorageClassUniform, GetTypeId(uniform.type)); + state.types.Append(Opcode{ SpvOpVariable }, uniformData.pointerTypeId, uniformData.varId, SpvStorageClassUniform); + + if (uniform.bindingIndex) + { + state.annotations.Append(Opcode{ SpvOpDecorate }, uniformData.varId, SpvDecorationBinding, *uniform.bindingIndex); + state.annotations.Append(Opcode{ SpvOpDecorate }, uniformData.varId, SpvDecorationDescriptorSet, 0); + } } for (const auto& func : shader.GetFunctions()) { - UInt32 resultId = AllocateResultId(); - Append(Opcode{ SpvOpName }, resultId, func.name); + auto& funcData = state.funcs.emplace_back(); + funcData.id = AllocateResultId(); + funcData.typeId = AllocateResultId(); - m_currentState->funcIds.push_back(resultId); + state.debugInfo.Append(Opcode{ SpvOpName }, funcData.id, func.name); + + state.types.Append(Opcode{ SpvOpTypeFunction }, WordCount{ 3 + static_cast(func.parameters.size()) }); + state.types.Append(funcData.typeId); + state.types.Append(GetTypeId(func.returnType)); + + for (const auto& param : func.parameters) + state.types.Append(GetTypeId(param.type)); } - state.output = &state.header; + // Register constants + for (const auto& constant : preVisitor.constants) + state.constantIds[constant] = AllocateResultId(); + + AppendConstants(); + + for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex) + { + const auto& func = shader.GetFunction(funcIndex); + + auto& funcData = state.funcs[funcIndex]; + + state.instructions.Append(Opcode{ SpvOpNop }); + + state.instructions.Append(Opcode{ SpvOpFunction }, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); + + for (const auto& param : func.parameters) + { + UInt32 paramResultId = AllocateResultId(); + funcData.paramsId.push_back(paramResultId); + + state.instructions.Append(Opcode{ SpvOpFunctionParameter }, GetTypeId(param.type), paramResultId); + } + + Visit(func.statement); + + state.instructions.Append(Opcode{ SpvOpFunctionEnd }); + } AppendHeader(); @@ -221,13 +375,12 @@ namespace Nz break; }*/ - state.header[state.boundIndex] = state.nextVarIndex; - std::vector ret; - ret.reserve(state.header.size() + state.info.size() + state.instructions.size()); - MergeBlocks(ret, state.header); - MergeBlocks(ret, state.info); + MergeBlocks(ret, state.debugInfo); + MergeBlocks(ret, state.annotations); + MergeBlocks(ret, state.types); + MergeBlocks(ret, state.constants); MergeBlocks(ret, state.instructions); return ret; @@ -238,41 +391,125 @@ namespace Nz m_environment = std::move(environment); } - std::size_t Nz::SpirvWriter::Append(UInt32 value) - { - std::size_t offset = GetOutputOffset(); - m_currentState->output->push_back(value); - - return offset; - } - - std::size_t SpirvWriter::Append(const Opcode& opcode, unsigned int wordCount) - { - return Append(UInt32(opcode.op) | UInt32(wordCount) << 16); - } - UInt32 Nz::SpirvWriter::AllocateResultId() { return m_currentState->nextVarIndex++; } + void SpirvWriter::AppendConstants() + { + for (const auto& [value, resultId] : m_currentState->constantIds) + { + UInt32 constantId = resultId; + std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + m_currentState->constants.Append(Opcode{ (arg) ? SpvOpConstantTrue : SpvOpConstantFalse }, constantId); + else if constexpr (std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstant }, GetTypeId(ShaderNodes::BasicType::Float1), constantId, Raw{ &arg, sizeof(arg) }); + else if constexpr (std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(ShaderNodes::BasicType::Float2), constantId, GetConstantId(arg.x), GetConstantId(arg.y)); + else if constexpr (std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(ShaderNodes::BasicType::Float3), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z)); + else if constexpr (std::is_same_v) + m_currentState->constants.Append(Opcode{ SpvOpConstantComposite }, GetTypeId(ShaderNodes::BasicType::Float3), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z), GetConstantId(arg.w)); + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, value); + } + } + void SpirvWriter::AppendHeader() { - Append(SpvMagicNumber); //< Spir-V magic number + m_currentState->header.Append(SpvMagicNumber); //< Spir-V magic number UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8; - Append(version); //< Spir-V version number (1.0 for compatibility) - Append(0); //< Generator identifier (TODO: Register generator to Khronos) + m_currentState->header.Append(version); //< Spir-V version number (1.0 for compatibility) + m_currentState->header.Append(0); //< Generator identifier (TODO: Register generator to Khronos) - m_currentState->boundIndex = Append(0); //< Bound (ID count), will be filled later - Append(0); //< Instruction schema (required to be 0 for now) + m_currentState->header.Append(m_currentState->nextVarIndex); //< Bound (ID count) + m_currentState->header.Append(0); //< Instruction schema (required to be 0 for now) - Append(Opcode{ SpvOpCapability }, SpvCapabilityShader); + m_currentState->header.Append(Opcode{ SpvOpCapability }, SpvCapabilityShader); for (const auto& [extInst, resultId] : m_currentState->extensionInstructions) - Append(Opcode{ SpvOpExtInstImport }, resultId, extInst); + m_currentState->header.Append(Opcode{ SpvOpExtInstImport }, resultId, extInst); - Append(Opcode{ SpvOpMemoryModel }, SpvAddressingModelLogical, SpvMemoryModelGLSL450); + m_currentState->header.Append(Opcode{ SpvOpMemoryModel }, SpvAddressingModelLogical, SpvMemoryModelGLSL450); + } + + void SpirvWriter::AppendStructType(std::size_t structIndex, UInt32 resultId) + { + const ShaderAst::Struct& s = m_context.shader->GetStruct(structIndex); + + m_currentState->types.Append(Opcode{ SpvOpTypeStruct }, WordCount{ static_cast(1 + 1 + s.members.size()) }); + m_currentState->types.Append(resultId); + + m_currentState->debugInfo.Append(Opcode{ SpvOpName }, resultId, s.name); + + m_currentState->annotations.Append(Opcode{ SpvOpDecorate }, resultId, SpvDecorationBlock); + + FieldOffsets structOffsets(StructLayout_Std140); + + for (std::size_t memberIndex = 0; memberIndex < s.members.size(); ++memberIndex) + { + const auto& member = s.members[memberIndex]; + m_currentState->types.Append(GetTypeId(member.type)); + m_currentState->debugInfo.Append(Opcode{ SpvOpMemberName }, resultId, memberIndex, member.name); + + std::visit([&](auto&& arg) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + std::size_t offset = [&] { + switch (arg) + { + case ShaderNodes::BasicType::Boolean: return structOffsets.AddField(StructFieldType_Bool1); + case ShaderNodes::BasicType::Float1: return structOffsets.AddField(StructFieldType_Float1); + case ShaderNodes::BasicType::Float2: return structOffsets.AddField(StructFieldType_Float2); + case ShaderNodes::BasicType::Float3: return structOffsets.AddField(StructFieldType_Float3); + case ShaderNodes::BasicType::Float4: return structOffsets.AddField(StructFieldType_Float4); + case ShaderNodes::BasicType::Mat4x4: return structOffsets.AddMatrix(StructFieldType_Float1, 4, 4, true); + case ShaderNodes::BasicType::Sampler2D: throw std::runtime_error("unexpected sampler2D as struct member"); + case ShaderNodes::BasicType::Void: throw std::runtime_error("unexpected void as struct member"); + } + + assert(false); + throw std::runtime_error("unhandled type"); + }(); + + m_currentState->annotations.Append(Opcode{ SpvOpMemberDecorate }, resultId, memberIndex, SpvDecorationOffset, offset); + + if (arg == ShaderNodes::BasicType::Mat4x4) + { + m_currentState->annotations.Append(Opcode{ SpvOpMemberDecorate }, resultId, memberIndex, SpvDecorationColMajor); + m_currentState->annotations.Append(Opcode{ SpvOpMemberDecorate }, resultId, memberIndex, SpvDecorationMatrixStride, 16); + } + } + else if constexpr (std::is_same_v) + { + // Register struct members type + const auto& structs = m_context.shader->GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); + if (it == structs.end()) + throw std::runtime_error("struct " + arg + " has not been defined"); + + std::size_t nestedStructIndex = std::distance(structs.begin(), it); + std::optional nestedFieldOffset = m_currentState->structFields[nestedStructIndex]; + if (!nestedFieldOffset) + throw std::runtime_error("struct dependency cycle"); + + structOffsets.AddStruct(nestedFieldOffset.value()); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, member.type); + } + + m_currentState->structFields[structIndex] = structOffsets; } void SpirvWriter::AppendTypes() @@ -287,28 +524,127 @@ namespace Nz using T = std::decay_t; if constexpr (std::is_same_v) { - // In SPIR-V, vec3 (for example) depends on float - UInt32 depResultId; - if (ShaderNodes::Node::GetComponentCount(arg) != 1) - depResultId = ProcessType(ShaderNodes::Node::GetComponentType(arg)); - switch (arg) { case ShaderNodes::BasicType::Boolean: - Append(Opcode{ SpvOpTypeBool }, resultId); + m_currentState->types.Append(Opcode{ SpvOpTypeBool }, resultId); break; case ShaderNodes::BasicType::Float1: - Append(Opcode{ SpvOpTypeFloat }, resultId); + m_currentState->types.Append(Opcode{ SpvOpTypeFloat }, resultId, 32); break; case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float3: case ShaderNodes::BasicType::Float4: + { + UInt32 vecSize = UInt32(arg) - UInt32(ShaderNodes::BasicType::Float2) + 1; + + m_currentState->types.Append(Opcode{ SpvOpTypeVector }, resultId, GetTypeId(ShaderNodes::BasicType::Float1), vecSize); + break; + } + case ShaderNodes::BasicType::Mat4x4: + { + m_currentState->types.Append(Opcode{ SpvOpTypeMatrix }, resultId, GetTypeId(ShaderNodes::BasicType::Float4), 4); + break; + } + case ShaderNodes::BasicType::Sampler2D: + { + UInt32 imageTypeId = resultId - 1; + + m_currentState->types.Append(Opcode{ SpvOpTypeImage }, imageTypeId, GetTypeId(ShaderNodes::BasicType::Float1), SpvDim2D, 0, 0, 0, 1, SpvImageFormatUnknown); + m_currentState->types.Append(Opcode{ SpvOpTypeSampledImage }, resultId, imageTypeId); + break; + } + case ShaderNodes::BasicType::Void: - Append(Opcode{ SpvOpTypeVoid }, resultId); + m_currentState->types.Append(Opcode{ SpvOpTypeVoid }, resultId); + break; + } + } + else if constexpr (std::is_same_v) + { + // Register struct members type + const auto& structs = m_context.shader->GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); + if (it == structs.end()) + throw std::runtime_error("struct " + arg + " has not been defined"); + + std::size_t structIndex = std::distance(structs.begin(), it); + AppendStructType(structIndex, resultId); + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, type); + } + } + + UInt32 SpirvWriter::GetConstantId(const ShaderNodes::Constant::Variant& value) const + { + auto typeIt = m_currentState->constantIds.find(value); + assert(typeIt != m_currentState->constantIds.end()); + + return typeIt->second; + } + + UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const + { + auto typeIt = m_currentState->typeIds.find(type); + assert(typeIt != m_currentState->typeIds.end()); + + return typeIt->second; + } + + void SpirvWriter::PushResultId(UInt32 value) + { + m_currentState->resultIds.push_back(value); + } + + UInt32 SpirvWriter::PopResultId() + { + if (m_currentState->resultIds.empty()) + throw std::runtime_error("invalid operation"); + + UInt32 resultId = m_currentState->resultIds.back(); + m_currentState->resultIds.pop_back(); + + return resultId; + } + + UInt32 SpirvWriter::RegisterType(ShaderExpressionType type) + { + auto it = m_currentState->typeIds.find(type); + if (it == m_currentState->typeIds.end()) + { + // Register sub-types, if any + std::visit([&](auto&& arg) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + switch (arg) + { + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Void: + break; //< Nothing to do + + // In SPIR-V, vec3 (for example) depends on float + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + RegisterType(ShaderNodes::BasicType::Float1); + break; + + case ShaderNodes::BasicType::Mat4x4: + RegisterType(ShaderNodes::BasicType::Float4); + break; + + case ShaderNodes::BasicType::Sampler2D: + RegisterType(ShaderNodes::BasicType::Float1); + AllocateResultId(); //< Reserve a result id for the image type break; } } @@ -322,46 +658,7 @@ namespace Nz const ShaderAst::Struct& s = *it; for (const auto& member : s.members) - ProcessType(member.type); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, type); - } - } - - std::size_t SpirvWriter::GetOutputOffset() const - { - assert(m_currentState); - return m_currentState->output->size(); - } - - UInt32 SpirvWriter::ProcessType(ShaderExpressionType type) - { - auto it = m_currentState->typeIds.find(type); - if (it == m_currentState->typeIds.end()) - { - // Register sub-types, if any - std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - // In SPIR-V, vec3 (for example) depends on float - if (ShaderNodes::Node::GetComponentCount(arg) != 1) - ProcessType(ShaderNodes::Node::GetComponentType(arg)); - } - else if constexpr (std::is_same_v) - { - // Register struct members type - const auto& structs = m_context.shader->GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); - if (it == structs.end()) - throw std::runtime_error("struct " + arg + " has not been defined"); - - const ShaderAst::Struct& s = *it; - for (const auto& member : s.members) - ProcessType(member.type); + RegisterType(member.type); } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); @@ -373,56 +670,147 @@ namespace Nz return it->second; } - void SpirvWriter::Visit(const ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired) - { - } - void SpirvWriter::Visit(const ShaderNodes::AccessMember& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::AssignOp& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::Branch& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::BinaryOp& /*node*/) + void SpirvWriter::Visit(const ShaderNodes::AccessMember& node) { + Visit(node.structExpr); } - void SpirvWriter::Visit(const ShaderNodes::Cast& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::Constant& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::DeclareVariable& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::ExpressionStatement& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::Identifier& /*node*/) + void SpirvWriter::Visit(const ShaderNodes::AssignOp& node) { + Visit(node.left); + Visit(node.right); } - void SpirvWriter::Visit(const ShaderNodes::IntrinsicCall& /*node*/) + void SpirvWriter::Visit(const ShaderNodes::Branch& node) { + throw std::runtime_error("not yet implemented"); } - void SpirvWriter::Visit(const ShaderNodes::Sample2D& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::StatementBlock& /*node*/) - { - } - void SpirvWriter::Visit(const ShaderNodes::SwizzleOp& /*node*/) + void SpirvWriter::Visit(const ShaderNodes::BinaryOp& node) { + Visit(node.left); + Visit(node.right); + + UInt32 resultId = AllocateResultId(); + UInt32 leftOperand = PopResultId(); + UInt32 rightOperand = PopResultId(); + + SpvOp op = [&] { + switch (node.op) + { + case ShaderNodes::BinaryType::Add: return SpvOpFAdd; + case ShaderNodes::BinaryType::Substract: return SpvOpFSub; + case ShaderNodes::BinaryType::Multiply: return SpvOpFMul; + case ShaderNodes::BinaryType::Divide: return SpvOpFDiv; + case ShaderNodes::BinaryType::Equality: return SpvOpFOrdEqual; + } + + assert(false); + throw std::runtime_error("unexpected binary operation"); + }(); + + m_currentState->instructions.Append(Opcode{ op }, GetTypeId(ShaderNodes::BasicType::Float3), resultId, leftOperand, rightOperand); } - void SpirvWriter::MergeBlocks(std::vector& output, const std::vector& from) + void SpirvWriter::Visit(const ShaderNodes::Cast& node) + { + for (auto& expr : node.expressions) + { + if (!expr) + break; + + Visit(expr); + } + } + + void SpirvWriter::Visit(const ShaderNodes::Constant& node) + { + std::visit([&] (const auto& value) + { + PushResultId(GetConstantId(value)); + }, node.value); + } + + void SpirvWriter::Visit(const ShaderNodes::DeclareVariable& node) + { + if (node.expression) + Visit(node.expression); + } + + void SpirvWriter::Visit(const ShaderNodes::ExpressionStatement& node) + { + Visit(node.expression); + } + + void SpirvWriter::Visit(const ShaderNodes::Identifier& node) + { + PushResultId(42); + } + + void SpirvWriter::Visit(const ShaderNodes::IntrinsicCall& node) + { + for (auto& param : node.parameters) + Visit(param); + } + + void SpirvWriter::Visit(const ShaderNodes::Sample2D& node) + { + Visit(node.sampler); + Visit(node.coordinates); + } + + void SpirvWriter::Visit(const ShaderNodes::StatementBlock& node) + { + for (auto& statement : node.statements) + Visit(statement); + } + + void SpirvWriter::Visit(const ShaderNodes::SwizzleOp& node) + { + Visit(node.expression); + } + + void SpirvWriter::MergeBlocks(std::vector& output, const Section& from) { std::size_t prevSize = output.size(); - output.resize(prevSize + from.size()); - std::copy(from.begin(), from.end(), output.begin() + prevSize); + output.resize(prevSize + from.data.size()); + std::copy(from.data.begin(), from.data.end(), output.begin() + prevSize); + } + + std::size_t SpirvWriter::Section::Append(const Opcode& opcode, const WordCount& wordCount) + { + return Append(UInt32(opcode.op) | UInt32(wordCount.wc) << 16); + } + + std::size_t SpirvWriter::Section::Append(const Raw& raw) + { + std::size_t offset = GetOutputOffset(); + + const UInt8* ptr = static_cast(raw.ptr); + + std::size_t size4 = CountWord(raw); + for (std::size_t i = 0; i < size4; ++i) + { + UInt32 codepoint = 0; + for (std::size_t j = 0; j < 4; ++j) + { + std::size_t pos = i * 4 + j; + if (pos < raw.size) + codepoint |= UInt32(ptr[pos]) << (j * 8); + } + +#ifdef NAZARA_BIG_ENDIAN + SwapBytes(codepoint); +#endif + + Append(codepoint); + } + + return offset; + } + + unsigned int SpirvWriter::Section::CountWord(const Raw& raw) + { + return (raw.size + sizeof(UInt32) - 1) / sizeof(UInt32); } }