Last active
December 4, 2024 08:32
-
-
Save jin-zhe/d1482e2ad732475b99ff941d3361cefa to your computer and use it in GitHub Desktop.
Pandas parallel apply function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
DESCRIPTION:
  This simple convenience function provides parallelization of pandas .apply()
  Adapted from: https://proinsias.github.io/tips/How-to-use-multiprocessing-with-pandas/

REQUIREMENTS:
  `multiprocess` and `dill` packages are required.
  ```
  python -m pip install multiprocess dill
  ```

EXAMPLE USAGE:
  ```
  df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
                     'num_wings': [2, 0, 0, 0],
                     'num_specimen_seen': [10, 2, 1, 8]},
                    index=['falcon', 'dog', 'spider', 'fish'])
  df['total_legs'] = parallel_apply(df, lambda row: row.num_legs * row.num_specimen_seen, axis=1)
  print(df)
  >>>         num_legs  num_wings  num_specimen_seen  total_legs
      falcon         2          2                 10          20
      dog            4          0                  2           8
      spider         8          0                  1           8
      fish           0          0                  8           0
  ```
'''
# Standard library
from math import ceil

# Third-party: `multiprocess` (dill-based fork of multiprocessing) is used
# instead of stdlib `multiprocessing` so that lambdas can be pickled to workers.
import multiprocess as mp
import pandas as pd
from tqdm import tqdm

# Registers .progress_apply() on pandas objects (tqdm integration).
# NOTE(review): nothing below calls .progress_apply(); presumably kept as a
# convenience for interactive use — confirm before removing.
tqdm.pandas()
def parallel_apply(
    data,                 # the pd.DataFrame or pd.Series to be applied over
    apply_fn,             # function to be applied over each row/element in data
    *apply_args,          # positional arguments forwarded to .apply()
    n_jobs=None,          # worker count; None -> mp.cpu_count() (resolved per call,
                          # not frozen at import time like the old default)
    progress_bar=True,    # show tqdm progress bars for enqueue/concat phases
    **apply_kwargs        # keyword arguments for .apply() (e.g. axis=1)
):
    """Parallelized equivalent of pandas ``.apply()``.

    Splits ``data`` into positional chunks, applies ``apply_fn`` to each
    chunk in a ``multiprocess`` worker pool, and concatenates the results
    in original order.

    Returns:
        The same object ``data.apply(apply_fn, *apply_args, **apply_kwargs)``
        would return (pd.Series or pd.DataFrame).

    Raises:
        TypeError: if ``data`` is not a pd.DataFrame or pd.Series.
    """
    # Validate BEFORE spawning workers; raise instead of assert so the
    # check survives `python -O`.
    if not isinstance(data, (pd.DataFrame, pd.Series)):
        raise TypeError('data must be pd.DataFrame or pd.Series!')

    # Empty input: the original code crashed here (split_size == 0 made
    # range()'s step zero). Mirror plain .apply() output instead.
    if len(data) == 0:
        return data.apply(apply_fn, *apply_args, **apply_kwargs)

    if n_jobs is None:
        n_jobs = mp.cpu_count()

    # Oversplit (2 chunks per worker) so faster workers can pick up slack.
    num_splits = n_jobs * 2
    split_size = max(1, ceil(len(data) / num_splits))  # size of each chunk
    # .iloc makes the positional slicing explicit (works identically for
    # DataFrames and Series regardless of index type).
    data_splits = [data.iloc[i: i + split_size]
                   for i in range(0, len(data), split_size)]
    map_fn = lambda chunk: chunk.apply(apply_fn, *apply_args, **apply_kwargs)

    with mp.Pool(n_jobs) as pool:
        # Enqueue processes
        if progress_bar is True:
            data_splits = tqdm(data_splits, desc='Enqueueing processes.')
        ret_list = pool.map(map_fn, data_splits)

    # Process processes
    if progress_bar is True:
        ret_list = tqdm(ret_list, desc='Processing processes.')
    return pd.concat(ret_list)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment