🟢 AWS S3 data to SageMaker machine learning training


Code snippets are from the following sources:

## measure the raw data-streaming throughput of the input pipeline (no training compute)
import torch, time
from statistics import mean, variance
dataset = get_dataset()
dl = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
stats_lst = []
t0 = time.perf_counter()
for batch_idx, batch in enumerate(dl, start=1):
    if batch_idx % 100 == 0:
        t = time.perf_counter() - t0
        print(f'Iteration {batch_idx} Time {t}')
        stats_lst.append(t)
        t0 = time.perf_counter()
mean_calc = mean(stats_lst[1:])      ## skip the first window (worker warm-up)
var_calc = variance(stats_lst[1:])
print(f'mean {mean_calc} variance {var_calc}')
## measure the training step time on a single cached batch (isolates compute from data streaming)
import torch, time
from statistics import mean, variance
dataset=get_dataset()
dl=torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
batch = next(iter(dl))
t0 = time.perf_counter()
for batch_idx in range(1,1000):
    train_step(batch)
    if batch_idx % 100 == 0:
        t = time.perf_counter() - t0
        print(f'Iteration {batch_idx} Time {t}')
        t0 = time.perf_counter()
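If the per-100-iteration times of the first loop are significantly higher than those of the second, data streaming is the bottleneck rather than compute. Note that train_step is not defined in these snippets; a trivial placeholder (pure assumption, just to make the loop runnable) could be:

import time
def train_step(batch):
    ## placeholder: simulate a fixed amount of compute per training step
    time.sleep(0.01)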
  • Create random synthetic data
import webdataset as wds
import numpy as np
from PIL import Image
import io
out_tar = 'wds.tar'
sink = wds.TarWriter(out_tar)
im_width = 1024
im_height = 1024
num_classes = 256
for i in range(100):
    image = Image.fromarray(np.random.randint(0, high=256,
                  size=(im_height,im_width,3), dtype=np.uint8))
    label = Image.fromarray(np.random.randint(0, high=num_classes,
                  size=(im_height,im_width), dtype=np.uint8))
    image_bytes = io.BytesIO()
    label_bytes = io.BytesIO()
    image.save(image_bytes, format='PNG')
    label.save(label_bytes, format='PNG')
    sample = {"__key__": str(i),
              f'image': image_bytes.getvalue(),
              f'label': label_bytes.getvalue()}
    sink.write(sample)
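The snippet above writes all samples into a single tar file. To produce the multiple numbered shards that get_dataset below expects (0.tar, 1.tar, ...), wds.ShardWriter can split the stream automatically. A minimal sketch, where make_sample is a hypothetical helper that builds the same sample dict as above and maxcount is an assumption:

import webdataset as wds
## split samples across numbered shards: 0.tar, 1.tar, ...
with wds.ShardWriter('%d.tar', maxcount=25) as sink:
    for i in range(100):
        sink.write(make_sample(i))  ## hypothetical helper returning {'__key__', 'image', 'label'}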
  • FastFile input mode for the SageMaker estimator (see the estimator-side sketch after the code below)
import os, webdataset
num_files = 8  ## number of tar shards in the channel; set to match your data
def get_dataset():
    ffm = os.environ['SM_CHANNEL_TRAINING']
    urls = [os.path.join(ffm, f'{i}.tar') for i in range(num_files)]
    dataset = (
        webdataset
            .WebDataset(urls, shardshuffle=True)  ## shard shuffle
            .shuffle(10)                          ## buffer shuffle
    )
    return dataset
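On the launcher side, FastFile mode is enabled via the estimator's input_mode parameter, so the tar shards are streamed from S3 on demand instead of being downloaded before training starts. A minimal sketch, assuming a PyTorch estimator and a hypothetical S3 prefix:

from sagemaker.pytorch import PyTorch
estimator = PyTorch(
    entry_point='train.py',
    role=role,                      ## your SageMaker execution role
    instance_type='ml.g4dn.xlarge',
    instance_count=1,
    framework_version='2.1',
    py_version='py310',
    input_mode='FastFile',          ## stream S3 objects on demand
)
estimator.fit({'training': 's3://my-bucket/webdataset/train/'})  ## hypothetical prefix; maps to SM_CHANNEL_TRAINING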

nov05 commented Jan 29, 2025

  • ScriptProcessor official documentation (a launch sketch follows this list)
  • My tutorial: Create custom docker image for SageMaker data processing jobs, create AWS ECR private repo, and upload the image to the repo ✅✅✅
  • AWS re:Post, pull ECR image from the repo of another account ✅
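For reference, launching a processing job with a custom image from a private ECR repo looks roughly like this; a sketch, where the image URI, role, script name, and S3 paths are placeholders:

from sagemaker.processing import ScriptProcessor, ProcessingInput, ProcessingOutput
processor = ScriptProcessor(
    image_uri='123456789012.dkr.ecr.us-east-1.amazonaws.com/my-repo:latest',  ## placeholder ECR image
    command=['python3'],
    role=role,
    instance_type='ml.m5.xlarge',
    instance_count=1,
)
processor.run(
    code='preprocess.py',  ## placeholder processing script
    inputs=[ProcessingInput(source='s3://my-bucket/raw',
                            destination='/opt/ml/processing/input')],
    outputs=[ProcessingOutput(source='/opt/ml/processing/output',
                              destination='s3://my-bucket/processed')],
)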


nov05 commented Jan 29, 2025

import webdataset as wds
from huggingface_hub import get_token
from torch.utils.data import DataLoader

hf_token = get_token()
url = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-train-{{0000..1023}}.tar"
url = f"pipe:curl -s -L {url} -H 'Authorization:Bearer {hf_token}'"
dataset = wds.WebDataset(url).decode()
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)
## alternatively, reshuffle with shard + buffer shuffling:
buffer_size = 1000
dataset = (
    wds.WebDataset(url, shardshuffle=True)
    .shuffle(buffer_size)
    .decode()
)

Generally, datasets in the WebDataset format are already pre-shuffled and ready to feed to a DataLoader, but you can still reshuffle the data with WebDataset’s approximate shuffling.

In addition to shuffling the list of shards (shardshuffle=True), WebDataset draws samples from an in-memory buffer (.shuffle(buffer_size)): a larger buffer gives better randomization at the cost of memory, with little impact on streaming speed.
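Note that a WebDataset pipeline is an iterable-style dataset with no fixed length, so epoch-based training loops need a nominal epoch size. A minimal sketch, assuming the dataset object from the snippet above (the epoch length of 1000 samples is an arbitrary assumption):

from torch.utils.data import DataLoader
## give the streamed dataset a nominal epoch length so that
## epoch-based loops know when to stop
dataset = dataset.with_epoch(1000)
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)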


nov05 commented Jan 31, 2025

  • TorchVision dataset
from torchvision import datasets, transforms
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224),
    transforms.ColorJitter(brightness=0.2, 
                           contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225]),
])  
train_dataset = datasets.ImageFolder(task.config.train, transform=train_transform)
  • WebDataset dataset
import torch
import webdataset as wds
from torchvision import transforms

def identity(x):
    return x

path = "s3://p5-amazon-bin-images/webdataset/train/shard-{000000..000001}.tar"
task.config.train = f"pipe:aws s3 cp {path} -"  ## curl can't read s3:// URLs; stream via the AWS CLI instead
# Create the WebDataset pipeline
train_dataset = (
    wds.WebDataset(task.config.train, shardshuffle=True)  ## Shuffle shards
        .shuffle(1000)  # Shuffle dataset
        .decode("pil")  
        .to_tuple("jpg", "cls")  # Tuple of image and label; specify file extensions
        .map_tuple(train_transform, identity)  # Apply the train transforms to the image
)
# Wrap the dataset in a DataLoader for batching
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=task.config.batch_size, 
    num_workers=task.config.num_cpu)
# Example usage in a training loop
for batch_images, batch_labels in train_loader:
    # Training code here
    print(batch_images.shape, batch_labels.shape)
