diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index ecf1a0923..c59422b43 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -63,7 +63,7 @@ namespace Nz void AppendCommentSection(const std::string& section); void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false); void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers); - void AppendHeader(const std::vector& forwardFunctionDeclarations); + void AppendHeader(); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); void AppendStatementList(std::vector& statements); @@ -74,7 +74,6 @@ namespace Nz void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node); void HandleInOut(); - void RegisterFunction(std::size_t funcIndex, std::string funcName); void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc); void RegisterVariable(std::size_t varIndex, std::string varName); diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index f93f1c216..226b5d19a 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -36,6 +37,14 @@ namespace Nz { using AstRecursiveVisitor::Visit; + void Visit(ShaderAst::CallFunctionExpression& node) override + { + AstRecursiveVisitor::Visit(node); + + assert(currentFunction); + currentFunction->calledFunctions.UnboundedSet(std::get(node.targetFunction)); + } + void Visit(ShaderAst::ConditionalStatement& node) override { if (TestBit(enabledOptions, node.optionIndex)) @@ -64,16 +73,31 @@ namespace Nz entryPoint = &node; } } - else - forwardFunctionDeclarations.push_back(&node); assert(node.funcIndex); - functionNames[node.funcIndex.value()] = node.name; + assert(functions.find(node.funcIndex.value()) == functions.end()); + FunctionData& funcData = functions[node.funcIndex.value()]; + funcData.name = node.name; + funcData.node = &node; + + currentFunction = &funcData; + + AstRecursiveVisitor::Visit(node); + + currentFunction = nullptr; } + struct FunctionData + { + std::string name; + Bitset<> calledFunctions; + ShaderAst::DeclareFunctionStatement* node; + }; + + FunctionData* currentFunction = nullptr; + std::optional selectedStage; - std::unordered_map functionNames; - std::vector forwardFunctionDeclarations; + std::unordered_map functions; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; UInt64 enabledOptions = 0; }; @@ -100,14 +124,14 @@ namespace Nz }; std::optional stage; - const States* states = nullptr; - ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; std::stringstream stream; std::unordered_map structs; - std::unordered_map functionNames; std::unordered_map variableNames; std::vector inputFields; std::vector outputFields; + Bitset<> declaredFunctions; + PreVisitor previsitor; + const States* states = nullptr; UInt64 enabledOptions = 0; bool isInEntryPoint = false; unsigned int indentLevel = 0; @@ -145,25 +169,11 @@ namespace Nz ShaderAst::StatementPtr& targetAst = *targetAstPtr; - PreVisitor previsitor; - previsitor.enabledOptions = states.enabledOptions; - previsitor.selectedStage = shaderStage; - targetAst->Visit(previsitor); + state.previsitor.enabledOptions = states.enabledOptions; + state.previsitor.selectedStage = shaderStage; + targetAst->Visit(state.previsitor); - if (!previsitor.entryPoint) - { - if (previsitor.forwardFunctionDeclarations.empty()) - throw std::runtime_error("no function found"); - - state.entryFunc = previsitor.forwardFunctionDeclarations.front(); - previsitor.forwardFunctionDeclarations.erase(previsitor.forwardFunctionDeclarations.begin()); - } - else - state.entryFunc = previsitor.entryPoint; - - state.functionNames = std::move(previsitor.functionNames); - - AppendHeader(previsitor.forwardFunctionDeclarations); + AppendHeader(); sanitizedAst->Visit(*this); @@ -255,7 +265,7 @@ namespace Nz case ImageType::E2D: Append("2D"); break; case ImageType::E2D_Array: Append("2DArray"); break; case ImageType::E3D: Append("3D"); break; - case ImageType::Cubemap: Append("Cube"); break; + case ImageType::Cubemap: Append("Cube"); break; } } @@ -356,7 +366,7 @@ namespace Nz } } - void GlslWriter::AppendHeader(const std::vector& forwardFunctionDeclarations) + void GlslWriter::AppendHeader() { unsigned int glslVersion; if (m_environment.glES) @@ -437,15 +447,6 @@ namespace Nz AppendLine("#endif"); AppendLine(); } - - if (!forwardFunctionDeclarations.empty()) - { - AppendCommentSection("function declarations"); - for (const ShaderAst::DeclareFunctionStatement* node : forwardFunctionDeclarations) - AppendFunctionDeclaration(*node, true); - - AppendLine(); - } } void GlslWriter::AppendLine(const std::string& txt) @@ -499,9 +500,6 @@ namespace Nz void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node) { - if (m_currentState->entryFunc != &node) - return; //< Ignore other entry points - HandleInOut(); AppendLine("void main()"); EnterScope(); @@ -578,7 +576,7 @@ namespace Nz AppendLine(); }; - const ShaderAst::DeclareFunctionStatement& node = *m_currentState->entryFunc; + const ShaderAst::DeclareFunctionStatement& node = *m_currentState->previsitor.entryPoint; const ShaderAst::StructDescription* inputStruct = nullptr; @@ -613,12 +611,6 @@ namespace Nz } } - void GlslWriter::RegisterFunction(std::size_t funcIndex, std::string funcName) - { - assert(m_currentState->functionNames.find(funcIndex) == m_currentState->functionNames.end()); - m_currentState->functionNames.emplace(funcIndex, std::move(funcName)); - } - void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc) { assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); @@ -722,7 +714,7 @@ namespace Nz void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node) { assert(std::holds_alternative(node.targetFunction)); - const std::string& targetName = Retrieve(m_currentState->functionNames, std::get(node.targetFunction)); + const std::string& targetName = Retrieve(m_currentState->previsitor.functions, std::get(node.targetFunction)).name; Append(targetName, "("); for (std::size_t i = 0; i < node.parameters.size(); ++i) @@ -869,6 +861,30 @@ namespace Nz { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); + if (node.entryStage && m_currentState->previsitor.entryPoint != &node) + return; //< Ignore other entry points + + assert(node.funcIndex); + auto& funcData = Retrieve(m_currentState->previsitor.functions, node.funcIndex.value()); + + // Declare functions called by this function which aren't already defined + bool hasPredeclaration = false; + for (std::size_t i = funcData.calledFunctions.FindFirst(); i != funcData.calledFunctions.npos; i = funcData.calledFunctions.FindNext(i)) + { + if (!m_currentState->declaredFunctions.UnboundedTest(i)) + { + hasPredeclaration = true; + + auto& targetFunc = Retrieve(m_currentState->previsitor.functions, i); + AppendFunctionDeclaration(*targetFunc.node, true); + + m_currentState->declaredFunctions.UnboundedSet(i); + } + } + + if (hasPredeclaration) + AppendLine(); + if (node.entryStage) return HandleEntryPoint(node); @@ -887,6 +903,8 @@ namespace Nz AppendStatementList(node.statements); } LeaveScope(); + + m_currentState->declaredFunctions.UnboundedSet(node.funcIndex.value()); } void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)