Shader: First working version on both Vulkan & OpenGL (ES)

This commit is contained in:
Jérôme Leclercq
2021-04-12 15:38:20 +02:00
parent f93a5bbdc1
commit ea99c6a19e
42 changed files with 1803 additions and 1053 deletions

View File

@@ -12,10 +12,12 @@
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
#include <tsl/ordered_map.h>
#include <tsl/ordered_set.h>
#include <SpirV/GLSL.std.450.h>
#include <cassert>
#include <map>
#include <stdexcept>
#include <type_traits>
#include <vector>
@@ -25,34 +27,61 @@ namespace Nz
{
namespace
{
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct Builtin
{
const char* debugName;
ShaderStageTypeFlags compatibleStages;
SpirvBuiltIn decoration;
};
std::unordered_map<std::string, Builtin> s_builtinMapping = {
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
};
class PreVisitor : public ShaderAst::AstScopedVisitor
{
public:
struct UniformVar
{
std::optional<UInt32> bindingIndex;
UInt32 pointerId;
};
using BuiltinDecoration = std::map<UInt32, SpirvBuiltIn>;
using LocationDecoration = std::map<UInt32, UInt32>;
using ExtInstList = std::unordered_set<std::string>;
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
using StructContainer = std::vector<ShaderAst::StructDescription>;
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
m_conditions(conditions),
m_constantCache(constantCache)
m_constantCache(constantCache),
m_externalBlockIndex(0),
m_funcs(funcs)
{
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
{
const Identifier* identifier = FindIdentifier(identifierName);
if (!identifier)
throw std::runtime_error("invalid identifier " + identifierName);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
assert(structIndex < declaredStructs.size());
return declaredStructs[structIndex];
});
}
void Visit(ShaderAst::AccessMemberExpression& node) override
void Visit(ShaderAst::AccessMemberIndexExpression& node) override
{
/*for (std::size_t index : node.memberIdentifiers)
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
AstRecursiveVisitor::Visit(node);
for (std::size_t index : node.memberIndices)
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(index)));
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::ConditionalExpression& node) override
@@ -64,6 +93,8 @@ namespace Nz
Visit(node.truePath);
else
Visit(node.falsePath);*/
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::ConditionalStatement& node) override
@@ -79,52 +110,189 @@ namespace Nz
{
std::visit([&](auto&& arg)
{
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
m_constantCache.Register(*m_constantCache.BuildConstant(arg));
}, node.value);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareExternalStatement& node) override
{
assert(node.varIndex);
std::size_t varIndex = *node.varIndex;
for (auto& extVar : node.externalVars)
{
SpirvConstantCache::Variable variable;
variable.debugName = extVar.name;
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable);
for (const auto& [attributeType, attributeParam] : extVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
uniformVar.bindingIndex = std::get<long long>(attributeParam);
break;
}
}
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
funcs.emplace_back(node);
std::optional<ShaderStageType> entryPointType;
for (auto& attribute : node.attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
entryPointType = it->second;
break;
}
}
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
if (funcIndex >= m_funcs.size())
m_funcs.resize(funcIndex + 1);
auto& funcData = m_funcs[funcIndex];
funcData.name = node.name;
if (!entryPointType)
{
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes));
for (auto& parameter : node.parameters)
{
auto& funcParam = funcData.parameters.emplace_back();
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type));
}
}
else
{
using EntryPoint = SpirvAstVisitor::EntryPoint;
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {}));
std::optional<EntryPoint::InputStruct> inputStruct;
std::vector<EntryPoint::Input> inputs;
if (!node.parameters.empty())
{
assert(node.parameters.size() == 1);
auto& parameter = node.parameters.front();
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc.members)
{
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
{
inputs.push_back({
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)),
varId
});
}
memberIndex++;
}
inputStruct = EntryPoint::InputStruct{
m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type))
};
}
std::optional<UInt32> outputStructId;
std::vector<EntryPoint::Output> outputs;
if (!IsNoType(node.returnType))
{
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc.members)
{
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
{
outputs.push_back({
Int32(memberIndex),
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
varId
});
}
memberIndex++;
}
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
}
funcData.entryPointData = EntryPoint{
*entryPointType,
inputStruct,
outputStructId,
funcIndex,
std::move(inputs),
std::move(outputs)
};
}
m_funcIndex = funcIndex;
AstScopedVisitor::Visit(node);
m_funcIndex.reset();
}
void Visit(ShaderAst::DeclareStructStatement& node) override
{
AstScopedVisitor::Visit(node);
SpirvConstantCache::Structure sType;
sType.name = node.description.name;
assert(node.structIndex);
std::size_t structIndex = *node.structIndex;
if (structIndex >= declaredStructs.size())
declaredStructs.resize(structIndex + 1);
for (const auto& [name, attribute, type] : node.description.members)
{
auto& sMembers = sType.members.emplace_back();
sMembers.name = name;
sMembers.type = SpirvConstantCache::BuildType(type);
}
declaredStructs[structIndex] = node.description;
m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) });
m_constantCache.Register(*m_constantCache.BuildType(node.description));
}
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
AstScopedVisitor::Visit(node);
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
assert(node.varIndex);
func.varIndexToVarId[*node.varIndex] = func.variables.size();
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function));
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
AstScopedVisitor::Visit(node);
}
@@ -144,40 +312,88 @@ namespace Nz
case ShaderAst::IntrinsicType::DotProduct:
break;
}
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::SwizzleExpression& node) override
{
AstScopedVisitor::Visit(node);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{
std::optional<std::reference_wrapper<Builtin>> builtinOpt;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
if (it != s_builtinMapping.end())
{
builtinOpt = it->second;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
if (builtinOpt)
{
Builtin& builtin = *builtinOpt;
if ((builtin.compatibleStages & entryPointType) == 0)
return 0;
SpirvBuiltIn builtinDecoration = builtin.decoration;
SpirvConstantCache::Variable variable;
variable.debugName = builtin.debugName;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
builtinDecorations[varId] = builtinDecoration;
return varId;
}
else if (attributeLocation)
{
SpirvConstantCache::Variable variable;
variable.debugName = member.name;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = *attributeLocation;
return varId;
}
return 0;
}
BuiltinDecoration builtinDecorations;
ExtInstList extInsts;
FunctionContainer funcs;
ExtVarContainer extVars;
LocationDecoration locationDecorations;
StructContainer declaredStructs;
private:
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
std::optional<std::size_t> m_funcIndex;
std::size_t m_externalBlockIndex;
std::vector<SpirvAstVisitor::FuncData>& m_funcs;
};
template<typename T>
constexpr ShaderAst::PrimitiveType GetBasicType()
{
if constexpr (std::is_same_v<T, bool>)
return ShaderAst::PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return(ShaderAst::PrimitiveType::Float32);
else if constexpr (std::is_same_v<T, Int32>)
return(ShaderAst::PrimitiveType::Int32);
else if constexpr (std::is_same_v<T, Vector2f>)
return(ShaderAst::PrimitiveType::Float2);
else if constexpr (std::is_same_v<T, Vector3f>)
return(ShaderAst::PrimitiveType::Float3);
else if constexpr (std::is_same_v<T, Vector4f>)
return(ShaderAst::PrimitiveType::Float4);
else if constexpr (std::is_same_v<T, Vector2i32>)
return(ShaderAst::PrimitiveType::Int2);
else if constexpr (std::is_same_v<T, Vector3i32>)
return(ShaderAst::PrimitiveType::Int3);
else if constexpr (std::is_same_v<T, Vector4i32>)
return(ShaderAst::PrimitiveType::Int4);
else
static_assert(AlwaysFalse<T>::value, "unhandled type");
}
}
struct SpirvWriter::State
@@ -194,18 +410,13 @@ namespace Nz
UInt32 id;
};
tsl::ordered_map<std::string, ExtVar> inputIds;
tsl::ordered_map<std::string, ExtVar> outputIds;
tsl::ordered_map<std::string, ExtVar> parameterIds;
tsl::ordered_map<std::string, ExtVar> uniformIds;
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<Func> funcs;
std::vector<SpirvBlock> functionBlocks;
std::vector<SpirvAstVisitor::FuncData> funcs;
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
PreVisitor* preVisitor;
// Output
SpirvSection header;
@@ -226,6 +437,9 @@ namespace Nz
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
ShaderAst::TransformVisitor transformVisitor;
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
m_context.states = &conditions;
State state;
@@ -235,245 +449,37 @@ namespace Nz
m_currentState = nullptr;
});
ShaderAst::AstCloner cloner;
// Register all extended instruction sets
PreVisitor preVisitor(conditions, state.constantTypeCache);
shader->Visit(preVisitor);
PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs);
transformedShader->Visit(preVisitor);
m_currentState->preVisitor = &preVisitor;
for (const std::string& extInst : preVisitor.extInsts)
state.extensionInstructions[extInst] = AllocateResultId();
// Register all types
/*for (const auto& func : shader.GetFunctions())
{
RegisterType(func.returnType);
for (const auto& param : func.parameters)
RegisterType(param.type);
}
for (const auto& input : shader.GetInputs())
RegisterPointerType(input.type, SpirvStorageClass::Input);
for (const auto& output : shader.GetOutputs())
RegisterPointerType(output.type, SpirvStorageClass::Output);
for (const auto& uniform : shader.GetUniforms())
RegisterPointerType(uniform.type, (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform);
for (const auto& func : shader.GetFunctions())
RegisterFunctionType(func.returnType, func.parameters);
for (const auto& type : preVisitor.variableTypes)
RegisterType(type);
for (const auto& builtin : preVisitor.builtinVars)
RegisterType(builtin->type);
// Register result id and debug infos for global variables/functions
for (const auto& builtin : preVisitor.builtinVars)
{
SpirvConstantCache::Variable variable;
SpirvBuiltIn builtinDecoration;
switch (builtin->entry)
{
case ShaderAst::BuiltinEntry::VertexPosition:
variable.debugName = "builtin_VertexPosition";
variable.storageClass = SpirvStorageClass::Output;
builtinDecoration = SpirvBuiltIn::Position;
break;
default:
throw std::runtime_error("unexpected builtin type");
}
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
assert(IsBasicType(builtinExprType));
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar builtinData;
builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass);
builtinData.typeId = GetTypeId(builtinType);
builtinData.varId = varId;
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);
state.builtinIds.emplace(builtin->entry, builtinData);
}
for (const auto& input : shader.GetInputs())
{
SpirvConstantCache::Variable variable;
variable.debugName = input.name;
variable.storageClass = SpirvStorageClass::Input;
variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar inputData;
inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass);
inputData.typeId = GetTypeId(input.type);
inputData.varId = varId;
state.inputIds.emplace(input.name, std::move(inputData));
if (input.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
}
for (const auto& output : shader.GetOutputs())
{
SpirvConstantCache::Variable variable;
variable.debugName = output.name;
variable.storageClass = SpirvStorageClass::Output;
variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar outputData;
outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass);
outputData.typeId = GetTypeId(output.type);
outputData.varId = varId;
state.outputIds.emplace(output.name, std::move(outputData));
if (output.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
}
for (const auto& uniform : shader.GetUniforms())
{
SpirvConstantCache::Variable variable;
variable.debugName = uniform.name;
variable.storageClass = (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar uniformData;
uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass);
uniformData.typeId = GetTypeId(uniform.type);
uniformData.varId = varId;
state.uniformIds.emplace(uniform.name, std::move(uniformData));
if (uniform.bindingIndex)
{
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
}
}*/
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
{
auto& funcData = state.funcs.emplace_back();
funcData.statement = &func;
funcData.id = AllocateResultId();
funcData.typeId = GetFunctionTypeId(func);
state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
}
std::size_t funcIndex = 0;
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
{
auto& funcData = state.funcs[funcIndex++];
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
state.functionBlocks.clear();
state.functionBlocks.emplace_back(*this);
state.parameterIds.clear();
for (const auto& param : func.parameters)
{
UInt32 paramResultId = AllocateResultId();
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
ExtVar parameterData;
parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
parameterData.typeId = GetTypeId(param.type);
parameterData.varId = paramResultId;
state.parameterIds.emplace(param.name, std::move(parameterData));
}
SpirvAstVisitor visitor(*this, state.functionBlocks);
for (const auto& statement : func.statements)
statement->Visit(visitor);
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} });
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
for (SpirvBlock& block : state.functionBlocks)
state.instructions.AppendSection(block);
state.instructions.Append(SpirvOp::OpFunctionEnd);
}
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
transformedShader->Visit(visitor);
AppendHeader();
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
for (auto&& [varIndex, extVar] : preVisitor.extVars)
{
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
if (!statement)
continue;
auto it = std::find_if(state.funcs.begin(), state.funcs.end(), [&](const auto& funcData) { return funcData.statement == statement; });
assert(it != state.funcs.end());
const auto& entryFunc = *it;
SpirvExecutionModel execModel;
ShaderStageType stage = static_cast<ShaderStageType>(i);
switch (stage)
if (extVar.bindingIndex)
{
case ShaderStageType::Fragment:
execModel = SpirvExecutionModel::Fragment;
break;
case ShaderStageType::Vertex:
execModel = SpirvExecutionModel::Vertex;
break;
default:
throw std::runtime_error("not yet implemented");
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, 0);
}
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(entryFunc.id);
appender(statement->name);
for (const auto& [name, varData] : state.builtinIds)
appender(varData.varId);
for (const auto& [name, varData] : state.inputIds)
appender(varData.varId);
for (const auto& [name, varData] : state.outputIds)
appender(varData.varId);
});
if (stage == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
}
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
for (auto&& [varId, location] : preVisitor.locationDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
std::vector<UInt32> ret;
MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo);
@@ -511,171 +517,53 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
}
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
}
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
{
auto it = m_currentState->builtinIds.find(builtin);
assert(it != m_currentState->builtinIds.end());
return it->second;
}
auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return it->second;
}
auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->outputIds.find(name);
assert(it != m_currentState->outputIds.end());
return it->second;
}
auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return it.value();
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
}
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadLocalVariable(const std::string& name)
{
auto it = m_currentState->varToResult.find(name);
assert(it != m_currentState->varToResult.end());
return it->second;
}
std::optional<UInt32> SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->varToResult.find(name);
if (it == m_currentState->varToResult.end())
return {};
return it->second;
}
UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
{
auto it = m_currentState->parameterIds.find(name);
assert(it != m_currentState->parameterIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->parameterIds.find(name);
assert(it != m_currentState->parameterIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadVariable(ExtVar& var)
{
if (!var.valueId.has_value())
std::optional<UInt32> fragmentFuncId;
for (auto& func : m_currentState->funcs)
{
UInt32 resultId = AllocateResultId();
m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
var.valueId = resultId;
if (func.entryPointData)
{
auto& entryPointData = func.entryPointData.value();
SpirvExecutionModel execModel;
switch (entryPointData.stageType)
{
case ShaderStageType::Fragment:
execModel = SpirvExecutionModel::Fragment;
break;
case ShaderStageType::Vertex:
execModel = SpirvExecutionModel::Vertex;
break;
default:
throw std::runtime_error("not yet implemented");
}
m_currentState->header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(func.funcId);
appender(func.name);
for (const auto& input : entryPointData.inputs)
appender(input.varId);
for (const auto& output : entryPointData.outputs)
appender(output.varId);
});
if (entryPointData.stageType == ShaderStageType::Fragment)
fragmentFuncId = func.funcId;
}
}
return var.valueId.value();
}
if (fragmentFuncId)
m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft);
std::optional<UInt32> SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache)
{
if (!var.valueId.has_value())
return {};
return var.valueId.value();
}
UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
}
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
{
assert(m_currentState);
m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
}
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
@@ -686,7 +574,56 @@ namespace Nz
for (const auto& parameter : functionNode.parameters)
parameterTypes.push_back(parameter.type);
return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes);
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes);
}
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildConstant(value));
}
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
{
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
assert(it != m_currentState->preVisitor->extVars.end());
return it->second.pointerId;
}
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
}
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildType(type));
}
void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from)