From a5b71f33b92df2f44d74a18ef2cd19f16ea31d43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Mon, 2 Aug 2021 11:12:34 +0200 Subject: [PATCH] Shader: Fix support of conditional functions --- include/Nazara/Shader/SpirvAstVisitor.hpp | 8 +++--- include/Nazara/Shader/SpirvAstVisitor.inl | 11 +++----- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 15 ++++++++++- src/Nazara/Shader/SpirvAstVisitor.cpp | 20 ++++++++++----- src/Nazara/Shader/SpirvWriter.cpp | 31 +++++++++++++++-------- 5 files changed, 57 insertions(+), 28 deletions(-) diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 6f1a1aa94..b49279b3b 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -27,7 +27,7 @@ namespace Nz struct FuncData; struct Variable; - inline SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector& funcData); + inline SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::unordered_map& funcData); SpirvAstVisitor(const SpirvAstVisitor&) = delete; SpirvAstVisitor(SpirvAstVisitor&&) = delete; ~SpirvAstVisitor() = default; @@ -147,10 +147,10 @@ namespace Nz std::size_t m_extVarIndex; std::size_t m_funcCallIndex; std::size_t m_funcIndex; + std::unordered_map& m_funcData; + std::unordered_map m_structs; + std::unordered_map m_variables; std::vector m_scopeSizes; - std::vector& m_funcData; - std::vector m_structs; - std::vector> m_variables; std::vector m_functionBlocks; std::vector m_resultIds; SpirvBlock* m_currentBlock; diff --git a/include/Nazara/Shader/SpirvAstVisitor.inl b/include/Nazara/Shader/SpirvAstVisitor.inl index 48ad986c7..12bd915d8 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.inl +++ b/include/Nazara/Shader/SpirvAstVisitor.inl @@ -3,11 +3,12 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include namespace Nz { - inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector& funcData) : + inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::unordered_map& funcData) : m_extVarIndex(0), m_funcIndex(0), m_funcData(funcData), @@ -27,17 +28,13 @@ namespace Nz inline void SpirvAstVisitor::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* structDesc) { - if (structIndex >= m_structs.size()) - m_structs.resize(structIndex + 1); - + assert(m_structs.find(structIndex) == m_structs.end()); m_structs[structIndex] = structDesc; } inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass) { - if (varIndex >= m_variables.size()) - m_variables.resize(varIndex + 1); - + assert(m_variables.find(varIndex) == m_variables.end()); m_variables[varIndex] = Variable{ storageClass, pointerId, diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 8e976679e..96be5b846 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -81,10 +81,22 @@ namespace Nz::ShaderAst { if (statementPtr->GetType() == NodeType::DeclareFunctionStatement) DeclareFunction(static_cast(*statementPtr)); + else if (statementPtr->GetType() == NodeType::ConditionalStatement) + { + const ConditionalStatement& condStatement = static_cast(*statementPtr); + if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement) + DeclareFunction(static_cast(*condStatement.statement)); + } } } else if (statement.GetType() == NodeType::DeclareFunctionStatement) DeclareFunction(static_cast(statement)); + else if (statement.GetType() == NodeType::ConditionalStatement) + { + const ConditionalStatement& condStatement = static_cast(statement); + if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement) + DeclareFunction(static_cast(*condStatement.statement)); + } try { @@ -1193,7 +1205,8 @@ namespace Nz::ShaderAst { assert(funcIndex < m_context->functions.size()); auto& funcData = m_context->functions[funcIndex]; - assert(funcData.defined); + if (!funcData.defined) + return; funcData.flags |= flags; diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index c41d77620..15ee3bf94 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -15,6 +15,16 @@ namespace Nz { + namespace + { + template const T& Retrieve(const std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + } + UInt32 SpirvAstVisitor::AllocateResultId() { return m_writer.AllocateResultId(); @@ -30,9 +40,7 @@ namespace Nz auto SpirvAstVisitor::GetVariable(std::size_t varIndex) const -> const Variable& { - assert(varIndex < m_variables.size()); - assert(m_variables[varIndex]); - return *m_variables[varIndex]; + return Retrieve(m_variables, varIndex); } void SpirvAstVisitor::Visit(ShaderAst::AccessIndexExpression& node) @@ -415,9 +423,9 @@ namespace Nz std::size_t functionIndex = std::get(node.targetFunction); UInt32 funcId = 0; - for (const auto& func : m_funcData) + for (const auto& [funcIndex, func] : m_funcData) { - if (func.funcIndex == functionIndex) + if (funcIndex == functionIndex) { funcId = func.funcId; break; @@ -425,7 +433,7 @@ namespace Nz } assert(funcId != 0); - const FuncData& funcData = m_funcData[m_funcIndex]; + const FuncData& funcData = Retrieve(m_funcData, m_funcIndex); const auto& funcCall = funcData.funcCalls[m_funcCallIndex++]; StackArray parameterIds = NazaraStackArrayNoInit(UInt32, node.parameters.size()); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index ce9d550ce..a979d6ef2 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -35,6 +35,20 @@ namespace Nz SpirvBuiltIn decoration; }; + template T& Retrieve(std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + + template const T& Retrieve(const std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + std::unordered_map s_builtinMapping = { { ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } }, { ShaderAst::BuiltinEntry::FragDepth, { "FragmentDepth", ShaderStageType::Fragment, SpirvBuiltIn::FragDepth } }, @@ -59,7 +73,7 @@ namespace Nz using FunctionContainer = std::vector>; using StructContainer = std::vector; - PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector& funcs) : + PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::unordered_map& funcs) : m_states(conditions), m_constantCache(constantCache), m_externalBlockIndex(0), @@ -91,7 +105,7 @@ namespace Nz AstRecursiveVisitor::Visit(node); assert(m_funcIndex); - auto& func = m_funcs[*m_funcIndex]; + auto& func = Retrieve(m_funcs, *m_funcIndex); auto& funcCall = func.funcCalls.emplace_back(); funcCall.firstVarIndex = func.variables.size(); @@ -152,9 +166,6 @@ namespace Nz assert(node.funcIndex); std::size_t funcIndex = *node.funcIndex; - if (funcIndex >= m_funcs.size()) - m_funcs.resize(funcIndex + 1); - auto& funcData = m_funcs[funcIndex]; funcData.name = node.name; funcData.funcIndex = funcIndex; @@ -409,7 +420,7 @@ namespace Nz SpirvConstantCache& m_constantCache; std::optional m_funcIndex; std::size_t m_externalBlockIndex; - std::vector& m_funcs; + std::unordered_map& m_funcs; }; } @@ -429,7 +440,7 @@ namespace Nz std::unordered_map extensionInstructionSet; std::unordered_map varToResult; - std::vector funcs; + std::unordered_map funcs; std::vector resultIds; UInt32 nextVarIndex = 1; SpirvConstantCache constantTypeCache; //< init after nextVarIndex @@ -488,7 +499,7 @@ namespace Nz state.extensionInstructionSet[extInst] = AllocateResultId(); // Assign function ID (required for forward declaration) - for (auto& func : state.funcs) + for (auto&& [funcIndex, func] : state.funcs) func.funcId = AllocateResultId(); SpirvAstVisitor visitor(*this, state.instructions, state.funcs); @@ -548,7 +559,7 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450); - for (auto& func : m_currentState->funcs) + for (auto&& [funcIndex, func] : m_currentState->funcs) { m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name); @@ -588,7 +599,7 @@ namespace Nz } // Write execution modes - for (auto& func : m_currentState->funcs) + for (auto&& [funcIndex, func] : m_currentState->funcs) { if (func.entryPointData) {