Last active
October 31, 2024 11:13
-
Star
(110)
You must be signed in to star a gist -
Fork
(12)
You must be signed in to fork a gist
-
-
Save andrewjong/6b02ff237533b3b2c554701fb53d5c4d to your computer and use it in GitHub Desktop.
PyTorch Image File Paths With Dataset Dataloader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torchvision import datasets | |
class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that also returns each image's file path.

    Extends ``torchvision.datasets.ImageFolder`` so that every item is the
    parent's tuple with the sample's file path appended as the last element.
    """

    # Override __getitem__: this is the method the DataLoader calls per sample.
    def __getitem__(self, index):
        """Return the parent's item tuple with the image path appended.

        Args:
            index: Position of the sample in the dataset.

        Returns:
            The tuple produced by ``ImageFolder.__getitem__`` followed by
            the file path of that sample.
        """
        # This is what ImageFolder normally returns.
        original_tuple = super().__getitem__(index)
        # Read the path from ``samples`` (the list the parent's __getitem__
        # loads from) so the reported path always matches the loaded image.
        path = self.samples[index][0]
        # Build a new tuple: original contents plus the path.
        return original_tuple + (path,)
# EXAMPLE USAGE:
# instantiate the dataset and dataloader
from torchvision import transforms

data_dir = "your/data_dir/here"
# A tensor transform is required: the DataLoader's default collate function
# cannot batch the PIL images that ImageFolder yields without one.
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor())  # our custom dataset
# DataLoader lives in torch.utils.data, not directly in torch.utils.
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
import torch
import torchvision
from torchvision import datasets, transforms
# Preprocessing pipeline: resize each image to 64x64, then convert it to a tensor.
# NOTE(review): this rebinds the name `transforms`, shadowing the imported
# torchvision.transforms module for the rest of the file.
transforms = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """Custom dataset that also returns each image's file path.

    Extends ``torchvision.datasets.ImageFolder`` so that every item is the
    parent's tuple with the sample's file path appended as the last element.
    """

    # Override __getitem__ (with the double underscores): this is the hook
    # that indexing and the DataLoader call. A method named plain ``getitem``
    # is never invoked by them, so no path would be returned.
    def __getitem__(self, index):
        """Return the parent's item tuple with the image path appended.

        Args:
            index: Position of the sample in the dataset.

        Returns:
            The tuple produced by ``ImageFolder.__getitem__`` followed by
            the file path of that sample.
        """
        # This is what ImageFolder normally returns.
        original_tuple = super().__getitem__(index)
        # Read the path from ``samples`` (the list the parent's __getitem__
        # loads from) so the reported path always matches the loaded image.
        path = self.samples[index][0]
        # Build a new tuple: original contents plus the path.
        return original_tuple + (path,)
# EXAMPLE USAGE:
# Instantiate the dataset and dataloader.
data_dir = './dog_vs_cat/train/'
# `transforms` here is the Compose pipeline built earlier in this snippet.
dataset = ImageFolderWithPaths(data_dir, transform=transforms)  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

# Iterate over data; each batch carries the file paths as a third element.
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
This code worked for me.