FFCV datamodule for PyTorch Lightning that supports iterating over many chunked .beton files. A usage sketch follows the class.
import glob
import os  # only referenced by the commented-out debug prints
import time

import pytorch_lightning as pl

class FFCVDataModule(pl.LightningDataModule):
    """Round-robins over chunked FFCV .beton files found under `beton_path`.

    Only the train dataloader is implemented; each call to train_dataloader()
    advances to the next .beton chunk.
    """

    def __init__(self, batch_size, train=None, reg=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False, beton_path=None, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.train_dataloader = self._train_dataloader
        self.beton_path = beton_path
        self.latents = kwargs.pop("latents", False)
        # collect the chunked .beton files under beton_path and visit them in sorted order
        self.read_paths = sorted(glob.glob(f"{beton_path}/**.beton"))
        self.curr_read_idx = -1
        # print(f"ffcv datamodule init called by pid={os.getpid()}, found {len(self.read_paths)} beton chunks")
    def prepare_data(self, *args, **kwargs):
        # nothing to prepare; the .beton chunks are written offline
        # print(f"ffcv prepare data called by pid={os.getpid()}, read_idx={self.curr_read_idx}")
        pass

    def setup(self, *args, **kwargs):
        pass
    def _train_dataloader(self):
        start_time = time.time()
        # ffcv is imported lazily, only when the loader is actually built
        from ffcv.loader import Loader, OrderOption
        from ffcv.fields.decoders import NDArrayDecoder
        from ffcv.transforms import ToTensor

        # map from field name to list of transforms (can add splatting here?)
        PIPELINES = {
            'id1': [ToTensor()],
            'id2': [ToTensor()],
            'img1': [ToTensor()],
            'img2': [ToTensor()],
            'img1_splat': [ToTensor()],
            'img2_splat': [ToTensor()],
            'img1_splat_mask': [ToTensor()],
            'img2_splat_mask': [ToTensor()],
        }
if "cams" in self.beton_path: | |
PIPELINES.update({ | |
'cam1': [NDArrayDecoder(), ToTensor()], | |
'cam2': [NDArrayDecoder(), ToTensor()], | |
'fov1': [ToTensor()], | |
'fov2': [ToTensor()], | |
}) | |
if "mono" in self.beton_path: | |
PIPELINES.update({ | |
'mono1': [NDArrayDecoder()], | |
'mono2': [NDArrayDecoder()], | |
'mono1_max': [ToTensor()], | |
'mono2_max': [ToTensor()], | |
'mono1_min': [ToTensor()], | |
'mono2_min': [ToTensor()], | |
}) | |
        # ground truth depth
        if "gdepth" in self.beton_path:
            PIPELINES.update({
                'gdepth1': [ToTensor()],
                'gdepth2': [ToTensor()],
                'gdepth1_splat': [ToTensor()],
                'gdepth2_splat': [ToTensor()],
            })
if "all" in self.beton_path: | |
from scripts.add_all_overfit import PIPELINES as ALL_PIPELINES | |
PIPELINES = ALL_PIPELINES | |
if self.latents: | |
from scripts.add_latents import latents_format_read | |
PIPELINES = latents_format_read | |
PIPELINES.pop("bg_mask1") | |
PIPELINES.pop("bg_mask2") | |
for k,v in PIPELINES.items(): | |
if len(PIPELINES[k]) == 0 or not isinstance(PIPELINES[k][-1], ToTensor): | |
PIPELINES[k].append(ToTensor()) | |
        # advance to the next chunk, skipping over any chunk that fails to open
        success = False
        while not success:
            try:
                self.curr_read_idx = (self.curr_read_idx + 1) % len(self.read_paths)
                read_path = self.read_paths[self.curr_read_idx]
                # print(f"train dataloader called by pid={os.getpid()}, read_idx={self.curr_read_idx}, read_path={read_path}")
                loader = Loader(read_path,
                                batch_size=self.batch_size,
                                num_workers=self.num_workers,
                                order=OrderOption.RANDOM,
                                pipelines=PIPELINES,
                                distributed=True)
                success = True
            except Exception as e:
                print(f"train dataloader exception: {e}")
                continue

        load_time = time.time()
        print(f"loaded chunk {read_path} in {load_time - start_time:.2f} seconds")
        return loader
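
A minimal usage sketch, not part of the original gist: it assumes the chunked .beton files already exist under a hypothetical /data/betons directory and that MyModel is some pl.LightningModule defined elsewhere. Because the Loader above is built with distributed=True, the sketch assumes a DDP run; reload_dataloaders_every_n_epochs=1 makes Lightning call train_dataloader() again at every epoch, which is what rotates training onto the next chunk.

# Usage sketch (assumptions: /data/betons holds chunk_000.beton, chunk_001.beton, ...;
# MyModel is a pl.LightningModule defined elsewhere).
import pytorch_lightning as pl

datamodule = FFCVDataModule(
    batch_size=32,
    num_workers=8,
    beton_path="/data/betons",  # hypothetical directory of chunked .beton files
)

model = MyModel()

# The FFCV Loader is created with distributed=True, so run under DDP.
# reload_dataloaders_every_n_epochs=1 makes Lightning rebuild the train
# dataloader each epoch, so every epoch trains on the next .beton chunk.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    max_epochs=10,
    reload_dataloaders_every_n_epochs=1,
)
trainer.fit(model, datamodule=datamodule)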