diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 3212d81cd..ad76fbe18 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -51,6 +51,7 @@ namespace Nz void Append(ShaderNodes::MemoryLayout layout); template void Append(const T& param); void AppendCommentSection(const std::string& section); + void AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers); void AppendFunction(const ShaderAst::Function& func); void AppendFunctionPrototype(const ShaderAst::Function& func); void AppendLine(const std::string& txt = {}); diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index 0f9e99347..8195494b6 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -33,6 +33,8 @@ namespace Nz void TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right); void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right); + const ShaderAst::StructMember& CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers); + using ShaderAstRecursiveVisitor::Visit; void Visit(ShaderNodes::AccessMember& node) override; void Visit(ShaderNodes::AssignOp& node) override; diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index af96afeda..15898af3f 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -144,11 +144,12 @@ namespace Nz ShaderExpressionType GetExpressionType() const override; void Visit(ShaderAstVisitor& visitor) override; - std::size_t memberIndex; ExpressionPtr structExpr; ShaderExpressionType exprType; + std::vector memberIndices; static inline std::shared_ptr Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType); + static inline std::shared_ptr Build(ExpressionPtr structExpr, std::vector memberIndices, ShaderExpressionType exprType); }; ////////////////////////////////////////////////////////////////////////// diff --git a/include/Nazara/Shader/ShaderNodes.inl b/include/Nazara/Shader/ShaderNodes.inl index 9265fe359..77bc813f3 100644 --- a/include/Nazara/Shader/ShaderNodes.inl +++ b/include/Nazara/Shader/ShaderNodes.inl @@ -168,10 +168,15 @@ namespace Nz::ShaderNodes } inline std::shared_ptr AccessMember::Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType) + { + return Build(std::move(structExpr), std::vector{ memberIndex }, exprType); + } + + inline std::shared_ptr AccessMember::Build(ExpressionPtr structExpr, std::vector memberIndices, ShaderExpressionType exprType) { auto node = std::make_shared(); node->exprType = std::move(exprType); - node->memberIndex = memberIndex; + node->memberIndices = std::move(memberIndices); node->structExpr = std::move(structExpr); return node; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 1e911ce13..39f2ec3ce 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -258,6 +258,26 @@ namespace Nz AppendLine(); } + void GlslWriter::AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers) + { + const auto& structs = m_context.shader->GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); + assert(it != structs.end()); + + const ShaderAst::Struct& s = *it; + assert(*memberIndex < s.members.size()); + + const auto& member = s.members[*memberIndex]; + Append("."); + Append(member.name); + + if (remainingMembers > 1) + { + assert(std::holds_alternative(member.type)); + AppendField(std::get(member.type), memberIndex + 1, remainingMembers - 1); + } + } + void GlslWriter::AppendFunction(const ShaderAst::Function& func) { NazaraAssert(!m_context.currentFunction, "A function is already being processed"); @@ -345,18 +365,7 @@ namespace Nz const ShaderExpressionType& exprType = node.structExpr->GetExpressionType(); assert(std::holds_alternative(exprType)); - const std::string& structName = std::get(exprType); - - const auto& structs = m_context.shader->GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); - assert(it != structs.end()); - - const ShaderAst::Struct& s = *it; - assert(node.memberIndex < s.members.size()); - - const auto& member = s.members[node.memberIndex]; - Append("."); - Append(member.name); + AppendField(std::get(exprType), node.memberIndices.data(), node.memberIndices.size()); } void GlslWriter::Visit(ShaderNodes::AssignOp& node) diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index 1e0f04899..39754b7a9 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -47,7 +47,7 @@ namespace Nz void ShaderAstCloner::Visit(ShaderNodes::AccessMember& node) { - PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndex, node.exprType)); + PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndices, node.exprType)); } void ShaderAstCloner::Visit(ShaderNodes::AssignOp& node) diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index 396cd3d98..03cd36d8b 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -132,9 +132,12 @@ namespace Nz void ShaderAstSerializerBase::Serialize(ShaderNodes::AccessMember& node) { - Value(node.memberIndex); Node(node.structExpr); Type(node.exprType); + + Container(node.memberIndices); + for (std::size_t& index : node.memberIndices) + Value(index); } void ShaderAstSerializerBase::Serialize(ShaderNodes::AssignOp& node) diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 2f3177283..b614ec4a2 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -83,6 +83,30 @@ namespace Nz throw AstError{ "Left expression type must match right expression type" }; } + const ShaderAst::StructMember& ShaderAstValidator::CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers) + { + const auto& structs = m_shader.GetStructs(); + auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); + if (it == structs.end()) + throw AstError{ "invalid structure" }; + + const ShaderAst::Struct& s = *it; + if (*memberIndex >= s.members.size()) + throw AstError{ "member index out of bounds" }; + + const auto& member = s.members[*memberIndex]; + + if (remainingMembers > 1) + { + if (!std::holds_alternative(member.type)) + throw AstError{ "member type does not match node type" }; + + return CheckField(std::get(member.type), memberIndex + 1, remainingMembers - 1); + } + else + return member; + } + void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node) { const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType(); @@ -91,16 +115,7 @@ namespace Nz const std::string& structName = std::get(exprType); - const auto& structs = m_shader.GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); - if (it == structs.end()) - throw AstError{ "invalid structure" }; - - const ShaderAst::Struct& s = *it; - if (node.memberIndex >= s.members.size()) - throw AstError{ "member index out of bounds" }; - - const auto& member = s.members[node.memberIndex]; + const auto& member = CheckField(structName, node.memberIndices.data(), node.memberIndices.size()); if (member.type != node.exprType) throw AstError{ "member type does not match node type" }; } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 398ee8e71..92630d506 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -42,7 +42,8 @@ namespace Nz void Visit(ShaderNodes::AccessMember& node) override { - m_constantCache.Register(*SpirvConstantCache::BuildConstant(UInt32(node.memberIndex))); + for (std::size_t index : node.memberIndices) + m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index))); ShaderAstRecursiveVisitor::Visit(node); } @@ -653,9 +654,16 @@ namespace Nz UInt32 memberPointerId = AllocateResultId(); UInt32 pointerType = RegisterPointerType(node.exprType, storage); //< FIXME UInt32 typeId = GetTypeId(node.exprType); - UInt32 indexId = GetConstantId(UInt32(node.memberIndex)); - m_currentState->instructions.Append(SpirvOp::OpAccessChain, pointerType, memberPointerId, pointerId, indexId); + m_currentState->instructions.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) + { + appender(pointerType); + appender(memberPointerId); + appender(pointerId); + + for (std::size_t index : node.memberIndices) + appender(GetConstantId(Int32(index))); + }); UInt32 resultId = AllocateResultId(); @@ -1047,8 +1055,6 @@ namespace Nz void SpirvWriter::Visit(ShaderNodes::Sample2D& node) { - // OpImageSampleImplicitLod %v4float %31 %35 - UInt32 typeId = GetTypeId(ShaderNodes::BasicType::Float4); UInt32 samplerId = EvaluateExpression(node.sampler); diff --git a/src/ShaderNode/DataModels/BufferField.cpp b/src/ShaderNode/DataModels/BufferField.cpp index 2d9cde16d..3f3986fb5 100644 --- a/src/ShaderNode/DataModels/BufferField.cpp +++ b/src/ShaderNode/DataModels/BufferField.cpp @@ -76,6 +76,9 @@ Nz::ShaderNodes::ExpressionPtr BufferField::GetExpression(Nz::ShaderNodes::Expre Nz::ShaderNodes::ExpressionPtr sourceExpr = Nz::ShaderBuilder::Identifier(varPtr); + std::vector memberIndices; + memberIndices.reserve(currentField.nestedFields.size() + 1); + const ShaderGraph::StructEntry* sourceStruct = &structEntry; for (std::size_t nestedIndex : currentField.nestedFields) { @@ -86,14 +89,16 @@ Nz::ShaderNodes::ExpressionPtr BufferField::GetExpression(Nz::ShaderNodes::Expre std::size_t nestedStructIndex = std::get(memberEntry.type); sourceStruct = &graph.GetStruct(nestedStructIndex); - sourceExpr = Nz::ShaderBuilder::AccessMember(std::move(sourceExpr), 0, graph.ToShaderExpressionType(memberEntry.type)); + memberIndices.push_back(nestedIndex); } + memberIndices.push_back(currentField.finalFieldIndex); + assert(currentField.finalFieldIndex < sourceStruct->members.size()); const auto& memberEntry = sourceStruct->members[currentField.finalFieldIndex]; assert(std::holds_alternative(memberEntry.type)); - return Nz::ShaderBuilder::AccessMember(std::move(sourceExpr), currentField.finalFieldIndex, graph.ToShaderExpressionType(std::get(memberEntry.type))); + return Nz::ShaderBuilder::AccessMember(std::move(sourceExpr), std::move(memberIndices), graph.ToShaderExpressionType(std::get(memberEntry.type))); } unsigned int BufferField::nPorts(QtNodes::PortType portType) const