FFCV datamodule for PyTorch Lightning that supports iterating over many chunked .beton files.
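For context, here is a minimal sketch of how one such chunk could be produced with FFCV's DatasetWriter. The field names and types are assumptions inferred from the PIPELINES dict in the datamodule below, not taken from the author's actual writer script:

# Hypothetical chunk-writing sketch; field names/shapes are guesses based on
# the PIPELINES dict consumed by the datamodule, not the original writer code.
from ffcv.writer import DatasetWriter
from ffcv.fields import NDArrayField, IntField
import numpy as np

def write_chunk(dataset, out_path):
    # one .beton file per dataset chunk, e.g. f"{beton_path}/chunk_{i:04d}.beton"
    writer = DatasetWriter(out_path, {
        'id1': IntField(),
        'img1': NDArrayField(shape=(3, 256, 256), dtype=np.dtype('float32')),
        # ... one field per key expected by the datamodule's pipelines
    }, num_workers=8)
    writer.from_indexed_dataset(dataset)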
import glob
import os
import time

import pytorch_lightning as pl


class FFCVDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, reg=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False, beton_path=None, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        # expose the bound method as this datamodule's train_dataloader hook
        self.train_dataloader = self._train_dataloader
        self.beton_path = beton_path
        self.latents = kwargs.pop("latents", False)
        # collect the chunked .beton files and iterate over them in sorted order
        self.read_paths = sorted(glob.glob(f"{beton_path}/**.beton"))
        assert len(self.read_paths) > 0, f"no .beton files found in {beton_path}"
        self.curr_read_idx = -1
        # print(f"\n \n ********** ffcv datamodule init called by pid: {os.getpid()} ********** \n \n")
        # print(f"\n \n ********** ffcv datamodule init no. read paths: {len(self.read_paths)} ********** \n \n")
    def prepare_data(self, *args, **kwargs):
        # print(f"\n \n ********** ffcv prepare data called by pid: {os.getpid()} and read_idx: {self.curr_read_idx} ********** \n \n")
        pass

    def setup(self, *args, **kwargs):
        pass
    def _train_dataloader(self):
        start_time = time.time()
        # import lazily so ffcv is only required when a loader is actually built
        from ffcv.fields.decoders import NDArrayDecoder
        from ffcv.loader import Loader, OrderOption
        from ffcv.transforms import ToTensor

        # map from field name to list of transforms (can add splatting here?)
        PIPELINES = {
            'id1': [ToTensor()],
            'id2': [ToTensor()],
            'img1': [ToTensor()],
            'img2': [ToTensor()],
            'img1_splat': [ToTensor()],
            'img2_splat': [ToTensor()],
            'img1_splat_mask': [ToTensor()],
            'img2_splat_mask': [ToTensor()],
        }
if "cams" in self.beton_path:
PIPELINES.update({
'cam1': [NDArrayDecoder(), ToTensor()],
'cam2': [NDArrayDecoder(), ToTensor()],
'fov1': [ToTensor()],
'fov2': [ToTensor()],
})
if "mono" in self.beton_path:
PIPELINES.update({
'mono1': [NDArrayDecoder()],
'mono2': [NDArrayDecoder()],
'mono1_max': [ToTensor()],
'mono2_max': [ToTensor()],
'mono1_min': [ToTensor()],
'mono2_min': [ToTensor()],
})
        # ground-truth depth fields
        if "gdepth" in self.beton_path:
            PIPELINES.update({
                'gdepth1': [ToTensor()],
                'gdepth2': [ToTensor()],
                'gdepth1_splat': [ToTensor()],
                'gdepth2_splat': [ToTensor()],
            })
        if "all" in self.beton_path:
            from scripts.add_all_overfit import PIPELINES as ALL_PIPELINES
            PIPELINES = ALL_PIPELINES
        if self.latents:
            from scripts.add_latents import latents_format_read
            PIPELINES = latents_format_read
            PIPELINES.pop("bg_mask1")
            PIPELINES.pop("bg_mask2")
        # ensure every pipeline ends with a ToTensor transform
        for k, v in PIPELINES.items():
            if len(v) == 0 or not isinstance(v[-1], ToTensor):
                v.append(ToTensor())
        # cycle through the chunked beton files; if one fails to open, move on to the next
        success = False
        while not success:
            try:
                self.curr_read_idx = (self.curr_read_idx + 1) % len(self.read_paths)
                read_path = self.read_paths[self.curr_read_idx]
                # print(f"\n \n ********** train dataloader called by pid: {os.getpid()} ********** and read_idx: {self.curr_read_idx} and read_path: {read_path} \n \n")
                loader = Loader(read_path,
                                batch_size=self.batch_size,
                                num_workers=self.num_workers,
                                order=OrderOption.RANDOM,
                                pipelines=PIPELINES,
                                distributed=True)  # requires torch.distributed to be initialized (e.g. DDP)
                success = True
            except Exception as e:
                print(f"\n \n ********** train dataloader exception: {e} ********** \n \n")
                continue
        load_time = time.time()
        print(f"loaded chunk {read_path} in {load_time - start_time:.2f} seconds")
        return loader
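
A minimal usage sketch, assuming the model class and beton directory below are placeholders. Since _train_dataloader builds a Loader over a single chunk, Lightning's reload_dataloaders_every_n_epochs flag can be used to rebuild the loader each epoch, advancing curr_read_idx to the next .beton file:

# Hypothetical usage; MyModel and the beton directory are placeholders.
if __name__ == "__main__":
    dm = FFCVDataModule(batch_size=32, beton_path="/data/chunks", num_workers=8)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,                            # distributed=True in the Loader expects DDP
        strategy="ddp",
        reload_dataloaders_every_n_epochs=1,  # pull the next beton chunk each epoch
    )
    trainer.fit(MyModel(), datamodule=dm)     # MyModel: your pl.LightningModule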