Core/TaskScheduler: Make implementation private

This commit is contained in:
SirLynix 2024-02-05 15:59:45 +01:00
parent 3eae055d3a
commit a4827a99a0
3 changed files with 55 additions and 72 deletions

View File

@ -8,9 +8,7 @@
#define NAZARA_CORE_TASKSCHEDULER_HPP #define NAZARA_CORE_TASKSCHEDULER_HPP
#include <NazaraUtils/Prerequisites.hpp> #include <NazaraUtils/Prerequisites.hpp>
#include <NazaraUtils/MemoryPool.hpp>
#include <Nazara/Core/Config.hpp> #include <Nazara/Core/Config.hpp>
#include <atomic>
#include <functional> #include <functional>
#include <memory> #include <memory>
@ -28,7 +26,7 @@ namespace Nz
void AddTask(Task&& task); void AddTask(Task&& task);
inline unsigned int GetWorkerCount() const; unsigned int GetWorkerCount() const;
void WaitForTasks(); void WaitForTasks();
@ -36,17 +34,9 @@ namespace Nz
TaskScheduler& operator=(TaskScheduler&&) = delete; TaskScheduler& operator=(TaskScheduler&&) = delete;
private: private:
struct Data;
class Worker; class Worker;
friend Worker; std::unique_ptr<Data> m_data;
Worker& GetWorker(unsigned int workerIndex);
void NotifyTaskCompletion();
std::atomic_uint m_remainingTasks;
std::size_t m_nextWorkerIndex;
std::vector<Worker> m_workers;
MemoryPool<Task> m_tasks;
unsigned int m_workerCount;
}; };
} }

View File

@ -6,10 +6,6 @@
namespace Nz namespace Nz
{ {
inline unsigned int TaskScheduler::GetWorkerCount() const
{
return m_workerCount;
}
} }
#include <Nazara/Core/DebugOff.hpp> #include <Nazara/Core/DebugOff.hpp>

View File

@ -33,12 +33,20 @@ namespace Nz
#endif #endif
} }
struct TaskScheduler::Data
{
std::atomic_uint remainingTasks = 0;
std::size_t nextWorkerIndex = 0;
std::vector<Worker> workers;
unsigned int workerCount;
};
class alignas(NAZARA_ANONYMOUS_NAMESPACE_PREFIX(hardware_destructive_interference_size * 2)) TaskScheduler::Worker class alignas(NAZARA_ANONYMOUS_NAMESPACE_PREFIX(hardware_destructive_interference_size * 2)) TaskScheduler::Worker
{ {
public: public:
Worker(TaskScheduler& owner, unsigned int workerIndex) : Worker(TaskScheduler::Data& data, unsigned int workerIndex) :
m_running(true), m_running(true),
m_owner(owner), m_data(data),
m_workerIndex(workerIndex) m_workerIndex(workerIndex)
{ {
m_thread = std::thread([this] m_thread = std::thread([this]
@ -52,7 +60,7 @@ namespace Nz
// "Implement" movement to make the compiler happy // "Implement" movement to make the compiler happy
Worker(Worker&& worker) : Worker(Worker&& worker) :
m_owner(worker.m_owner) m_data(worker.m_data)
{ {
NAZARA_UNREACHABLE(); NAZARA_UNREACHABLE();
} }
@ -62,22 +70,28 @@ namespace Nz
m_thread.join(); m_thread.join();
} }
void AddTask(TaskScheduler::Task* task) void AddTask(TaskScheduler::Task&& task)
{ {
m_tasks.enqueue(task); m_tasks.enqueue(std::move(task));
WakeUp(); WakeUp();
} }
void NotifyTaskCompletion()
{
if (--m_data.remainingTasks == 0)
m_data.remainingTasks.notify_one();
}
void Run() void Run()
{ {
// Wait until task scheduler started // Wait until task scheduler started
m_notifier.wait(false); m_notifier.wait(false);
m_notifier.clear(); m_notifier.clear();
StackArray<unsigned int> randomWorkerIndices = NazaraStackArrayNoInit(unsigned int, m_owner.GetWorkerCount() - 1); StackArray<unsigned int> randomWorkerIndices = NazaraStackArrayNoInit(unsigned int, m_data.workerCount - 1);
{ {
unsigned int* indexPtr = randomWorkerIndices.data(); unsigned int* indexPtr = randomWorkerIndices.data();
for (unsigned int i = 0; i < m_owner.GetWorkerCount(); ++i) for (unsigned int i = 0; i < m_data.workerCount; ++i)
{ {
if (i != m_workerIndex) if (i != m_workerIndex)
*indexPtr++ = i; *indexPtr++ = i;
@ -90,12 +104,12 @@ namespace Nz
while (m_running.load(std::memory_order_relaxed)) while (m_running.load(std::memory_order_relaxed))
{ {
// Get a task // Get a task
TaskScheduler::Task* task = nullptr; TaskScheduler::Task task;
if (!m_tasks.try_dequeue(task)) if (!m_tasks.try_dequeue(task))
{ {
for (unsigned int workerIndex : randomWorkerIndices) for (unsigned int workerIndex : randomWorkerIndices)
{ {
task = m_owner.GetWorker(workerIndex).StealTask(); task = m_data.workers[workerIndex].StealTask();
if (task) if (task)
break; break;
} }
@ -108,9 +122,9 @@ namespace Nz
__tsan_acquire(task); __tsan_acquire(task);
#endif #endif
(*task)(); task();
m_owner.NotifyTaskCompletion(); NotifyTaskCompletion();
} }
else else
{ {
@ -128,9 +142,9 @@ namespace Nz
m_notifier.notify_one(); m_notifier.notify_one();
} }
TaskScheduler::Task* StealTask() TaskScheduler::Task StealTask()
{ {
TaskScheduler::Task* task = nullptr; TaskScheduler::Task task;
m_tasks.try_dequeue(task); m_tasks.try_dequeue(task);
return task; return task;
} }
@ -153,57 +167,53 @@ namespace Nz
std::atomic_bool m_running; std::atomic_bool m_running;
std::atomic_flag m_notifier; std::atomic_flag m_notifier;
std::thread m_thread; //< std::jthread is not yet widely implemented std::thread m_thread; //< std::jthread is not yet widely implemented
moodycamel::ConcurrentQueue<TaskScheduler::Task*> m_tasks; moodycamel::ConcurrentQueue<TaskScheduler::Task> m_tasks;
TaskScheduler& m_owner; TaskScheduler::Data& m_data;
unsigned int m_workerIndex; unsigned int m_workerIndex;
}; };
NAZARA_WARNING_POP() NAZARA_WARNING_POP()
TaskScheduler::TaskScheduler(unsigned int workerCount) : TaskScheduler::TaskScheduler(unsigned int workerCount)
m_remainingTasks(0),
m_nextWorkerIndex(0),
m_tasks(256 * sizeof(Task)),
m_workerCount(workerCount)
{ {
if (m_workerCount == 0) if (workerCount == 0)
m_workerCount = std::max(Core::Instance()->GetHardwareInfo().GetCpuThreadCount(), 1u); workerCount = std::max(Core::Instance()->GetHardwareInfo().GetCpuThreadCount(), 1u);
m_workers.reserve(m_workerCount); m_data = std::make_unique<Data>();
for (unsigned int i = 0; i < m_workerCount; ++i) m_data->workerCount = workerCount;
m_workers.emplace_back(*this, i);
for (Worker& worker : m_workers) m_data->workers.reserve(workerCount);
for (unsigned int i = 0; i < workerCount; ++i)
m_data->workers.emplace_back(*m_data, i);
for (Worker& worker : m_data->workers)
worker.WakeUp(); worker.WakeUp();
} }
TaskScheduler::~TaskScheduler() TaskScheduler::~TaskScheduler()
{ {
// Wake up workers and tell them to exit // Wake up workers and tell them to exit
for (Worker& worker : m_workers) for (Worker& worker : m_data->workers)
worker.Shutdown(); worker.Shutdown();
// wait for them to have exited // wait for them to have exited
m_workers.clear(); m_data->workers.clear();
} }
void TaskScheduler::AddTask(Task&& task) void TaskScheduler::AddTask(Task&& task)
{ {
std::size_t taskIndex; //< not used m_data->remainingTasks++;
Task* taskPtr = m_tasks.Allocate(taskIndex, std::move(task));
#ifdef NAZARA_WITH_TSAN Worker& worker = m_data->workers[m_data->nextWorkerIndex++];
// Workaround for TSan false-positive worker.AddTask(std::move(task));
__tsan_release(taskPtr);
#endif
m_remainingTasks++; if (m_data->nextWorkerIndex >= m_data->workers.size())
m_data->nextWorkerIndex = 0;
}
Worker& worker = m_workers[m_nextWorkerIndex++]; unsigned int TaskScheduler::GetWorkerCount() const
worker.AddTask(taskPtr); {
return m_data->workerCount;
if (m_nextWorkerIndex >= m_workers.size())
m_nextWorkerIndex = 0;
} }
void TaskScheduler::WaitForTasks() void TaskScheduler::WaitForTasks()
@ -212,26 +222,13 @@ namespace Nz
for (;;) for (;;)
{ {
// Load and test current value // Load and test current value
unsigned int remainingTasks = m_remainingTasks.load(); unsigned int remainingTasks = m_data->remainingTasks.load();
if (remainingTasks == 0) if (remainingTasks == 0)
break; break;
// If task count isn't 0, wait until it's signaled // If task count isn't 0, wait until it's signaled
// (we need to retest remainingTasks because a worker can signal m_remainingTasks while we're still adding tasks) // (we need to retest remainingTasks because a worker can signal m_remainingTasks while we're still adding tasks)
m_remainingTasks.wait(remainingTasks); m_data->remainingTasks.wait(remainingTasks);
} }
m_tasks.Clear();
}
auto TaskScheduler::GetWorker(unsigned int workerIndex) -> Worker&
{
return m_workers[workerIndex];
}
void TaskScheduler::NotifyTaskCompletion()
{
if (--m_remainingTasks == 0)
m_remainingTasks.notify_one();
} }
} }