From d679eccb430cce1ab71f835378758d58c171b0e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Wed, 7 Jul 2021 21:36:40 +0200 Subject: [PATCH] Shader: Fix struct indexes in case of disabled field --- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 27 ++++++++++++++--------- src/Nazara/Shader/GlslWriter.cpp | 23 ++++++++++++++++++- src/Nazara/Shader/SpirvConstantCache.cpp | 3 +++ src/Nazara/Shader/SpirvWriter.cpp | 6 +++++ 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index a48681575..84a4d8ec4 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -135,21 +135,28 @@ namespace Nz::ShaderAst assert(structIndex < m_context->structs.size()); const StructDescription* s = m_context->structs[structIndex]; - auto it = std::find_if(s->members.begin(), s->members.end(), [&](const auto& field) + // Retrieve member index (not counting disabled fields) + Int32 fieldIndex = 0; + const StructDescription::StructMember* fieldPtr = nullptr; + for (const auto& field : s->members) { - if (field.name != identifier) - return false; - if (field.cond.HasValue() && !field.cond.GetResultingValue()) - return false; + continue; - return true; - }); - if (it == s->members.end()) + if (field.name == identifier) + { + fieldPtr = &field; + break; + } + + fieldIndex++; + } + + if (!fieldPtr) throw AstError{ "unknown field " + identifier }; - accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s->members.begin(), it)))); - accessIndexPtr->cachedExpressionType = ResolveType(it->type); + accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex)); + accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type); } else if (IsVectorType(exprType)) { diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index b28eaec2d..160c606c7 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -390,8 +390,23 @@ namespace Nz assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression); auto& constantValue = static_cast(**memberIndices); Int32 index = std::get(constantValue.value); + assert(index >= 0); - const auto& member = structDesc->members[index]; + auto it = structDesc->members.begin(); + for (; it != structDesc->members.end(); ++it) + { + const auto& member = *it; + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + + if (index == 0) + break; + + index--; + } + + assert(it != structDesc->members.end()); + const auto& member = *it; Append("."); Append(member.name); @@ -586,6 +601,9 @@ namespace Nz { for (const auto& member : structDesc.members) { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + if (member.builtin.HasValue()) { auto it = s_builtinMapping.find(member.builtin.GetResultingValue()); @@ -898,6 +916,9 @@ namespace Nz bool first = true; for (const auto& member : structDesc->members) { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + if (!first) AppendLine(); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index d569bb22f..4e9699dc9 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -588,6 +588,9 @@ namespace Nz for (const auto& member : structDesc.members) { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + auto& sMembers = sType.members.emplace_back(); sMembers.name = member.name; sMembers.type = BuildType(member.type); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 9640c80c4..49fb35d7e 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -218,6 +218,9 @@ namespace Nz std::size_t memberIndex = 0; for (const auto& member : structDesc->members) { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0) { inputs.push_back({ @@ -248,6 +251,9 @@ namespace Nz std::size_t memberIndex = 0; for (const auto& member : structDesc->members) { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0) { outputs.push_back({