Add SpirvConstantCache

And unsigned int types for shaders
This commit is contained in:
Jérôme Leclercq
2020-08-20 01:05:16 +02:00
parent 0b507708f4
commit 9df219e402
14 changed files with 1341 additions and 421 deletions

View File

@@ -4,10 +4,10 @@
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/Endianness.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <tsl/ordered_map.h>
@@ -24,23 +24,25 @@ namespace Nz
{
namespace
{
using ConstantVariant = ShaderNodes::Constant::Variant;
class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor
{
public:
using BuiltinContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::BuiltinVariable>>;
using ConstantContainer = tsl::ordered_set<ConstantVariant>;
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
PreVisitor(SpirvConstantCache& constantCache) :
m_constantCache(constantCache)
{
}
using ShaderAstRecursiveVisitor::Visit;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override
{
constants.emplace(Int32(node.memberIndex));
m_constantCache.Register(*SpirvConstantCache::BuildConstant(UInt32(node.memberIndex)));
ShaderAstRecursiveVisitor::Visit(node);
}
@@ -49,35 +51,8 @@ namespace Nz
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32>)
constants.emplace(arg);
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i32>)
{
constants.emplace(arg.x);
constants.emplace(arg.y);
constants.emplace(arg);
}
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i32>)
{
constants.emplace(arg.x);
constants.emplace(arg.y);
constants.emplace(arg.z);
constants.emplace(arg);
}
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i32>)
{
constants.emplace(arg.x);
constants.emplace(arg.y);
constants.emplace(arg.z);
constants.emplace(arg.w);
constants.emplace(arg);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
},
node.value);
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
}, node.value);
ShaderAstRecursiveVisitor::Visit(node);
}
@@ -118,7 +93,7 @@ namespace Nz
builtinVars.insert(std::static_pointer_cast<const ShaderNodes::BuiltinVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::InputVariable& var) override
void Visit(ShaderNodes::InputVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
@@ -128,7 +103,7 @@ namespace Nz
localVars.insert(std::static_pointer_cast<const ShaderNodes::LocalVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::OutputVariable& var) override
void Visit(ShaderNodes::OutputVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
@@ -138,32 +113,18 @@ namespace Nz
paramVars.insert(std::static_pointer_cast<const ShaderNodes::ParameterVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::UniformVariable& var) override
void Visit(ShaderNodes::UniformVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
BuiltinContainer builtinVars;
ConstantContainer constants;
ExtInstList extInsts;
LocalContainer localVars;
ParameterContainer paramVars;
};
class AssignVisitor : public ShaderAstRecursiveVisitor
{
public:
void Visit(ShaderNodes::AccessMember& node) override
{
}
void Visit(ShaderNodes::Identifier& node) override
{
}
void Visit(ShaderNodes::SwizzleOp& node) override
{
}
private:
SpirvConstantCache& m_constantCache;
};
template<typename T>
@@ -202,6 +163,11 @@ namespace Nz
struct SpirvWriter::State
{
State() :
constantTypeCache(nextVarIndex)
{
}
struct Func
{
UInt32 typeId;
@@ -209,18 +175,16 @@ namespace Nz
std::vector<UInt32> paramsId;
};
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
tsl::ordered_map<ConstantVariant, UInt32> constantIds;
tsl::ordered_map<ShaderExpressionType, UInt32> typeIds;
std::vector<Func> funcs;
tsl::ordered_map<std::string, ExtVar> inputIds;
tsl::ordered_map<std::string, ExtVar> outputIds;
tsl::ordered_map<std::string, ExtVar> uniformIds;
std::vector<std::optional<FieldOffsets>> structFields;
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<Func> funcs;
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
// Output
SpirvSection header;
@@ -251,13 +215,11 @@ namespace Nz
m_currentState = nullptr;
});
state.structFields.resize(shader.GetStructCount());
std::vector<ShaderNodes::StatementPtr> functionStatements;
ShaderAstCloner cloner;
PreVisitor preVisitor;
PreVisitor preVisitor(state.constantTypeCache);
for (const auto& func : shader.GetFunctions())
{
functionStatements.emplace_back(cloner.Clone(func.statement));
@@ -277,13 +239,16 @@ namespace Nz
}
for (const auto& input : shader.GetInputs())
RegisterType(input.type);
RegisterPointerType(input.type, SpirvStorageClass::Input);
for (const auto& output : shader.GetOutputs())
RegisterType(output.type);
RegisterPointerType(output.type, SpirvStorageClass::Output);
for (const auto& uniform : shader.GetUniforms())
RegisterType(uniform.type);
RegisterPointerType(uniform.type, SpirvStorageClass::Uniform);
for (const auto& func : shader.GetFunctions())
RegisterFunctionType(func.returnType, func.parameters);
for (const auto& local : preVisitor.localVars)
RegisterType(local->type);
@@ -291,104 +256,103 @@ namespace Nz
for (const auto& builtin : preVisitor.builtinVars)
RegisterType(builtin->type);
// Register constant types
for (const auto& constant : preVisitor.constants)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
RegisterType(GetBasicType<T>());
}, constant);
}
AppendTypes();
// Register result id and debug infos for global variables/functions
for (const auto& builtin : preVisitor.builtinVars)
{
const ShaderExpressionType& builtinExprType = builtin->type;
assert(std::holds_alternative<ShaderNodes::BasicType>(builtinExprType));
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
ExtVar builtinData;
builtinData.pointerTypeId = AllocateResultId();
builtinData.typeId = GetTypeId(builtinType);
builtinData.varId = AllocateResultId();
SpvBuiltIn spvBuiltin;
std::string debugName;
SpirvConstantCache::Variable variable;
SpirvBuiltIn builtinDecoration;
switch (builtin->entry)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
debugName = "builtin_VertexPosition";
spvBuiltin = SpvBuiltInPosition;
variable.debugName = "builtin_VertexPosition";
variable.storageClass = SpirvStorageClass::Output;
builtinDecoration = SpirvBuiltIn::Position;
break;
default:
throw std::runtime_error("unexpected builtin type");
}
state.debugInfo.Append(SpirvOp::OpName, builtinData.varId, debugName);
state.types.Append(SpirvOp::OpTypePointer, builtinData.pointerTypeId, SpvStorageClassOutput, builtinData.typeId);
state.types.Append(SpirvOp::OpVariable, builtinData.pointerTypeId, builtinData.varId, SpvStorageClassOutput);
const ShaderExpressionType& builtinExprType = builtin->type;
assert(std::holds_alternative<ShaderNodes::BasicType>(builtinExprType));
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, spvBuiltin);
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar builtinData;
builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass);
builtinData.typeId = GetTypeId(builtinType);
builtinData.varId = varId;
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, builtinDecoration);
state.builtinIds.emplace(builtin->entry, builtinData);
}
for (const auto& input : shader.GetInputs())
{
SpirvConstantCache::Variable variable;
variable.debugName = input.name;
variable.storageClass = SpirvStorageClass::Input;
variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar inputData;
inputData.pointerTypeId = AllocateResultId();
inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass);
inputData.typeId = GetTypeId(input.type);
inputData.varId = AllocateResultId();
inputData.varId = varId;
state.inputIds.emplace(input.name, inputData);
state.debugInfo.Append(SpirvOp::OpName, inputData.varId, input.name);
state.types.Append(SpirvOp::OpTypePointer, inputData.pointerTypeId, SpvStorageClassInput, inputData.typeId);
state.types.Append(SpirvOp::OpVariable, inputData.pointerTypeId, inputData.varId, SpvStorageClassInput);
state.inputIds.emplace(input.name, std::move(inputData));
if (input.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, inputData.varId, SpvDecorationLocation, *input.locationIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *input.locationIndex);
}
for (const auto& output : shader.GetOutputs())
{
SpirvConstantCache::Variable variable;
variable.debugName = output.name;
variable.storageClass = SpirvStorageClass::Output;
variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar outputData;
outputData.pointerTypeId = AllocateResultId();
outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass);
outputData.typeId = GetTypeId(output.type);
outputData.varId = AllocateResultId();
outputData.varId = varId;
state.outputIds.emplace(output.name, outputData);
state.debugInfo.Append(SpirvOp::OpName, outputData.varId, output.name);
state.types.Append(SpirvOp::OpTypePointer, outputData.pointerTypeId, SpvStorageClassOutput, outputData.typeId);
state.types.Append(SpirvOp::OpVariable, outputData.pointerTypeId, outputData.varId, SpvStorageClassOutput);
state.outputIds.emplace(output.name, std::move(outputData));
if (output.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, outputData.varId, SpvDecorationLocation, *output.locationIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *output.locationIndex);
}
for (const auto& uniform : shader.GetUniforms())
{
SpirvConstantCache::Variable variable;
variable.debugName = uniform.name;
variable.storageClass = SpirvStorageClass::Uniform;
variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar uniformData;
uniformData.pointerTypeId = AllocateResultId();
uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass);
uniformData.typeId = GetTypeId(uniform.type);
uniformData.varId = AllocateResultId();
uniformData.varId = varId;
state.uniformIds.emplace(uniform.name, uniformData);
state.debugInfo.Append(SpirvOp::OpName, uniformData.varId, uniform.name);
state.types.Append(SpirvOp::OpTypePointer, uniformData.pointerTypeId, SpvStorageClassUniform, uniformData.typeId);
state.types.Append(SpirvOp::OpVariable, uniformData.pointerTypeId, uniformData.varId, SpvStorageClassUniform);
state.uniformIds.emplace(uniform.name, std::move(uniformData));
if (uniform.bindingIndex)
{
state.annotations.Append(SpirvOp::OpDecorate, uniformData.varId, SpvDecorationBinding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, uniformData.varId, SpvDecorationDescriptorSet, 0);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationBinding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationDescriptorSet, 0);
}
}
@@ -396,26 +360,11 @@ namespace Nz
{
auto& funcData = state.funcs.emplace_back();
funcData.id = AllocateResultId();
funcData.typeId = AllocateResultId();
funcData.typeId = GetFunctionTypeId(func.returnType, func.parameters);
state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
state.types.AppendVariadic(SpirvOp::OpTypeFunction, [&](const auto& appender)
{
appender(funcData.typeId);
appender(GetTypeId(func.returnType));
for (const auto& param : func.parameters)
appender(GetTypeId(param.type));
});
}
// Register constants
for (const auto& constant : preVisitor.constants)
state.constantIds[constant] = AllocateResultId();
AppendConstants();
std::size_t entryPointIndex = std::numeric_limits<std::size_t>::max();
for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex)
@@ -448,11 +397,13 @@ namespace Nz
assert(entryPointIndex != std::numeric_limits<std::size_t>::max());
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo, m_currentState->types);
AppendHeader();
SpvExecutionModel execModel;
const auto& entryFuncData = shader.GetFunction(entryPointIndex);
const auto& entryFunc = m_currentState->funcs[entryPointIndex];
const auto& entryFunc = state.funcs[entryPointIndex];
assert(m_context.shader);
switch (m_context.shader->GetStage())
@@ -471,21 +422,19 @@ namespace Nz
// OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos
std::size_t nameSize = state.header.CountWord(entryFuncData.name);
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(entryFunc.id);
appender(entryFuncData.name);
for (const auto& [name, varData] : m_currentState->builtinIds)
for (const auto& [name, varData] : state.builtinIds)
appender(varData.varId);
for (const auto& [name, varData] : m_currentState->inputIds)
for (const auto& [name, varData] : state.inputIds)
appender(varData.varId);
for (const auto& [name, varData] : m_currentState->outputIds)
for (const auto& [name, varData] : state.outputIds)
appender(varData.varId);
});
@@ -513,31 +462,6 @@ namespace Nz
return m_currentState->nextVarIndex++;
}
void SpirvWriter::AppendConstants()
{
for (const auto& [value, resultId] : m_currentState->constantIds)
{
UInt32 constantId = resultId;
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
m_currentState->constants.Append((arg) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, constantId);
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>)
m_currentState->constants.Append(SpirvOp::OpConstant, GetTypeId(GetBasicType<T>()), constantId, SpirvSection::Raw{ &arg, sizeof(arg) });
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i>)
m_currentState->constants.Append(SpirvOp::OpConstantComposite, GetTypeId(GetBasicType<T>()), constantId, GetConstantId(arg.x), GetConstantId(arg.y));
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i>)
m_currentState->constants.Append(SpirvOp::OpConstantComposite, GetTypeId(GetBasicType<T>()), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z));
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i>)
m_currentState->constants.Append(SpirvOp::OpConstantComposite, GetTypeId(GetBasicType<T>()), constantId, GetConstantId(arg.x), GetConstantId(arg.y), GetConstantId(arg.z), GetConstantId(arg.w));
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, value);
}
}
void SpirvWriter::AppendHeader()
{
m_currentState->header.Append(SpvMagicNumber); //< Spir-V magic number
@@ -557,180 +481,41 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
}
void SpirvWriter::AppendStructType(std::size_t structIndex, UInt32 resultId)
{
const ShaderAst::Struct& s = m_context.shader->GetStruct(structIndex);
m_currentState->types.Append(SpirvOp::OpTypeStruct, SpirvSection::OpSize{ static_cast<unsigned int>(1 + 1 + s.members.size()) });
m_currentState->types.Append(resultId);
m_currentState->debugInfo.Append(SpirvOp::OpName, resultId, s.name);
m_currentState->annotations.Append(SpirvOp::OpDecorate, resultId, SpvDecorationBlock);
FieldOffsets structOffsets(StructLayout_Std140);
for (std::size_t memberIndex = 0; memberIndex < s.members.size(); ++memberIndex)
{
const auto& member = s.members[memberIndex];
m_currentState->types.Append(GetTypeId(member.type));
m_currentState->debugInfo.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name);
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
{
std::size_t offset = [&] {
switch (arg)
{
case ShaderNodes::BasicType::Boolean: return structOffsets.AddField(StructFieldType_Bool1);
case ShaderNodes::BasicType::Float1: return structOffsets.AddField(StructFieldType_Float1);
case ShaderNodes::BasicType::Float2: return structOffsets.AddField(StructFieldType_Float2);
case ShaderNodes::BasicType::Float3: return structOffsets.AddField(StructFieldType_Float3);
case ShaderNodes::BasicType::Float4: return structOffsets.AddField(StructFieldType_Float4);
case ShaderNodes::BasicType::Int1: return structOffsets.AddField(StructFieldType_Int1);
case ShaderNodes::BasicType::Int2: return structOffsets.AddField(StructFieldType_Int2);
case ShaderNodes::BasicType::Int3: return structOffsets.AddField(StructFieldType_Int3);
case ShaderNodes::BasicType::Int4: return structOffsets.AddField(StructFieldType_Int4);
case ShaderNodes::BasicType::Mat4x4: return structOffsets.AddMatrix(StructFieldType_Float1, 4, 4, true);
case ShaderNodes::BasicType::Sampler2D: throw std::runtime_error("unexpected sampler2D as struct member");
case ShaderNodes::BasicType::Void: throw std::runtime_error("unexpected void as struct member");
}
assert(false);
throw std::runtime_error("unhandled type");
}();
m_currentState->annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpvDecorationOffset, offset);
if (arg == ShaderNodes::BasicType::Mat4x4)
{
m_currentState->annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpvDecorationColMajor);
m_currentState->annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpvDecorationMatrixStride, 16);
}
}
else if constexpr (std::is_same_v<T, std::string>)
{
// Register struct members type
const auto& structs = m_context.shader->GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
throw std::runtime_error("struct " + arg + " has not been defined");
std::size_t nestedStructIndex = std::distance(structs.begin(), it);
std::optional<FieldOffsets> nestedFieldOffset = m_currentState->structFields[nestedStructIndex];
if (!nestedFieldOffset)
throw std::runtime_error("struct dependency cycle");
structOffsets.AddStruct(nestedFieldOffset.value());
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, member.type);
}
m_currentState->structFields[structIndex] = structOffsets;
}
void SpirvWriter::AppendTypes()
{
for (const auto& [type, typeId] : m_currentState->typeIds.values_container())
{
UInt32 resultId = typeId;
// Register sub-types, if any
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
{
switch (arg)
{
case ShaderNodes::BasicType::Boolean:
m_currentState->types.Append(SpirvOp::OpTypeBool, resultId);
break;
case ShaderNodes::BasicType::Float1:
m_currentState->types.Append(SpirvOp::OpTypeFloat, resultId, 32);
break;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
{
ShaderNodes::BasicType baseType = ShaderNodes::Node::GetComponentType(arg);
UInt32 vecSize = UInt32(arg) - UInt32(baseType) + 1;
m_currentState->types.Append(SpirvOp::OpTypeVector, resultId, GetTypeId(baseType), vecSize);
break;
}
case ShaderNodes::BasicType::Int1:
m_currentState->types.Append(SpirvOp::OpTypeInt, resultId, 32, 1);
break;
case ShaderNodes::BasicType::Mat4x4:
{
m_currentState->types.Append(SpirvOp::OpTypeMatrix, resultId, GetTypeId(ShaderNodes::BasicType::Float4), 4);
break;
}
case ShaderNodes::BasicType::Sampler2D:
{
UInt32 imageTypeId = resultId - 1;
m_currentState->types.Append(SpirvOp::OpTypeImage, imageTypeId, GetTypeId(ShaderNodes::BasicType::Float1), SpvDim2D, 0, 0, 0, 1, SpvImageFormatUnknown);
m_currentState->types.Append(SpirvOp::OpTypeSampledImage, resultId, imageTypeId);
break;
}
case ShaderNodes::BasicType::Void:
m_currentState->types.Append(SpirvOp::OpTypeVoid, resultId);
break;
}
}
else if constexpr (std::is_same_v<T, std::string>)
{
// Register struct members type
const auto& structs = m_context.shader->GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
throw std::runtime_error("struct " + arg + " has not been defined");
std::size_t structIndex = std::distance(structs.begin(), it);
AppendStructType(structIndex, resultId);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
}
}
UInt32 SpirvWriter::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
{
Visit(expr);
return PopResultId();
}
UInt32 SpirvWriter::GetConstantId(const ShaderNodes::Constant::Variant& value) const
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
{
auto typeIt = m_currentState->constantIds.find(value);
assert(typeIt != m_currentState->constantIds.end());
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
}
return typeIt->second;
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type));
return m_currentState->constantTypeCache.GetId({
SpirvConstantCache::Function {
SpirvConstantCache::BuildType(*m_context.shader, retType),
std::move(parameterTypes)
}
});
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const
{
auto typeIt = m_currentState->typeIds.find(type);
assert(typeIt != m_currentState->typeIds.end());
return typeIt->second;
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type));
}
void SpirvWriter::PushResultId(UInt32 value)
@@ -762,68 +547,42 @@ namespace Nz
return var.valueId.value();
}
UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type));
return m_currentState->constantTypeCache.Register({
SpirvConstantCache::Function {
SpirvConstantCache::BuildType(*m_context.shader, retType),
std::move(parameterTypes)
}
});
}
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderExpressionType type)
{
auto it = m_currentState->typeIds.find(type);
if (it == m_currentState->typeIds.end())
{
// Register sub-types, if any
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
{
switch (arg)
{
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Void:
break; //< Nothing to do
// In SPIR-V, vec3 (for example) depends on float
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::Mat4x4:
RegisterType(ShaderNodes::Node::GetComponentType(arg));
break;
case ShaderNodes::BasicType::Sampler2D:
RegisterType(ShaderNodes::BasicType::Float1);
AllocateResultId(); //< Reserve a result id for the image type
break;
}
}
else if constexpr (std::is_same_v<T, std::string>)
{
// Register struct members type
const auto& structs = m_context.shader->GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
throw std::runtime_error("struct " + arg + " has not been defined");
const ShaderAst::Struct& s = *it;
for (const auto& member : s.members)
RegisterType(member.type);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
it = m_currentState->typeIds.emplace(std::move(type), AllocateResultId()).first;
}
return it->second;
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type));
}
void SpirvWriter::Visit(ShaderNodes::AccessMember& node)
{
UInt32 pointerId;
SpvStorageClass storage;
SpirvStorageClass storage;
switch (node.structExpr->GetType())
{
@@ -848,7 +607,7 @@ namespace Nz
auto it = m_currentState->inputIds.find(inputVar.name);
assert(it != m_currentState->inputIds.end());
storage = SpvStorageClassInput;
storage = SpirvStorageClass::Input;
pointerId = it->second.varId;
break;
@@ -860,7 +619,7 @@ namespace Nz
auto it = m_currentState->outputIds.find(outputVar.name);
assert(it != m_currentState->outputIds.end());
storage = SpvStorageClassOutput;
storage = SpirvStorageClass::Output;
pointerId = it->second.varId;
break;
@@ -872,7 +631,7 @@ namespace Nz
auto it = m_currentState->uniformIds.find(uniformVar.name);
assert(it != m_currentState->uniformIds.end());
storage = SpvStorageClassUniform;
storage = SpirvStorageClass::Uniform;
pointerId = it->second.varId;
break;
@@ -892,11 +651,9 @@ namespace Nz
}
UInt32 memberPointerId = AllocateResultId();
UInt32 pointerType = AllocateResultId();
UInt32 pointerType = RegisterPointerType(node.exprType, storage); //< FIXME
UInt32 typeId = GetTypeId(node.exprType);
UInt32 indexId = GetConstantId(Int32(node.memberIndex));
m_currentState->types.Append(SpirvOp::OpTypePointer, pointerType, storage, typeId);
UInt32 indexId = GetConstantId(UInt32(node.memberIndex));
m_currentState->instructions.Append(SpirvOp::OpAccessChain, pointerType, memberPointerId, pointerId, indexId);
@@ -1002,6 +759,10 @@ namespace Nz
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIAdd;
case ShaderNodes::BasicType::Boolean:
@@ -1026,6 +787,10 @@ namespace Nz
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpISub;
case ShaderNodes::BasicType::Boolean:
@@ -1052,6 +817,12 @@ namespace Nz
case ShaderNodes::BasicType::Int4:
return SpirvOp::OpSDiv;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpUDiv;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
@@ -1077,6 +848,10 @@ namespace Nz
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIEqual;
case ShaderNodes::BasicType::Sampler2D:
@@ -1141,6 +916,10 @@ namespace Nz
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIMul;
case ShaderNodes::BasicType::Mat4x4: