PyTorch DataLoader benchmark
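The script below sweeps DataLoader settings (batch_size, num_workers, and pin_memory when CUDA is available), times up to max_batches batches for each configuration, prints a pandas summary table, and saves a comparison figure to ./dataloader_benchmark.jpg. As a rough usage sketch (my_dataset is a hypothetical map-style Dataset, not part of the gist; the keyword arguments mirror the function defined below):

# Usage sketch, assuming `my_dataset` is any map-style torch Dataset.
check_dataloader_speed(
    my_dataset,
    batch_sizes=[8, 32, 128],   # narrower sweep than the defaults
    num_workers=[0, 2, 4, 8],
    max_batches=50,
)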
import time
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

from fourm.data.dummy_dataset import DummyDataset
def check_dataloader_speed(dataset: Dataset,
                           batch_sizes: list[int] = [1, 2, 4, 8, 16, 32, 64, 128],
                           num_workers: list[int] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                           max_batches: int = 50):
    print(f"Checking DataLoader speed for dataset with {len(dataset)} samples (up to {max_batches} batches per test)...")

    # pin_memory=True is only worth testing when a CUDA device is available
    pin_memory_options = [False]
    if torch.cuda.is_available():
        pin_memory_options = [True, False]

    results = []
    for batch_size in tqdm(batch_sizes, desc="Batch Sizes"):
        for num_worker in tqdm(num_workers, desc="Num Workers", leave=False):
            for pin_mem in tqdm(pin_memory_options, desc="Pin Memory", leave=False):
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_worker, pin_memory=pin_mem)

                # Time how long it takes to iterate over up to max_batches batches
                start_time = time.time()
                processed_batches_count = 0
                for i, batch in enumerate(dataloader):
                    if i >= max_batches:  # Limit the number of batches processed
                        break
                    processed_batches_count += 1
                end_time = time.time()

                total_time = end_time - start_time
                if processed_batches_count == 0:
                    avg_time_per_batch = float('inf')
                    samples_per_second = 0.0
                else:
                    avg_time_per_batch = total_time / processed_batches_count
                    if total_time > 0:
                        samples_per_second = (processed_batches_count * batch_size) / total_time
                    else:  # total_time is 0, implies instantaneous processing for processed_batches_count > 0
                        samples_per_second = float('inf') if (processed_batches_count * batch_size) > 0 else 0.0

                results.append({
                    "batch_size": batch_size,
                    "num_workers": num_worker,
                    "pin_memory": pin_mem,
                    "total_time_taken": total_time,
                    "avg_time_per_batch": avg_time_per_batch,
                    "avg_samples_per_second": samples_per_second,
                    "processed_batches": processed_batches_count
                })
    # --- Print results and plot ---
    print("\n--- Experiment Results Summary ---")
    if not results:
        print("No results to display.")
        return

    df_results = pd.DataFrame(results)
    df_results['pin_memory'] = df_results['pin_memory'].astype(bool)  # Ensure boolean type
    print(df_results.to_string())

    fig, axes = plt.subplots(1, 3, figsize=(22, 7))  # Adjusted figsize
    fig.suptitle("DataLoader Performance Analysis (Higher Samples/Sec is Better)", fontsize=16)

    # Plot 1: Throughput vs. Batch Size
    try:
        pivot_bs_data = df_results.groupby(['batch_size', 'num_workers'])['avg_samples_per_second'].mean().reset_index()
        pivot_bs = pivot_bs_data.pivot(index='batch_size', columns='num_workers', values='avg_samples_per_second')
        if not pivot_bs.empty:
            pivot_bs.plot(ax=axes[0], marker='o')
            axes[0].set_ylabel("Average Samples/Second")
            axes[0].legend(title="Num Workers", loc='best')
        else:
            axes[0].text(0.5, 0.5, "Not enough data diversity\nfor Batch Size plot", ha='center', va='center', transform=axes[0].transAxes)
    except Exception as e:
        axes[0].text(0.5, 0.5, f"Error plotting:\n{e}", ha='center', va='center', transform=axes[0].transAxes)
    axes[0].set_xlabel("Batch Size")
    axes[0].set_title("Throughput vs. Batch Size")
    axes[0].grid(True, linestyle='--')
    axes[0].set_xscale('log')  # Log scale for batch size

    # Plot 2: Throughput vs. Num Workers
    try:
        pivot_nw_data = df_results.groupby(['num_workers', 'batch_size'])['avg_samples_per_second'].mean().reset_index()
        pivot_nw = pivot_nw_data.pivot(index='num_workers', columns='batch_size', values='avg_samples_per_second')
        if not pivot_nw.empty:
            pivot_nw.plot(ax=axes[1], marker='o')
            axes[1].legend(title="Batch Size", loc='best')
        else:
            axes[1].text(0.5, 0.5, "Not enough data diversity\nfor Num Workers plot", ha='center', va='center', transform=axes[1].transAxes)
    except Exception as e:
        axes[1].text(0.5, 0.5, f"Error plotting:\n{e}", ha='center', va='center', transform=axes[1].transAxes)
    axes[1].set_xlabel("Number of Workers")
    axes[1].set_ylabel("Average Samples/Second")
    axes[1].set_title("Throughput vs. Num Workers")
    axes[1].grid(True, linestyle='--')
    # Plot 3: Throughput vs. Pin Memory
    if df_results['pin_memory'].nunique() > 1:  # Only plot if both True and False were tested
        try:
            throughput_vs_pin = df_results.groupby('pin_memory')['avg_samples_per_second'].mean()
            plot_labels = sorted(throughput_vs_pin.index.map(str).unique())  # ['False', 'True']
            # Map the string labels back to booleans to index the grouped Series
            plot_values = [throughput_vs_pin[label == 'True'] for label in plot_labels]
            axes[2].bar(plot_labels, plot_values, color=['skyblue', 'lightcoral'])
            axes[2].set_ylabel("Average Samples/Second")
        except Exception as e:
            axes[2].text(0.5, 0.5, f"Error plotting:\n{e}", ha='center', va='center', transform=axes[2].transAxes)
    else:
        axes[2].text(0.5, 0.5, "Pin Memory test N/A\n(e.g. CUDA not available or\nonly one option tested)",
                     ha='center', va='center', transform=axes[2].transAxes)
    axes[2].set_xlabel("Pin Memory")
    axes[2].set_title("Avg Throughput vs. Pin Memory")
    axes[2].grid(True, axis='y', linestyle='--')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout for suptitle
    plt.savefig('./dataloader_benchmark.jpg', dpi=300, bbox_inches='tight')


if __name__ == '__main__':
    dummy_dataset = DummyDataset()
    check_dataloader_speed(dummy_dataset, max_batches=50)
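DummyDataset is imported from the fourm package; if that package is not installed, a minimal stand-in (an assumption for illustration, not part of the original gist) could look like the following, which simply serves random tensors:

# Hypothetical stand-in for fourm.data.dummy_dataset.DummyDataset,
# assumed here only so the benchmark can run without the fourm package.
import torch
from torch.utils.data import Dataset

class DummyDataset(Dataset):
    def __init__(self, num_samples: int = 10_000, shape: tuple = (3, 224, 224)):
        self.num_samples = num_samples
        self.shape = shape

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Each item is a random "image" tensor and an integer label.
        return torch.randn(self.shape), idx % 10

Swapping this in (and dropping the fourm import) keeps the script self-contained, though throughput numbers will then reflect synthetic data generation rather than a real loading pipeline.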