Skip to content

Instantly share code, notes, and snippets.

@jamessdixon
Created January 24, 2025 05:36
Show Gist options
Save jamessdixon/e3fbb1db18b90dec00fa26e258be920e to your computer and use it in GitHub Desktop.
classifier_reinforcement
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\dixonjames\\Documents\\trained\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"import torch\n",
"from PIL import Image\n",
"from transformers import CLIPModel, CLIPProcessor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feedback collected: A passport\n"
]
}
],
"source": [
"model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
"\n",
"feedback_dataset = []\n",
"\n",
"def maker(image):\n",
" descriptions = [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
" inputs = processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n",
" outputs = model(**inputs)\n",
" logits_per_image = outputs.logits_per_image \n",
" probs = logits_per_image.softmax(dim=1) \n",
" predicted_class = descriptions[probs.argmax()]\n",
" confidence = probs.max().item()\n",
"\n",
" if confidence > 0.5:\n",
" return predicted_class\n",
" else:\n",
" return None\n",
"\n",
"def collect_feedback(image, correct_label):\n",
" feedback_dataset.append((image, correct_label))\n",
" print(f\"Feedback collected: {correct_label}\")\n",
"\n",
"\n",
"input_folder = 'data_new'\n",
"input_filename = '0.png'\n",
"image_path = os.path.join(input_folder, input_filename)\n",
"image = Image.open(image_path)\n",
"prediction = maker(image)\n",
"\n",
"if prediction is None or prediction != \"A passport\": \n",
" collect_feedback(image, \"A passport\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from transformers import CLIPProcessor, CLIPModel\n",
"from PIL import Image\n",
"\n",
"LABELS = [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
"label_to_index = {label: idx for idx, label in enumerate(LABELS)}\n",
"\n",
"class FeedbackDataset(Dataset):\n",
" def __init__(self, dataset, processor):\n",
" self.dataset = dataset\n",
" self.processor = processor\n",
"\n",
" def __len__(self):\n",
" return len(self.dataset)\n",
"\n",
" def __getitem__(self, idx):\n",
" image, label = self.dataset[idx]\n",
" inputs = self.processor(images=image, return_tensors=\"pt\")\n",
" inputs = {key: val.squeeze(0) for key, val in inputs.items()} # Remove batch dim\n",
" return inputs, label_to_index[label]\n",
" \n",
"fine_tune_dataset = FeedbackDataset(feedback_dataset, processor)\n",
"train_loader = DataLoader(fine_tune_dataset, batch_size=4, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 3, 224, 224])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\dixonjames\\AppData\\Local\\Temp\\ipykernel_27320\\1890607759.py:27: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" loss = criterion(logits, torch.tensor(labels))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 loss: 1.2374294996261597\n",
"torch.Size([1, 3, 224, 224])\n",
"Epoch 2 loss: 1.0383309125900269\n",
"torch.Size([1, 3, 224, 224])\n",
"Epoch 3 loss: 0.8594766855239868\n",
"torch.Size([1, 3, 224, 224])\n",
"Epoch 4 loss: 0.7024854421615601\n",
"torch.Size([1, 3, 224, 224])\n",
"Epoch 5 loss: 0.5679658055305481\n",
"Fine-tuning complete!\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"class SimpleClassifier(nn.Module):\n",
" def __init__(self, embedding_dim, num_classes):\n",
" super(SimpleClassifier, self).__init__()\n",
" self.fc = nn.Linear(embedding_dim, num_classes) # Fully connected layer\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
"\n",
"embedding_dim = model.visual_projection.out_features \n",
"num_classes = len(LABELS)\n",
"\n",
"classifier = SimpleClassifier(embedding_dim, num_classes)\n",
"optimizer = optim.Adam(classifier.parameters(), lr=0.001)\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"for epoch in range(5):\n",
" for inputs, labels in train_loader:\n",
" with torch.no_grad():\n",
" print(inputs[\"pixel_values\"].shape)\n",
" image_features = model.get_image_features(pixel_values=inputs[\"pixel_values\"])\n",
" \n",
"\n",
" logits = classifier(image_features)\n",
" loss = criterion(logits, torch.tensor(labels)) \n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" print(f\"Epoch {epoch+1} loss: {loss.item()}\")\n",
"\n",
"print(\"Fine-tuning complete!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def maker(image):\n",
" descriptions = LABELS # [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
" inputs = processor(images=image, return_tensors=\"pt\")\n",
"\n",
" with torch.no_grad():\n",
" image_features = model.get_image_features(**inputs)\n",
"\n",
" logits = classifier(image_features)\n",
" predicted_index = logits.argmax().item()\n",
" predicted_class = descriptions[predicted_index]\n",
"\n",
" return predicted_class"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A passport\n"
]
}
],
"source": [
"input_folder = 'data_new'\n",
"input_filename = '0.png'\n",
"image_path = os.path.join(input_folder, input_filename)\n",
"image = Image.open(image_path)\n",
"prediction = maker(image)\n",
"print(prediction)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment