@ArthurDelannoyazerty
Created May 23, 2025 09:46
PyTorch DataLoader benchmark
import time
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

from fourm.data.dummy_dataset import DummyDataset


def check_dataloader_speed(dataset: Dataset,
                           batch_sizes: list[int] = [1, 2, 4, 8, 16, 32, 64, 128],
                           num_workers: list[int] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                           max_batches: int = 50):
    print(f"Checking DataLoader speed for dataset with {len(dataset)} samples (up to {max_batches} batches per test)...")

    pin_memory_options = [False]
    if torch.cuda.is_available():
        pin_memory_options = [True, False]

    results = []
    for batch_size in tqdm(batch_sizes, desc="Batch Sizes"):
        for num_worker in tqdm(num_workers, desc="Num Workers", leave=False):
            for pin_mem in tqdm(pin_memory_options, desc="Pin Memory", leave=False):
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_worker, pin_memory=pin_mem)

                start_time = time.time()
                processed_batches_count = 0
                for i, batch in enumerate(dataloader):
                    if i >= max_batches:  # Limit the number of batches processed
                        break
                    processed_batches_count += 1
                end_time = time.time()

                total_time = end_time - start_time
                if processed_batches_count == 0:
                    avg_time_per_batch = float('inf')
                    samples_per_second = 0.0
                else:
                    avg_time_per_batch = total_time / processed_batches_count
                    if total_time > 0:
                        samples_per_second = (processed_batches_count * batch_size) / total_time
                    else:  # total_time is 0, implies instantaneous processing for processed_batches_count > 0
                        samples_per_second = float('inf') if (processed_batches_count * batch_size) > 0 else 0.0

                results.append({
                    "batch_size": batch_size,
                    "num_workers": num_worker,
                    "pin_memory": pin_mem,
                    "total_time_taken": total_time,
                    "avg_time_per_batch": avg_time_per_batch,
                    "avg_samples_per_second": samples_per_second,
                    "processed_batches": processed_batches_count
                })

    # --- Print results and plot ---
    print("\n--- Experiment Results Summary ---")
    if not results:
        print("No results to display.")
        return

    df_results = pd.DataFrame(results)
    df_results['pin_memory'] = df_results['pin_memory'].astype(bool)  # Ensure boolean type
    print(df_results.to_string())

    fig, axes = plt.subplots(1, 3, figsize=(22, 7))
    fig.suptitle("DataLoader Performance Analysis (Higher Samples/Sec is Better)", fontsize=16)

    # Plot 1: Throughput vs. Batch Size
    try:
        pivot_bs_data = df_results.groupby(['batch_size', 'num_workers'])['avg_samples_per_second'].mean().reset_index()
        pivot_bs = pivot_bs_data.pivot(index='batch_size', columns='num_workers', values='avg_samples_per_second')
        if not pivot_bs.empty:
            pivot_bs.plot(ax=axes[0], marker='o')
            axes[0].set_ylabel("Average Samples/Second")
            axes[0].legend(title="Num Workers", loc='best')
        else:
            axes[0].text(0.5, 0.5, "Not enough data diversity\nfor Batch Size plot", ha='center', va='center', transform=axes[0].transAxes)
    except Exception as e:
        axes[0].text(0.5, 0.5, f"Error plotting:\n{e}", ha='center', va='center', transform=axes[0].transAxes)
    axes[0].set_xlabel("Batch Size")
    axes[0].set_title("Throughput vs. Batch Size")
    axes[0].grid(True, linestyle='--')
    axes[0].set_xscale('log')  # Log scale for batch size

    # Plot 2: Throughput vs. Num Workers
    try:
        pivot_nw_data = df_results.groupby(['num_workers', 'batch_size'])['avg_samples_per_second'].mean().reset_index()
        pivot_nw = pivot_nw_data.pivot(index='num_workers', columns='batch_size', values='avg_samples_per_second')
        if not pivot_nw.empty:
            pivot_nw.plot(ax=axes[1], marker='o')
            axes[1].legend(title="Batch Size", loc='best')
        else:
            axes[1].text(0.5, 0.5, "Not enough data diversity\nfor Num Workers plot", ha='center', va='center', transform=axes[1].transAxes)
    except Exception as e:
        axes[1].text(0.5, 0.5, f"Error plotting:\n{e}", ha='center', va='center', transform=axes[1].transAxes)
    axes[1].set_xlabel("Number of Workers")
    axes[1].set_ylabel("Average Samples/Second")
    axes[1].set_title("Throughput vs. Num Workers")
    axes[1].grid(True, linestyle='--')

    # Plot 3: Throughput vs. Pin Memory
    if df_results['pin_memory'].nunique() > 1:  # Only plot if both True and False were tested
        try:
            throughput_vs_pin = df_results.groupby('pin_memory')['avg_samples_per_second'].mean()
            plot_labels = sorted(throughput_vs_pin.index.map(str).unique())  # ['False', 'True']
            plot_values = [throughput_vs_pin[eval(label)] for label in plot_labels]
            axes[2].bar(plot_labels, plot_values, color=['skyblue', 'lightcoral'])
            axes[2].set_ylabel("Average Samples/Second")
        except Exception as e:
            axes[2].text(0.5, 0.5, f"Error plotting:\n{e}", ha='center', va='center', transform=axes[2].transAxes)
    else:
        axes[2].text(0.5, 0.5, "Pin Memory test N/A\n(e.g. CUDA not available or\nonly one option tested)",
                     ha='center', va='center', transform=axes[2].transAxes)
    axes[2].set_xlabel("Pin Memory")
    axes[2].set_title("Avg Throughput vs. Pin Memory")
    axes[2].grid(True, axis='y', linestyle='--')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout for suptitle
    plt.savefig('./dataloader_benchmark.jpg', dpi=300, bbox_inches='tight')


if __name__ == '__main__':
    dummy_dataset = DummyDataset()
    check_dataloader_speed(dummy_dataset, max_batches=50)
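
The only non-standard dependency is fourm.data.dummy_dataset.DummyDataset. If that package is not available, a minimal stand-in is sketched below; the sample count, tensor shape, and label scheme are arbitrary assumptions, chosen only so the benchmark has batches to iterate over.

import torch
from torch.utils.data import Dataset


class DummyDataset(Dataset):
    """Minimal stand-in for fourm.data.dummy_dataset.DummyDataset (hypothetical shapes)."""

    def __init__(self, num_samples: int = 10_000, shape: tuple[int, ...] = (3, 224, 224)):
        self.num_samples = num_samples
        self.shape = shape

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        # A random tensor plus an integer label is enough to exercise the DataLoader.
        return torch.randn(self.shape), idx % 10

With this stand-in in place of the fourm import, the script runs end to end on any machine; only the absolute throughput numbers depend on the tensor shape chosen. Any Dataset that implements __len__ and __getitem__ can be passed to check_dataloader_speed in the same way.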