Last active
July 29, 2024 02:46
-
-
Save gbiz123/27f77712a7c4d26b000c41ef795aa2c2 to your computer and use it in GitHub Desktop.
Downsample Binary Pytorch Dataset Down To Size Of Smallest Class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset | |
import torchvision | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import time | |
import os | |
from PIL import Image | |
from tempfile import TemporaryDirectory | |
import random | |
def downsample_balance_binary_dataset(dataset: Dataset) -> Dataset:
    """Balance a binary dataset by truncating the larger class.

    Every sample is expected to be indexable, with its label at position 1
    (for example an ``(input, label)`` tuple) and labels comparable to the
    integers ``0`` and ``1``.

    Args:
        dataset: The dataset to balance. It is iterated exactly once.

    Returns:
        ``dataset`` itself, unchanged, when both classes have the same number
        of samples. Otherwise a ``Subset`` holding the first ``k`` samples of
        the larger class followed by all samples of the smaller class, where
        ``k`` is the size of the smaller class.

    Note:
        Selection is deterministic (the first ``k`` indices in dataset
        order), not random. When a ``Subset`` is returned, samples whose
        label is neither ``0`` nor ``1`` are not included in it.
    """
    class_0_indices = []
    class_1_indices = []
    # Collect both index lists in one pass; each dataset access may be
    # costly (e.g. decoding a file), so avoid iterating more than once.
    for index, sample in enumerate(dataset):
        label = sample[1]
        if label == 0:
            class_0_indices.append(index)
        elif label == 1:
            class_1_indices.append(index)

    if len(class_0_indices) == len(class_1_indices):
        return dataset

    if len(class_0_indices) > len(class_1_indices):
        majority, minority = class_0_indices, class_1_indices
    else:
        majority, minority = class_1_indices, class_0_indices

    # Truncated majority first, then the full minority class, matching the
    # ordering callers have always received.
    all_indices = majority[:len(minority)] + minority
    print(f"Sampled dataset down to {len(all_indices)} samples")
    return Subset(dataset, all_indices)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment