diff --git a/include/Nazara/Shader/SpirvDecoder.hpp b/include/Nazara/Shader/SpirvDecoder.hpp new file mode 100644 index 000000000..fc7ccd935 --- /dev/null +++ b/include/Nazara/Shader/SpirvDecoder.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVDECODER_HPP +#define NAZARA_SPIRVDECODER_HPP + +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + class NAZARA_SHADER_API SpirvDecoder + { + public: + SpirvDecoder() = default; + SpirvDecoder(const SpirvDecoder&) = default; + SpirvDecoder(SpirvDecoder&&) = default; + ~SpirvDecoder() = default; + + void Decode(const UInt32* codepoints, std::size_t count); + + SpirvDecoder& operator=(const SpirvDecoder&) = default; + SpirvDecoder& operator=(SpirvDecoder&&) = default; + + protected: + struct SpirvHeader; + + inline const UInt32* GetCurrentPtr() const; + + virtual bool HandleHeader(const SpirvHeader& header); + virtual bool HandleOpcode(const SpirvInstruction& instruction, UInt32 wordCount) = 0; + + std::string ReadString(); + UInt32 ReadWord(); + + struct SpirvHeader + { + UInt32 generatorId; + UInt32 bound; + UInt32 schema; + UInt32 versionNumber; + }; + + private: + const UInt32* m_currentCodepoint; + const UInt32* m_codepointEnd; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvDecoder.inl b/include/Nazara/Shader/SpirvDecoder.inl new file mode 100644 index 000000000..aa937bd5b --- /dev/null +++ b/include/Nazara/Shader/SpirvDecoder.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline const UInt32* SpirvDecoder::GetCurrentPtr() const + { + return m_currentCodepoint; + } +} + +#include diff --git a/include/Nazara/Shader/SpirvPrinter.hpp b/include/Nazara/Shader/SpirvPrinter.hpp index c7c7815f7..014fce121 100644 --- a/include/Nazara/Shader/SpirvPrinter.hpp +++ b/include/Nazara/Shader/SpirvPrinter.hpp @@ -9,12 +9,13 @@ #include #include +#include #include #include namespace Nz { - class NAZARA_SHADER_API SpirvPrinter + class NAZARA_SHADER_API SpirvPrinter : SpirvDecoder { public: struct Settings; @@ -39,9 +40,8 @@ namespace Nz }; private: - void AppendInstruction(); - std::string ReadString(); - UInt32 ReadWord(); + bool HandleHeader(const SpirvHeader& header) override; + bool HandleOpcode(const SpirvInstruction& instruction, UInt32 wordCount) override; struct State; diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 17ac7d023..3b770ae87 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -9,8 +9,8 @@ #include #include -#include #include +#include #include #include #include @@ -90,9 +90,7 @@ namespace Nz struct Context { - ShaderAst::AstCache cache; const States* states = nullptr; - std::vector functionBlocks; }; struct ExtVar diff --git a/src/Nazara/Shader/SpirvDecoder.cpp b/src/Nazara/Shader/SpirvDecoder.cpp new file mode 100644 index 000000000..141b009f1 --- /dev/null +++ b/src/Nazara/Shader/SpirvDecoder.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + void SpirvDecoder::Decode(const UInt32* codepoints, std::size_t count) + { + m_currentCodepoint = codepoints; + m_codepointEnd = codepoints + count; + + UInt32 magicNumber = ReadWord(); + if (magicNumber != SpirvMagicNumber) + throw std::runtime_error("invalid Spir-V: magic number didn't match"); + + UInt32 versionNumber = ReadWord(); + if (versionNumber > SpirvVersion) + throw std::runtime_error("Spir-V is more recent than decoder, dismissing"); + + SpirvHeader header; + header.generatorId = ReadWord(); + header.bound = ReadWord(); + header.schema = ReadWord(); + header.versionNumber = versionNumber; + + if (!HandleHeader(header)) + return; + + while (m_currentCodepoint < m_codepointEnd) + { + UInt32 firstWord = ReadWord(); + + UInt16 wordCount = static_cast((firstWord >> 16) & 0xFFFF); + UInt16 opcode = static_cast(firstWord & 0xFFFF); + + const SpirvInstruction* inst = GetInstructionData(opcode); + if (!inst) + throw std::runtime_error("invalid instruction"); + + if (!HandleOpcode(*inst, wordCount)) + break; + + m_currentCodepoint += wordCount - 1; + } + } + + bool SpirvDecoder::HandleHeader(const SpirvHeader& /*header*/) + { + return true; + } + + std::string SpirvDecoder::ReadString() + { + std::string str; + + for (;;) + { + UInt32 value = ReadWord(); + for (std::size_t j = 0; j < 4; ++j) + { + char c = static_cast((value >> (j * 8)) & 0xFF); + if (c == '\0') + return str; + + str.push_back(c); + } + } + } + + UInt32 SpirvDecoder::ReadWord() + { + if (m_currentCodepoint >= m_codepointEnd) + throw std::runtime_error("unexpected end of stream"); + + return *m_currentCodepoint++; + } +} diff --git a/src/Nazara/Shader/SpirvPrinter.cpp b/src/Nazara/Shader/SpirvPrinter.cpp index 01eb65cb0..cd618c8bb 100644 --- a/src/Nazara/Shader/SpirvPrinter.cpp +++ b/src/Nazara/Shader/SpirvPrinter.cpp @@ -21,10 +21,7 @@ namespace Nz { } - const UInt32* codepoints; std::size_t resultOffset; - std::size_t index = 0; - std::size_t count; std::ostringstream stream; const Settings& settings; }; @@ -32,71 +29,49 @@ namespace Nz std::string SpirvPrinter::Print(const UInt32* codepoints, std::size_t count, const Settings& settings) { State state(settings); - state.codepoints = codepoints; - state.count = count; m_currentState = &state; CallOnExit resetOnExit([&] { m_currentState = nullptr; }); - UInt32 magicNumber = ReadWord(); - if (magicNumber != SpirvMagicNumber) - throw std::runtime_error("invalid Spir-V: magic number didn't match"); - - if (m_currentState->settings.printHeader) - m_currentState->stream << "Spir-V module\n"; - - UInt32 versionNumber = ReadWord(); - if (versionNumber > SpirvVersion) - throw std::runtime_error("Spir-V is more recent than printer, dismissing"); - - UInt8 majorVersion = ((versionNumber) >> 16) & 0xFF; - UInt8 minorVersion = ((versionNumber) >> 8) & 0xFF; - - UInt32 generatorId = ReadWord(); - UInt32 bound = ReadWord(); - UInt32 schema = ReadWord(); - - state.resultOffset = std::snprintf(nullptr, 0, "%%%u = ", bound); - - if (m_currentState->settings.printHeader) - { - m_currentState->stream << "Version " + std::to_string(+majorVersion) << "." << std::to_string(+minorVersion) << "\n"; - m_currentState->stream << "Generator: " << std::to_string(generatorId) << "\n"; - m_currentState->stream << "Bound: " << std::to_string(bound) << "\n"; - m_currentState->stream << "Schema: " << std::to_string(schema) << "\n"; - } - - while (m_currentState->index < m_currentState->count) - AppendInstruction(); + Decode(codepoints, count); return m_currentState->stream.str(); } - void SpirvPrinter::AppendInstruction() + bool SpirvPrinter::HandleHeader(const SpirvHeader& header) { - std::size_t startIndex = m_currentState->index; + UInt8 majorVersion = ((header.versionNumber) >> 16) & 0xFF; + UInt8 minorVersion = ((header.versionNumber) >> 8) & 0xFF; - UInt32 firstWord = ReadWord(); + m_currentState->resultOffset = std::snprintf(nullptr, 0, "%%%u = ", header.bound); - UInt16 wordCount = static_cast((firstWord >> 16) & 0xFFFF); - UInt16 opcode = static_cast(firstWord & 0xFFFF); + if (m_currentState->settings.printHeader) + { + m_currentState->stream << "Version " + std::to_string(+majorVersion) << "." << std::to_string(+minorVersion) << "\n"; + m_currentState->stream << "Generator: " << std::to_string(header.generatorId) << "\n"; + m_currentState->stream << "Bound: " << std::to_string(header.bound) << "\n"; + m_currentState->stream << "Schema: " << std::to_string(header.schema) << "\n"; + } - const SpirvInstruction* inst = GetInstructionData(opcode); - if (!inst) - throw std::runtime_error("invalid instruction"); + return true; + } + + bool SpirvPrinter::HandleOpcode(const SpirvInstruction& instruction, UInt32 wordCount) + { + const UInt32* startPtr = GetCurrentPtr(); if (m_currentState->settings.printParameters) { std::ostringstream instructionStream; - instructionStream << inst->name; + instructionStream << instruction.name; UInt32 resultId = 0; std::size_t currentOperand = 0; - std::size_t instructionEnd = startIndex + wordCount; - while (m_currentState->index < instructionEnd) + const UInt32* endPtr = startPtr + wordCount; + while (GetCurrentPtr() < endPtr) { - const SpirvInstruction::Operand* operand = &inst->operands[currentOperand]; + const SpirvInstruction::Operand* operand = &instruction.operands[currentOperand]; if (operand->kind != SpirvOperandKind::IdResult) { @@ -181,7 +156,7 @@ namespace Nz break; } - + case SpirvOperandKind::PairLiteralIntegerIdRef: { ReadWord(); @@ -210,13 +185,12 @@ namespace Nz default: break; - } } else resultId = ReadWord(); - if (currentOperand < inst->minOperandCount - 1) + if (currentOperand < instruction.minOperandCount - 1) currentOperand++; } @@ -231,42 +205,12 @@ namespace Nz m_currentState->stream << instructionStream.str(); } else - { - m_currentState->stream << inst->name; - - m_currentState->index += wordCount - 1; - if (m_currentState->index > m_currentState->count) - throw std::runtime_error("unexpected end of stream"); - } + m_currentState->stream << instruction.name; m_currentState->stream << "\n"; - assert(m_currentState->index == startIndex + wordCount); - } + assert(GetCurrentPtr() == startPtr + wordCount); - std::string SpirvPrinter::ReadString() - { - std::string str; - - for (;;) - { - UInt32 value = ReadWord(); - for (std::size_t j = 0; j < 4; ++j) - { - char c = static_cast((value >> (j * 8)) & 0xFF); - if (c == '\0') - return str; - - str.push_back(c); - } - } - } - - UInt32 SpirvPrinter::ReadWord() - { - if (m_currentState->index >= m_currentState->count) - throw std::runtime_error("unexpected end of stream"); - - return m_currentState->codepoints[m_currentState->index++]; + return true; } }