Last active
March 11, 2022 17:05
-
-
Save chertov/ee763daf5dae51209326405307004bbf to your computer and use it in GitHub Desktop.
simple thread pool
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::sync::Arc; | |
use parking_lot::RwLock; | |
use std::collections::VecDeque; | |
use log::trace; | |
use anyhow::anyhow; | |
struct Thread { | |
thread_handler: std::thread::JoinHandle<()>, | |
running: Arc<RwLock<bool>>, | |
join: Arc<RwLock<bool>>, | |
} | |
impl Thread { | |
fn new(tasks: Arc<RwLock<VecDeque<Task>>>) -> Self { | |
let running = Arc::new(RwLock::new(false)); | |
let join = Arc::new(RwLock::new(false)); | |
let thread_handler = std::thread::spawn({ | |
let running = running.clone(); | |
let join = join.clone(); | |
move || { Thread::_run(tasks, running, join) } | |
}); | |
Self { thread_handler, running, join } | |
} | |
fn _run(tasks: Arc<RwLock<VecDeque<Task>>>, running: Arc<RwLock<bool>>, join: Arc<RwLock<bool>>) { | |
let id = std::thread::current().id(); | |
trace!("[pool: {:?}] create a new pool thread", id); | |
loop { | |
let task = { tasks.write().pop_back() }; | |
if let Some(task) = task { | |
trace!("[pool: {:?}] a new task running", id); | |
{ *running.write() = true; } | |
(task.task)(); | |
{ *running.write() = false; } | |
trace!("[pool: {:?}] task complete", id); | |
} else { | |
{ if *join.read() { break } } | |
trace!("[pool: {:?}] thread park", id); | |
std::thread::park(); | |
} | |
} | |
trace!("[pool: {:?}] complete", id); | |
} | |
} | |
/// A unit of work queued on the pool: a boxed closure that a worker calls exactly once.
///
/// NOTE(review): the `Sync` bound appears to be required because the queue of tasks is
/// shared between threads behind a `parking_lot::RwLock` — confirm before loosening it.
struct Task {
    task: Box<dyn FnOnce() + Send + Sync + 'static>,
}
struct ThreadPoolInner { | |
threads: Vec<Thread>, | |
tasks: Arc<RwLock<VecDeque<Task>>>, | |
all_join: bool, | |
} | |
impl ThreadPoolInner { | |
pub fn new(thread_count: usize) -> Result<Self, anyhow::Error> { | |
if thread_count == 0 { return Err(anyhow!("The count threads of pool must be more then zero")) } | |
let mut threads = vec![]; | |
let tasks = Arc::new(RwLock::new(VecDeque::new() )); | |
for _ in 0..thread_count { | |
threads.push(Thread::new(tasks.clone()) ); | |
} | |
Ok(Self { threads, tasks, all_join: false }) | |
} | |
} | |
#[derive(Clone)] | |
struct ThreadPool { | |
pool: Arc<RwLock<ThreadPoolInner>> | |
} | |
impl ThreadPool { | |
pub fn new(thread_count: usize) -> Result<Self, anyhow::Error> { | |
let pool = ThreadPoolInner::new(thread_count)?; | |
let pool = Arc::new(RwLock::new(pool)); | |
Ok(Self{ pool }) | |
} | |
pub fn run_task<F>(&self, f: F) -> Result<(), anyhow::Error> | |
where | |
F: FnOnce() -> () + Sync + Send + 'static, | |
{ | |
let pool = self.pool.write(); | |
if pool.all_join { return Err(anyhow!("The pool is in a stop process")); } | |
pool.tasks.write().push_front(Task{ task: Box::new(f) }); | |
for thread in &pool.threads { | |
if !*thread.running.read() { | |
thread.thread_handler.thread().unpark(); | |
break; | |
} | |
} | |
Ok(()) | |
} | |
pub fn join(&self) { | |
let mut pool = self.pool.write(); | |
if pool.all_join { return; } | |
trace!("[pool] wait pool threads"); | |
for thread in std::mem::replace(&mut pool.threads, vec![]) { | |
*thread.join.write() = true; | |
thread.thread_handler.thread().unpark(); | |
let _ = thread.thread_handler.join(); | |
} | |
trace!("[pool] exit"); | |
} | |
} | |
pub fn log_init() { | |
env_logger::builder() | |
.target(env_logger::Target::Stdout) | |
.filter_level(log::LevelFilter::Trace) | |
.init(); | |
log::set_max_level(log::LevelFilter::Trace); | |
} | |
/// Demo workload: prints a start line, blocks the calling thread for one second,
/// then prints an end line.
fn simple_task(id: u64) {
    let pause = std::time::Duration::from_secs(1);
    println!("Task {} start", id);
    std::thread::sleep(pause);
    println!("Task {} end", id);
}
pub fn main() -> Result<(), anyhow::Error> { | |
log_init(); | |
let pool = ThreadPool::new(10)?; | |
pool.run_task(|| { println!("zxxzczxc") })?; | |
pool.run_task(|| simple_task(0))?; | |
pool.run_task(|| simple_task(1))?; | |
pool.run_task(|| simple_task(2))?; | |
pool.run_task(|| simple_task(3))?; | |
pool.run_task(|| simple_task(4))?; | |
pool.run_task(|| simple_task(5))?; | |
pool.run_task(|| simple_task(6))?; | |
pool.run_task(|| simple_task(7))?; | |
pool.run_task(|| simple_task(8))?; | |
pool.run_task(|| simple_task(9))?; | |
pool.join(); | |
Ok(()) | |
} | |
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Once};
    use parking_lot::RwLock;
    use super::{ThreadPool, log_init};

    /// Initializes logging at most once per test binary.
    ///
    /// The test harness runs tests on several threads of one process, and installing
    /// the global logger twice fails, so every test goes through this guard.
    fn init_logging() {
        static INIT: Once = Once::new();
        INIT.call_once(log_init);
    }

    #[test]
    fn test_zero_thead() {
        init_logging();
        assert!(ThreadPool::new(0).is_err());
    }

    /// Submits three tasks that sleep 1 s, 2 s and 3 s and then record their id.
    /// Joins the pool and returns the ids in completion order.
    fn test_pool(pool: ThreadPool) -> Result<Vec<usize>, anyhow::Error> {
        let task_ids = Arc::new(RwLock::new(Vec::new()));
        for (id, secs) in [(0usize, 1u64), (1, 2), (2, 3)] {
            let task_ids = Arc::clone(&task_ids);
            pool.run_task(move || {
                std::thread::sleep(std::time::Duration::from_secs(secs));
                task_ids.write().push(id);
            })?;
        }
        pool.join();
        let finished = task_ids.read().clone();
        Ok(finished)
    }

    // Timing-sensitive: asserts on wall-clock time measured with `std::time::Instant`.
    // Run it on its own (e.g. `cargo test test_one_thead`) if the machine is heavily loaded.
    #[test]
    fn test_one_thead() -> Result<(), anyhow::Error> {
        init_logging();
        let pool = ThreadPool::new(1)?;
        let start = std::time::Instant::now();
        let task_ids = test_pool(pool)?;
        assert_eq!(task_ids, vec![0, 1, 2]);
        // One worker runs the tasks back to back.
        assert_eq!(start.elapsed().as_secs(), 1 + 2 + 3);
        Ok(())
    }

    // Timing-sensitive: asserts on wall-clock time measured with `std::time::Instant`.
    // Run it on its own if the machine is heavily loaded.
    #[test]
    fn test_two_theads() -> Result<(), anyhow::Error> {
        init_logging();
        let pool = ThreadPool::new(2)?;
        let start = std::time::Instant::now();
        let task_ids = test_pool(pool)?;
        assert_eq!(task_ids, vec![0, 1, 2]);
        // The 1 s and 2 s tasks run in parallel; the 3 s task starts when the 1 s task ends.
        assert_eq!(start.elapsed().as_secs(), 1 + 3);
        Ok(())
    }

    // Timing-sensitive: asserts on wall-clock time measured with `std::time::Instant`.
    // Run it on its own if the machine is heavily loaded.
    #[test]
    fn test_multi_theads() -> Result<(), anyhow::Error> {
        init_logging();
        let pool = ThreadPool::new(3)?;
        let start = std::time::Instant::now();
        let task_ids = test_pool(pool)?;
        assert_eq!(task_ids, vec![0, 1, 2]);
        // Measure once so the printed value and the asserted value agree.
        let elapsed = start.elapsed();
        println!("t {}", elapsed.as_secs_f64());
        // All three tasks run in parallel; the longest one takes 3 s.
        assert_eq!(elapsed.as_secs(), 3);
        Ok(())
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment