Shader: Fix struct indexes in case of disabled field

This commit is contained in:
Jérôme Leclercq 2021-07-07 21:36:40 +02:00
parent 72edff30c7
commit d679eccb43
4 changed files with 48 additions and 11 deletions

View File

@ -135,21 +135,28 @@ namespace Nz::ShaderAst
assert(structIndex < m_context->structs.size()); assert(structIndex < m_context->structs.size());
const StructDescription* s = m_context->structs[structIndex]; const StructDescription* s = m_context->structs[structIndex];
auto it = std::find_if(s->members.begin(), s->members.end(), [&](const auto& field) // Retrieve member index (not counting disabled fields)
Int32 fieldIndex = 0;
const StructDescription::StructMember* fieldPtr = nullptr;
for (const auto& field : s->members)
{ {
if (field.name != identifier)
return false;
if (field.cond.HasValue() && !field.cond.GetResultingValue()) if (field.cond.HasValue() && !field.cond.GetResultingValue())
return false; continue;
return true; if (field.name == identifier)
}); {
if (it == s->members.end()) fieldPtr = &field;
break;
}
fieldIndex++;
}
if (!fieldPtr)
throw AstError{ "unknown field " + identifier }; throw AstError{ "unknown field " + identifier };
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s->members.begin(), it)))); accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex));
accessIndexPtr->cachedExpressionType = ResolveType(it->type); accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type);
} }
else if (IsVectorType(exprType)) else if (IsVectorType(exprType))
{ {

View File

@ -390,8 +390,23 @@ namespace Nz
assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression); assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression);
auto& constantValue = static_cast<ShaderAst::ConstantExpression&>(**memberIndices); auto& constantValue = static_cast<ShaderAst::ConstantExpression&>(**memberIndices);
Int32 index = std::get<Int32>(constantValue.value); Int32 index = std::get<Int32>(constantValue.value);
assert(index >= 0);
const auto& member = structDesc->members[index]; auto it = structDesc->members.begin();
for (; it != structDesc->members.end(); ++it)
{
const auto& member = *it;
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (index == 0)
break;
index--;
}
assert(it != structDesc->members.end());
const auto& member = *it;
Append("."); Append(".");
Append(member.name); Append(member.name);
@ -586,6 +601,9 @@ namespace Nz
{ {
for (const auto& member : structDesc.members) for (const auto& member : structDesc.members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (member.builtin.HasValue()) if (member.builtin.HasValue())
{ {
auto it = s_builtinMapping.find(member.builtin.GetResultingValue()); auto it = s_builtinMapping.find(member.builtin.GetResultingValue());
@ -898,6 +916,9 @@ namespace Nz
bool first = true; bool first = true;
for (const auto& member : structDesc->members) for (const auto& member : structDesc->members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (!first) if (!first)
AppendLine(); AppendLine();

View File

@ -588,6 +588,9 @@ namespace Nz
for (const auto& member : structDesc.members) for (const auto& member : structDesc.members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
auto& sMembers = sType.members.emplace_back(); auto& sMembers = sType.members.emplace_back();
sMembers.name = member.name; sMembers.name = member.name;
sMembers.type = BuildType(member.type); sMembers.type = BuildType(member.type);

View File

@ -218,6 +218,9 @@ namespace Nz
std::size_t memberIndex = 0; std::size_t memberIndex = 0;
for (const auto& member : structDesc->members) for (const auto& member : structDesc->members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0) if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
{ {
inputs.push_back({ inputs.push_back({
@ -248,6 +251,9 @@ namespace Nz
std::size_t memberIndex = 0; std::size_t memberIndex = 0;
for (const auto& member : structDesc->members) for (const auto& member : structDesc->members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0) if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
{ {
outputs.push_back({ outputs.push_back({