Skip to content

Instantly share code, notes, and snippets.

@chertov
Last active March 11, 2022 17:05
Show Gist options
  • Save chertov/ee763daf5dae51209326405307004bbf to your computer and use it in GitHub Desktop.
Save chertov/ee763daf5dae51209326405307004bbf to your computer and use it in GitHub Desktop.
simple thread pool
use std::sync::Arc;
use parking_lot::RwLock;
use std::collections::VecDeque;
use log::trace;
use anyhow::anyhow;
struct Thread {
thread_handler: std::thread::JoinHandle<()>,
running: Arc<RwLock<bool>>,
join: Arc<RwLock<bool>>,
}
impl Thread {
fn new(tasks: Arc<RwLock<VecDeque<Task>>>) -> Self {
let running = Arc::new(RwLock::new(false));
let join = Arc::new(RwLock::new(false));
let thread_handler = std::thread::spawn({
let running = running.clone();
let join = join.clone();
move || { Thread::_run(tasks, running, join) }
});
Self { thread_handler, running, join }
}
fn _run(tasks: Arc<RwLock<VecDeque<Task>>>, running: Arc<RwLock<bool>>, join: Arc<RwLock<bool>>) {
let id = std::thread::current().id();
trace!("[pool: {:?}] create a new pool thread", id);
loop {
let task = { tasks.write().pop_back() };
if let Some(task) = task {
trace!("[pool: {:?}] a new task running", id);
{ *running.write() = true; }
(task.task)();
{ *running.write() = false; }
trace!("[pool: {:?}] task complete", id);
} else {
{ if *join.read() { break } }
trace!("[pool: {:?}] thread park", id);
std::thread::park();
}
}
trace!("[pool: {:?}] complete", id);
}
}
/// A unit of work queued on the pool: a boxed closure that a worker calls exactly once.
///
/// `Box<dyn ...>` already defaults to a `'static` object bound, so no explicit
/// lifetime is needed.
struct Task {
    task: Box<dyn FnOnce() + Send + Sync>,
}
/// Shared state behind a `ThreadPool` handle.
///
/// Holds the worker threads, the task queue they all drain, and the `all_join` flag.
/// `true` means the pool is stopping and must not accept new work.
struct ThreadPoolInner {
    threads: Vec<Thread>,
    tasks: Arc<RwLock<VecDeque<Task>>>,
    all_join: bool,
}

impl ThreadPoolInner {
    /// Creates the shared task queue and spawns `thread_count` workers on it.
    ///
    /// # Errors
    /// Returns an error when `thread_count` is zero, since a pool without workers could
    /// never execute a task.
    pub fn new(thread_count: usize) -> Result<Self, anyhow::Error> {
        if thread_count < 1 {
            return Err(anyhow!("The count threads of pool must be more then zero"));
        }
        let tasks: Arc<RwLock<VecDeque<Task>>> = Arc::new(RwLock::new(VecDeque::new()));
        let threads: Vec<Thread> = (0..thread_count)
            .map(|_| Thread::new(Arc::clone(&tasks)))
            .collect();
        Ok(Self {
            threads,
            tasks,
            all_join: false,
        })
    }
}
#[derive(Clone)]
struct ThreadPool {
pool: Arc<RwLock<ThreadPoolInner>>
}
impl ThreadPool {
pub fn new(thread_count: usize) -> Result<Self, anyhow::Error> {
let pool = ThreadPoolInner::new(thread_count)?;
let pool = Arc::new(RwLock::new(pool));
Ok(Self{ pool })
}
pub fn run_task<F>(&self, f: F) -> Result<(), anyhow::Error>
where
F: FnOnce() -> () + Sync + Send + 'static,
{
let pool = self.pool.write();
if pool.all_join { return Err(anyhow!("The pool is in a stop process")); }
pool.tasks.write().push_front(Task{ task: Box::new(f) });
for thread in &pool.threads {
if !*thread.running.read() {
thread.thread_handler.thread().unpark();
break;
}
}
Ok(())
}
pub fn join(&self) {
let mut pool = self.pool.write();
if pool.all_join { return; }
trace!("[pool] wait pool threads");
for thread in std::mem::replace(&mut pool.threads, vec![]) {
*thread.join.write() = true;
thread.thread_handler.thread().unpark();
let _ = thread.thread_handler.join();
}
trace!("[pool] exit");
}
}
/// Installs an `env_logger` logger that writes every record, down to `Trace`, to stdout.
///
/// Safe to call more than once. `init()` panics when a global logger is already
/// installed, which happens when several tests in one process call this function.
/// The fallible `try_init()` is used instead, and its error is ignored.
pub fn log_init() {
    // Ignored on purpose: `try_init` only fails when a global logger is already set.
    let _ = env_logger::builder()
        .target(env_logger::Target::Stdout)
        .filter_level(log::LevelFilter::Trace)
        .try_init();
    log::set_max_level(log::LevelFilter::Trace);
}
/// Demo workload: prints a start line, blocks the calling thread for one second,
/// then prints an end line.
fn simple_task(id: u64) {
    let pause = std::time::Duration::from_secs(1);
    println!("Task {} start", id);
    std::thread::sleep(pause);
    println!("Task {} end", id);
}
/// Demo entry point.
///
/// Starts a 10-thread pool, queues one trivial print task plus ten one-second
/// `simple_task`s (ids 0 through 9, in order), then waits for the pool to finish.
pub fn main() -> Result<(), anyhow::Error> {
    log_init();
    let pool = ThreadPool::new(10)?;
    pool.run_task(|| println!("zxxzczxc"))?;
    for id in 0..10 {
        pool.run_task(move || simple_task(id))?;
    }
    pool.join();
    Ok(())
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Once};
    use parking_lot::RwLock;
    use super::{ThreadPool, log_init};

    /// Runs `log_init` at most once per test process.
    ///
    /// The test harness runs tests on several threads of one process, and installing a
    /// global logger twice is an error (`env_logger`'s `init()` panics in that case).
    fn init_logging() {
        static LOGGER: Once = Once::new();
        LOGGER.call_once(log_init);
    }

    #[test]
    fn test_zero_thead() {
        init_logging();
        assert!(ThreadPool::new(0).is_err());
    }

    /// Queues three tasks that sleep 1 s, 2 s and 3 s and then record their id.
    ///
    /// Joins the pool and returns the recorded ids in completion order.
    fn test_pool(pool: ThreadPool) -> Result<Vec<usize>, anyhow::Error> {
        let task_ids = Arc::new(RwLock::new(Vec::new()));
        for &(id, secs) in &[(0usize, 1u64), (1, 2), (2, 3)] {
            let task_ids = Arc::clone(&task_ids);
            pool.run_task(move || {
                std::thread::sleep(std::time::Duration::from_secs(secs));
                task_ids.write().push(id);
            })?;
        }
        pool.join();
        let task_ids = task_ids.read().clone();
        Ok(task_ids)
    }

    // Timing-sensitive: asserts on wall-clock time measured with `std::time::Instant`,
    // so it may fail on a heavily loaded machine.
    #[test]
    fn test_one_thead() -> Result<(), anyhow::Error> {
        init_logging();
        let pool = ThreadPool::new(1)?;
        let start = std::time::Instant::now();
        let task_ids = test_pool(pool)?;
        assert_eq!(task_ids, vec![0, 1, 2]);
        // A single worker runs the tasks back to back.
        assert_eq!(start.elapsed().as_secs(), 1 + 2 + 3);
        Ok(())
    }

    // Timing-sensitive: asserts on wall-clock time measured with `std::time::Instant`,
    // so it may fail on a heavily loaded machine.
    #[test]
    fn test_two_theads() -> Result<(), anyhow::Error> {
        init_logging();
        let pool = ThreadPool::new(2)?;
        let start = std::time::Instant::now();
        let task_ids = test_pool(pool)?;
        assert_eq!(task_ids, vec![0, 1, 2]);
        // The 3 s task starts on the first worker after its 1 s task finishes.
        assert_eq!(start.elapsed().as_secs(), 1 + 3);
        Ok(())
    }

    // Timing-sensitive: asserts on wall-clock time measured with `std::time::Instant`,
    // so it may fail on a heavily loaded machine.
    #[test]
    fn test_multi_theads() -> Result<(), anyhow::Error> {
        init_logging();
        let pool = ThreadPool::new(3)?;
        let start = std::time::Instant::now();
        let task_ids = test_pool(pool)?;
        assert_eq!(task_ids, vec![0, 1, 2]);
        // Measure once so the printed value and the asserted value agree.
        let elapsed = start.elapsed();
        let t = elapsed.as_secs_f64();
        println!("t {t}");
        assert_eq!(elapsed.as_secs(), 3);
        Ok(())
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment