From e3ad9be759b82c4f7d73fb964b980d6a27ddae7e Mon Sep 17 00:00:00 2001 From: SirLynix Date: Sat, 3 Feb 2024 22:40:12 +0100 Subject: [PATCH] Core/TaskScheduler: Fix work ending condition Use remaining task count instead of idle worker count, this avoids a race condition where a worker signals idle after being tasked with a new job --- include/Nazara/Core/TaskScheduler.hpp | 5 ++- src/Nazara/Core/TaskScheduler.cpp | 33 ++++--------------- .../Engine/Core/TaskSchedulerTests.cpp | 27 +++++++++++++-- 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/include/Nazara/Core/TaskScheduler.hpp b/include/Nazara/Core/TaskScheduler.hpp index a73e2bdf1..4f0c9c6fe 100644 --- a/include/Nazara/Core/TaskScheduler.hpp +++ b/include/Nazara/Core/TaskScheduler.hpp @@ -40,11 +40,10 @@ namespace Nz friend Worker; Worker& GetWorker(unsigned int workerIndex); - void NotifyWorkerActive(); - void NotifyWorkerIdle(); + void NotifyTaskCompletion(); std::atomic_bool m_idle; - std::atomic_uint m_idleWorkerCount; + std::atomic_uint m_remainingTasks; std::size_t m_nextWorkerIndex; std::vector m_workers; MemoryPool m_tasks; diff --git a/src/Nazara/Core/TaskScheduler.cpp b/src/Nazara/Core/TaskScheduler.cpp index ca1d3f6e1..ac8573af4 100644 --- a/src/Nazara/Core/TaskScheduler.cpp +++ b/src/Nazara/Core/TaskScheduler.cpp @@ -116,6 +116,7 @@ namespace Nz { // Wait until task scheduler started m_notifier.wait(false); + m_notifier.clear(); StackArray randomWorkerIndices = NazaraStackArrayNoInit(unsigned int, m_owner.GetWorkerCount() - 1); { @@ -158,20 +159,13 @@ namespace Nz (*task)(); -#ifdef NAZARA_WITH_TSAN - // Workaround for TSan false-positive - __tsan_release(task); -#endif + m_owner.NotifyTaskCompletion(); } else { // Wait for tasks if we don't have any right now - m_owner.NotifyWorkerIdle(); - m_notifier.wait(false); m_notifier.clear(); - - m_owner.NotifyWorkerActive(); } } while (m_running.load(std::memory_order_relaxed)); @@ -210,7 +204,7 @@ namespace Nz TaskScheduler::TaskScheduler(unsigned int workerCount) : m_idle(false), - m_idleWorkerCount(0), + m_remainingTasks(0), m_nextWorkerIndex(0), m_tasks(256 * sizeof(Task)), m_workerCount(workerCount) @@ -224,9 +218,6 @@ namespace Nz for (unsigned int i = 0; i < m_workerCount; ++i) m_workers[i].WakeUp(); - - // Wait until all worked started - m_idle.wait(false); } TaskScheduler::~TaskScheduler() @@ -246,6 +237,8 @@ namespace Nz __tsan_release(taskPtr); #endif + m_remainingTasks++; + Worker& worker = m_workers[m_nextWorkerIndex++]; worker.AddTask(taskPtr); @@ -256,13 +249,6 @@ namespace Nz void TaskScheduler::WaitForTasks() { m_idle.wait(false); - -#ifdef NAZARA_WITH_TSAN - // Workaround for TSan false-positive - for (Task& task : m_tasks) - __tsan_acquire(&task); -#endif - m_tasks.Clear(); } @@ -271,14 +257,9 @@ namespace Nz return m_workers[workerIndex]; } - void TaskScheduler::NotifyWorkerActive() + void TaskScheduler::NotifyTaskCompletion() { - m_idleWorkerCount--; - } - - void TaskScheduler::NotifyWorkerIdle() - { - if (++m_idleWorkerCount == m_workers.size()) + if (--m_remainingTasks == 0) { m_idle = true; m_idle.notify_one(); diff --git a/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp b/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp index 8ef643e75..342f10b74 100644 --- a/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp +++ b/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp @@ -1,14 +1,17 @@ +#include #include #include #include +#include +#include SCENARIO("TaskScheduler", "[CORE][TaskScheduler]") { - for (std::size_t workerCount : { 0, 1, 2, 4 }) + for (std::size_t workerCount : { 0, 1, 2, 4, 8 }) { GIVEN("A task scheduler with " << workerCount << " workers") { - Nz::TaskScheduler scheduler(4); + Nz::TaskScheduler scheduler(workerCount); WHEN("We add a single task and wait for it") { @@ -19,6 +22,26 @@ SCENARIO("TaskScheduler", "[CORE][TaskScheduler]") CHECK(executed); } + WHEN("We add time-consuming tasks, they are split between workers") + { + std::atomic_uint count = 0; + + Nz::HighPrecisionClock clock; + for (unsigned int i = 0; i < scheduler.GetWorkerCount(); ++i) + { + scheduler.AddTask([&] + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + count++; + }); + } + scheduler.WaitForTasks(); + Nz::Time elapsedTime = clock.GetElapsedTime(); + + CHECK(count == scheduler.GetWorkerCount()); + CHECK(elapsedTime < Nz::Time::Milliseconds(120)); + } + WHEN("We add a lot of tasks and wait for all of them") { constexpr std::size_t taskCount = 512;