import io
import json
from typing import Any, cast

import datasets
import requests
from datasets import Features, IterableDataset
from peft import LoraConfig
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from trl import SFTConfig, SFTTrainer
HF_TOKEN = "..."  # placeholder: supply your Hugging Face access token
def load_image(url):
    # Download an image over HTTP and decode it with PIL
    response = requests.get(url)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))

def image_from_bytes(image_bytes):
    # Decode an in-memory image (unused below; kept as a convenience helper)
    return Image.open(io.BytesIO(image_bytes))
def main():
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN
    )
    model.config.use_cache = False  # disable the KV cache for training

    processor = AutoProcessor.from_pretrained(model_id, padding_side="right", token=HF_TOKEN)
    processor.tokenizer.pad_token = processor.tokenizer.eos_token  # use the eos token as pad token
    processor.tokenizer.padding_side = "right"  # right-pad so labels align with inputs

    USE_ITERABLE_DATASET = False
    # A single training example: a user turn with two image placeholders and
    # an assistant turn with the target text
    messages_obj = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "image"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "duck"}],
        },
    ]
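    # Each {"type": "image"} entry above is rendered by apply_chat_template as
    # Gemma 3's image placeholder token; the processor later pairs those
    # placeholders with the actual pixel data supplied in collate_fn.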
    if USE_ITERABLE_DATASET:
        def train_iterable_gen():
            yield {"messages": json.dumps(messages_obj)}

        train_ds = IterableDataset.from_generator(
            train_iterable_gen,
            features=Features({
                "messages": datasets.Value(dtype="string", id=None),
            }),
        )
    else:
        train_ds = [
            {"messages": json.dumps(messages_obj)},
        ]
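    # Both branches yield the same single example. The messages are stored as
    # a JSON string rather than nested dicts, so one string-valued Features
    # schema covers both the list and IterableDataset paths; collate_fn
    # decodes it with json.loads.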
    # Demo image; Gemma 3's vision tower expects 896x896 inputs
    image = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg").resize((896, 896))
    def collate_fn(examples):
        print("collate_fn examples", examples)
        # Decode the JSON-encoded messages and apply the chat template
        texts = [
            processor.apply_chat_template(json.loads(example["messages"]), tokenize=False, add_generation_prompt=False)
            for example in examples
        ]
        # Hardcoded for this demo: one batch entry with the same image twice,
        # matching the two {"type": "image"} placeholders in messages_obj
        images = [[image.convert("RGB"), image.convert("RGB")]]
        print("collate_fn texts", texts)
        print("collate_fn images", images)
        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        print("collate_fn pixel_values", batch["pixel_values"].shape)
        print("collate_fn input_ids", batch["input_ids"].shape)
        # The labels are the input_ids, with padding and image tokens masked
        # out of the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == processor.image_token_id] = -100
        batch["labels"] = labels
        return batch
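    # Note: only pad and image tokens are masked above, so the loss is also
    # computed on the user-turn text. Masking everything before the assistant
    # turn (completion-only training) is a common refinement, left out here
    # for brevity.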
    # Set up LoRA configuration for causal language modeling
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
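    # r=8 / lora_alpha=16 keeps the adapter small, and only the attention
    # q/v projections are adapted. Extending target_modules (e.g. k_proj,
    # o_proj, or the MLP projections) trades memory for capacity.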
    # Define training arguments
    training_args = SFTConfig(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=2e-4,
        logging_steps=1,
        save_steps=25,
        report_to="tensorboard",
        group_by_length=False,
        remove_unused_columns=False,  # keep the raw "messages" column for collate_fn
        dataset_kwargs={"skip_prepare_dataset": True},  # collate_fn does all preprocessing
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        max_steps=1,
    )
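    # max_steps=1 overrides num_train_epochs, so this run performs a single
    # optimizer step; enough to smoke-test the pipeline end to end.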
    # Create the SFTTrainer with the LoRA parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=cast(Any, train_ds),
        peft_config=lora_config,
        args=training_args,
        data_collator=collate_fn,
        processing_class=processor.tokenizer,
    )
print("Training model...")
trainer.train()
print("Training complete.")
if __name__ == "__main__":
    main()