co_ecs 0.9.0
Cobalt ECS
Loading...
Searching...
No Matches
thread_pool.hpp
Go to the documentation of this file.
1#pragma once
2
5
6#include <random>
7#include <semaphore>
8#include <thread>
9
10namespace co_ecs {
11
18public:
// Thread type used by workers to run their scheduling loop (see worker::start()).
19 using thread_t = std::thread;
20
22 class worker {
23 public:
24#ifdef CO_ECS_WORKER_STATS
26 struct worker_stats {
27 std::atomic<uint64_t> task_count;
28 std::atomic<uint64_t> steal_count;
29 std::atomic<uint64_t> idle_count;
30
31 void inc_task() {
32 task_count.fetch_add(1, std::memory_order::relaxed);
33 }
34
35 void inc_steal() {
36 steal_count.fetch_add(1, std::memory_order::relaxed);
37 }
38
39 void inc_idle() {
40 idle_count.fetch_add(1, std::memory_order::relaxed);
41 }
42 };
43#endif
44
48 worker(thread_pool& pool, uint16_t id) : _pool(pool), _id(id) {
49 }
50
53 [[nodiscard]] std::size_t id() const noexcept {
54 return _id;
55 }
56
59 static worker& current() noexcept {
60 return *current_worker;
61 }
62
66 task_t* submit(auto&& func, task_t* parent = nullptr) {
67 task_t* task = task_pool::allocate(std::forward<decltype(func)>(func), parent);
68 submit(task);
69 return task;
70 }
71
74 void submit(task_t* task) {
75 get_queue().push(task);
76 _pool.wake_worker();
77 }
78
81 void wait(task_t* task) {
82 while (!task->is_completed()) {
83 auto* next_task = get_task();
84 if (next_task) {
85 execute(next_task);
86 } else {
87#ifdef CO_ECS_WORKER_STATS
88 _stats.inc_idle();
89#endif
90 }
91 _pool.wake_worker();
92 }
93 }
94
95#ifdef CO_ECS_WORKER_STATS
98 const worker_stats& stats() const noexcept {
99 return _stats;
100 }
101#endif
102
103 private:
104 friend class thread_pool;
105
106 static inline thread_local worker* current_worker;
107
108 void run() {
109 current_worker = this;
110
111 while (true) {
112 task_t* task;
113
114 // fetch and execute tasks while we can
115 do {
116 task = get_task();
117 if (task) {
118 execute(task);
119 } else {
120 idle();
121 }
122 } while (task);
123
124 if (!is_active()) {
125 break;
126 }
127 }
128 }
129
130 void start() {
131 _thread = thread_t([this]() { run(); });
132 }
133
134 void stop() {
135 _active.store(false, std::memory_order::relaxed);
136 }
137
138 void join() {
139 if (_thread.joinable()) {
140 _thread.join();
141 }
142 }
143
144 auto is_active() const noexcept -> bool {
145 return _active.load(std::memory_order::relaxed);
146 }
147
148 [[nodiscard]] task_t* get_task() {
149 // First, attempt to retrieve a task from the worker's own local queue.
150 if (auto maybe_task = get_queue().pop()) {
151 return *maybe_task;
152 }
153
154 // No tasks in the local queue; attempt to steal from the main worker queue, if not the main worker.
155 if (worker* main_worker = &_pool.main_worker(); main_worker != this) {
156 if (auto maybe_task = steal(*main_worker)) {
157 return *maybe_task;
158 }
159 }
160
161 // If stealing from the main worker fails, attempt to steal from a random worker.
162 // This method is optimal for smaller numbers of workers (e.g., 4-8).
163 if (worker* random_worker = _pool.random_worker(); random_worker && random_worker != this) {
164 if (auto maybe_task = steal(*random_worker)) {
165 return *maybe_task;
166 }
167 }
168
169 return nullptr;
170 }
171
172 [[nodiscard]]
173 std::optional<task_t*> steal(worker& worker) {
174 auto maybe_task = worker.get_queue().steal();
175#ifdef CO_ECS_WORKER_STATS
176 if (maybe_task) {
177 _stats.inc_steal();
178 }
179#endif
180 return maybe_task;
181 }
182
183 void execute(task_t* task) {
184 task->execute();
185#ifdef CO_ECS_WORKER_STATS
186 _stats.inc_task();
187#endif
188 }
189
190 void idle() {
191 _pool.wait();
192
193#ifdef CO_ECS_WORKER_STATS
194 _stats.inc_idle();
195#endif
196 }
197
198 [[nodiscard]]
199 detail::work_stealing_queue<task_t*>& get_queue() noexcept {
200 return _queue;
201 }
202
203 private:
204 detail::work_stealing_queue<task_t*> _queue;
205 thread_pool& _pool;
206
207 std::atomic<bool> _active{ true };
208 thread_t _thread{};
209 std::size_t _id;
210#ifdef CO_ECS_WORKER_STATS
211 worker_stats _stats;
212#endif
213 };
214
217 thread_pool(std::size_t num_workers = std::thread::hardware_concurrency()) {
218 assert(num_workers > 0 && "Number of workers should be > 0");
219 _workers.reserve(num_workers);
220
221 if (_instance) {
222 throw std::logic_error("Thread pool already created");
223 }
224 _instance = this;
225
226 // create main worker that will execute tasks in main thread
227 _workers.emplace_back(std::make_unique<worker>(*this, 0));
228 worker::current_worker = _workers[0].get();
229
230 // create background workers
231 for (auto i = 1; i < num_workers; i++) {
232 _workers.emplace_back(std::make_unique<worker>(*this, i));
233 }
234
235 // start workers
236 for (auto i = 1; i < num_workers; i++) {
237 _workers[i]->start();
238 }
239 }
240
243 // stop workers
244 for (auto i = 1; i < _workers.size(); i++) {
245 _workers[i]->stop();
246 }
247
248 // join worker threads
249 for (auto i = 1; i < _workers.size(); i++) {
250 _workers[i]->join();
251 }
252
253 _instance = nullptr;
254 }
255
// Copy construction is disabled.
256 thread_pool(const thread_pool&) = delete;
258
261
264 static thread_pool& get() {
265 if (!_instance) {
266 static thread_pool tp;
267 _instance = &tp;
268 }
269
270 return *_instance;
271 }
272
276 task_t* submit(auto&& func, task_t* parent = nullptr) {
277 return current_worker().submit(std::forward<decltype(func)>(func), parent);
278 }
279
282 void wait(task_t* task) {
283 current_worker().wait(task);
284 }
285
289 worker& get_worker_by_id(std::size_t id) noexcept {
290 return *_workers.at(id);
291 }
292
295 [[nodiscard]]
296 std::size_t num_workers() const noexcept {
297 return _workers.size();
298 }
299
302 static worker& current_worker() noexcept {
303 return worker::current();
304 }
305
306private:
307 worker& main_worker() noexcept {
308 return *_workers[0];
309 }
310
311 worker* random_worker() noexcept {
312 if (num_workers() == 1) {
313 // no other workers than main
314 return nullptr;
315 }
316
317 std::uniform_int_distribution<std::size_t> dist{ 1, num_workers() - 1 };
318 std::default_random_engine random_engine{ std::random_device()() };
319
320 auto random_index = dist(random_engine);
321
322 return _workers[random_index].get();
323 }
324
325 void wake_worker() {
326 _worker_wait_semaphore.release();
327 }
328
329 void wait() {
330 constexpr auto wait_time = std::chrono::milliseconds(5);
331 _worker_wait_semaphore.try_acquire_for(wait_time);
332 }
333
334private:
// Currently registered pool; set by the constructor and by get(), cleared by the destructor.
335 static inline thread_pool* _instance;
336
// Workers owned by the pool; index 0 is the main worker, the rest are background workers.
337 std::vector<std::unique_ptr<worker>> _workers;
// Semaphore that idle workers wait on; starts with zero permits and is released by wake_worker().
338 std::counting_semaphore<> _worker_wait_semaphore{ 0 };
339};
340
341} // namespace co_ecs
static task_t * allocate(auto &&func, task_t *parent=nullptr)
Allocates a task with the specified function and parent, placing it in a circular buffer.
Definition task.hpp:68
Represents a task that can be executed, monitored for completion, and linked to a parent task.
Definition task.hpp:10
bool is_completed() const noexcept
Checks if the task has been completed.
Definition task.hpp:33
Thread pool worker.
Definition thread_pool.hpp:22
void submit(task_t *task)
Submit a task into the worker's local queue.
Definition thread_pool.hpp:74
void wait(task_t *task)
Wait for task completion.
Definition thread_pool.hpp:81
std::size_t id() const noexcept
Return ID of the worker.
Definition thread_pool.hpp:53
friend class thread_pool
Definition thread_pool.hpp:104
static worker & current() noexcept
Get current thread worker.
Definition thread_pool.hpp:59
worker(thread_pool &pool, uint16_t id)
Create a thread pool worker.
Definition thread_pool.hpp:48
task_t * submit(auto &&func, task_t *parent=nullptr)
Submit a task into the worker's local queue.
Definition thread_pool.hpp:66
Generic thread pool implementation.
Definition thread_pool.hpp:17
static worker & current_worker() noexcept
Get current worker.
Definition thread_pool.hpp:302
thread_pool & operator=(const thread_pool &)=delete
thread_pool & operator=(thread_pool &&)=delete
std::thread thread_t
Definition thread_pool.hpp:19
task_t * submit(auto &&func, task_t *parent=nullptr)
Submit a task to a thread pool.
Definition thread_pool.hpp:276
std::size_t num_workers() const noexcept
Return the number of workers.
Definition thread_pool.hpp:296
worker & get_worker_by_id(std::size_t id) noexcept
Get worker by ID.
Definition thread_pool.hpp:289
~thread_pool()
Destroy thread pool, worker threads are notified to exit and joined.
Definition thread_pool.hpp:242
thread_pool(const thread_pool &)=delete
thread_pool(thread_pool &&)=delete
static thread_pool & get()
Get thread pool instance.
Definition thread_pool.hpp:264
void wait(task_t *task)
Wait for a task to complete.
Definition thread_pool.hpp:282
thread_pool(std::size_t num_workers=std::thread::hardware_concurrency())
Construct thread pool with num_workers workers.
Definition thread_pool.hpp:217
Definition archetype.hpp:11