Skip to content

Instantly share code, notes, and snippets.

@ShairozS
Last active December 6, 2021 00:52
Show Gist options
  • Save ShairozS/80e2d454c0acd473bc3da844116ed450 to your computer and use it in GitHub Desktop.
Save ShairozS/80e2d454c0acd473bc3da844116ed450 to your computer and use it in GitHub Desktop.
Hierarchical ImageFolder for PyTorch
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os
class HierarchicalImageFolder(Dataset):
    '''
    A version of torchvision.datasets.ImageFolder adapted to images in
    a hierarchical file structure (i.e. iNat18). Every file under ``root``
    whose path ends with ``ftype`` is one sample; its label is a tuple of the
    names of its last ``hlevel`` enclosing directories. The __getitem__(x)
    method returns the x-th image, along with that tuple label.
    '''

    def __init__(self, root, hlevel=2, ftype='.jpg', transforms=None):
        '''
        Index all matching images under ``root``.

        Args:
            root: Directory searched recursively for images.
            hlevel: Number of trailing parent-directory names that form each
                label, e.g. 2 -> ('genus', 'species'). 0 gives an empty tuple.
            ftype: Filename suffix, or tuple of suffixes, to include
                (case-sensitive).
            transforms: Optional callable applied to each PIL image in
                __getitem__. Inside this method the name shadows the
                module-level ``torchvision.transforms`` import.

        Prints progress messages to stdout while indexing.
        '''
        super().__init__()
        self.root = root
        print("Forming filetree...")
        self.imgs, self.labels = self.read_nested_images(root, ftype, hlevel)
        print("done")
        self.transforms = transforms

    def __len__(self):
        '''Return the number of indexed images.'''
        return len(self.imgs)

    def __getitem__(self, x):
        '''
        Return ``(image, label)`` for the x-th sample.

        ``image`` is a PIL image, or whatever ``self.transforms`` returns.
        ``label`` is the tuple of directory names produced by
        read_nested_images.
        '''
        img_path, label = self.imgs[x], self.labels[x]
        # Decode with PIL directly. The plt.imread -> Image.fromarray route
        # fails for PNGs, because matplotlib decodes them to float arrays.
        # The ``with`` closes the file handle; copy/convert force the lazy
        # pixel data to load before that happens.
        with Image.open(img_path) as im:
            if im.mode in ('RGB', 'RGBA', 'L'):
                img = im.copy()
            else:
                # Palette, CMYK, LA, ... -> RGBA. This matches what the
                # matplotlib-based loading produced for such files.
                img = im.convert('RGBA')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    @staticmethod
    def read_nested_images(root, ftype='.jpg', hlevel=2):
        '''
        Collect image paths under ``root`` and their hierarchical labels.

        Args:
            root: Directory walked recursively.
            ftype: Suffix, or tuple of suffixes, a file path must end with
                to be included.
            hlevel: Non-negative number of trailing parent-directory names
                kept as the label. If a path has fewer parent components than
                hlevel, all of them are used.

        Returns:
            (img_paths, img_labels): parallel lists of file paths and label
            tuples, sorted by path so sample indices are reproducible across
            runs and filesystems.

        Raises:
            ValueError: if hlevel is negative.
        '''
        if hlevel < 0:
            raise ValueError(f"hlevel must be >= 0, got {hlevel!r}")
        samples = []
        for dirpath, _dirs, files in os.walk(root, topdown=False):
            for name in files:
                path = os.path.join(dirpath, name)
                if not path.endswith(ftype):
                    continue
                parents = tuple(os.path.normpath(path).split(os.sep)[:-1])
                # Explicit start index: a bare [-hlevel:] would return the
                # whole tuple when hlevel == 0.
                start = max(0, len(parents) - hlevel)
                samples.append((path, parents[start:]))
        # Each path occurs once, so this orders by path alone.
        samples.sort()
        img_paths = [path for path, _ in samples]
        img_labels = [label for _, label in samples]
        return img_paths, img_labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment