Shader: Add function parameters and return handling
This commit is contained in:
@@ -191,16 +191,17 @@ namespace Nz
|
||||
{
|
||||
UInt32 typeId;
|
||||
UInt32 id;
|
||||
std::vector<UInt32> paramsId;
|
||||
};
|
||||
|
||||
tsl::ordered_map<std::string, ExtVar> inputIds;
|
||||
tsl::ordered_map<std::string, ExtVar> outputIds;
|
||||
tsl::ordered_map<std::string, ExtVar> parameterIds;
|
||||
tsl::ordered_map<std::string, ExtVar> uniformIds;
|
||||
std::unordered_map<std::string, UInt32> extensionInstructions;
|
||||
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<std::string, UInt32> varToResult;
|
||||
std::vector<Func> funcs;
|
||||
std::vector<SpirvBlock> functionBlocks;
|
||||
std::vector<UInt32> resultIds;
|
||||
UInt32 nextVarIndex = 1;
|
||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||
@@ -307,7 +308,7 @@ namespace Nz
|
||||
builtinData.typeId = GetTypeId(builtinType);
|
||||
builtinData.varId = varId;
|
||||
|
||||
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, builtinDecoration);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);
|
||||
|
||||
state.builtinIds.emplace(builtin->entry, builtinData);
|
||||
}
|
||||
@@ -329,7 +330,7 @@ namespace Nz
|
||||
state.inputIds.emplace(input.name, std::move(inputData));
|
||||
|
||||
if (input.locationIndex)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *input.locationIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
|
||||
}
|
||||
|
||||
for (const auto& output : shader.GetOutputs())
|
||||
@@ -349,7 +350,7 @@ namespace Nz
|
||||
state.outputIds.emplace(output.name, std::move(outputData));
|
||||
|
||||
if (output.locationIndex)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *output.locationIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
|
||||
}
|
||||
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
@@ -370,8 +371,8 @@ namespace Nz
|
||||
|
||||
if (uniform.bindingIndex)
|
||||
{
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationBinding, *uniform.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationDescriptorSet, 0);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,77 +397,86 @@ namespace Nz
|
||||
|
||||
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
|
||||
|
||||
std::vector<SpirvBlock> blocks;
|
||||
blocks.emplace_back(*this);
|
||||
state.functionBlocks.clear();
|
||||
state.functionBlocks.emplace_back(*this);
|
||||
|
||||
state.parameterIds.clear();
|
||||
|
||||
for (const auto& param : func.parameters)
|
||||
{
|
||||
UInt32 paramResultId = AllocateResultId();
|
||||
funcData.paramsId.push_back(paramResultId);
|
||||
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
|
||||
|
||||
blocks.back().Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
|
||||
ExtVar parameterData;
|
||||
parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
|
||||
parameterData.typeId = GetTypeId(param.type);
|
||||
parameterData.varId = paramResultId;
|
||||
|
||||
state.parameterIds.emplace(param.name, std::move(parameterData));
|
||||
}
|
||||
|
||||
SpirvAstVisitor visitor(*this, blocks);
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks);
|
||||
visitor.Visit(functionStatements[funcIndex]);
|
||||
|
||||
if (func.returnType == ShaderNodes::BasicType::Void)
|
||||
blocks.back().Append(SpirvOp::OpReturn);
|
||||
else
|
||||
throw std::runtime_error("returning values from functions is not yet supported"); //< TODO
|
||||
if (!state.functionBlocks.back().IsTerminated())
|
||||
{
|
||||
assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void));
|
||||
state.functionBlocks.back().Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
blocks.back().Append(SpirvOp::OpFunctionEnd);
|
||||
for (SpirvBlock& block : state.functionBlocks)
|
||||
state.instructions.AppendSection(block);
|
||||
|
||||
for (SpirvBlock& block : blocks)
|
||||
state.instructions.Append(block);
|
||||
state.instructions.Append(SpirvOp::OpFunctionEnd);
|
||||
}
|
||||
|
||||
assert(entryPointIndex != std::numeric_limits<std::size_t>::max());
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
|
||||
AppendHeader();
|
||||
|
||||
SpvExecutionModel execModel;
|
||||
const auto& entryFuncData = shader.GetFunction(entryPointIndex);
|
||||
const auto& entryFunc = state.funcs[entryPointIndex];
|
||||
|
||||
assert(m_context.shader);
|
||||
switch (m_context.shader->GetStage())
|
||||
if (entryPointIndex != std::numeric_limits<std::size_t>::max())
|
||||
{
|
||||
case ShaderStageType::Fragment:
|
||||
execModel = SpvExecutionModelFragment;
|
||||
break;
|
||||
SpvExecutionModel execModel;
|
||||
const auto& entryFuncData = shader.GetFunction(entryPointIndex);
|
||||
const auto& entryFunc = state.funcs[entryPointIndex];
|
||||
|
||||
case ShaderStageType::Vertex:
|
||||
execModel = SpvExecutionModelVertex;
|
||||
break;
|
||||
assert(m_context.shader);
|
||||
switch (m_context.shader->GetStage())
|
||||
{
|
||||
case ShaderStageType::Fragment:
|
||||
execModel = SpvExecutionModelFragment;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
case ShaderStageType::Vertex:
|
||||
execModel = SpvExecutionModelVertex;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
}
|
||||
|
||||
// OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos
|
||||
|
||||
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
|
||||
{
|
||||
appender(execModel);
|
||||
appender(entryFunc.id);
|
||||
appender(entryFuncData.name);
|
||||
|
||||
for (const auto& [name, varData] : state.builtinIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.inputIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.outputIds)
|
||||
appender(varData.varId);
|
||||
});
|
||||
|
||||
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
|
||||
}
|
||||
|
||||
// OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos
|
||||
|
||||
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
|
||||
{
|
||||
appender(execModel);
|
||||
appender(entryFunc.id);
|
||||
appender(entryFuncData.name);
|
||||
|
||||
for (const auto& [name, varData] : state.builtinIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.inputIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.outputIds)
|
||||
appender(varData.varId);
|
||||
});
|
||||
|
||||
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
MergeSections(ret, state.header);
|
||||
MergeSections(ret, state.debugInfo);
|
||||
@@ -489,14 +499,14 @@ namespace Nz
|
||||
|
||||
void SpirvWriter::AppendHeader()
|
||||
{
|
||||
m_currentState->header.Append(SpvMagicNumber); //< Spir-V magic number
|
||||
m_currentState->header.AppendRaw(SpvMagicNumber); //< Spir-V magic number
|
||||
|
||||
UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8;
|
||||
m_currentState->header.Append(version); //< Spir-V version number (1.0 for compatibility)
|
||||
m_currentState->header.Append(0); //< Generator identifier (TODO: Register generator to Khronos)
|
||||
m_currentState->header.AppendRaw(version); //< Spir-V version number (1.0 for compatibility)
|
||||
m_currentState->header.AppendRaw(0); //< Generator identifier (TODO: Register generator to Khronos)
|
||||
|
||||
m_currentState->header.Append(m_currentState->nextVarIndex); //< Bound (ID count)
|
||||
m_currentState->header.Append(0); //< Instruction schema (required to be 0 for now)
|
||||
m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count)
|
||||
m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now)
|
||||
|
||||
m_currentState->header.Append(SpirvOp::OpCapability, SpvCapabilityShader);
|
||||
|
||||
@@ -506,6 +516,20 @@ namespace Nz
|
||||
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
|
||||
}
|
||||
|
||||
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
{
|
||||
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
|
||||
parameterTypes.reserve(parameters.size());
|
||||
|
||||
for (const auto& parameter : parameters)
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function));
|
||||
|
||||
return SpirvConstantCache::Function{
|
||||
SpirvConstantCache::BuildType(*m_context.shader, retType),
|
||||
std::move(parameterTypes)
|
||||
};
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
|
||||
@@ -513,18 +537,7 @@ namespace Nz
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
{
|
||||
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
|
||||
parameterTypes.reserve(parameters.size());
|
||||
|
||||
for (const auto& parameter : parameters)
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type));
|
||||
|
||||
return m_currentState->constantTypeCache.GetId({
|
||||
SpirvConstantCache::Function {
|
||||
SpirvConstantCache::BuildType(*m_context.shader, retType),
|
||||
std::move(parameterTypes)
|
||||
}
|
||||
});
|
||||
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
|
||||
@@ -602,6 +615,22 @@ namespace Nz
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->parameterIds.find(name);
|
||||
assert(it != m_currentState->parameterIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->parameterIds.find(name);
|
||||
assert(it != m_currentState->parameterIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
@@ -623,7 +652,7 @@ namespace Nz
|
||||
if (!var.valueId.has_value())
|
||||
{
|
||||
UInt32 resultId = AllocateResultId();
|
||||
m_currentState->instructions.Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
|
||||
m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
|
||||
|
||||
var.valueId = resultId;
|
||||
}
|
||||
@@ -646,18 +675,7 @@ namespace Nz
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
{
|
||||
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
|
||||
parameterTypes.reserve(parameters.size());
|
||||
|
||||
for (const auto& parameter : parameters)
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type));
|
||||
|
||||
return m_currentState->constantTypeCache.Register({
|
||||
SpirvConstantCache::Function {
|
||||
SpirvConstantCache::BuildType(*m_context.shader, retType),
|
||||
std::move(parameterTypes)
|
||||
}
|
||||
});
|
||||
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)
|
||||
|
||||
Reference in New Issue
Block a user