classifier_reinforcement
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\dixonjames\\Documents\\trained\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "from PIL import Image\n",
    "from transformers import CLIPModel, CLIPProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feedback collected: A passport\n"
     ]
    }
   ],
   "source": [
    "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "\n",
    "feedback_dataset = []\n",
    "\n",
    "def maker(image):\n",
    "    \"\"\"Zero-shot CLIP classification; returns None when confidence is low.\"\"\"\n",
    "    descriptions = [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
    "    inputs = processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n",
    "    with torch.no_grad():  # inference only, no gradients needed\n",
    "        outputs = model(**inputs)\n",
    "    logits_per_image = outputs.logits_per_image  # image-text similarity scores\n",
    "    probs = logits_per_image.softmax(dim=1)  # scores -> probabilities\n",
    "    predicted_class = descriptions[probs.argmax().item()]\n",
    "    confidence = probs.max().item()\n",
    "\n",
    "    if confidence > 0.5:\n",
    "        return predicted_class\n",
    "    return None\n",
    "\n",
    "def collect_feedback(image, correct_label):\n",
    "    \"\"\"Record a human-supplied correction for later fine-tuning.\"\"\"\n",
    "    feedback_dataset.append((image, correct_label))\n",
    "    print(f\"Feedback collected: {correct_label}\")\n",
    "\n",
    "input_folder = 'data_new'\n",
    "input_filename = '0.png'\n",
    "image_path = os.path.join(input_folder, input_filename)\n",
    "image = Image.open(image_path)\n",
    "prediction = maker(image)\n",
    "\n",
    "# If the zero-shot prediction is missing or wrong, store the correct label.\n",
    "if prediction is None or prediction != \"A passport\":\n",
    "    collect_feedback(image, \"A passport\")\n"
   ]
  },
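  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same correction loop generalizes to many images. The next cell is a minimal sketch, assuming a hypothetical `known_labels` mapping from filenames to their correct descriptions; any prediction that is missing or disagrees with the label lands in `feedback_dataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: bulk feedback collection. `known_labels` is a hypothetical\n",
    "# placeholder for however the ground-truth labels are actually stored.\n",
    "known_labels = {'0.png': 'A passport'}\n",
    "\n",
    "for filename, correct_label in known_labels.items():\n",
    "    img = Image.open(os.path.join(input_folder, filename))\n",
    "    pred = maker(img)\n",
    "    if pred is None or pred != correct_label:\n",
    "        collect_feedback(img, correct_label)"
   ]
  },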
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "from PIL import Image\n",
    "\n",
    "LABELS = [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
    "label_to_index = {label: idx for idx, label in enumerate(LABELS)}\n",
    "\n",
    "class FeedbackDataset(Dataset):\n",
    "    \"\"\"Wraps (image, label) feedback pairs as CLIP-ready tensors.\"\"\"\n",
    "\n",
    "    def __init__(self, dataset, processor):\n",
    "        self.dataset = dataset\n",
    "        self.processor = processor\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image, label = self.dataset[idx]\n",
    "        inputs = self.processor(images=image, return_tensors=\"pt\")\n",
    "        inputs = {key: val.squeeze(0) for key, val in inputs.items()}  # remove batch dim\n",
    "        return inputs, label_to_index[label]\n",
    "\n",
    "fine_tune_dataset = FeedbackDataset(feedback_dataset, processor)\n",
    "train_loader = DataLoader(fine_tune_dataset, batch_size=4, shuffle=True)"
   ]
  },
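  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch's default collate function stacks the processor's tensor dict along a new batch dimension and turns the integer labels into a `LongTensor`. A quick illustrative sanity check, not part of the training path:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: peek at one batch to confirm shapes and label types.\n",
    "batch_inputs, batch_labels = next(iter(train_loader))\n",
    "print(batch_inputs[\"pixel_values\"].shape)  # e.g. torch.Size([1, 3, 224, 224]) with one feedback item\n",
    "print(batch_labels)  # already a LongTensor of class indices"
   ]
  },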
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 224, 224])\n",
      "Epoch 1 loss: 1.2374294996261597\n",
      "torch.Size([1, 3, 224, 224])\n",
      "Epoch 2 loss: 1.0383309125900269\n",
      "torch.Size([1, 3, 224, 224])\n",
      "Epoch 3 loss: 0.8594766855239868\n",
      "torch.Size([1, 3, 224, 224])\n",
      "Epoch 4 loss: 0.7024854421615601\n",
      "torch.Size([1, 3, 224, 224])\n",
      "Epoch 5 loss: 0.5679658055305481\n",
      "Fine-tuning complete!\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "class SimpleClassifier(nn.Module):\n",
    "    \"\"\"A single linear layer over frozen CLIP image embeddings.\"\"\"\n",
    "\n",
    "    def __init__(self, embedding_dim, num_classes):\n",
    "        super().__init__()\n",
    "        self.fc = nn.Linear(embedding_dim, num_classes)  # fully connected layer\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc(x)\n",
    "\n",
    "embedding_dim = model.visual_projection.out_features\n",
    "num_classes = len(LABELS)\n",
    "\n",
    "classifier = SimpleClassifier(embedding_dim, num_classes)\n",
    "optimizer = optim.Adam(classifier.parameters(), lr=0.001)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in range(5):\n",
    "    for inputs, labels in train_loader:\n",
    "        # CLIP stays frozen: extract embeddings without tracking gradients.\n",
    "        with torch.no_grad():\n",
    "            print(inputs[\"pixel_values\"].shape)  # debug: batch shape\n",
    "            image_features = model.get_image_features(pixel_values=inputs[\"pixel_values\"])\n",
    "\n",
    "        logits = classifier(image_features)\n",
    "        loss = criterion(logits, labels)  # labels is already a LongTensor from the DataLoader\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    print(f\"Epoch {epoch+1} loss: {loss.item()}\")\n",
    "\n",
    "print(\"Fine-tuning complete!\")\n"
   ]
  },
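  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Only the small linear head carries fine-tuned weights; the CLIP backbone can always be reloaded from the hub. A minimal persistence sketch, with `classifier_head.pt` as an assumed filename:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: persist and restore the fine-tuned head. The filename is an assumption.\n",
    "torch.save(classifier.state_dict(), \"classifier_head.pt\")\n",
    "\n",
    "# Later / elsewhere: rebuild the head and restore its weights.\n",
    "restored = SimpleClassifier(embedding_dim, num_classes)\n",
    "restored.load_state_dict(torch.load(\"classifier_head.pt\"))\n",
    "restored.eval()"
   ]
  },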
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def maker(image):\n",
    "    \"\"\"Classify with the fine-tuned linear head on CLIP image features.\"\"\"\n",
    "    descriptions = LABELS  # [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
    "    inputs = processor(images=image, return_tensors=\"pt\")\n",
    "\n",
    "    with torch.no_grad():  # pure inference: no gradients for backbone or head\n",
    "        image_features = model.get_image_features(**inputs)\n",
    "        logits = classifier(image_features)\n",
    "\n",
    "    predicted_index = logits.argmax().item()\n",
    "    predicted_class = descriptions[predicted_index]\n",
    "\n",
    "    return predicted_class"
   ]
  },
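  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike the zero-shot version, this `maker` always returns a label. Below is a sketch of restoring the earlier low-confidence behavior by thresholding the head's softmax; the 0.5 cutoff mirrors the original zero-shot cell and is otherwise arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: reject low-confidence predictions, mirroring the zero-shot maker.\n",
    "def maker_with_threshold(image, threshold=0.5):\n",
    "    inputs = processor(images=image, return_tensors=\"pt\")\n",
    "    with torch.no_grad():\n",
    "        image_features = model.get_image_features(**inputs)\n",
    "        probs = classifier(image_features).softmax(dim=1)\n",
    "    confidence, predicted_index = probs.max(dim=1)\n",
    "    if confidence.item() > threshold:\n",
    "        return LABELS[predicted_index.item()]\n",
    "    return None"
   ]
  },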
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A passport\n"
     ]
    }
   ],
   "source": [
    "input_folder = 'data_new'\n",
    "input_filename = '0.png'\n",
    "image_path = os.path.join(input_folder, input_filename)\n",
    "image = Image.open(image_path)\n",
    "prediction = maker(image)\n",
    "print(prediction)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}