Finetune CLIP on a 'webdataset' formatted dataset
import argparse
import io
import os
import time
from glob import glob

import torch
import torch.nn.functional as F
import wandb  # Quit early if the user doesn't have wandb installed.
import webdataset as wds
from PIL import Image
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms as T

import clip
from clip.clip import load
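# Assumed dependencies (install commands shown for reference, adjust to your environment):
#   pip install git+https://github.com/openai/CLIP.git
#   pip install webdataset wandb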
# argument parsing
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, required=True,
                    help='name of a CLIP model (e.g. "ViT-B/32") or path to a .pt checkpoint to resume from')
parser.add_argument('--image_text_path', type=str, required=True,
                    help='glob pattern matching the webdataset .tar shards of image/text pairs used to finetune CLIP')
parser.add_argument('--clip_output_file_name', type=str, default="clip",
                    help='output file name (".pt" is appended)')
parser.add_argument('--wandb_name', default='clip_finetuning',
                    help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')

train_group = parser.add_argument_group('Training settings')
train_group.add_argument('--epochs', default=40,
                         type=int, help='Number of epochs')
train_group.add_argument('--text_seq_len', default=77,
                         type=int, help='Text sequence length')
train_group.add_argument('--save_every_n_steps', default=1000,
                         type=int, help='Save a checkpoint every n steps')
train_group.add_argument('--batch_size', default=32,
                         type=int, help='Batch size')
train_group.add_argument('--ga_steps', default=1, type=int,
                         help='Number of gradient accumulation steps per optimizer update')
train_group.add_argument('--learning_rate', default=1e-7,
                         type=float, help='Learning rate')
train_group.add_argument('--clip_grad_norm', default=0.5,
                         type=float, help='Clip gradient norm')
train_group.add_argument('--warmup_steps', default=10000, type=int)

args = parser.parse_args()
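# Example invocation (illustrative; the script filename, model name, and shard path are placeholders):
#   python finetune_clip_webdataset.py \
#       --model_name "ViT-B/32" \
#       --image_text_path "/data/shards/*.tar" \
#       --batch_size 32 --epochs 40 --warmup_steps 10000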
# helpers
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def get_trainable_params(model):
    return [params for params in model.parameters() if params.requires_grad]
def create_clip_img_transform(image_width):
    # Normalization statistics used by the released CLIP preprocessors.
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]
    transform = T.Compose([
        T.Resize(336, interpolation=T.InterpolationMode.LANCZOS),
        T.Resize(image_width, interpolation=T.InterpolationMode.LANCZOS),
        # T.Resize with an int only rescales the shorter side, so crop to a square
        # to match CLIP's fixed input resolution.
        T.CenterCrop((image_width, image_width)),
        T.ToTensor(),
        T.Normalize(mean=clip_mean, std=clip_std)
    ])
    return transform
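# Illustrative usage (not executed here; "example.jpg" is a placeholder path):
#   preprocess = create_clip_img_transform(224)
#   x = preprocess(Image.open("example.jpg").convert("RGB"))  # float tensor of shape (3, 224, 224)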
def create_webdataset(
        urls,
        image_transform,
        enable_text=True,
        enable_image=True,
        image_key='jpg',
        caption_key='txt',
        enable_metadata=False,
        cache_path=None,):
    dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10,
                             handler=wds.handlers.warn_and_continue)
    tokenizer = lambda text: clip.tokenize([text], truncate=True)[0]

    def filter_dataset(item):
        if enable_text and caption_key not in item:
            return False
        if enable_image and image_key not in item:
            return False
        if enable_metadata and "json" not in item:
            return False
        return True

    filtered_dataset = dataset.select(filter_dataset)
    def preprocess_dataset(item):
        output = {}
        if enable_image:
            image_data = item[image_key]
            # Force 3-channel RGB so grayscale/RGBA images survive the Normalize step.
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            image_tensor = image_transform(image)
            output["image_filename"] = item["__key__"]
            output["image_tensor"] = image_tensor
        if enable_text:
            text = item[caption_key]
            caption = text.decode("utf-8")
            tokenized_text = tokenizer(caption)
            output["text_tokens"] = tokenized_text
            output["text"] = caption
        if enable_metadata:
            metadata_file = item["json"]
            metadata = metadata_file.decode("utf-8")
            output["metadata"] = metadata
        return output

    transformed_dataset = filtered_dataset.map(preprocess_dataset,
                                               handler=wds.handlers.warn_and_continue)
    return transformed_dataset
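# Expected shard layout (per the webdataset convention): each .tar groups files by a shared
# key, with the file extension selecting the field, e.g.
#   000000001.jpg   image bytes        (read via `image_key`)
#   000000001.txt   caption text       (read via `caption_key`)
#   000000001.json  optional metadata  (only required when enable_metadata=True)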
CLIP_OUTPUT_FILE_NAME = args.clip_output_file_name + ".pt"
CLIP_FINAL_OUTPUT_FILE_NAME = args.clip_output_file_name + "-final.pt"
WARMUP_STEPS = int(args.warmup_steps)  # enables learning rate warmup.
EPOCHS = args.epochs
BATCH_SIZE = args.batch_size
TEXT_SEQ_LEN = args.text_seq_len
# Start from a tiny learning rate when warming up; the training loop ramps it back up.
LEARNING_RATE = args.learning_rate if WARMUP_STEPS == 0 else 1e-12
print(f"Starting with learning rate: {LEARNING_RATE}")
GRAD_CLIP_NORM = args.clip_grad_norm
ACCUM_STEPS = args.ga_steps
SAVE_EVERY_N_STEPS = args.save_every_n_steps
MODEL_NAME = args.model_name
truncate_captions = True
IMAGE_SIZE = 224  # resolution used for the dataset transform below
# create the dataset and dataloader
DATASET_SIZE = int(1e9)  # nominal dataset length; a WebDataset is iterable and reports no __len__
WEBDATASET_PATH = glob(args.image_text_path)
dataset = create_webdataset(WEBDATASET_PATH, create_clip_img_transform(IMAGE_SIZE),
                            enable_text=True, enable_image=True, image_key='jpg',
                            caption_key='txt', enable_metadata=False, cache_path=None)
dl = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=24,
                pin_memory=True, prefetch_factor=2, drop_last=True)
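# Note: WebDataset yields an IterableDataset, so DataLoader-level shuffling stays off
# (PyTorch raises an error when shuffle=True is combined with an iterable-style dataset).
# If shuffling is wanted, it would be applied on the dataset itself, e.g. the buffer-based
# `dataset.shuffle(1000)` that webdataset provides (the buffer size here is illustrative).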
# Load CLIP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if MODEL_NAME.endswith(".pt"):
    assert os.path.exists(MODEL_NAME), "checkpoint does not exist"
    print(f"Resuming training from {MODEL_NAME}.")
clip_model, _ = load(MODEL_NAME, device=device)
clip_model.train()
# clip_model.eval()  # TODO experimenting with https://github.com/openai/CLIP/issues/150
# Note: on CUDA, clip.load() returns fp16 weights; many public finetuning recipes cast the
# model to fp32 (clip_model.float()) before training to keep Adam numerically stable.
input_res = clip_model.visual.input_resolution  # 224 for ViT-B/32; 336 for ViT-L/14@336px
# NB: the dataset transform above was built with IMAGE_SIZE=224; keep it in sync with
# input_res when using a model trained at a different resolution.
clip_transform = create_clip_img_transform(input_res)
# optimizer
opt = Adam(get_trainable_params(clip_model), lr=LEARNING_RATE,
           betas=(0.9, 0.98), eps=1e-06, weight_decay=0.)
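# Optional and illustrative (not wired into the loop below): a cosine decay schedule is a
# common companion to linear warmup, e.g.
#   scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100_000)
# stepped once after each optimizer update.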
model_config = dict(
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    clip_grad_norm=GRAD_CLIP_NORM,
    ga_steps=ACCUM_STEPS,
    model_name=MODEL_NAME,
    save_every_n_steps=SAVE_EVERY_N_STEPS,
    clip_output_file_name=CLIP_OUTPUT_FILE_NAME,
    clip_final_output_file_name=CLIP_FINAL_OUTPUT_FILE_NAME,
    wandb_name=args.wandb_name,
    text_seq_len=TEXT_SEQ_LEN,
    image_width=input_res,
    truncate_captions=truncate_captions,
    device=device,
)

run = wandb.init(
    project=args.wandb_name,  # 'clip_finetuning' by default
    config=model_config,
)
def save_model(path):
    save_obj = clip_model.state_dict()
    torch.save(save_obj, path)
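# Reloading later (illustrative): clip.load() also accepts a local .pt path, which is what the
# `--model_name some_checkpoint.pt` resume branch above relies on; alternatively the raw
# state_dict can be restored onto an already-constructed model:
#   clip_model.load_state_dict(torch.load(CLIP_OUTPUT_FILE_NAME, map_location=device))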
if WARMUP_STEPS > 0:
    print("Warmup steps:", WARMUP_STEPS)

save_model(f'./{CLIP_OUTPUT_FILE_NAME}')

# training
print("Training started.")
steps = 0
t = time.time()  # Get initial time.
for epoch in range(0, EPOCHS):
    print(f"Epoch {epoch}")
    try:
        for i, item in enumerate(dl):
            texts = item["text_tokens"]
            images = item["image_tensor"]
            if i % 10 == 0:
                t = time.time()
            texts, images = texts.to(device), images.to(device)
            logits_per_image, logits_per_text = clip_model(images, texts)
            # Symmetric contrastive loss: the i-th image should match the i-th caption.
            labels = torch.arange(BATCH_SIZE, device=device)
            image_loss = F.cross_entropy(logits_per_image, labels)
            text_loss = F.cross_entropy(logits_per_text, labels)
            loss = (image_loss + text_loss) / 2
            # Scale by the accumulation steps so gradients average over the effective batch.
            (loss / ACCUM_STEPS).backward()
            if (i + 1) % ACCUM_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(get_trainable_params(clip_model), GRAD_CLIP_NORM)
                opt.step()
                opt.zero_grad()
            log = {}
            lr = opt.param_groups[0]['lr']
            # Linear warmup toward a 1e-6 learning rate over WARMUP_STEPS steps.
            if WARMUP_STEPS > 0 and lr < 1e-6:
                lr = 1e-6 * (steps / WARMUP_STEPS)
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
                print(f"Warmup step {steps}/{WARMUP_STEPS}")
                print(f"Learning rate: {lr}")
            if i % 10 == 0:
                print(f'epoch - {epoch},', f'step - {i},', f'loss - {loss.item()}',
                      f'text_loss - {text_loss.item()}', f'image_loss - {image_loss.item()}')
                log = {
                    **log,
                    'epoch': epoch,
                    'iter': i,
                    'loss': loss.item(),
                    'text_loss': text_loss.item(),
                    'image_loss': image_loss.item(),
                    'lr': lr
                }
            if i % 10 == 9:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')
            if i % SAVE_EVERY_N_STEPS == 0:
                save_model(f'./{CLIP_OUTPUT_FILE_NAME}')
            steps += 1
            wandb.log(log)
        # save the trained model to wandb as an artifact at the end of every epoch
        model_artifact = wandb.Artifact(
            'finetuned-clip', type='model', metadata=dict(model_config))
        model_artifact.add_file(f'./{CLIP_OUTPUT_FILE_NAME}')
        run.log_artifact(model_artifact)
    except KeyError as e:
        print(e)
        break

save_model(f'./{CLIP_FINAL_OUTPUT_FILE_NAME}')
model_artifact = wandb.Artifact(
    'finetuned-clip', type='model', metadata=dict(model_config))
model_artifact.add_file(f'./{CLIP_FINAL_OUTPUT_FILE_NAME}')
run.log_artifact(model_artifact)
wandb.finish()