Last active: December 6, 2021 00:52
-
-
Save ShairozS/80e2d454c0acd473bc3da844116ed450 to your computer and use it in GitHub Desktop.
Hierarchical ImageFolder (PyTorch)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
import os | |
class HierarchicalImageFolder(Dataset):
    """Image dataset whose labels come from a nested directory structure.

    A variant of ``torchvision.datasets.ImageFolder`` for images stored in a
    hierarchical file tree (e.g. iNat18). Item ``x`` is returned as
    ``(image, label)``, where ``label`` is a tuple holding the names of the
    last ``hlevel`` directories containing the image.
    """

    def __init__(self, root, hlevel=2, ftype='.jpg', transforms=None):
        """Scan ``root`` for images and record their hierarchical labels.

        Args:
            root: Directory to search recursively for images.
            hlevel: Number of trailing directory names kept as the label.
                Expected to be >= 1. With 0, ``[-0:]`` keeps every path
                component.
            ftype: Filename suffix to match. The match uses ``str.endswith``,
                so it is case-sensitive and a tuple of suffixes also works.
            transforms: Optional callable applied to each PIL image.
        """
        self.root = root
        print("Forming filetree...")
        self.imgs, self.labels = self.read_nested_images(root, ftype, hlevel)
        print("done")
        self.transforms = transforms

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.imgs)

    def __getitem__(self, x):
        """Return ``(image, label)`` for the ``x``-th image.

        The image is decoded with PIL and converted to RGB, as torchvision's
        default ImageFolder loader does. Decoding through
        ``matplotlib.pyplot.imread`` returned PNGs as float arrays, which
        ``Image.fromarray`` rejects. It also produced grayscale or RGBA images
        for some JPEGs, which then could not be batched with RGB samples.
        """
        img_path, label = self.imgs[x], self.labels[x]
        # Convert inside the ``with`` block: Image.open is lazy, and the pixel
        # data must be read before the file handle is closed.
        with open(img_path, 'rb') as fh:
            img = Image.open(fh).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    @staticmethod
    def read_nested_images(root, ftype='.jpg', hlevel=2):
        """Collect image paths under ``root`` together with their labels.

        Args:
            root: Directory to walk recursively.
            ftype: Suffix (or tuple of suffixes) a path must end with.
            hlevel: Number of trailing directory components kept per label.

        Returns:
            ``(img_paths, img_labels)``, two parallel lists sorted by path so
            that dataset indices are reproducible across filesystems.
            Each label is a tuple of directory names taken from the
            normalised path. If ``hlevel`` exceeds the depth below ``root``,
            the label also includes components of ``root`` itself.
        """
        samples = []
        for dirpath, _subdirs, files in os.walk(root, topdown=False):
            for name in files:
                path = os.path.join(dirpath, name)
                if not path.endswith(ftype):
                    continue
                # Drop the filename; keep the directory components.
                components = tuple(os.path.normpath(path).split(os.sep)[:-1])
                samples.append((path, components[-hlevel:]))
        # os.walk order is filesystem-dependent; sort for deterministic indexing.
        samples.sort(key=lambda sample: sample[0])
        img_paths = [path for path, _ in samples]
        img_labels = [label for _, label in samples]
        return img_paths, img_labels
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.