NazaraEngine/src/Nazara/Shader/SpirvWriter.cpp

742 lines
25 KiB
C++

// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Engine - Shader module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/Ast/AstCloner.hpp>
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
#include <Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <SpirV/GLSL.std.450.h>
#include <frozen/unordered_map.h>
#include <tsl/ordered_map.h>
#include <tsl/ordered_set.h>
#include <cassert>
#include <map>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
namespace
{
struct SpirvBuiltin
{
const char* debugName;
ShaderStageTypeFlags compatibleStages;
SpirvBuiltIn decoration;
};
constexpr auto s_spirvBuiltinMapping = frozen::make_unordered_map<ShaderAst::BuiltinEntry, SpirvBuiltin>({
{ ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } },
{ ShaderAst::BuiltinEntry::FragDepth, { "FragmentDepth", ShaderStageType::Fragment, SpirvBuiltIn::FragDepth } },
{ ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
});
class SpirvPreVisitor : public ShaderAst::AstRecursiveVisitor
{
public:
struct UniformVar
{
UInt32 bindingIndex;
UInt32 descriptorSet;
UInt32 pointerId;
};
using BuiltinDecoration = std::map<UInt32, SpirvBuiltIn>;
using LocationDecoration = std::map<UInt32, UInt32>;
using ExtInstList = std::unordered_set<std::string>;
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
using StructContainer = std::vector<ShaderAst::StructDescription*>;
SpirvPreVisitor(SpirvConstantCache& constantCache, std::unordered_map<std::size_t, SpirvAstVisitor::FuncData>& funcs) :
m_constantCache(constantCache),
m_funcs(funcs)
{
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
{
assert(structIndex < declaredStructs.size());
return *declaredStructs[structIndex];
});
}
void Visit(ShaderAst::AccessIndexExpression& node) override
{
AstRecursiveVisitor::Visit(node);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::BinaryExpression& node) override
{
AstRecursiveVisitor::Visit(node);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::CallFunctionExpression& node) override
{
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = Retrieve(m_funcs, *m_funcIndex);
auto& funcCall = func.funcCalls.emplace_back();
funcCall.firstVarIndex = func.variables.size();
for (const auto& parameter : node.parameters)
{
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(*GetExpressionType(*parameter), SpirvStorageClass::Function));
}
}
void Visit(ShaderAst::ConditionalExpression& /*node*/) override
{
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
}
void Visit(ShaderAst::ConditionalStatement& /*node*/) override
{
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
}
void Visit(ShaderAst::ConstantValueExpression& node) override
{
std::visit([&](auto&& arg)
{
m_constantCache.Register(*m_constantCache.BuildConstant(arg));
}, node.value);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareExternalStatement& node) override
{
for (auto& extVar : node.externalVars)
{
SpirvConstantCache::Variable variable;
variable.debugName = extVar.name;
const ShaderAst::ExpressionType& extVarType = extVar.type.GetResultingValue();
if (ShaderAst::IsSamplerType(extVarType))
{
variable.storageClass = SpirvStorageClass::UniformConstant;
variable.type = m_constantCache.BuildPointerType(extVarType, variable.storageClass);
}
else
{
assert(ShaderAst::IsUniformType(extVarType));
const auto& uniformType = std::get<ShaderAst::UniformType>(extVarType);
const auto& structType = uniformType.containedType;
assert(structType.structIndex < declaredStructs.size());
const auto& type = m_constantCache.BuildType(*declaredStructs[structType.structIndex], { SpirvDecoration::Block });
variable.storageClass = SpirvStorageClass::Uniform;
variable.type = m_constantCache.BuildPointerType(type, variable.storageClass);
}
assert(extVar.bindingIndex.IsResultingValue());
assert(extVar.varIndex);
UniformVar& uniformVar = extVars[*extVar.varIndex];
uniformVar.pointerId = m_constantCache.Register(variable);
uniformVar.bindingIndex = extVar.bindingIndex.GetResultingValue();
uniformVar.descriptorSet = (extVar.bindingSet.HasValue()) ? extVar.bindingSet.GetResultingValue() : 0;
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
std::optional<ShaderStageType> entryPointType;
if (node.entryStage.HasValue())
entryPointType = node.entryStage.GetResultingValue();
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
auto& funcData = m_funcs[funcIndex];
funcData.name = node.name;
funcData.funcIndex = funcIndex;
if (!entryPointType)
{
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type.GetResultingValue());
if (node.returnType.HasValue())
{
const auto& returnType = node.returnType.GetResultingValue();
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(returnType, parameterTypes));
}
else
{
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes));
}
for (auto& parameter : node.parameters)
{
const auto& parameterType = parameter.type.GetResultingValue();
auto& funcParam = funcData.parameters.emplace_back();
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameterType));
}
}
else
{
using EntryPoint = SpirvAstVisitor::EntryPoint;
std::vector<SpirvExecutionMode> executionModes;
if (*entryPointType == ShaderStageType::Fragment)
{
executionModes.push_back(SpirvExecutionMode::OriginUpperLeft);
if (node.earlyFragmentTests.HasValue() && node.earlyFragmentTests.GetResultingValue())
executionModes.push_back(SpirvExecutionMode::EarlyFragmentTests);
if (node.depthWrite.HasValue())
{
executionModes.push_back(SpirvExecutionMode::DepthReplacing);
switch (node.depthWrite.GetResultingValue())
{
case ShaderAst::DepthWriteMode::Replace: break;
case ShaderAst::DepthWriteMode::Greater: executionModes.push_back(SpirvExecutionMode::DepthGreater); break;
case ShaderAst::DepthWriteMode::Less: executionModes.push_back(SpirvExecutionMode::DepthLess); break;
case ShaderAst::DepthWriteMode::Unchanged: executionModes.push_back(SpirvExecutionMode::DepthUnchanged); break;
}
}
}
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {}));
std::optional<EntryPoint::InputStruct> inputStruct;
std::vector<EntryPoint::Input> inputs;
if (!node.parameters.empty())
{
assert(node.parameters.size() == 1);
auto& parameter = node.parameters.front();
const auto& parameterType = parameter.type.GetResultingValue();
assert(std::holds_alternative<ShaderAst::StructType>(parameterType));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameterType).structIndex;
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc->members)
{
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
{
inputs.push_back({
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type.GetResultingValue(), SpirvStorageClass::Function)),
varId
});
}
memberIndex++;
}
inputStruct = EntryPoint::InputStruct{
m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type.GetResultingValue()))
};
}
std::optional<UInt32> outputStructId;
std::vector<EntryPoint::Output> outputs;
if (node.returnType.HasValue() && !IsNoType(node.returnType.GetResultingValue()))
{
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
assert(std::holds_alternative<ShaderAst::StructType>(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc->members)
{
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
{
outputs.push_back({
Int32(memberIndex),
m_constantCache.Register(*m_constantCache.BuildType(member.type.GetResultingValue())),
varId
});
}
memberIndex++;
}
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(returnType));
}
funcData.entryPointData = EntryPoint{
*entryPointType,
inputStruct,
outputStructId,
std::move(inputs),
std::move(outputs),
std::move(executionModes)
};
}
m_funcIndex = funcIndex;
AstRecursiveVisitor::Visit(node);
m_funcIndex.reset();
}
void Visit(ShaderAst::DeclareStructStatement& node) override
{
AstRecursiveVisitor::Visit(node);
assert(node.structIndex);
std::size_t structIndex = *node.structIndex;
if (structIndex >= declaredStructs.size())
declaredStructs.resize(structIndex + 1);
declaredStructs[structIndex] = &node.description;
m_constantCache.Register(*m_constantCache.BuildType(node.description));
}
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
assert(node.varIndex);
func.varIndexToVarId[*node.varIndex] = func.variables.size();
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType.GetResultingValue(), SpirvStorageClass::Function));
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderAst::IntrinsicExpression& node) override
{
AstRecursiveVisitor::Visit(node);
switch (node.intrinsic)
{
// Require GLSL.std.450
case ShaderAst::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::Exp:
case ShaderAst::IntrinsicType::Length:
case ShaderAst::IntrinsicType::Max:
case ShaderAst::IntrinsicType::Min:
case ShaderAst::IntrinsicType::Normalize:
case ShaderAst::IntrinsicType::Pow:
case ShaderAst::IntrinsicType::Reflect:
extInsts.emplace("GLSL.std.450");
break;
// Part of SPIR-V core
case ShaderAst::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::SampleTexture:
break;
}
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::SwizzleExpression& node) override
{
AstRecursiveVisitor::Visit(node);
for (std::size_t i = 0; i < node.componentCount; ++i)
{
Int32 indexCount = SafeCast<Int32>(node.components[i]);
m_constantCache.Register(*m_constantCache.BuildConstant(indexCount));
}
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::UnaryExpression& node) override
{
AstRecursiveVisitor::Visit(node);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{
if (member.builtin.HasValue())
{
auto it = s_spirvBuiltinMapping.find(member.builtin.GetResultingValue());
assert(it != s_spirvBuiltinMapping.end());
const SpirvBuiltin& builtin = it->second;
if ((builtin.compatibleStages & entryPointType) == 0)
return 0;
SpirvBuiltIn builtinDecoration = builtin.decoration;
SpirvConstantCache::Variable variable;
variable.debugName = builtin.debugName;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass);
UInt32 varId = m_constantCache.Register(variable);
builtinDecorations[varId] = builtinDecoration;
return varId;
}
else if (member.locationIndex.HasValue())
{
SpirvConstantCache::Variable variable;
variable.debugName = member.name;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = member.locationIndex.GetResultingValue();
return varId;
}
return 0;
}
BuiltinDecoration builtinDecorations;
ExtInstList extInsts;
ExtVarContainer extVars;
LocationDecoration locationDecorations;
StructContainer declaredStructs;
private:
SpirvConstantCache& m_constantCache;
std::optional<std::size_t> m_funcIndex;
std::unordered_map<std::size_t, SpirvAstVisitor::FuncData>& m_funcs;
};
}
struct SpirvWriter::State
{
State() :
constantTypeCache(nextVarIndex)
{
}
struct Func
{
const ShaderAst::DeclareFunctionStatement* statement = nullptr;
UInt32 typeId;
UInt32 id;
};
std::unordered_map<std::string, UInt32> extensionInstructionSet;
std::unordered_map<std::string, UInt32> varToResult;
std::unordered_map<std::size_t, SpirvAstVisitor::FuncData> funcs;
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
SpirvPreVisitor* previsitor;
// Output
SpirvSection header;
SpirvSection constants;
SpirvSection debugInfo;
SpirvSection annotations;
SpirvSection instructions;
};
SpirvWriter::SpirvWriter() :
m_currentState(nullptr)
{
}
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst::Module& module, const States& states)
{
ShaderAst::ModulePtr sanitizedModule;
const ShaderAst::Module* targetModule;
if (!states.sanitized)
{
ShaderAst::SanitizeVisitor::Options options;
options.moduleResolver = states.shaderModuleResolver;
options.optionValues = states.optionValues;
options.reduceLoopsToWhile = true;
options.removeAliases = true;
options.removeCompoundAssignments = true;
options.removeConstDeclaration = true;
options.removeMatrixCast = true;
options.removeOptionDeclaration = true;
options.splitMultipleBranches = true;
options.useIdentifierAccessesForStructs = false;
sanitizedModule = ShaderAst::Sanitize(module, options);
targetModule = sanitizedModule.get();
}
else
targetModule = &module;
ShaderAst::ModulePtr optimizedModule;
if (states.optimize)
{
ShaderAst::StatementPtr tempAst;
ShaderAst::DependencyCheckerVisitor::Config dependencyConfig;
dependencyConfig.usedShaderStages = ShaderStageType_All;
optimizedModule = ShaderAst::PropagateConstants(*targetModule);
optimizedModule = ShaderAst::EliminateUnusedPass(*optimizedModule, dependencyConfig);
targetModule = optimizedModule.get();
}
// Previsitor
m_context.states = &states;
State state;
m_currentState = &state;
CallOnExit onExit([this]()
{
m_currentState = nullptr;
});
// Register all extended instruction sets
SpirvPreVisitor previsitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(previsitor);
targetModule->rootNode->Visit(previsitor);
m_currentState->previsitor = &previsitor;
for (const std::string& extInst : previsitor.extInsts)
state.extensionInstructionSet[extInst] = AllocateResultId();
// Assign function ID (required for forward declaration)
for (auto&& [funcIndex, func] : state.funcs)
func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
for (const auto& importedModule : targetModule->importedModules)
importedModule.module->rootNode->Visit(visitor);
targetModule->rootNode->Visit(visitor);
AppendHeader();
for (auto&& [varIndex, extVar] : previsitor.extVars)
{
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
}
for (auto&& [varId, builtin] : previsitor.builtinDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
for (auto&& [varId, location] : previsitor.locationDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
std::vector<UInt32> ret;
MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo);
MergeSections(ret, state.annotations);
MergeSections(ret, state.constants);
MergeSections(ret, state.instructions);
return ret;
}
void SpirvWriter::SetEnv(Environment environment)
{
m_environment = std::move(environment);
}
UInt32 SpirvWriter::AllocateResultId()
{
return m_currentState->nextVarIndex++;
}
void SpirvWriter::AppendHeader()
{
m_currentState->header.AppendRaw(SpirvMagicNumber); //< Spir-V magic number
UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8;
m_currentState->header.AppendRaw(version); //< Spir-V version number (1.0 for compatibility)
m_currentState->header.AppendRaw(0); //< Generator identifier (TODO: Register generator to Khronos)
m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count)
m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now)
m_currentState->header.Append(SpirvOp::OpCapability, SpirvCapability::Shader);
for (const auto& [extInst, resultId] : m_currentState->extensionInstructionSet)
m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
for (auto&& [funcIndex, func] : m_currentState->funcs)
{
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
if (func.entryPointData)
{
auto& entryPointData = func.entryPointData.value();
SpirvExecutionModel execModel;
switch (entryPointData.stageType)
{
case ShaderStageType::Fragment:
execModel = SpirvExecutionModel::Fragment;
break;
case ShaderStageType::Vertex:
execModel = SpirvExecutionModel::Vertex;
break;
default:
throw std::runtime_error("not yet implemented");
}
auto& funcData = func;
m_currentState->header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(funcData.funcId);
appender(funcData.name);
for (const auto& input : entryPointData.inputs)
appender(input.varId);
for (const auto& output : entryPointData.outputs)
appender(output.varId);
});
}
}
// Write execution modes
for (auto&& [funcIndex, func] : m_currentState->funcs)
{
if (func.entryPointData)
{
for (SpirvExecutionMode executionMode : func.entryPointData->executionModes)
m_currentState->header.Append(SpirvOp::OpExecutionMode, func.funcId, executionMode);
}
}
}
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
std::vector<ShaderAst::ExpressionType> parameterTypes;
parameterTypes.reserve(functionNode.parameters.size());
for (const auto& parameter : functionNode.parameters)
parameterTypes.push_back(parameter.type.GetResultingValue());
if (functionNode.returnType.HasValue())
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType.GetResultingValue(), parameterTypes);
else
return m_currentState->constantTypeCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes);
}
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildConstant(value));
}
UInt32 SpirvWriter::GetExtendedInstructionSet(const std::string& instructionSetName) const
{
auto it = m_currentState->extensionInstructionSet.find(instructionSetName);
assert(it != m_currentState->extensionInstructionSet.end());
return it->second;
}
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
{
auto it = m_currentState->previsitor->extVars.find(extVarIndex);
assert(it != m_currentState->previsitor->extVars.end());
return it->second.pointerId;
}
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
}
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildType(type));
}
void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from)
{
const std::vector<UInt32>& bytecode = from.GetBytecode();
std::size_t prevSize = output.size();
output.resize(prevSize + bytecode.size());
std::copy(bytecode.begin(), bytecode.end(), output.begin() + prevSize);
}
}