Shader: Add SpirvDecoder

This commit is contained in:
Jérôme Leclercq
2021-04-04 20:29:23 +02:00
parent 5a63eb4d97
commit 09df5f389e
6 changed files with 194 additions and 90 deletions

View File

@@ -0,0 +1,87 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvDecoder.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <cassert>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void SpirvDecoder::Decode(const UInt32* codepoints, std::size_t count)
{
m_currentCodepoint = codepoints;
m_codepointEnd = codepoints + count;
UInt32 magicNumber = ReadWord();
if (magicNumber != SpirvMagicNumber)
throw std::runtime_error("invalid Spir-V: magic number didn't match");
UInt32 versionNumber = ReadWord();
if (versionNumber > SpirvVersion)
throw std::runtime_error("Spir-V is more recent than decoder, dismissing");
SpirvHeader header;
header.generatorId = ReadWord();
header.bound = ReadWord();
header.schema = ReadWord();
header.versionNumber = versionNumber;
if (!HandleHeader(header))
return;
while (m_currentCodepoint < m_codepointEnd)
{
UInt32 firstWord = ReadWord();
UInt16 wordCount = static_cast<UInt16>((firstWord >> 16) & 0xFFFF);
UInt16 opcode = static_cast<UInt16>(firstWord & 0xFFFF);
const SpirvInstruction* inst = GetInstructionData(opcode);
if (!inst)
throw std::runtime_error("invalid instruction");
if (!HandleOpcode(*inst, wordCount))
break;
m_currentCodepoint += wordCount - 1;
}
}
bool SpirvDecoder::HandleHeader(const SpirvHeader& /*header*/)
{
return true;
}
std::string SpirvDecoder::ReadString()
{
std::string str;
for (;;)
{
UInt32 value = ReadWord();
for (std::size_t j = 0; j < 4; ++j)
{
char c = static_cast<char>((value >> (j * 8)) & 0xFF);
if (c == '\0')
return str;
str.push_back(c);
}
}
}
UInt32 SpirvDecoder::ReadWord()
{
if (m_currentCodepoint >= m_codepointEnd)
throw std::runtime_error("unexpected end of stream");
return *m_currentCodepoint++;
}
}

View File

@@ -21,10 +21,7 @@ namespace Nz
{
}
const UInt32* codepoints;
std::size_t resultOffset;
std::size_t index = 0;
std::size_t count;
std::ostringstream stream;
const Settings& settings;
};
@@ -32,71 +29,49 @@ namespace Nz
std::string SpirvPrinter::Print(const UInt32* codepoints, std::size_t count, const Settings& settings)
{
State state(settings);
state.codepoints = codepoints;
state.count = count;
m_currentState = &state;
CallOnExit resetOnExit([&] { m_currentState = nullptr; });
UInt32 magicNumber = ReadWord();
if (magicNumber != SpirvMagicNumber)
throw std::runtime_error("invalid Spir-V: magic number didn't match");
if (m_currentState->settings.printHeader)
m_currentState->stream << "Spir-V module\n";
UInt32 versionNumber = ReadWord();
if (versionNumber > SpirvVersion)
throw std::runtime_error("Spir-V is more recent than printer, dismissing");
UInt8 majorVersion = ((versionNumber) >> 16) & 0xFF;
UInt8 minorVersion = ((versionNumber) >> 8) & 0xFF;
UInt32 generatorId = ReadWord();
UInt32 bound = ReadWord();
UInt32 schema = ReadWord();
state.resultOffset = std::snprintf(nullptr, 0, "%%%u = ", bound);
if (m_currentState->settings.printHeader)
{
m_currentState->stream << "Version " + std::to_string(+majorVersion) << "." << std::to_string(+minorVersion) << "\n";
m_currentState->stream << "Generator: " << std::to_string(generatorId) << "\n";
m_currentState->stream << "Bound: " << std::to_string(bound) << "\n";
m_currentState->stream << "Schema: " << std::to_string(schema) << "\n";
}
while (m_currentState->index < m_currentState->count)
AppendInstruction();
Decode(codepoints, count);
return m_currentState->stream.str();
}
void SpirvPrinter::AppendInstruction()
bool SpirvPrinter::HandleHeader(const SpirvHeader& header)
{
std::size_t startIndex = m_currentState->index;
UInt8 majorVersion = ((header.versionNumber) >> 16) & 0xFF;
UInt8 minorVersion = ((header.versionNumber) >> 8) & 0xFF;
UInt32 firstWord = ReadWord();
m_currentState->resultOffset = std::snprintf(nullptr, 0, "%%%u = ", header.bound);
UInt16 wordCount = static_cast<UInt16>((firstWord >> 16) & 0xFFFF);
UInt16 opcode = static_cast<UInt16>(firstWord & 0xFFFF);
if (m_currentState->settings.printHeader)
{
m_currentState->stream << "Version " + std::to_string(+majorVersion) << "." << std::to_string(+minorVersion) << "\n";
m_currentState->stream << "Generator: " << std::to_string(header.generatorId) << "\n";
m_currentState->stream << "Bound: " << std::to_string(header.bound) << "\n";
m_currentState->stream << "Schema: " << std::to_string(header.schema) << "\n";
}
const SpirvInstruction* inst = GetInstructionData(opcode);
if (!inst)
throw std::runtime_error("invalid instruction");
return true;
}
bool SpirvPrinter::HandleOpcode(const SpirvInstruction& instruction, UInt32 wordCount)
{
const UInt32* startPtr = GetCurrentPtr();
if (m_currentState->settings.printParameters)
{
std::ostringstream instructionStream;
instructionStream << inst->name;
instructionStream << instruction.name;
UInt32 resultId = 0;
std::size_t currentOperand = 0;
std::size_t instructionEnd = startIndex + wordCount;
while (m_currentState->index < instructionEnd)
const UInt32* endPtr = startPtr + wordCount;
while (GetCurrentPtr() < endPtr)
{
const SpirvInstruction::Operand* operand = &inst->operands[currentOperand];
const SpirvInstruction::Operand* operand = &instruction.operands[currentOperand];
if (operand->kind != SpirvOperandKind::IdResult)
{
@@ -181,7 +156,7 @@ namespace Nz
break;
}
case SpirvOperandKind::PairLiteralIntegerIdRef:
{
ReadWord();
@@ -210,13 +185,12 @@ namespace Nz
default:
break;
}
}
else
resultId = ReadWord();
if (currentOperand < inst->minOperandCount - 1)
if (currentOperand < instruction.minOperandCount - 1)
currentOperand++;
}
@@ -231,42 +205,12 @@ namespace Nz
m_currentState->stream << instructionStream.str();
}
else
{
m_currentState->stream << inst->name;
m_currentState->index += wordCount - 1;
if (m_currentState->index > m_currentState->count)
throw std::runtime_error("unexpected end of stream");
}
m_currentState->stream << instruction.name;
m_currentState->stream << "\n";
assert(m_currentState->index == startIndex + wordCount);
}
assert(GetCurrentPtr() == startPtr + wordCount);
std::string SpirvPrinter::ReadString()
{
std::string str;
for (;;)
{
UInt32 value = ReadWord();
for (std::size_t j = 0; j < 4; ++j)
{
char c = static_cast<char>((value >> (j * 8)) & 0xFF);
if (c == '\0')
return str;
str.push_back(c);
}
}
}
UInt32 SpirvPrinter::ReadWord()
{
if (m_currentState->index >= m_currentState->count)
throw std::runtime_error("unexpected end of stream");
return m_currentState->codepoints[m_currentState->index++];
return true;
}
}