Shader/GlslWriter: Move function forward declaration before functions using them

Because of some weird ass shit bug in nvidia driver if functions were forward declared before declaration of UBO they were using
This commit is contained in:
Jérôme Leclercq 2021-05-26 22:21:57 +02:00
parent 948f0517ea
commit 58fe411750
2 changed files with 67 additions and 50 deletions

View File

@ -63,7 +63,7 @@ namespace Nz
void AppendCommentSection(const std::string& section);
void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false);
void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers);
void AppendHeader(const std::vector<ShaderAst::DeclareFunctionStatement*>& forwardFunctionDeclarations);
void AppendHeader();
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);
void AppendStatementList(std::vector<ShaderAst::StatementPtr>& statements);
@ -74,7 +74,6 @@ namespace Nz
void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node);
void HandleInOut();
void RegisterFunction(std::size_t funcIndex, std::string funcName);
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc);
void RegisterVariable(std::size_t varIndex, std::string varName);

View File

@ -4,6 +4,7 @@
#include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Core/Bitset.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Math/Algorithm.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
@ -36,6 +37,14 @@ namespace Nz
{
using AstRecursiveVisitor::Visit;
void Visit(ShaderAst::CallFunctionExpression& node) override
{
AstRecursiveVisitor::Visit(node);
assert(currentFunction);
currentFunction->calledFunctions.UnboundedSet(std::get<std::size_t>(node.targetFunction));
}
void Visit(ShaderAst::ConditionalStatement& node) override
{
if (TestBit<UInt64>(enabledOptions, node.optionIndex))
@ -64,16 +73,31 @@ namespace Nz
entryPoint = &node;
}
}
else
forwardFunctionDeclarations.push_back(&node);
assert(node.funcIndex);
functionNames[node.funcIndex.value()] = node.name;
assert(functions.find(node.funcIndex.value()) == functions.end());
FunctionData& funcData = functions[node.funcIndex.value()];
funcData.name = node.name;
funcData.node = &node;
currentFunction = &funcData;
AstRecursiveVisitor::Visit(node);
currentFunction = nullptr;
}
struct FunctionData
{
std::string name;
Bitset<> calledFunctions;
ShaderAst::DeclareFunctionStatement* node;
};
FunctionData* currentFunction = nullptr;
std::optional<ShaderStageType> selectedStage;
std::unordered_map<std::size_t, std::string> functionNames;
std::vector<ShaderAst::DeclareFunctionStatement*> forwardFunctionDeclarations;
std::unordered_map<std::size_t, FunctionData> functions;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
UInt64 enabledOptions = 0;
};
@ -100,14 +124,14 @@ namespace Nz
};
std::optional<ShaderStageType> stage;
const States* states = nullptr;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
std::unordered_map<std::size_t, ShaderAst::StructDescription> structs;
std::unordered_map<std::size_t, std::string> functionNames;
std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields;
Bitset<> declaredFunctions;
PreVisitor previsitor;
const States* states = nullptr;
UInt64 enabledOptions = 0;
bool isInEntryPoint = false;
unsigned int indentLevel = 0;
@ -145,25 +169,11 @@ namespace Nz
ShaderAst::StatementPtr& targetAst = *targetAstPtr;
PreVisitor previsitor;
previsitor.enabledOptions = states.enabledOptions;
previsitor.selectedStage = shaderStage;
targetAst->Visit(previsitor);
state.previsitor.enabledOptions = states.enabledOptions;
state.previsitor.selectedStage = shaderStage;
targetAst->Visit(state.previsitor);
if (!previsitor.entryPoint)
{
if (previsitor.forwardFunctionDeclarations.empty())
throw std::runtime_error("no function found");
state.entryFunc = previsitor.forwardFunctionDeclarations.front();
previsitor.forwardFunctionDeclarations.erase(previsitor.forwardFunctionDeclarations.begin());
}
else
state.entryFunc = previsitor.entryPoint;
state.functionNames = std::move(previsitor.functionNames);
AppendHeader(previsitor.forwardFunctionDeclarations);
AppendHeader();
sanitizedAst->Visit(*this);
@ -255,7 +265,7 @@ namespace Nz
case ImageType::E2D: Append("2D"); break;
case ImageType::E2D_Array: Append("2DArray"); break;
case ImageType::E3D: Append("3D"); break;
case ImageType::Cubemap: Append("Cube"); break;
case ImageType::Cubemap: Append("Cube"); break;
}
}
@ -356,7 +366,7 @@ namespace Nz
}
}
void GlslWriter::AppendHeader(const std::vector<ShaderAst::DeclareFunctionStatement*>& forwardFunctionDeclarations)
void GlslWriter::AppendHeader()
{
unsigned int glslVersion;
if (m_environment.glES)
@ -437,15 +447,6 @@ namespace Nz
AppendLine("#endif");
AppendLine();
}
if (!forwardFunctionDeclarations.empty())
{
AppendCommentSection("function declarations");
for (const ShaderAst::DeclareFunctionStatement* node : forwardFunctionDeclarations)
AppendFunctionDeclaration(*node, true);
AppendLine();
}
}
void GlslWriter::AppendLine(const std::string& txt)
@ -499,9 +500,6 @@ namespace Nz
void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node)
{
if (m_currentState->entryFunc != &node)
return; //< Ignore other entry points
HandleInOut();
AppendLine("void main()");
EnterScope();
@ -578,7 +576,7 @@ namespace Nz
AppendLine();
};
const ShaderAst::DeclareFunctionStatement& node = *m_currentState->entryFunc;
const ShaderAst::DeclareFunctionStatement& node = *m_currentState->previsitor.entryPoint;
const ShaderAst::StructDescription* inputStruct = nullptr;
@ -613,12 +611,6 @@ namespace Nz
}
}
void GlslWriter::RegisterFunction(std::size_t funcIndex, std::string funcName)
{
assert(m_currentState->functionNames.find(funcIndex) == m_currentState->functionNames.end());
m_currentState->functionNames.emplace(funcIndex, std::move(funcName));
}
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc)
{
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
@ -722,7 +714,7 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
{
assert(std::holds_alternative<std::size_t>(node.targetFunction));
const std::string& targetName = Retrieve(m_currentState->functionNames, std::get<std::size_t>(node.targetFunction));
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, std::get<std::size_t>(node.targetFunction)).name;
Append(targetName, "(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
@ -869,6 +861,30 @@ namespace Nz
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
if (node.entryStage && m_currentState->previsitor.entryPoint != &node)
return; //< Ignore other entry points
assert(node.funcIndex);
auto& funcData = Retrieve(m_currentState->previsitor.functions, node.funcIndex.value());
// Declare functions called by this function which aren't already defined
bool hasPredeclaration = false;
for (std::size_t i = funcData.calledFunctions.FindFirst(); i != funcData.calledFunctions.npos; i = funcData.calledFunctions.FindNext(i))
{
if (!m_currentState->declaredFunctions.UnboundedTest(i))
{
hasPredeclaration = true;
auto& targetFunc = Retrieve(m_currentState->previsitor.functions, i);
AppendFunctionDeclaration(*targetFunc.node, true);
m_currentState->declaredFunctions.UnboundedSet(i);
}
}
if (hasPredeclaration)
AppendLine();
if (node.entryStage)
return HandleEntryPoint(node);
@ -887,6 +903,8 @@ namespace Nz
AppendStatementList(node.statements);
}
LeaveScope();
m_currentState->declaredFunctions.UnboundedSet(node.funcIndex.value());
}
void GlslWriter::Visit(ShaderAst::DeclareOptionStatement& /*node*/)