Core/TaskScheduler: Fix race conditions when calling AddTask while workers are not idle

Update TaskScheduler.cpp
SirLynix 2024-02-02 16:20:40 +01:00
parent 44e55adcd9
commit fa73e463a6
1 changed file with 41 additions and 33 deletions
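The change consolidates the notify logic into a WakeUp() helper and keeps the worker's wait()/clear() pairing, so a task pushed while a worker is about to go to sleep still wakes it. As a rough idea of the C++20 std::atomic_flag wake-up token this relies on, here is a minimal self-contained sketch (illustrative only, not the engine's code; the mutex-guarded queue, the names and the shutdown handling are simplified assumptions):

// Illustrative sketch only (not Nazara code): one producer, one worker,
// a std::atomic_flag used as a wake-up token. Requires C++20.
#include <atomic>
#include <cstdio>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

int main()
{
    std::mutex queueMutex;
    std::deque<std::function<void()>> tasks; // stand-in for the real work queue
    std::atomic_flag notifier;               // clear = "no pending wake-up"
    std::atomic<bool> running{ true };

    std::thread worker([&]
    {
        for (;;)
        {
            std::function<void()> task;
            {
                std::lock_guard lock(queueMutex);
                if (!tasks.empty())
                {
                    task = std::move(tasks.front());
                    tasks.pop_front();
                }
            }

            if (task)
                task(); // run outside the lock
            else if (!running.load(std::memory_order_relaxed))
                break;
            else
            {
                // Sleep only while the flag is clear; a wake-up raised between the
                // empty-queue check above and this wait() leaves the flag set, so
                // wait(false) returns immediately and the task is not missed.
                notifier.wait(false);
                notifier.clear(); // consume the wake-up token
            }
        }
    });

    auto addTask = [&](std::function<void()> task)
    {
        {
            std::lock_guard lock(queueMutex);
            tasks.push_back(std::move(task));
        }
        // Publish the task first, then wake the worker; notify only when the flag
        // was previously clear to avoid redundant notifications (the WakeUp() idea).
        if (!notifier.test_and_set())
            notifier.notify_one();
    };

    for (int i = 0; i < 4; ++i)
        addTask([i] { std::printf("task %d\n", i); });

    running.store(false, std::memory_order_relaxed);
    if (!notifier.test_and_set()) // make sure the worker wakes to observe shutdown
        notifier.notify_one();

    worker.join();
}

The flag behaves like a one-slot mailbox: repeated wake-ups collapse into a single pending token, and the worker drains its queue before consuming it, which is what allows the producer to skip notify_one() when the flag is already set.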


@@ -62,14 +62,13 @@ namespace Nz
             void AddTask(TaskScheduler::Task* task)
             {
                 m_tasks.push(task);
-                if (!m_notifier.test_and_set())
-                    m_notifier.notify_one();
+                WakeUp();
             }
 
             void Run()
             {
-                bool idle = true;
-                m_notifier.wait(false); // wait until task scheduler finishes initializing
+                // Wait until task scheduler started
+                m_notifier.wait(false);
 
                 StackArray<unsigned int> randomWorkerIndices = NazaraStackArrayNoInit(unsigned int, m_owner.GetWorkerCount() - 1);
                 {
@@ -84,9 +83,23 @@ namespace Nz
                     std::shuffle(randomWorkerIndices.begin(), randomWorkerIndices.end(), gen);
                 }
 
+                bool idle = false;
                 do
                 {
-                    auto ExecuteTask = [&](TaskScheduler::Task* task)
+                    // Wait for tasks if we don't have any right now
+                    // FIXME: We can't use pop() because push() and pop() are not thread-safe (and push is called on another thread), but steal() is
+                    // is it an issue?
+                    std::optional<TaskScheduler::Task*> task = m_tasks.steal();
+                    if (!task)
+                    {
+                        for (unsigned int workerIndex : randomWorkerIndices)
+                        {
+                            if (task = m_owner.GetWorker(workerIndex).StealTask())
+                                break;
+                        }
+                    }
+
+                    if (task)
                     {
                         if (idle)
                         {
@@ -94,39 +107,22 @@ namespace Nz
                             idle = false;
                         }
 
-                        (*task)();
-                    };
-
-                    // Wait for tasks if we don't have any right now
-                    std::optional<TaskScheduler::Task*> task = m_tasks.pop();
-                    if (task)
-                        ExecuteTask(*task);
+                        NAZARA_ASSUME(*task != nullptr);
+                        (**task)();
+                    }
                     else
                     {
-                        // Try to steal a task from another worker in a random order to avoid contention
-                        for (unsigned int workerIndex : randomWorkerIndices)
+                        if (!idle)
                         {
-                            if (task = m_owner.GetWorker(workerIndex).StealTask())
-                            {
-                                ExecuteTask(*task);
-                                break;
-                            }
+                            m_owner.NotifyWorkerIdle();
+                            idle = true;
                         }
 
-                        if (!task)
-                        {
-                            if (!idle)
-                            {
-                                m_owner.NotifyWorkerIdle();
-                                idle = true;
-                            }
-
-                            m_notifier.wait(false);
-                            m_notifier.clear();
-                        }
+                        m_notifier.wait(false);
+                        m_notifier.clear();
                     }
                 }
-                while (m_running);
+                while (m_running.load(std::memory_order_relaxed));
             }
 
             std::optional<TaskScheduler::Task*> StealTask()
@@ -134,6 +130,12 @@ namespace Nz
                 return m_tasks.steal();
             }
 
+            void WakeUp()
+            {
+                if (!m_notifier.test_and_set())
+                    m_notifier.notify_one();
+            }
+
             Worker& operator=(const Worker& worker) = delete;
 
             Worker& operator=(Worker&&)
@@ -153,7 +155,7 @@ namespace Nz
     NAZARA_WARNING_POP()
 
     TaskScheduler::TaskScheduler(unsigned int workerCount) :
-    m_idle(true),
+    m_idle(false),
     m_nextWorkerIndex(0),
     m_tasks(256 * sizeof(Task)),
     m_workerCount(workerCount)
@@ -161,11 +163,17 @@ namespace Nz
         if (m_workerCount == 0)
            m_workerCount = std::max(Core::Instance()->GetHardwareInfo().GetCpuThreadCount(), 1u);
 
-        m_idleWorkerCount = m_workerCount;
+        m_idleWorkerCount = 0;
 
         m_workers.reserve(m_workerCount);
        for (unsigned int i = 0; i < m_workerCount; ++i)
            m_workers.emplace_back(*this, i);
 
+        for (unsigned int i = 0; i < m_workerCount; ++i)
+            m_workers[i].WakeUp();
+
+        // Wait until all worked started
+        WaitForTasks();
     }
 
     TaskScheduler::~TaskScheduler()
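The constructor change reverses the startup handshake: workers now start counted as busy (m_idle(false), m_idleWorkerCount = 0), are woken explicitly through WakeUp(), and the constructor then blocks in WaitForTasks() until every worker has reached its idle wait. NotifyWorkerIdle() and WaitForTasks() are not shown in this diff, so the following is only a guessed, minimal sketch of that kind of counting barrier (C++20 atomic wait/notify; everything except the member names quoted above is hypothetical):

// Hypothetical sketch of the startup handshake implied by the constructor above,
// not the engine's actual NotifyWorkerIdle()/WaitForTasks() implementation.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    const unsigned int workerCount = 4;
    std::atomic<unsigned int> idleWorkerCount{ 0 }; // starts at 0: nobody is idle yet
    std::atomic<bool> allIdle{ false };

    auto notifyWorkerIdle = [&]
    {
        // The last worker to become idle flips the "scheduler is idle" flag and wakes waiters
        if (idleWorkerCount.fetch_add(1, std::memory_order_acq_rel) + 1 == workerCount)
        {
            allIdle.store(true, std::memory_order_release);
            allIdle.notify_all();
        }
    };

    auto waitForTasks = [&]
    {
        allIdle.wait(false); // block until every worker has reported idle
    };

    std::vector<std::jthread> workers;
    for (unsigned int i = 0; i < workerCount; ++i)
    {
        workers.emplace_back([&, i]
        {
            // ... a real worker would set up its task queue here ...
            std::printf("worker %u ready\n", i);
            notifyWorkerIdle(); // equivalent of reaching the idle wait in Run()
            // ... then it would enter its task loop ...
        });
    }

    waitForTasks(); // the constructor's "wait until all workers started" barrier
    std::printf("all %u workers are idle, scheduler ready\n", workerCount);
}

With the idle count starting at zero, WaitForTasks() doubles as a startup barrier, which is presumably why the constructor can wake every worker and then immediately wait on it.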