diff --git a/include/Nazara/Core/TaskScheduler.hpp b/include/Nazara/Core/TaskScheduler.hpp index 4f0c9c6fe..3b732f060 100644 --- a/include/Nazara/Core/TaskScheduler.hpp +++ b/include/Nazara/Core/TaskScheduler.hpp @@ -42,7 +42,6 @@ namespace Nz Worker& GetWorker(unsigned int workerIndex); void NotifyTaskCompletion(); - std::atomic_bool m_idle; std::atomic_uint m_remainingTasks; std::size_t m_nextWorkerIndex; std::vector m_workers; diff --git a/src/Nazara/Core/TaskScheduler.cpp b/src/Nazara/Core/TaskScheduler.cpp index ac8573af4..12b054f58 100644 --- a/src/Nazara/Core/TaskScheduler.cpp +++ b/src/Nazara/Core/TaskScheduler.cpp @@ -203,7 +203,6 @@ namespace Nz NAZARA_WARNING_POP() TaskScheduler::TaskScheduler(unsigned int workerCount) : - m_idle(false), m_remainingTasks(0), m_nextWorkerIndex(0), m_tasks(256 * sizeof(Task)), @@ -227,8 +226,6 @@ namespace Nz void TaskScheduler::AddTask(Task&& task) { - m_idle = false; - std::size_t taskIndex; //< not used Task* taskPtr = m_tasks.Allocate(taskIndex, std::move(task)); @@ -248,7 +245,19 @@ namespace Nz void TaskScheduler::WaitForTasks() { - m_idle.wait(false); + // Wait until remaining task counter reaches 0 + for (;;) + { + // Load and test current value + unsigned int remainingTasks = m_remainingTasks.load(); + if (remainingTasks == 0) + break; + + // If task count isn't 0, wait until it's signaled + // (we need to retest remainingTasks because a worker can signal m_remainingTasks while we're still adding tasks) + m_remainingTasks.wait(remainingTasks); + } + m_tasks.Clear(); } @@ -260,9 +269,6 @@ namespace Nz void TaskScheduler::NotifyTaskCompletion() { if (--m_remainingTasks == 0) - { - m_idle = true; - m_idle.notify_one(); - } + m_remainingTasks.notify_one(); } } diff --git a/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp b/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp index a2e7e1716..772a5a678 100644 --- a/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp +++ b/tests/UnitTests/Engine/Core/TaskSchedulerTests.cpp @@ -39,7 +39,7 @@ SCENARIO("TaskScheduler", "[CORE][TaskScheduler]") Nz::Time elapsedTime = clock.GetElapsedTime(); CHECK(count == scheduler.GetWorkerCount()); - CHECK(elapsedTime < Nz::Time::Milliseconds(scheduler.GetWorkerCount() * 100)); + CHECK(elapsedTime < Nz::Time::Milliseconds(std::max(scheduler.GetWorkerCount(), 2u) * 100)); } WHEN("We add a lot of tasks and wait for all of them")