Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save nov05/95cb7edcbe2e8bb68c9d29bdc00b9ca8 to your computer and use it in GitHub Desktop.
Save nov05/95cb7edcbe2e8bb68c9d29bdc00b9ca8 to your computer and use it in GitHub Desktop.

🟒 AWS S3 data to SageMaker machine learning training


Code snippets are from the following sources:

import torch, time
from statistics import mean, variance
# Benchmark pure data loading: iterate the DataLoader without training and
# record how long each window of 100 batches takes.
dataset = get_dataset()
dl = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
stats_lst = []
t0 = time.perf_counter()
for batch_idx, batch in enumerate(dl, start=1):
    if batch_idx % 100:
        continue  # only report on every 100th batch
    t = time.perf_counter() - t0
    print(f'Iteration {batch_idx} Time {t}')
    stats_lst.append(t)
    t0 = time.perf_counter()
# The first window is excluded from the statistics (presumably because it
# includes start-up overhead -- confirm).
steady_windows = stats_lst[1:]
mean_calc = mean(steady_windows)
var_calc = variance(steady_windows)
print(f'mean {mean_calc} variance {var_calc}')
## measure how the step time changes when running on the streamed data samples
import torch, time
from statistics import mean, variance
# Benchmark the training step alone: fetch a single batch up front and feed
# that same batch to train_step() on every iteration.
dataset = get_dataset()
dl = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
batch = next(iter(dl))
t0 = time.perf_counter()
# NOTE: range(1, 1000) stops at 999, so timings are printed for
# iterations 100..900 only.
for batch_idx in range(1, 1000):
    train_step(batch)
    if batch_idx % 100:
        continue  # only report on every 100th step
    t = time.perf_counter() - t0
    print(f'Iteration {batch_idx} Time {t}')
    t0 = time.perf_counter()
  • Create random synthetic data
import webdataset as wds
import numpy as np
from PIL import Image
import io
# Generate 100 random image/label pairs and pack them into a single
# WebDataset tar archive.
out_tar = 'wds.tar'
im_width = 1024
im_height = 1024
num_classes = 256
# Use the writer as a context manager so the archive is always closed and
# finalized, even if writing a sample fails part-way through.
with wds.TarWriter(out_tar) as sink:
    for i in range(100):
        # Random RGB image and a single-channel label map with values in
        # [0, num_classes).
        image = Image.fromarray(np.random.randint(0, high=256,
                      size=(im_height, im_width, 3), dtype=np.uint8))
        label = Image.fromarray(np.random.randint(0, high=num_classes,
                      size=(im_height, im_width), dtype=np.uint8))
        # Encode both as PNG in memory; the tar stores the raw bytes.
        image_bytes = io.BytesIO()
        label_bytes = io.BytesIO()
        image.save(image_bytes, format='PNG')
        label.save(label_bytes, format='PNG')
        sample = {"__key__": str(i),
                  'image': image_bytes.getvalue(),
                  'label': label_bytes.getvalue()}
        sink.write(sample)
  • FastFile input mode for SageMaker estimator
import os, webdataset
def get_dataset():
    """Build a shuffled WebDataset over the tar shards of the training channel.

    The channel directory is read from the ``SM_CHANNEL_TRAINING`` environment
    variable, and shards are expected to be named ``0.tar`` ...
    ``{num_files - 1}.tar`` (``num_files`` comes from the enclosing scope).

    Returns:
        The WebDataset pipeline with shard-level shuffling enabled and a
        sample shuffle buffer of size 10.
    """
    channel_dir = os.environ['SM_CHANNEL_TRAINING']
    urls = []
    for shard_idx in range(num_files):
        urls.append(os.path.join(channel_dir, f'{shard_idx}.tar'))
    shards = webdataset.WebDataset(urls, shardshuffle=True)  ## shard shuffle
    return shards.shuffle(10)                                ## buffer shuffle
@nov05
Copy link
Author

nov05 commented Jan 28, 2025

  • FastFile input mode, 2 types of approach

@nov05
Copy link
Author

nov05 commented Jan 28, 2025

  • SageMaker File input mode
import sagemaker
from sagemaker.inputs import TrainingInput

# S3 prefixes holding the images and their metadata
s3_images_prefix = 's3://bucket-name/images/'
s3_metadata_prefix = 's3://bucket-name/metadata/'

# One input channel per prefix, both fully replicated and using the
# 'FastFile' input mode
train_data = {
    channel: TrainingInput(s3_data=prefix, distribution='FullyReplicated', input_mode='FastFile')
    for channel, prefix in (
        ('images', s3_images_prefix),
        ('metadata', s3_metadata_prefix),
    )
}

# Estimator settings (placeholders to be replaced with real values)
estimator_settings = dict(
    image_uri='your-docker-image',
    role='your-iam-role',
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    output_path='s3://bucket-name/output/',
    base_job_name='your-training-job',
)
estimator = sagemaker.estimator.Estimator(**estimator_settings)

# Launch the training job with the two input channels
estimator.fit(inputs=train_data)
  • In the train script
import os
import json
from PIL import Image

# SageMaker automatically sets the environment variable for each input channel
images_dir = os.environ['SM_CHANNEL_IMAGES']  # The directory where images are stored
metadata_dir = os.environ['SM_CHANNEL_METADATA']  # The directory where metadata is stored

# Example: Load images and corresponding metadata
for image_file in os.listdir(images_dir):
    image_path = os.path.join(images_dir, image_file)
    # Image.open() is lazy and keeps the file handle open; load the pixel data
    # inside a context manager so the handle is released for every image
    # instead of leaking one descriptor per file. `img` stays usable afterwards.
    with Image.open(image_path) as img:
        img.load()

    # Load corresponding metadata (assuming metadata file names match image file names)
    image_stem = os.path.splitext(image_file)[0]
    metadata_file = os.path.join(metadata_dir, f'{image_stem}.json')
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)

    # Now you can use both the image and its metadata for training

@nov05
Copy link
Author

nov05 commented Jan 29, 2025


import torch
import webdataset as wds

class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that streams samples from a list of WebDataset tar URLs."""

    def __init__(self, urls):
        # Shard URLs, visited in the given order by __iter__.
        self.urls = urls

    def __iter__(self):
        """Yield ("png", "json") tuples from each shard in turn.

        Each shard gets its own pipeline with a 1000-sample shuffle buffer
        and "rgb" decoding.
        """
        for shard_url in self.urls:
            pipeline = wds.WebDataset(shard_url)
            pipeline = pipeline.shuffle(1000).decode("rgb").to_tuple("png", "json")
            # Process each sample here if needed before handing it on
            yield from pipeline

# Example usage: wrap two remote shards and batch them with a DataLoader
shard_urls = ['http://example.com/data1.tar', 'http://example.com/data2.tar']
dataset = MyIterableDataset(shard_urls)
loader_options = dict(batch_size=32, num_workers=4)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
for batch in dataloader:
    # Use the batch here
    print(batch)
# The loader is constructed a second time with the same settings
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
  • Shuffling Mechanism: The shuffle(1000) function creates an in-memory buffer of 1000 samples from the tar file and shuffles them. As training progresses, the buffer is refilled, and shuffling continues. This ensures that each training epoch processes the data in a randomized order while keeping memory usage efficient.

  • Streaming in the Data from S3 Directly with Pipe Method
# Build a "pipe:" URL that wraps an `aws s3 cp <url> -` command (the trailing
# "-" sends the object to stdout) and hand it to WebDataset.
url = "s3://bucket/data.tar"
s3_url = f"pipe:aws s3 cp {url} -"
dataset = wds.WebDataset(s3_url)
# Same dataset again, this time passing split_by_worker as the node splitter;
# this rebinds `dataset`, replacing the object created on the previous line.
dataset = wds.WebDataset(s3_url, nodesplitter=wds.split_by_worker)
# NOTE(review): the remaining lines look like excerpts from a generator body --
# `continue` and `yield` are only valid inside a loop within a function, and
# `iter_dataset` is not defined in this snippet.
# NOTE: despite its name, num_workers holds the WORLD_SIZE value here.
num_workers = int(os.environ["WORLD_SIZE"])
rank = int(os.environ['RANK'])
img_obj = next(iter(iter_dataset))
img_val, img_id = img_obj
# Keep only samples whose integer id modulo WORLD_SIZE equals this
# process's RANK; every other sample is skipped.
if int(img_id) % num_workers != rank:
  continue
else:
  yield img_val

@nov05
Copy link
Author

nov05 commented Jan 29, 2025

🟒 SageMaker ETL options

Summary:

  • SageMaker Processing Jobs: For Python-based ETL tasks, this is the most straightforward option. ✅
  • Data Wrangler: If you have simple tabular transformations.
  • SageMaker Pipelines: For complex end-to-end pipelines that include both ETL and model training.
  • Combine SageMaker and Glue: For very large or complex data transformations.

  • My own working code ✅✅✅
## get my own AWS account number (first line of the secrets file)
with open('../secrets/aws_account_number', 'r') as file:
    # readline() returns '' on an empty file, so the name is always bound
    aws_account_number = file.readline().strip()
# Fail here with a clear message rather than later with a NameError or a
# malformed ECR image URI.
if not aws_account_number:
    raise ValueError(
        "No AWS account number found on the first line of "
        "'../secrets/aws_account_number'"
    )

## no need to run this cell 
# ## To pull ECR image from another AWS account 
# import boto3
# import subprocess
# import base64
# ecr_client = boto3.client('ecr', region_name='us-east-1')
# # Retrieve the authentication token from ECR
# response = ecr_client.get_authorization_token()
# authorization_data = response['authorizationData'][0]
# token = authorization_data['authorizationToken']
# registry_uri = authorization_data['proxyEndpoint']
# decoded_token = base64.b64decode(token).decode('utf-8')
# username, password = decoded_token.split(':')
# # Docker login command
# login_command = f"docker login --username {username} --password {password} {registry_uri}"
# subprocess.run(login_command, shell=True, check=True)
# # Now you can use this image in your SageMaker processing job 
## TODO: Perform any data cleaning or data preprocessing
from sagemaker.processing import ScriptProcessor
## Container image for the processing job, hosted in my own account's ECR
## (a custom image; the default SageMaker image could be used instead)
processing_image_uri = f'{aws_account_number}.dkr.ecr.us-east-1.amazonaws.com/udacity/p5-amazon-bin-images:latest'
processor = ScriptProcessor(
    command=['python3'],
    image_uri=processing_image_uri,
    role=sagemaker_role_arn,  # Execution role
    instance_count=1,
    instance_type='ml.t3.large',  # Use the appropriate instance type
    volume_size_in_gb=10,  # Minimal disk space since we're streaming
    base_job_name='p5-amazon-bin-images',
)
## Command-line options for the processing script, kept as flag -> value
## pairs and flattened into the [flag, value, flag, value, ...] list
script_options = {
    '--SM_INPUT_BUCKET': 'aft-vbi-pds',
    '--SM_INPUT_PREFIX_IMAGES': 'bin-images/',
    '--SM_INPUT_PREFIX_METADATA': 'metadata/',
    '--SM_OUTPUT_BUCKET': 'p5-amazon-bin-images',
    '--SM_OUTPUT_PREFIX': 'webdataset/',
}
script_arguments = [token for option in script_options.items() for token in option]
processor.run(
    code='../scripts_process/test_convert_to_webdataset.py',  # Your script to process data
    arguments=script_arguments,
)

  • My code: WebDataset.ShardWriter() ✅✅✅
# type_prefix is one of 'train/', 'val/' or 'test/'
def convert_dataset(type_prefix, file_list, maxcount=1000):
    """Pack S3 images into WebDataset shards and upload the shards to S3.

    Args:
        type_prefix: Sub-prefix appended to ``output_prefix`` for this call
            (e.g. ``'train/'``).
        file_list: Iterable of ``(image_id, label)`` pairs; the image is read
            from ``{input_prefix_images}{image_id}.jpg`` in ``input_bucket``.
        maxcount: Maximum number of samples per shard.

    Images that cannot be read are skipped with a warning. Shards are written
    to a temporary directory that is removed once the upload has finished.
    """
    import tempfile  # local import keeps this block self-contained

    # A fresh scratch directory per call guarantees that only the shards
    # written by THIS call are uploaded. Globbing "shard-*.tar" in the working
    # directory would also pick up leftovers from an earlier call (e.g. extra
    # 'train/' shards being uploaded under 'val/').
    with tempfile.TemporaryDirectory() as shard_dir:
        shard_pattern = os.path.join(shard_dir, "shard-%06d.tar")
        with wds.ShardWriter(shard_pattern, maxcount=maxcount) as sink:
            for image_id, label in file_list:
                image_key = f'{input_prefix_images}{image_id}.jpg'
                try:  # Skip images that cannot be fetched from S3
                    image_data = read_s3_file(input_bucket, image_key)
                except Exception as e:
                    print(f"⚠️ Skipping image '{image_key}' due to error: {e}")
                    continue
                # Save as WebDataset sample
                sink.write({
                    "__key__": f"{image_id}",
                    "image": image_data,
                    "label": label,
                })
        # Upload the tar files to S3 (sorted for a deterministic order)
        tar_paths = sorted(glob.glob(os.path.join(shard_dir, "shard-*.tar")))
        tar_list = [os.path.basename(tar_path) for tar_path in tar_paths]
        for tar_path, file_name in zip(tar_paths, tar_list):
            s3_key = os.path.join(output_prefix, type_prefix, file_name)
            s3_client.upload_file(tar_path, output_bucket, s3_key)
    print(f"🟒 Successfully uploaded tar files to "
          f"s3://{output_bucket}/{output_prefix}{type_prefix}:\n"
          f"    {tar_list}")
  • My code: WebDataset.TarWriter() + io.BytesIO() ✅✅✅
def convert_dataset(image_keys, num_tar_files):
    """Build one WebDataset tar in memory from S3 images and upload it to S3.

    Args:
        image_keys: S3 object keys in ``input_bucket``; keys that do not end
            in ``.jpg``/``.jpeg`` are skipped.
        num_tar_files: Number used in the uploaded file name
            (``{output_prefix}data_{num_tar_files}.tar``).

    Each sample holds the image bytes plus the JSON metadata read from
    ``{input_prefix_metadata}{base_name}.json``; samples whose image or
    metadata cannot be read are skipped with a warning.
    """
    # The archive is assembled entirely in memory
    tar_stream = io.BytesIO()
    with wds.TarWriter(tar_stream) as sink:
        for image_key in image_keys:
            if not image_key.endswith(('.jpg', '.jpeg')):
                print(f"⚠️ Skipping non-image file: {image_key}")
                continue
            object_name = image_key.rsplit('/', 1)[-1]
            base_name = os.path.splitext(object_name)[0]
            try:  # Both the metadata JSON and the image must be readable
                metadata_data = read_s3_file(input_bucket, f'{input_prefix_metadata}{base_name}.json')
                image_data = read_s3_file(input_bucket, image_key)
            except Exception as e:
                print(f"⚠️ Skipping image '{image_key}' due to error: {e}")
                continue
            # One WebDataset sample per image
            sample = {
                "__key__": f"{base_name}",
                "image": image_data,
                "metadata": metadata_data
            }
            sink.write(sample)
    # Rewind so the upload starts from the first byte of the archive
    tar_stream.seek(0)
    file_name = f'{output_prefix}data_{num_tar_files}.tar'
    s3_client.upload_fileobj(tar_stream, output_bucket, file_name)
    print(f"🟒 Successfully uploaded tar file to s3://{output_bucket}/{file_name}")

@nov05
Copy link
Author

nov05 commented Jan 29, 2025

  • ScriptProcessor official documentation
  • My tutorial: Create custom docker image for SageMaker data processing jobs, create AWS ECR private repo, and upload the image to the repo ✅✅✅
  • AWS re:Post, pull ECR image from the repo of another account ✅

@nov05
Copy link
Author

nov05 commented Jan 29, 2025

import webdataset as wds
from huggingface_hub import get_token
from torch.utils.data import DataLoader

# Stream the shards over HTTP with curl, authenticating with the locally
# stored Hugging Face token.
hf_token = get_token()
shard_pattern = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-train-{{0000..1023}}.tar"
url = f"pipe:curl -s -L {shard_pattern} -H 'Authorization:Bearer {hf_token}'"

# Plain pipeline: decode samples and batch them with a DataLoader
dataset = wds.WebDataset(url).decode()
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)

# Shuffled variant: shuffle the shard order, then shuffle samples through a
# buffer of buffer_size before decoding
buffer_size = 1000
shuffled_shards = wds.WebDataset(url, shardshuffle=True)
dataset = shuffled_shards.shuffle(buffer_size).decode()

Generally, datasets in WebDataset formats are already shuffled and ready to feed to a DataLoader. But you can still reshuffle the data with WebDataset’s approximate shuffling.

In addition to shuffling the list of shards, WebDataset uses a buffer to shuffle a dataset without any cost to speed.

@nov05
Copy link
Author

nov05 commented Jan 31, 2025

  • TorchVision dataset
from torchvision import datasets, transforms
# Random augmentations applied to each training image
augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
]
# Tensor conversion followed by per-channel normalization
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
train_transform = transforms.Compose(augmentation_steps + tensor_steps)
train_dataset = datasets.ImageFolder(task.config.train, transform=train_transform)
  • WebDataset dataset
import webdataset as wds
from torchvision import transforms

def identity(x):
    """Return the argument unchanged (a no-op transform)."""
    result = x
    return result

# The shards live in S3. curl cannot fetch s3:// URLs (and `-s` hides the
# failure, so the pipe would just yield no data), so stream each shard
# through the AWS CLI instead; the trailing "-" writes the object to stdout.
# The single-brace range is expanded by WebDataset into one URL per shard.
# NOTE(review): requires the AWS CLI and S3 read permission in the training
# container -- confirm for your image.
path = "s3://p5-amazon-bin-images/webdataset/train/shard-{000000..000001}.tar"
task.config.train = f"pipe:aws s3 cp {path} -"
# Create the WebDataset pipeline
# NOTE(review): to_tuple("jpg", "cls") expects samples stored under "jpg" and
# "cls" keys; the conversion snippets earlier on this page write "image" and
# "label" -- confirm the key names match the shards actually used.
train_dataset = (
    wds.WebDataset(task.config.train, shardshuffle=True)  ## Shuffle shards
        .shuffle(1000)  # Shuffle samples through a 1000-sample buffer
        .decode("pil")  # Decode images to PIL
        .to_tuple("jpg", "cls")  # Tuple of image and label; specify file extensions
        .map_tuple(train_transform, identity)  # Apply the train transforms to the image
)
# Wrap the dataset in a DataLoader for batching
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=task.config.batch_size,
    num_workers=task.config.num_cpu)
# Example usage in a training loop
for batch_images, batch_labels in train_loader:
    # Training code here
    print(batch_images.shape, batch_labels.shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment