Skip to content

Instantly share code, notes, and snippets.

@Mohammed-Sunasra
Last active September 16, 2022 09:21
Show Gist options
  • Save Mohammed-Sunasra/65abff56d4c3afcd6b11cd3b5cd0846b to your computer and use it in GitHub Desktop.
Save Mohammed-Sunasra/65abff56d4c3afcd6b11cd3b5cd0846b to your computer and use it in GitHub Desktop.
Custom PyTorch Dataset and Data Loader
#Custom data generator class
class CactusDataset(Dataset):
    """Map-style dataset yielding (image, label) samples listed in a DataFrame.

    The DataFrame is read positionally: column 0 holds the image file name
    (relative to ``image_path``) and column 1 holds the class label. Images
    are loaded lazily, one per ``__getitem__`` call; batching is left to
    whatever consumes the dataset (e.g. a ``DataLoader``).
    """

    def __init__(self, df_data, image_path, image_size, transform=None):
        """Store the sample table and the image loading options.

        Args:
            df_data: DataFrame whose first column is the image file name and
                whose second column is the label.
            image_path: Directory containing the image files (str or
                path-like; it is converted with ``str`` when building paths).
            image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
            transform: Optional callable applied to the resized PIL image.
        """
        super().__init__()
        self.data = df_data
        self.image_path = image_path
        # Kept on the instance so __getitem__ does not depend on a global
        # variable that happens to be called ``image_size``.
        self.image_size = image_size
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples (rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at row ``index``.

        The image is converted to RGB, resized to ``self.image_size`` and,
        if a transform was supplied, passed through it. The label is wrapped
        in a tensor via ``torch.from_numpy``.
        """
        image_name = self.data.iloc[index, 0]
        file_path = f"{self.image_path}/{image_name}"

        # ``convert`` loads the pixel data into a new image, so the file
        # handle can be released as soon as the ``with`` block exits.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert('RGB')

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize(self.image_size, Image.LANCZOS)

        if self.transform is not None:
            image = self.transform(image)

        label = self.data.iloc[index, 1]
        label = torch.from_numpy(np.asarray(label))
        return image, label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment