Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save ianturton/f0e393f7e27bcf9ee82a8defddcecacf to your computer and use it in GitHub Desktop.

Select an option

Save ianturton/f0e393f7e27bcf9ee82a8defddcecacf to your computer and use it in GitHub Desktop.
Short example of object detection training in TorchGeo.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d7ff0785",
"metadata": {},
"outputs": [],
"source": [
"import torchgeo"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "64d6ec01",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.6.0.dev0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torchgeo.__version__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d91b04cd",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchgeo.trainers import ObjectDetectionTask\n",
"from torchgeo.datasets import VHR10\n",
"from torch.utils.data import DataLoader\n",
"import lightning.pytorch as pl\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8a565672",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"loading annotations into memory...\n",
"Done (t=0.02s)\n",
"creating index...\n",
"index created!\n"
]
}
],
"source": [
"def preprocess(sample):\n",
"    \"\"\"Rescale the sample's image from 0-255 pixel values to floats in [0, 1].\n",
"\n",
"    Assumes the dataset yields 8-bit (uint8) image tensors -- TODO confirm.\n",
"    The sample dict is modified in place and also returned, so the function\n",
"    can be passed directly as ``transforms=`` to the dataset below.\n",
"    \"\"\"\n",
"    pixels = sample[\"image\"]\n",
"    sample[\"image\"] = pixels.float().div(255.0)\n",
"    return sample\n",
"\n",
"ds = VHR10(root='data/VHR10/', split='positive', transforms=preprocess, download=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4f4aab35",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"650"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(ds)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "abde065f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'image': tensor([[[0.3059, 0.3059, 0.3098, ..., 0.3765, 0.3686, 0.3647],\n",
" [0.3020, 0.3020, 0.3098, ..., 0.3804, 0.3725, 0.3686],\n",
" [0.2941, 0.2980, 0.3059, ..., 0.3804, 0.3686, 0.3569],\n",
" ...,\n",
" [0.4431, 0.4431, 0.4471, ..., 0.3373, 0.3333, 0.3333],\n",
" [0.4431, 0.4431, 0.4471, ..., 0.3373, 0.3333, 0.3294],\n",
" [0.4392, 0.4431, 0.4431, ..., 0.3412, 0.3333, 0.3294]],\n",
" \n",
" [[0.3490, 0.3490, 0.3529, ..., 0.4275, 0.4196, 0.4157],\n",
" [0.3451, 0.3451, 0.3529, ..., 0.4314, 0.4235, 0.4196],\n",
" [0.3333, 0.3373, 0.3451, ..., 0.4314, 0.4196, 0.4078],\n",
" ...,\n",
" [0.4941, 0.4941, 0.4980, ..., 0.3569, 0.3529, 0.3529],\n",
" [0.4941, 0.4941, 0.4980, ..., 0.3569, 0.3529, 0.3490],\n",
" [0.4902, 0.4941, 0.4941, ..., 0.3608, 0.3529, 0.3490]],\n",
" \n",
" [[0.1922, 0.1922, 0.1961, ..., 0.3098, 0.3098, 0.3059],\n",
" [0.1882, 0.1882, 0.1961, ..., 0.3137, 0.3137, 0.3098],\n",
" [0.1882, 0.1922, 0.2000, ..., 0.3137, 0.3098, 0.2980],\n",
" ...,\n",
" [0.4667, 0.4588, 0.4627, ..., 0.2314, 0.2275, 0.2275],\n",
" [0.4667, 0.4588, 0.4627, ..., 0.2314, 0.2275, 0.2235],\n",
" [0.4627, 0.4588, 0.4588, ..., 0.2353, 0.2275, 0.2235]]]),\n",
" 'labels': tensor([1]),\n",
" 'boxes': tensor([[563., 485., 629., 571.]]),\n",
" 'masks': tensor([[[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]]], dtype=torch.uint8)}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds[0]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "47fe2f86",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([3, 808, 958]), torch.Size([3, 806, 950]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds[0][\"image\"].shape, ds[1][\"image\"].shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "38476acc",
"metadata": {},
"outputs": [
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment