| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "6a5b65df", | |
| "metadata": {}, | |
| "source": [ | |
| "# LEVIR-CD+ change detection example notebook\n", | |
| "\n", | |
| "We start off by installing torchgeo. If you are running this on Colab, then you will need to restart your runtime after this step." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "4627b902", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "!pip install torchgeo" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "475f3715", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "\n", | |
| "import torchgeo\n", | |
| "from torchgeo.datasets import LEVIRCDPlus\n", | |
| "from torchgeo.datasets.utils import unbind_samples\n", | |
| "from torchgeo.trainers import SemanticSegmentationTask\n", | |
| "from torchgeo.datamodules.utils import dataset_split\n", | |
| "\n", | |
| "import lightning.pytorch as pl\n", | |
| "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", | |
| "from lightning.pytorch import Trainer, seed_everything\n", | |
| "from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger\n", | |
| "from lightning.pytorch import LightningDataModule\n", | |
| "\n", | |
| "import torch\n", | |
| "from torch.utils.data import DataLoader\n", | |
| "import kornia.augmentation as K\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import torchvision\n", | |
| "from torchvision.transforms import Compose\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "from sklearn.metrics import precision_score, recall_score" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "2ae75c6f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "('0.5.1', '2.1.3', '2.0.1+cu117')" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torchgeo.__version__, pl.__version__, torch.__version__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "daedd8ce", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "torch.cuda.is_available()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "0b012728", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# some experiment parameters\n", | |
| "\n", | |
| "experiment_name = \"experiment_test\"\n", | |
| "experiment_dir = f\"results/{experiment_name}\"\n", | |
| "os.makedirs(experiment_dir, exist_ok=True)\n", | |
| "\n", | |
| "batch_size = 8\n", | |
| "learning_rate = 0.0001\n", | |
| "gpu_id = 0\n", | |
| "device = torch.device(f\"cuda:{gpu_id}\")\n", | |
| "num_dataloader_workers = 2\n", | |
| "patch_size = 256\n", | |
| "val_split_pct = 0.1 # how much of our training set to hold out as a validation set" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "ca211445", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Files already downloaded and verified\n", | |
| "Files already downloaded and verified\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(637, 348)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Download the dataset and see how many images are in the train and test splits\n", | |
| "\n", | |
| "train_dataset = LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=\"train\", download=True, checksum=True)\n", | |
| "test_dataset = LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=\"test\", download=True, checksum=True)\n", | |
| "len(train_dataset), len(test_dataset)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "8d7e6981", | |
| "metadata": {}, | |
| "source": [ | |
| "## Excersise 1\n", | |
| "\n", | |
| "Plot some examples from the `train_dataset` (note: torchgeo will help you out here)." | |
| ] | |
| }, | |
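| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "ex1sketch", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# A possible solution sketch for Exercise 1 (one of several ways to do it).\n", | |
| "# It assumes the LEVIRCDPlus dataset exposes a plot() method, as most torchgeo\n", | |
| "# datasets do; if your version does not, plot sample[\"image\"] and sample[\"mask\"]\n", | |
| "# with matplotlib directly instead.\n", | |
| "for i in range(3):\n", | |
| "    sample = train_dataset[i]\n", | |
| "    train_dataset.plot(sample)\n", | |
| "    plt.show()" | |
| ] | |
| }, | |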
| { | |
| "cell_type": "markdown", | |
| "id": "8127d129", | |
| "metadata": {}, | |
| "source": [ | |
| "## Define a PyTorch Lightning module and datamodule\n", | |
| "\n", | |
| "PyTorch Lightning organizes the steps required for training deep learning models in `LightningModules`, and organizes the dataset handling to creating dataloaders in `LightningDataModules`. TorchGeo provides pre-built LightningDataModules for a handful of datasets, and pre-built \"trainers\" (i.e. LightningModules) for a variety of different types of tasks.\n", | |
| "\n", | |
| "For this tutorial, we will lightly extend TorchGeo's `SemanticSegmentationTask` (just to add some custom plotting code) and create a new LightningDataModule for the LEVIR-CD+ dataset." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "26f62ac5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n", | |
| " \n", | |
| " def plot(self, sample):\n", | |
| " image1 = sample[\"image\"][:3]\n", | |
| " image2 = sample[\"image\"][3:]\n", | |
| " mask = sample[\"mask\"]\n", | |
| " prediction = sample[\"prediction\"]\n", | |
| "\n", | |
| " fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(4 * 5, 5))\n", | |
| " axs[0].imshow(image1.permute(1, 2, 0))\n", | |
| " axs[0].axis(\"off\")\n", | |
| " axs[1].imshow(image2.permute(1, 2, 0))\n", | |
| " axs[1].axis(\"off\")\n", | |
| " axs[2].imshow(mask)\n", | |
| " axs[2].axis(\"off\")\n", | |
| " axs[3].imshow(prediction)\n", | |
| " axs[3].axis(\"off\")\n", | |
| "\n", | |
| " axs[0].set_title(\"Image 1\")\n", | |
| " axs[1].set_title(\"Image 2\")\n", | |
| " axs[2].set_title(\"Mask\")\n", | |
| " axs[3].set_title(\"Prediction\")\n", | |
| "\n", | |
| " plt.tight_layout()\n", | |
| " \n", | |
| " return fig\n", | |
| "\n", | |
| " # The only difference between this code and the same from SemanticSegmentationTask is our redirect to use our own plotting function\n", | |
| " def training_step(self, *args, **kwargs):\n", | |
| " batch = args[0]\n", | |
| " batch_idx = args[1]\n", | |
| " \n", | |
| " x = batch[\"image\"]\n", | |
| " y = batch[\"mask\"]\n", | |
| " y_hat = self.forward(x)\n", | |
| " y_hat_hard = y_hat.argmax(dim=1)\n", | |
| "\n", | |
| " loss = self.criterion(y_hat, y)\n", | |
| "\n", | |
| " self.log(\"train_loss\", loss, on_step=True, on_epoch=False)\n", | |
| " self.train_metrics(y_hat_hard, y)\n", | |
| "\n", | |
| " if batch_idx < 10:\n", | |
| " batch[\"prediction\"] = y_hat_hard\n", | |
| " for key in [\"image\", \"mask\", \"prediction\"]:\n", | |
| " batch[key] = batch[key].cpu()\n", | |
| " sample = unbind_samples(batch)[0]\n", | |
| " fig = self.plot(sample)\n", | |
| " summary_writer = self.logger.experiment\n", | |
| " summary_writer.add_figure(\n", | |
| " f\"image/train/{batch_idx}\", fig, global_step=self.global_step\n", | |
| " )\n", | |
| " plt.close()\n", | |
| " \n", | |
| " return loss\n", | |
| " \n", | |
| " # The only difference between this code and the same from SemanticSegmentationTask is our redirect to use our own plotting function\n", | |
| " def validation_step(self, *args, **kwargs):\n", | |
| " batch = args[0]\n", | |
| " batch_idx = args[1]\n", | |
| " x = batch[\"image\"]\n", | |
| " y = batch[\"mask\"]\n", | |
| " y_hat = self.forward(x)\n", | |
| " y_hat_hard = y_hat.argmax(dim=1)\n", | |
| "\n", | |
| " loss = self.criterion(y_hat, y)\n", | |
| "\n", | |
| " self.log(\"val_loss\", loss, on_step=False, on_epoch=True)\n", | |
| " self.val_metrics(y_hat_hard, y)\n", | |
| "\n", | |
| " if batch_idx < 10:\n", | |
| " batch[\"prediction\"] = y_hat_hard\n", | |
| " for key in [\"image\", \"mask\", \"prediction\"]:\n", | |
| " batch[key] = batch[key].cpu()\n", | |
| " sample = unbind_samples(batch)[0]\n", | |
| " fig = self.plot(sample)\n", | |
| " summary_writer = self.logger.experiment\n", | |
| " summary_writer.add_figure(\n", | |
| " f\"image/val/{batch_idx}\", fig, global_step=self.global_step\n", | |
| " )\n", | |
| " plt.close()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "f420887f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class LEVIRCDPlusDataModule(pl.LightningDataModule):\n", | |
| "\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " batch_size=32,\n", | |
| " num_workers=0,\n", | |
| " val_split_pct=0.2,\n", | |
| " patch_size=(256, 256),\n", | |
| " **kwargs,\n", | |
| " ):\n", | |
| " super().__init__()\n", | |
| " self.batch_size = batch_size\n", | |
| " self.num_workers = num_workers\n", | |
| " self.val_split_pct = val_split_pct\n", | |
| " self.patch_size = patch_size\n", | |
| " self.kwargs = kwargs\n", | |
| "\n", | |
| " def on_after_batch_transfer(\n", | |
| " self, batch, batch_idx\n", | |
| " ):\n", | |
| " if (\n", | |
| " hasattr(self, \"trainer\")\n", | |
| " and self.trainer is not None\n", | |
| " and hasattr(self.trainer, \"training\")\n", | |
| " and self.trainer.training\n", | |
| " ):\n", | |
| " # Kornia expects masks to be floats with a channel dimension\n", | |
| " x = batch[\"image\"]\n", | |
| " y = batch[\"mask\"].float().unsqueeze(1)\n", | |
| "\n", | |
| " train_augmentations = K.AugmentationSequential(\n", | |
| " K.RandomRotation(p=0.5, degrees=90),\n", | |
| " K.RandomHorizontalFlip(p=0.5),\n", | |
| " K.RandomVerticalFlip(p=0.5),\n", | |
| " K.RandomCrop(self.patch_size),\n", | |
| " K.RandomSharpness(p=0.5),\n", | |
| " data_keys=[\"input\", \"mask\"],\n", | |
| " )\n", | |
| " x, y = train_augmentations(x, y)\n", | |
| "\n", | |
| " # torchmetrics expects masks to be longs without a channel dimension\n", | |
| " batch[\"image\"] = x\n", | |
| " batch[\"mask\"] = y.squeeze(1).long()\n", | |
| "\n", | |
| " return batch\n", | |
| " \n", | |
| " def preprocess(self, sample):\n", | |
| " sample[\"image\"] = (sample[\"image\"] / 255.0).float()\n", | |
| " sample[\"image\"] = torch.flatten(sample[\"image\"], 0, 1)\n", | |
| " sample[\"mask\"] = sample[\"mask\"].long()\n", | |
| " return sample\n", | |
| "\n", | |
| " def prepare_data(self):\n", | |
| " LEVIRCDPlus(split=\"train\", **self.kwargs)\n", | |
| "\n", | |
| " def setup(self, stage=None):\n", | |
| " train_transforms = Compose([self.preprocess])\n", | |
| " test_transforms = Compose([self.preprocess])\n", | |
| "\n", | |
| " train_dataset = LEVIRCDPlus(\n", | |
| " split=\"train\", transforms=train_transforms, **self.kwargs\n", | |
| " )\n", | |
| "\n", | |
| " if self.val_split_pct > 0.0:\n", | |
| " self.train_dataset, self.val_dataset, _ = dataset_split(\n", | |
| " train_dataset, val_pct=self.val_split_pct, test_pct=0.0\n", | |
| " )\n", | |
| " else:\n", | |
| " self.train_dataset = train_dataset\n", | |
| " self.val_dataset = train_dataset\n", | |
| "\n", | |
| " self.test_dataset = LEVIRCDPlus(\n", | |
| " split=\"test\", transforms=test_transforms, **self.kwargs\n", | |
| " )\n", | |
| "\n", | |
| " def train_dataloader(self):\n", | |
| " return DataLoader(\n", | |
| " self.train_dataset,\n", | |
| " batch_size=self.batch_size,\n", | |
| " num_workers=self.num_workers,\n", | |
| " shuffle=True,\n", | |
| " )\n", | |
| "\n", | |
| " def val_dataloader(self):\n", | |
| " return DataLoader(\n", | |
| " self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False\n", | |
| " )\n", | |
| "\n", | |
| " def test_dataloader(self):\n", | |
| " return DataLoader(\n", | |
| " self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False\n", | |
| " )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d221e5db", | |
| "metadata": {}, | |
| "source": [ | |
| "## Setting up a training run" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "97a5ff80", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "datamodule = LEVIRCDPlusDataModule(\n", | |
| " root=\"data/LEVIRCDPlus\",\n", | |
| " batch_size=batch_size,\n", | |
| " num_workers=num_dataloader_workers,\n", | |
| " val_split_pct=val_split_pct,\n", | |
| " patch_size=(patch_size, patch_size),\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "82b472f5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "task = CustomSemanticSegmentationTask(\n", | |
| " model=\"unet\",\n", | |
| " backbone=\"resnet18\",\n", | |
| " weights=True,\n", | |
| " in_channels=6,\n", | |
| " num_classes=2,\n", | |
| " loss=\"ce\",\n", | |
| " ignore_index=None,\n", | |
| " lr=learning_rate,\n", | |
| " patience=10\n", | |
| ")\n", | |
| "\n", | |
| "checkpoint_callback = ModelCheckpoint(\n", | |
| " monitor=\"val_loss\",\n", | |
| " dirpath=experiment_dir,\n", | |
| " save_top_k=1,\n", | |
| " save_last=True,\n", | |
| ")\n", | |
| "\n", | |
| "early_stopping_callback = EarlyStopping(\n", | |
| " monitor=\"val_loss\",\n", | |
| " min_delta=0.00,\n", | |
| " patience=10,\n", | |
| ")\n", | |
| "\n", | |
| "tb_logger = TensorBoardLogger(\n", | |
| " save_dir=\"logs/\",\n", | |
| " name=experiment_name\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "e54642fd", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "%load_ext tensorboard" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "94fe9c6d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "%tensorboard --logdir logs/" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "6fc5259c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "GPU available: True (cuda), used: True\n", | |
| "TPU available: False, using: 0 TPU cores\n", | |
| "IPU available: False, using: 0 IPUs\n", | |
| "HPU available: False, using: 0 HPUs\n", | |
| "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", | |
| "\n", | |
| " | Name | Type | Params\n", | |
| "---------------------------------------------------\n", | |
| "0 | model | Unet | 14.3 M\n", | |
| "1 | loss | CrossEntropyLoss | 0 \n", | |
| "2 | train_metrics | MetricCollection | 0 \n", | |
| "3 | val_metrics | MetricCollection | 0 \n", | |
| "4 | test_metrics | MetricCollection | 0 \n", | |
| "---------------------------------------------------\n", | |
| "14.3 M Trainable params\n", | |
| "0 Non-trainable params\n", | |
| "14.3 M Total params\n", | |
| "57.351 Total estimated model params size (MB)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "Sanity Checking: 0it [00:00, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/calebrobinson/.conda/envs/test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (18) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", | |
| " rank_zero_warn(\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "a15abdbc468b44d0b1a43a18e285ae95", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "Training: 0it [00:00, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "Validation: 0it [00:00, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "trainer = pl.Trainer(\n", | |
| " callbacks=[checkpoint_callback, early_stopping_callback],\n", | |
| " logger=[tb_logger],\n", | |
| " default_root_dir=experiment_dir,\n", | |
| " min_epochs=10,\n", | |
| " max_epochs=200,\n", | |
| " accelerator='gpu',\n", | |
| " devices=[gpu_id]\n", | |
| ")\n", | |
| "\n", | |
| "_ = trainer.fit(model=task, datamodule=datamodule)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "2cfacd81", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "trainer.test(model=task, datamodule=datamodule)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "346e4afe", | |
| "metadata": {}, | |
| "source": [ | |
| "## Custom test step to compute the precision, recall, and F1 metrics" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "b61db9fb", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Example of how to load a trained task from a checkpoint file\n", | |
| "# task = CustomSemanticSegmentationTask.load_from_checkpoint(\"results/...\")\n", | |
| "# datamodule.setup(\"test\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "c9b7a93c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model = task.model.to(device).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "0e545e06", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 44/44 [00:21<00:00, 2.04it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "y_preds = []\n", | |
| "y_trues = []\n", | |
| "for batch in tqdm(datamodule.test_dataloader()):\n", | |
| " images = batch[\"image\"].to(device)\n", | |
| " y_trues.append(batch[\"mask\"].numpy().ravel()[::500])\n", | |
| " with torch.inference_mode():\n", | |
| " y_pred = model(images).argmax(dim=1).cpu().numpy().ravel()[::500]\n", | |
| " y_preds.append(y_pred)\n", | |
| "\n", | |
| "y_preds = np.concatenate(y_preds)\n", | |
| "y_trues = np.concatenate(y_trues)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "8b5a6975", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "precision = precision_score(y_trues, y_preds)\n", | |
| "recall = recall_score(y_trues, y_preds)\n", | |
| "f1 = 2 * (precision * recall) / (precision + recall)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "bf25b1d4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(0.7234695667426767, 0.5552638664512655, 0.6283037550460812)" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "precision, recall, f1" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "python3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Then on running trainer.fit you will get
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `CustomSemanticSegmentationTask`
It is necessary to use the new lightning format for imports:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch import LightningDataModule
You will next get:
AttributeError: 'CustomSemanticSegmentationTask' object has no attribute 'loss'
It is necessary in the custom trainer to use:
loss: Tensor = self.criterion(y_hat, y)
After those changes the model summary prints as:
  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0
1 | train_metrics | MetricCollection | 0
2 | val_metrics   | MetricCollection | 0
3 | test_metrics  | MetricCollection | 0
4 | model         | Unet             | 14.3 M
---------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.351    Total estimated model params size (MB)
Sanity Checking: 0/? [00:00<?, ?it/s]
It remains at this stage; what could be the problem?
Hey @mustafaemre2 -- are you running on the GPU?
Updated with @robmarkcole's fixes (and ensured that the notebook runs end-to-end) for torchgeo 0.5.1 (thanks Robin!)
Yes, I used your code exactly.
My GPU is an RTX 3060 Laptop GPU @calebrob6
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high') which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
I get this message @calebrob6
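That message is just a performance hint, not an error. If you want to act on it, one option is to add the following near the imports; it trades a bit of numerical precision for speed, per the linked PyTorch docs:
torch.set_float32_matmul_precision('high')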
Download the dataset and see how many images are in the train and test splits
train_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="train", download=True, checksum=True)
test_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="test", download=True, checksum=True)
len(train_dataset), len(test_dataset)
It gives this error:
RuntimeError: The MD5 checksum of the download file data/LEVIRCDPlus/LEVIR-CD+.zip does not match the one on record. Please delete the file and try again. If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues.
I get this message @calebrob6
@nadeem-git-coder Have you found any solution for that?
@ProtikBose I downloaded the dataset manually and used it.
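For anyone hitting the MD5 error above, a workaround sketch: download and unzip the LEVIR-CD+ archive yourself into data/LEVIRCDPlus (that path is just an assumption here; use wherever you extracted it), then build the datasets without downloading or checksum verification:
train_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="train", download=False, checksum=False)
test_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="test", download=False, checksum=False)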
Have you encountered the error?
What is the error?
Running with torchgeo 0.5.0 will give:
The necessary update: