Created August 16, 2024 18:45
Entreprenerdly.com - Restoring Image Quality With AI using Real ESRGAN SwinIR and BSRGAN.ipynb
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "0mDz75f2i7CM" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "\"\"\"\n", | |
| "This Jupyter notebook is a supplementary resource for the article\n", | |
| "'Restoring Image Quality With AI using Real-ESRGAN and SwinIR' on Entreprenerdly.com.\n", | |
| "It contains all the code snippets and examples discussed in the article,\n", | |
| "providing a hands-on approach to understanding the concepts and techniques presented.\n", | |
| "For a comprehensive understanding, please refer to the article at\n", | |
| "https://entreprenerdly.com/restoring-image-quality-with-ai-using-real-esrgan-and-swinir/\n", | |
| "\"\"\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ICOfsSv5i5mM" | |
| }, | |
| "source": [ | |
| "<div style=\"text-align: center;\">\n", | |
| " <img src=\"https://entreprenerdly.com/wp-content/uploads/2024/03/logo-com.png\" width=\"50%\" alt=\"Entreprenerdly logo\">\n", | |
| "</div>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "dskuEwGcjQUm" | |
| }, | |
| "source": [ | |
| "# 1) Preparations\n", | |
| "\n", | |
| "Before starting, set the following in the **Runtime** menu -> **Change runtime type**:\n", | |
| "* Runtime Type = Python 3\n", | |
| "* Hardware Accelerator = GPU\n", | |
| "\n", | |
| "Also use a browser other than Firefox, which cannot upload images in step 2 (Chrome is recommended).\n", | |
| "\n", | |
| "Then we clone the Real-ESRGAN, BSRGAN and SwinIR repositories, set up the environment, and download the pre-trained models." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "ugfa6HKSjYdZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Clone Real-ESRGAN and work from inside its folder\n", | |
| "!git clone https://github.com/xinntao/Real-ESRGAN.git\n", | |
| "%cd Real-ESRGAN\n", | |
| "# Set up the environment\n", | |
| "!pip install basicsr\n", | |
| "!pip install facexlib\n", | |
| "!pip install gfpgan\n", | |
| "!pip install -r requirements.txt\n", | |
| "!python setup.py develop\n", | |
| "\n", | |
| "# Clone BSRGAN\n", | |
| "!git clone https://github.com/cszn/BSRGAN.git\n", | |
| "\n", | |
| "# Clone SwinIR (remove any earlier copy first; -f avoids an error if none exists)\n", | |
| "!rm -rf SwinIR\n", | |
| "!git clone https://github.com/JingyunLiang/SwinIR.git\n", | |
| "!pip install timm\n", | |
| "\n", | |
| "# Download the pre-trained models\n", | |
| "!wget https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth -P BSRGAN/model_zoo\n", | |
| "!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models\n", | |
| "#!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth -P experiments/pretrained_models\n", | |
| "!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth -P experiments/pretrained_models" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 443 | |
| }, | |
| "id": "9WlpWEHcosui", | |
| "outputId": "765a09df-f5f3-4ec4-f28f-09462ffc0a30" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| " Note1: You can find an image on the web or download images from the RealSRSet (proposed in BSRGAN, ICCV2021) at https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/RealSRSet+5images.zip.\n", | |
| " Note2: You may need Chrome to enable file uploading!\n", | |
| " Note3: If out-of-memory, set test_patch_wise = True.\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Saving building.png to building.png\n", | |
| "Saving dped_crop00061.png to dped_crop00061.png\n", | |
| "Saving foreman.png to foreman.png\n", | |
| "Saving frog.png to frog.png\n", | |
| "Saving oldphoto6.png to oldphoto6.png\n", | |
| "Saving OST_009.png to OST_009.png\n", | |
| "move building.png to BSRGAN/testsets/RealSRSet/building.png\n", | |
| "move dped_crop00061.png to BSRGAN/testsets/RealSRSet/dped_crop00061.png\n", | |
| "move foreman.png to BSRGAN/testsets/RealSRSet/foreman.png\n", | |
| "move frog.png to BSRGAN/testsets/RealSRSet/frog.png\n", | |
| "move oldphoto6.png to BSRGAN/testsets/RealSRSet/oldphoto6.png\n", | |
| "move OST_009.png to BSRGAN/testsets/RealSRSet/OST_009.png\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import os\n", | |
| "import glob\n", | |
| "from google.colab import files\n", | |
| "import shutil\n", | |
| "print(' Note1: You can find an image on the web or download images from the RealSRSet (proposed in BSRGAN, ICCV2021) at https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/RealSRSet+5images.zip.\\n Note2: You may need Chrome to enable file uploading!\\n Note3: If out-of-memory, set test_patch_wise = True.\\n')\n", | |
| "\n", | |
| "# Process images in tiles to reduce GPU memory (used by Real-ESRGAN and SwinIR;\n", | |
| "# BSRGAN is skipped when this is True)\n", | |
| "test_patch_wise = False\n", | |
| "\n", | |
| "# Upload into BSRGAN's default test folder so all three models read the same inputs\n", | |
| "upload_folder = 'BSRGAN/testsets/RealSRSet'\n", | |
| "result_folder = 'results'\n", | |
| "\n", | |
| "# Start from empty input and result folders\n", | |
| "if os.path.isdir(upload_folder):\n", | |
| "    shutil.rmtree(upload_folder)\n", | |
| "if os.path.isdir(result_folder):\n", | |
| "    shutil.rmtree(result_folder)\n", | |
| "os.makedirs(upload_folder)\n", | |
| "os.makedirs(result_folder)\n", | |
| "\n", | |
| "# upload images\n", | |
| "uploaded = files.upload()\n", | |
| "for filename in uploaded.keys():\n", | |
| " dst_path = os.path.join(upload_folder, filename)\n", | |
| " print(f'move {filename} to {dst_path}')\n", | |
| " shutil.move(filename, dst_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "vm8Kdss6Wamh" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Release cached, unused GPU memory held by PyTorch before running the models\n", | |
| "import torch\n", | |
| "torch.cuda.empty_cache()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "ERBZIETTSHpp", | |
| "outputId": "d490cadc-2301-4424-ab58-41c1e4e1bd06" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "The file has been updated successfully.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import os\n", | |
| "\n", | |
| "# basicsr imports torchvision.transforms.functional_tensor, which newer torchvision\n", | |
| "# releases renamed to _functional_tensor. Patch that import so basicsr can load.\n", | |
| "\n", | |
| "# Use IPython shell capture to get the package location from `pip show`\n", | |
| "file_paths = !pip show basicsr | grep \"Location\"\n", | |
| "if file_paths:\n", | |
| "    file_path = os.path.join(file_paths[0].split(\": \")[1], \"basicsr/data/degradations.py\")\n", | |
| "\n", | |
| "    # Check if the file exists\n", | |
| "    if os.path.exists(file_path):\n", | |
| "        with open(file_path, \"r\") as file:\n", | |
| "            file_content = file.read()\n", | |
| "\n", | |
| "        old_import = \"from torchvision.transforms.functional_tensor import rgb_to_grayscale\"\n", | |
| "        new_import = \"from torchvision.transforms._functional_tensor import rgb_to_grayscale\"\n", | |
| "\n", | |
| "        # Rewrite the file only if the old import is still present\n", | |
| "        if old_import in file_content:\n", | |
| "            with open(file_path, \"w\") as file:\n", | |
| "                file.write(file_content.replace(old_import, new_import))\n", | |
| "            print(\"The file has been updated successfully.\")\n", | |
| "        else:\n", | |
| "            print(\"No change needed: the old import was not found in\", file_path)\n", | |
| "    else:\n", | |
| "        print(\"The specified file does not exist:\", file_path)\n", | |
| "else:\n", | |
| "    print(\"Failed to find the installation location for 'basicsr'. Please check the package installation.\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "iN5wA5yCS2Th", | |
| "outputId": "a6979650-f5d7-492a-e23f-b157ad53095a" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "/content/Real-ESRGAN/BSRGAN\n", | |
| "LogHandlers setup!\n", | |
| "24-05-15 03:45:52.967 : Model Name : BSRGAN\n", | |
| "24-05-15 03:45:52.983 : GPU ID : 0\n", | |
| "[3, 3, 64, 23, 32, 4]\n", | |
| "24-05-15 03:45:53.578 : Input Path : testsets/RealSRSet\n", | |
| "24-05-15 03:45:53.578 : Output Path : testsets/RealSRSet_results_x4\n", | |
| "24-05-15 03:45:53.578 : ---1 --> BSRGAN --> x4--> OST_009.png\n", | |
| "24-05-15 03:45:56.774 : ---2 --> BSRGAN --> x4--> building.png\n", | |
| "24-05-15 03:45:57.025 : ---3 --> BSRGAN --> x4--> dped_crop00061.png\n", | |
| "24-05-15 03:45:57.598 : ---4 --> BSRGAN --> x4--> foreman.png\n", | |
| "24-05-15 03:45:58.239 : ---5 --> BSRGAN --> x4--> frog.png\n", | |
| "24-05-15 03:45:58.796 : ---6 --> BSRGAN --> x4--> oldphoto6.png\n", | |
| "/content/Real-ESRGAN\n", | |
| "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", | |
| " warnings.warn(\n", | |
| "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n", | |
| " warnings.warn(msg)\n", | |
| "Testing 0 OST_009\n", | |
| "Testing 1 building\n", | |
| "Testing 2 dped_crop00061\n", | |
| "Testing 3 foreman\n", | |
| "Testing 4 frog\n", | |
| "Testing 5 oldphoto6\n", | |
| "loading model from experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth\n", | |
| "/usr/local/lib/python3.10/dist-packages/torch/functional.py:507: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3549.)\n", | |
| " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", | |
| "Testing 0 OST_009 \n", | |
| "Testing 1 building \n", | |
| "Testing 2 dped_crop00061 \n", | |
| "Testing 3 foreman \n", | |
| "Testing 4 frog \n", | |
| "Testing 5 oldphoto6 \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
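| "# Start from an empty results folder\n", | |
| "!rm -rf results\n", | |
| "os.makedirs('results', exist_ok=True)\n", | |
| "\n", | |
| "# BSRGAN (only run when test_patch_wise is False)\n", | |
| "if not test_patch_wise:\n", | |
| "    %cd BSRGAN\n", | |
| "    !python main_test_bsrgan.py\n", | |
| "    %cd ..\n", | |
| "    shutil.move('BSRGAN/testsets/RealSRSet_results_x4', 'results/BSRGAN')\n", | |
| "\n", | |
| "# Real-ESRGAN (--face_enhance additionally restores faces with GFPGAN)\n", | |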
| "if test_patch_wise:\n", | |
| " !python inference_realesrgan.py -n RealESRGAN_x4plus --input BSRGAN/testsets/RealSRSet -s 4 --output results/realESRGAN --tile 800 --face_enhance\n", | |
| "else:\n", | |
| " !python inference_realesrgan.py -n RealESRGAN_x4plus --input BSRGAN/testsets/RealSRSet -s 4 --output results/realESRGAN --face_enhance\n", | |
| "\n", | |
| "# SwinIR-Large\n", | |
| "if test_patch_wise:\n", | |
| " !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq BSRGAN/testsets/RealSRSet --scale 4 --large_model --tile 640\n", | |
| "else:\n", | |
| " !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq BSRGAN/testsets/RealSRSet --scale 4 --large_model\n", | |
| "shutil.move('results/swinir_real_sr_x4_large', 'results/SwinIR_large')\n", | |
| "# Rename *_SwinIR.png to *_SwinIR_large.png: Colab file downloading does not handle same-name files\n", | |
| "for path in sorted(glob.glob(os.path.join('results/SwinIR_large', '*.png'))):\n", | |
| "    os.rename(path, path.replace('SwinIR.png', 'SwinIR_large.png'))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true, | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| }, | |
| "id": "7sdfXV7tTXVs", | |
| "outputId": "5e7177b2-75dc-43f7-b39d-cdc2adbcc3e4" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Note: BSRGAN may be better at face restoration, but worse at building restoration because it uses different datasets in training.\n", | |
| "\n", | |
| "\n", | |
| "Note: BSRGAN may be better at face restoration, but worse at building restoration because it uses different datasets in training.\n", | |
| "\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { |
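The `test_patch_wise` flag in the upload cell switches Real-ESRGAN and SwinIR to tiled inference (`--tile 800` and `--tile 640`). Tiling lowers peak GPU memory by processing the image piece by piece. The sketch below illustrates only the general idea. It is not either repository's implementation, and the two repos differ in details such as how overlapping tiles are blended.

The sketch makes two simplifying assumptions:

* The image is grayscale and stored as plain Python lists.
* A hypothetical `nearest_upscale` function stands in for a real model, so the code runs without PyTorch.

Each tile is cut with `pad` extra pixels of context, upscaled, and then has that context cropped away before it is written into the output.

```python
from typing import Callable, List

# Grayscale image as a list of pixel rows (assumption for this sketch)
Image = List[List[int]]


def nearest_upscale(img: Image, scale: int) -> Image:
    """Hypothetical stand-in for an SR model: nearest-neighbour upscaling."""
    return [[px for px in row for _ in range(scale)] for row in img for _ in range(scale)]


def crop(img: Image, y0: int, y1: int, x0: int, x1: int) -> Image:
    """Return rows y0:y1 and columns x0:x1 of img."""
    return [row[x0:x1] for row in img[y0:y1]]


def tiled_upscale(img: Image, upscale_fn: Callable[[Image, int], Image],
                  scale: int, tile: int = 64, pad: int = 8) -> Image:
    """Upscale img tile by tile, giving each tile `pad` pixels of context."""
    h, w = len(img), len(img[0])
    out = [[0] * (w * scale) for _ in range(h * scale)]
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            # Tile bounds in the input image
            y1, x1 = min(y + tile, h), min(x + tile, w)
            # Padded bounds, clipped to the image borders
            py0, px0 = max(y - pad, 0), max(x - pad, 0)
            py1, px1 = min(y1 + pad, h), min(x1 + pad, w)
            patch = upscale_fn(crop(img, py0, py1, px0, px1), scale)
            # Drop the upscaled context, keeping only the tile itself
            oy, ox = (y - py0) * scale, (x - px0) * scale
            core = crop(patch, oy, oy + (y1 - y) * scale, ox, ox + (x1 - x) * scale)
            for dy, row in enumerate(core):
                out[y * scale + dy][x * scale:x1 * scale] = row
    return out
```

With the nearest-neighbour stand-in, the tiled result matches upscaling the whole image at once, because each output pixel depends on a single input pixel. A real network looks at a neighbourhood of pixels, which is why the padding exists: without it, tile borders can show visible seams.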
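The last cell compares the three models' outputs, which end up in separate folders under `results/`. A small helper, such as the hypothetical `collect_outputs` below, can pair each uploaded image with its restored versions.

It relies on two assumptions:

* Each model writes to its own subfolder.
* Output names begin with `<input stem>_`, as in Real-ESRGAN's default `<name>_out.png` and the renamed `<name>_SwinIR_large.png`.

Check the BSRGAN output names in your own run before relying on it.

```python
from pathlib import Path
from typing import Dict


def collect_outputs(input_dir: str, results_dir: str) -> Dict[str, Dict[str, Path]]:
    """Map each input image stem to {model folder name: output file}.

    Assumes one subfolder per model under results_dir and output names that
    start with '<input stem>_' (for example 'frog_out.png').
    """
    stems = sorted(p.stem for p in Path(input_dir).glob("*.png"))
    table: Dict[str, Dict[str, Path]] = {stem: {} for stem in stems}
    model_dirs = sorted(p for p in Path(results_dir).iterdir() if p.is_dir())
    for model_dir in model_dirs:
        for out in sorted(model_dir.glob("*.png")):
            matches = [s for s in stems if out.name.startswith(s + "_")]
            if matches:
                # Prefer the longest stem so 'frog_2_out.png' is not claimed by 'frog'
                table[max(matches, key=len)][model_dir.name] = out
    return table
```

In the notebook, this would be called as `collect_outputs('BSRGAN/testsets/RealSRSet', 'results')`. Inputs with no matching output simply map to an empty dictionary.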
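The basicsr patch cell finds the package with `pip show basicsr | grep "Location"`. One alternative is `importlib.util.find_spec`, sketched below under the assumption that basicsr is a regular installed package. For a top-level name, `find_spec` locates the package without executing it, so it still works while basicsr's own import is broken.

Both `package_file` and `patch_import` are hypothetical helpers, not part of basicsr. `patch_import` rewrites the file only when the old import is present and returns `False` otherwise, so re-running it is harmless.

```python
import importlib.util
from pathlib import Path
from typing import Optional

OLD_IMPORT = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
NEW_IMPORT = "from torchvision.transforms._functional_tensor import rgb_to_grayscale"


def package_file(package: str, relpath: str) -> Optional[Path]:
    """Return the path of relpath inside an installed package, or None.

    find_spec on a top-level name locates the package without importing it.
    """
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None
    for root in spec.submodule_search_locations:
        candidate = Path(root) / relpath
        if candidate.exists():
            return candidate
    return None


def patch_import(path: Path, old: str = OLD_IMPORT, new: str = NEW_IMPORT) -> bool:
    """Replace old with new in the file; return True only if it changed."""
    text = Path(path).read_text()
    if old not in text:
        return False
    Path(path).write_text(text.replace(old, new))
    return True
```

In Colab, this would be used as `target = package_file("basicsr", "data/degradations.py")`, followed by `patch_import(target)` once `target` is confirmed not to be `None`.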