/*
* Copyright (C) 2022 Mikhail Burakov. This file is part of toolbox.
*
* toolbox is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* toolbox is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with toolbox. If not, see <https://www.gnu.org/licenses/>.
*/
#include "thread_pool.h"

#include <stdatomic.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>
#include <threads.h>
// One slot of the pool's task table. A slot whose fun is NULL is treated as
// vacant by the fetch/store helpers in this file.
struct ThreadPoolTask {
  // Function to invoke on a worker thread; NULL marks the slot as empty.
  void (*fun)(void*);
  // Opaque argument handed to fun unchanged.
  void* user;
};
// Removes the first occupied slot (lowest index with a non-NULL fun) from the
// task table and copies it into *task. The slot is released by clearing its
// fun pointer; the table itself is never shrunk.
//
// Returns true when a task was copied out, false when every slot is vacant
// (in which case *task is left untouched). Both call sites in this file
// invoke it with tasks_mutex held.
static bool FetchTask(struct ThreadPool* thread_pool,
                      struct ThreadPoolTask* task) {
  size_t index = 0;
  while (index < thread_pool->tasks_count) {
    struct ThreadPoolTask* slot = &thread_pool->tasks[index++];
    if (!slot->fun) continue;
    *task = *slot;
    slot->fun = NULL;
    return true;
  }
  return false;
}
static bool StoreTask(struct ThreadPool* thread_pool,
const struct ThreadPoolTask* task) {
for (size_t i = 0; i < thread_pool->tasks_count; i++) {
if (!thread_pool->tasks[i].fun) {
thread_pool->tasks[i] = *task;
return true;
}
}
size_t tasks_count = thread_pool->tasks_count + 1;
struct ThreadPoolTask* tasks =
realloc(thread_pool->tasks, tasks_count * sizeof(struct ThreadPoolTask));
if (!tasks) return false;
tasks[thread_pool->tasks_count] = *task;
thread_pool->tasks = tasks;
thread_pool->tasks_count = tasks_count;
return true;
}
static int ThreadProc(void* arg) {
for (struct ThreadPool* thread_pool = arg;
atomic_load_explicit(&thread_pool->running, memory_order_relaxed);) {
if (mtx_lock(&thread_pool->tasks_mutex) != thrd_success) {
// TODO(mburakov): Could we do something other than just reattempt?
thrd_yield();
continue;
}
for (;;) {
if (!atomic_load_explicit(&thread_pool->running, memory_order_relaxed)) {
mtx_unlock(&thread_pool->tasks_mutex);
return 0;
}
struct ThreadPoolTask task;
if (FetchTask(thread_pool, &task)) {
mtx_unlock(&thread_pool->tasks_mutex);
task.fun(task.user);
break;
}
cnd_wait(&thread_pool->tasks_cond, &thread_pool->tasks_mutex);
}
}
return 0;
}
// Initializes *thread_pool and starts threads_count worker threads running
// ThreadProc.
//
// Returns true (1) on success. On any failure everything acquired so far is
// released again (already started workers are stopped and joined) and false
// (0) is returned; the pool must not be used or destroyed in that case.
int ThreadPool_Create(struct ThreadPool* thread_pool, size_t threads_count) {
  bool locked = false;
  atomic_init(&thread_pool->running, 1);
  // calloc rejects a threads_count whose byte size would overflow size_t.
  thread_pool->threads = calloc(threads_count, sizeof(thrd_t));
  if (!thread_pool->threads) return false;
  thread_pool->threads_count = 0;
  if (cnd_init(&thread_pool->tasks_cond) != thrd_success) goto rollback_threads;
  if (mtx_init(&thread_pool->tasks_mutex, mtx_plain) != thrd_success)
    goto rollback_tasks_cond;
  thread_pool->tasks = NULL;
  thread_pool->tasks_count = 0;
  // threads_count in the pool tracks how many workers were actually started,
  // so the rollback below joins exactly those.
  for (; thread_pool->threads_count < threads_count;
       thread_pool->threads_count++) {
    thrd_t* thread = &thread_pool->threads[thread_pool->threads_count];
    if (thrd_create(thread, ThreadProc, thread_pool) != thrd_success)
      goto rollback_running;
  }
  return true;

rollback_running:
  // Clear the flag and wake the workers while holding tasks_mutex. A worker
  // checks the flag and enters cnd_wait under that same mutex, so doing the
  // shutdown under the lock guarantees it either sees the cleared flag or is
  // already waiting and receives the broadcast. Without the lock the wakeup
  // could be lost and thrd_join below would block forever. If the lock cannot
  // be taken, fall back to signalling without it.
  locked = mtx_lock(&thread_pool->tasks_mutex) == thrd_success;
  atomic_store_explicit(&thread_pool->running, 0, memory_order_relaxed);
  cnd_broadcast(&thread_pool->tasks_cond);
  if (locked) mtx_unlock(&thread_pool->tasks_mutex);
  // Join in reverse start order; the counter ends at 0 rather than wrapping.
  while (thread_pool->threads_count > 0)
    thrd_join(thread_pool->threads[--thread_pool->threads_count], NULL);
  mtx_destroy(&thread_pool->tasks_mutex);
rollback_tasks_cond:
  cnd_destroy(&thread_pool->tasks_cond);
rollback_threads:
  free(thread_pool->threads);
  thread_pool->threads = NULL;
  return false;
}
// Queues fun(user) for execution on one of the pool's worker threads.
//
// Returns true when the task was stored and the workers were notified.
// Returns false when fun is NULL, when tasks_mutex could not be locked, or
// when the task table could not be grown.
bool ThreadPoolSchedule(struct ThreadPool* thread_pool, void (*fun)(void*),
                        void* user) {
  // A NULL fun is how the task table marks a vacant slot, so such a task
  // could never be fetched. Reject it instead of reporting a success that
  // silently drops the request.
  if (!fun) return false;
  if (mtx_lock(&thread_pool->tasks_mutex) != thrd_success) return false;
  struct ThreadPoolTask task = {.fun = fun, .user = user};
  bool result = StoreTask(thread_pool, &task);
  // Notify while still holding the mutex so a worker about to wait cannot
  // miss the new task.
  if (result) cnd_broadcast(&thread_pool->tasks_cond);
  mtx_unlock(&thread_pool->tasks_mutex);
  return result;
}
// Stops all worker threads, waits for them to exit and releases every
// resource owned by the pool. A task that is currently executing is allowed
// to finish (its worker is joined); tasks still sitting in the table are
// discarded without being run.
void ThreadPoolDestroy(struct ThreadPool* thread_pool) {
  // Clear the flag and wake the workers while holding tasks_mutex. A worker
  // checks the flag and enters cnd_wait under that same mutex, so doing the
  // shutdown under the lock guarantees it either sees the cleared flag or is
  // already waiting and receives the broadcast. Without the lock the wakeup
  // could be lost and thrd_join below would block forever. If the lock cannot
  // be taken, fall back to signalling without it.
  bool locked = mtx_lock(&thread_pool->tasks_mutex) == thrd_success;
  atomic_store_explicit(&thread_pool->running, 0, memory_order_relaxed);
  cnd_broadcast(&thread_pool->tasks_cond);
  if (locked) mtx_unlock(&thread_pool->tasks_mutex);

  // Join in reverse start order; the counter ends at 0 rather than wrapping.
  while (thread_pool->threads_count > 0)
    thrd_join(thread_pool->threads[--thread_pool->threads_count], NULL);

  mtx_destroy(&thread_pool->tasks_mutex);
  cnd_destroy(&thread_pool->tasks_cond);

  // Drop the dangling pointers so stale use is easier to detect.
  free(thread_pool->threads);
  thread_pool->threads = NULL;
  free(thread_pool->tasks);
  thread_pool->tasks = NULL;
  thread_pool->tasks_count = 0;
}