@nov05
Last active February 6, 2025 22:35

✅✅✅ My working code: Create WebDataset from local data files to local .tar files

## Example: write one sample (image + label + metadata) into a local .tar file with wds.TarWriter
import webdataset as wds
import io
import json
print("👉 WebDataset version:", wds.__version__)
tar_stream = io.BytesIO()
base_name = "100313"
with wds.TarWriter(tar_stream) as sink:
    with open("../data/bin-images/100313.jpg", "rb") as f:
        image_data = f.read()
    with open("../data/metadata/100313.json", "rb") as f:
        label = json.load(f)['EXPECTED_QUANTITY']  ## read the label from the JSON metadata
    with open("../data/metadata/100313.json", "rb") as f:
        metadata_data = f.read()
    # Save as a WebDataset sample; keys other than __key__ become the member-name
    # suffixes inside the tar (100313.image, 100313.label, 100313.metadata)
    sink.write({
        "__key__": f"{base_name}",
        "image": image_data,
        "label": str(label),
        "metadata": metadata_data,
    })
# Once the tar file is in memory, save to local file
tar_stream.seek(0)
with open("../data/test/test.tar", "wb") as f:
    f.write(tar_stream.getvalue())
!tar -tf ../data/test/test.tar  ## (Jupyter shell command) verify the tar contents
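
For many samples, wds.ShardWriter can write a series of numbered shard files instead of a single tar. A minimal, untested sketch of the same idea scaled up; the sample_ids list here is made up and the directory layout is assumed to match the one above.

## sketch: write many local samples into numbered shards with wds.ShardWriter
import json
import webdataset as wds

sample_ids = ["100313", "100314"]  ## hypothetical list of basenames
## maxcount caps the samples per shard; output files are named ...-000000.tar, ...-000001.tar, etc.
with wds.ShardWriter("../data/test/train-shard-%06d.tar", maxcount=1000) as sink:
    for base_name in sample_ids:
        with open(f"../data/bin-images/{base_name}.jpg", "rb") as f:
            image_data = f.read()
        with open(f"../data/metadata/{base_name}.json", "rb") as f:
            metadata_data = f.read()
        label = json.loads(metadata_data)["EXPECTED_QUANTITY"]  ## same label field as above
        sink.write({
            "__key__": base_name,
            "image": image_data,
            "label": str(label),
            "metadata": metadata_data,
        })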

✅✅✅ My working code: Stream WebDataset .tar data from S3 and transform the data for training

## Test: stream WebDataset shards from S3 and transform them; pay attention to the object types.
import webdataset as wds
import matplotlib.pyplot as plt
from PIL import Image
import io
try:
    s3_uri = "s3://p5-amazon-bin-images/webdataset/train/train-shard-{000000..000001}.tar"
    path = f"pipe:aws s3 cp {s3_uri} -"  ## write to standard output (stdout)
    train_dataset = (
        wds.WebDataset(
                path, 
                shardshuffle=True,
                # nodesplitter=wds.split_by_worker,  ## distributed training
            )
            .shuffle(1000)  # Shuffle samples within a buffer
            ## The tuple names have to match the keys used when the WebDataset was written;
            ## check the "scripts_process/*convert_to_webdataset*.py" files
            .to_tuple("image", "label")  ## Tuple of (image bytes, label bytes)
            .map_tuple(
                lambda x: Image.open(io.BytesIO(x)),  # Decode the raw bytes into a PIL image
                lambda x: x.decode(),                 # Decode the label bytes into a string
            )
    ) 
    for image, label in train_dataset:
        print(type(label), label)
        # img = Image.open(io.BytesIO(image))
        print(type(image))
        plt.imshow(image)
        plt.show()
        break
except Exception as e:
    print(e)
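
A minimal sketch of feeding the same stream into a plain torch DataLoader: convert samples to tensors, batch inside the WebDataset pipeline with .batched(), and pass batch_size=None to the DataLoader so the pre-made batches go through unchanged. The S3 path is reused from above; the resize size, batch size, and number of workers are arbitrary assumptions, and this is untested here.

## sketch: stream from S3 straight into a torch DataLoader
import io
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def decode_image(x):  ## named functions (not lambdas) so DataLoader workers can pickle them
    return to_tensor(Image.open(io.BytesIO(x)).convert("RGB"))

def decode_label(x):
    return torch.tensor(int(x.decode()), dtype=torch.int64)

path = "pipe:aws s3 cp s3://p5-amazon-bin-images/webdataset/train/train-shard-{000000..000001}.tar -"
dataset = (
    wds.WebDataset(path, shardshuffle=True)
    .shuffle(1000)
    .to_tuple("image", "label")
    .map_tuple(decode_image, decode_label)
    .batched(16)  ## batching happens inside the pipeline
)
loader = DataLoader(dataset, batch_size=None, num_workers=2)  ## batch_size=None: batches arrive pre-made
for images, labels in loader:
    print(images.shape, labels.shape)
    break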

nov05 commented Feb 6, 2025

⚠️ WebDataset pipeline that can work with the Torch DataLoader

  • Data ingestion and transformation work fine; however, iterating across the different GPUs (ranks, nodes, and workers) still has some issues (a possible fix is sketched after the full listing below).
## imports used throughout the snippets below
import io
import numpy as np
import torch
import torch.distributed as dist
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
from torchvision import transforms


def key_transform(x):
    return int(x)



class image_transform:
    ## Defined as a class (not a lambda) so it can be pickled by DataLoader workers
    def __call__(self, x):
        return Image.open(io.BytesIO(x))
    


train_transform = transforms.Compose([
    image_transform(),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])  



def label_transform(x):
    ## Original labels are (1,2,3,4,5);
    ## convert them to (0,1,2,3,4)
    return torch.tensor(int(x.decode())-1, dtype=torch.int64)



## webdataset's WebDataset class inherits from torch's IterableDataset, and so does this custom wrapper
class WebDatasetDDP(IterableDataset):
    def __init__(self,
                 path,  
                 num_samples=0,
                 world_size=1, 
                 rank=0,  
                 no_shuffle=False,
                 shuffle_shard_size=100,
                 split_by_node=False,
                 split_by_worker=False,
                #  shardshuffle=True,
                #  empty_check=False,
                 key_transform=None,
                 train_transform=None,
                 label_transform=None,
                 shuffle_sample_size=1000,
                #  batch_size=64,
                ):
        super().__init__()
        self.dataset = (
## WebDataset
## https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-library
            # wds.WebDataset(
            #     path, 
            #     shardshuffle=shardshuffle,
            #     ## Official doc: add wds.split_by_node here if you are using multiple nodes
            #     # nodesplitter=wds.split_by_node, 
            #     ## Or "ValueError: you need to add an explicit nodesplitter 
            #     ## to your input pipeline for multi-node training"
            #     nodesplitter=wds.split_by_worker,
            #     empty_check=empty_check, 
            # )
            # .shuffle(shuffle_buffer_size)  # Shuffle dataset 
            # ## The tuple names have to match the WebDataset keys;
            # ## check the "scripts_process/*convert_to_webdataset*.py" files
            # .to_tuple("__key__", "image", "label")  ## Tuple of image and label
            # .map_tuple(
            #     key_transform,
            #     train_transform,  # Apply the train transforms to the image
            #     ## lambda functions can't be pickled, hence cause errors when num_workers>1
            #     label_transform,  
            # )
## WebDataset pipeline
## https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-pipeline-api
            wds.DataPipeline(
                wds.SimpleShardList(path),
                # at this point we have an iterator over all the shards
                wds.shuffle(shuffle_shard_size) if not no_shuffle else None,
                # add wds.split_by_node here if you are using multiple nodes
                wds.split_by_node if split_by_node else None,
                wds.split_by_worker if split_by_worker else None,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
                # this shuffles the samples in memory
                wds.shuffle(shuffle_sample_size) if not no_shuffle else None,
                # this decodes the images and json
                # wds.decode("pil"),
                wds.to_tuple("__key__", "image", "label"),
                # wds.map(preprocess),
                wds.map_tuple(
                    key_transform,
                    train_transform, 
                    label_transform,  
                ),
                wds.shuffle(shuffle_sample_size) if not no_shuffle else None,
                # wds.batched(batch_size),
            )
        )
        self.world_size = world_size
        self.rank = rank
        self.num_samples = num_samples
        self.split_by_node = split_by_node
        self.split_by_worker = split_by_worker

    def __len__(self):
        return self.num_samples
    
    def __iter__(self): 
        for key,image,label in self.dataset:  ## Use dataset keys to distribute data
            ## ⚠️ need a fix: with split_by_node/split_by_worker enabled the shards are already
            ## partitioned, so this extra filter discards all but ~1/world_size of each rank's samples
            if key % self.world_size == self.rank:  ## Ensure each GPU gets different data
                yield (image, label)



def collate_fn(batch):
    images, labels = zip(*batch)
    # Stack the images into a single tensor (this assumes the images have the same size)
    images = torch.stack(images)
    labels = torch.stack(labels)
    return images, labels



    ## For distributed training, use torch.utils.data.DistributedSampler or WebDataset's own splitting?
    path = f"pipe:aws s3 cp {task.config.train_data_path} -"
    train_dataset = (
        WebDatasetDDP(
            path, 
            num_samples=task.config.train_data_size,
            world_size=dist.get_world_size(), 
            rank=dist.get_rank(), 
            split_by_node=True,
            split_by_worker=True,
            shuffle_sample_size=1000,
            key_transform=key_transform,
            train_transform=train_transform,
            label_transform=label_transform,
        )
    ) 
    path = f"pipe:aws s3 cp {task.config.val_data_path} -"
    val_dataset = (
        WebDatasetDDP(
            path, 
            num_samples=task.config.val_data_size,
            no_shuffle=True,
            key_transform=key_transform,
            train_transform=train_transform,
            label_transform=label_transform,                 
        )   
    )
    path = f"pipe:aws s3 cp {task.config.test_data_path} -"
    test_dataset = (
        WebDatasetDDP(
            path, 
            num_samples=task.config.test_data_size,
            no_shuffle=True,
            key_transform=key_transform,
            train_transform=train_transform,
            label_transform=label_transform,                 
        )  
    )
 
    ## Handle class imbalance; the class weights will be used in the loss function.
    ## The commented-out version below assumes train_dataset is an instance of
    ## torchvision.datasets.ImageFolder; compute_class_weight() returns a numpy.ndarray.
    # class_weights = compute_class_weight(
    #     class_weight='balanced', 
    #     classes=np.unique(train_dataset.targets),   
    #     y=train_dataset.targets)
    ## Use pre-calculated class weights if the dataset is very large.
    classes = np.unique(list(task.config.class_weights_dict.keys()))   ## np.unique() returns the class labels sorted, which is required here
    task.config.num_classes = len(classes)  ## get number of total classes for net creation
    class_weights = [task.config.class_weights_dict[k] for k in classes]
    class_weights = torch.tensor(
        class_weights, 
        dtype=torch.float32).to(task.config.device)
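    ## (Sketch, not from the original script:) these class weights are typically handed to
    ## the loss function, e.g.
    ## criterion = torch.nn.CrossEntropyLoss(weight=class_weights)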
    
    # ## SMDDP: set num_replicas and rank in Torch DistributedSampler
    # train_sampler = DistributedSampler(  ## ⚠️ samplers require a map-style dataset; doesn't work with an IterableDataset like WebDataset
    #     train_dataset,
    #     num_replicas=dist.get_world_size(),
    #     rank=dist.get_rank(),
    #     shuffle=False,
    # )
    
    ## Torch dataloader
    task.train_loader = DataLoader(
        train_dataset, 
        batch_size=task.config.batch_size_ddp, 
        shuffle=False,  ## Must be False for an IterableDataset; shuffling happens inside the WebDataset pipeline
        # sampler=train_sampler, # ⚠️ Distributed Sampler + WebDataset causes error
        num_workers=task.config.num_cpu,
        persistent_workers=True,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    task.val_loader = DataLoader(
        val_dataset, 
        batch_size=task.config.batch_size, 
        shuffle=False,   ## Don't shuffle for eval anyway
        ## no DDP sampler
        num_workers=task.config.num_cpu,
        persistent_workers=True,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    task.test_loader = DataLoader(
        test_dataset, 
        batch_size=task.config.batch_size, 
        shuffle=False,  ## Don't shuffle for eval anyway
        # no DDP sampler
        num_workers=task.config.num_cpu,
        persistent_workers=True,
        pin_memory=True,
        collate_fn=collate_fn,
    )
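
⚠️ A possible fix for the per-GPU iteration issue flagged above (a sketch, untested here): drop the manual key % world_size filter and let WebDataset's own splitters do the partitioning. wds.split_by_node hands each node/rank a disjoint subset of shards, wds.split_by_worker does the same per DataLoader worker, and .with_epoch() gives every rank the same nominal epoch length so DDP ranks stay in step. The names path, train_transform, and label_transform are reused from above; the batch size and samples-per-rank count are assumptions.

## sketch: shard-level splitting + fixed epoch length instead of filtering by key
import webdataset as wds
from torch.utils.data import DataLoader

batch_size = 64                        ## assumed
samples_per_rank = 10_000              ## assumed, e.g. train_data_size // world_size

train_dataset = (
    wds.DataPipeline(
        wds.SimpleShardList(path),
        wds.shuffle(100),              ## shuffle the shard list
        wds.split_by_node,             ## each node/rank gets a disjoint subset of shards
        wds.split_by_worker,           ## each DataLoader worker gets a disjoint subset
        wds.tarfile_to_samples(),
        wds.shuffle(1000),             ## shuffle samples in memory
        wds.to_tuple("image", "label"),
        wds.map_tuple(train_transform, label_transform),
        wds.batched(batch_size),
    )
    .with_epoch(samples_per_rank // batch_size)  ## counts batches here, since .batched() comes first
)

train_loader = DataLoader(
    train_dataset,
    batch_size=None,                   ## batches are already formed inside the pipeline
    num_workers=4,
    pin_memory=True,
)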
