Shader: Handle type as expressions

This commit is contained in:
Jérôme Leclercq
2022-02-08 17:03:34 +01:00
parent 5ce8120a0c
commit 402e16bd2b
53 changed files with 1746 additions and 1141 deletions

View File

@@ -144,17 +144,18 @@ namespace Nz
SpirvConstantCache::Variable variable;
variable.debugName = extVar.name;
if (ShaderAst::IsSamplerType(extVar.type))
const ShaderAst::ExpressionType& extVarType = extVar.type.GetResultingValue();
if (ShaderAst::IsSamplerType(extVarType))
{
variable.storageClass = SpirvStorageClass::UniformConstant;
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
variable.type = m_constantCache.BuildPointerType(extVarType, variable.storageClass);
}
else
{
assert(ShaderAst::IsUniformType(extVar.type));
const auto& uniformType = std::get<ShaderAst::UniformType>(extVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniformType.containedType));
const auto& structType = std::get<ShaderAst::StructType>(uniformType.containedType);
assert(ShaderAst::IsUniformType(extVarType));
const auto& uniformType = std::get<ShaderAst::UniformType>(extVarType);
const auto& structType = uniformType.containedType;
assert(structType.structIndex < declaredStructs.size());
const auto& type = m_constantCache.BuildType(*declaredStructs[structType.structIndex], { SpirvDecoration::Block });
@@ -188,16 +189,27 @@ namespace Nz
{
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
parameterTypes.push_back(parameter.type.GetResultingValue());
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes));
if (node.returnType.HasValue())
{
const auto& returnType = node.returnType.GetResultingValue();
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(returnType, parameterTypes));
}
else
{
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes));
}
for (auto& parameter : node.parameters)
{
const auto& parameterType = parameter.type.GetResultingValue();
auto& funcParam = funcData.parameters.emplace_back();
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type));
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameterType));
}
}
else
@@ -235,9 +247,11 @@ namespace Nz
{
assert(node.parameters.size() == 1);
auto& parameter = node.parameters.front();
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
const auto& parameterType = parameter.type.GetResultingValue();
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
assert(std::holds_alternative<ShaderAst::StructType>(parameterType));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameterType).structIndex;
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
@@ -250,7 +264,7 @@ namespace Nz
{
inputs.push_back({
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type.GetResultingValue(), SpirvStorageClass::Function)),
varId
});
}
@@ -259,18 +273,20 @@ namespace Nz
}
inputStruct = EntryPoint::InputStruct{
m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type))
m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type.GetResultingValue()))
};
}
std::optional<UInt32> outputStructId;
std::vector<EntryPoint::Output> outputs;
if (!IsNoType(node.returnType))
if (node.returnType.HasValue())
{
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
assert(std::holds_alternative<ShaderAst::StructType>(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
@@ -283,7 +299,7 @@ namespace Nz
{
outputs.push_back({
Int32(memberIndex),
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
m_constantCache.Register(*m_constantCache.BuildType(member.type.GetResultingValue())),
varId
});
}
@@ -291,7 +307,7 @@ namespace Nz
memberIndex++;
}
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(returnType));
}
funcData.entryPointData = EntryPoint{
@@ -334,7 +350,7 @@ namespace Nz
func.varIndexToVarId[*node.varIndex] = func.variables.size();
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function));
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType.GetResultingValue(), SpirvStorageClass::Function));
}
void Visit(ShaderAst::IdentifierExpression& node) override
@@ -408,7 +424,7 @@ namespace Nz
variable.debugName = builtin.debugName;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass);
UInt32 varId = m_constantCache.Register(variable);
builtinDecorations[varId] = builtinDecoration;
@@ -421,7 +437,7 @@ namespace Nz
variable.debugName = member.name;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = member.locationIndex.GetResultingValue();
@@ -643,9 +659,12 @@ namespace Nz
parameterTypes.reserve(functionNode.parameters.size());
for (const auto& parameter : functionNode.parameters)
parameterTypes.push_back(parameter.type);
parameterTypes.push_back(parameter.type.GetResultingValue());
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes);
if (functionNode.returnType.HasValue())
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType.GetResultingValue(), parameterTypes);
else
return m_currentState->constantTypeCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes);
}
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const