Entreprenerdly.com - Restoring Image Quality With AI using Real ESRGAN SwinIR and BSRGAN.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0mDz75f2i7CM"
},
"outputs": [],
"source": [
"\"\"\"\n",
"This Jupyter notebook is a supplementary resource for the article\n",
"'Restoring Image Quality With AI using Real-ESRGAN and SwinIR' on Entreprenerdly.com.\n",
"It contains all the code snippets and examples discussed in the article,\n",
"providing a hands-on approach to understanding the concepts and techniques presented.\n",
"For a comprehensive understanding, please refer to the article at\n",
"https://entreprenerdly.com/restoring-image-quality-with-ai-using-real-esrgan-and-swinir/\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ICOfsSv5i5mM"
},
"source": [
"<div style=\"text-align: center;\">\n",
" <img src=\"https://entreprenerdly.com/wp-content/uploads/2024/03/logo-com.png\" width=\"50%\" alt=\"Entreprenerdly logo\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dskuEwGcjQUm"
},
"source": [
"# 1) Preparations\n",
"\n",
"Before starting, open **Runtime** -> **Change runtime type** and select\n",
"* Runtime Type = Python 3\n",
"* Hardware Accelerator = GPU\n",
"\n",
"Use a browser other than Firefox, because the image upload in step 2 may not work there.\n",
"\n",
"Then clone the repositories, set up the environment, and download the pre-trained models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ugfa6HKSjYdZ"
},
"outputs": [],
"source": [
"# Clone Real-ESRGAN\n",
"!git clone https://github.com/xinntao/Real-ESRGAN.git\n",
"%cd Real-ESRGAN\n",
"# Set up the environment\n",
"!pip install basicsr\n",
"!pip install facexlib\n",
"!pip install gfpgan\n",
"!pip install -r requirements.txt\n",
"!python setup.py develop\n",
"\n",
"# Clone BSRGAN\n",
"!git clone https://github.com/cszn/BSRGAN.git\n",
"\n",
"# Clone SwinIR (remove any copy left over from a previous run first)\n",
"!rm -rf SwinIR\n",
"!git clone https://github.com/JingyunLiang/SwinIR.git\n",
"!pip install timm\n",
"\n",
"# Download the pre-trained models\n",
"!wget https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth -P BSRGAN/model_zoo\n",
"!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models\n",
"#!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth -P experiments/pretrained_models\n",
"!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth -P experiments/pretrained_models"
]
},
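{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check, a minimal sketch that is not part of the original workflow: a failed `wget` only prints an error and does not stop the notebook, so it can help to confirm that the three checkpoints exist before running inference. The paths are copied from the `wget -P` targets above and assume the working directory is still `Real-ESRGAN`. `missing_checkpoints` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Checkpoint paths copied from the wget -P targets in the previous cell\n",
"checkpoints = [\n",
"    'BSRGAN/model_zoo/BSRGAN.pth',\n",
"    'experiments/pretrained_models/RealESRGAN_x4plus.pth',\n",
"    'experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth',\n",
"]\n",
"\n",
"def missing_checkpoints(paths):\n",
"    # Hypothetical helper: return paths that are absent or empty (e.g. an interrupted download)\n",
"    return [p for p in paths if not os.path.isfile(p) or os.path.getsize(p) == 0]\n",
"\n",
"missing = missing_checkpoints(checkpoints)\n",
"print('All checkpoints present.' if not missing else f'Missing or empty: {missing}')"
]
},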
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 443
},
"id": "9WlpWEHcosui",
"outputId": "765a09df-f5f3-4ec4-f28f-09462ffc0a30"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Note1: You can find an image on the web or download images from the RealSRSet (proposed in BSRGAN, ICCV2021) at https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/RealSRSet+5images.zip.\n",
" Note2: You may need Chrome to enable file uploading!\n",
" Note3: If out-of-memory, set test_patch_wise = True.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving building.png to building.png\n",
"Saving dped_crop00061.png to dped_crop00061.png\n",
"Saving foreman.png to foreman.png\n",
"Saving frog.png to frog.png\n",
"Saving oldphoto6.png to oldphoto6.png\n",
"Saving OST_009.png to OST_009.png\n",
"move building.png to BSRGAN/testsets/RealSRSet/building.png\n",
"move dped_crop00061.png to BSRGAN/testsets/RealSRSet/dped_crop00061.png\n",
"move foreman.png to BSRGAN/testsets/RealSRSet/foreman.png\n",
"move frog.png to BSRGAN/testsets/RealSRSet/frog.png\n",
"move oldphoto6.png to BSRGAN/testsets/RealSRSet/oldphoto6.png\n",
"move OST_009.png to BSRGAN/testsets/RealSRSet/OST_009.png\n"
]
}
],
"source": [
"import os\n",
"import glob\n",
"from google.colab import files\n",
"import shutil\n",
"print(' Note1: You can find an image on the web or download images from the RealSRSet (proposed in BSRGAN, ICCV2021) at https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/RealSRSet+5images.zip.\\n Note2: You may need Chrome to enable file uploading!\\n Note3: If out-of-memory, set test_patch_wise = True.\\n')\n",
"\n",
"# Upscale tile by tile to reduce GPU memory use (Real-ESRGAN and SwinIR only; BSRGAN is skipped when this is True)\n",
"test_patch_wise = False\n",
"\n",
"# Upload into BSRGAN's default test folder, where main_test_bsrgan.py looks for input images\n",
"upload_folder = 'BSRGAN/testsets/RealSRSet'\n",
"result_folder = 'results'\n",
"\n",
"if os.path.isdir(upload_folder):\n",
" shutil.rmtree(upload_folder)\n",
"if os.path.isdir(result_folder):\n",
" shutil.rmtree(result_folder)\n",
"os.mkdir(upload_folder)\n",
"os.mkdir(result_folder)\n",
"\n",
"# upload images\n",
"uploaded = files.upload()\n",
"for filename in uploaded.keys():\n",
" dst_path = os.path.join(upload_folder, filename)\n",
" print(f'move {filename} to {dst_path}')\n",
" shutil.move(filename, dst_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vm8Kdss6Wamh"
},
"outputs": [],
"source": [
"# Release unused cached GPU memory held by PyTorch's caching allocator\n",
"import torch\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ERBZIETTSHpp",
"outputId": "d490cadc-2301-4424-ab58-41c1e4e1bd06"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The file has been updated successfully.\n"
]
}
],
"source": [
"import os\n",
"\n",
"# Capture the output of pip show (IPython's ! syntax) to locate the installed basicsr package\n",
"file_paths = !pip show basicsr | grep \"Location\"\n",
"if file_paths:\n",
" file_path = os.path.join(file_paths[0].split(\": \")[1], \"basicsr/data/degradations.py\")\n",
"\n",
" # Check if the file exists\n",
" if os.path.exists(file_path):\n",
" # Open the file for reading\n",
" with open(file_path, \"r\") as file:\n",
" file_content = file.read()\n",
"\n",
" # torchvision >= 0.17 no longer ships transforms.functional_tensor; point basicsr at the private _functional_tensor module\n",
" new_content = file_content.replace(\n",
" \"from torchvision.transforms.functional_tensor import rgb_to_grayscale\",\n",
" \"from torchvision.transforms._functional_tensor import rgb_to_grayscale\"\n",
" )\n",
"\n",
" # Open the file for writing and overwrite its content with the modified content\n",
" with open(file_path, \"w\") as file:\n",
" file.write(new_content)\n",
"\n",
" print(\"The file has been updated successfully.\")\n",
" else:\n",
" print(\"The specified file does not exist:\", file_path)\n",
"else:\n",
" print(\"Failed to find the installation location for 'basicsr'. Please check the package installation.\")\n"
]
},
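{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whether the patch above is needed depends on the installed torchvision. The public `torchvision.transforms.functional_tensor` module was deprecated in torchvision 0.15 and removed in 0.17, while the private `_functional_tensor` module is available from 0.15 on. The sketch below uses only the standard library's `importlib.util.find_spec` to report which of the two import paths resolves in the current runtime. `has_module` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import importlib.util\n",
"\n",
"def has_module(name):\n",
"    # Hypothetical helper: True if the module can be located without executing it.\n",
"    # find_spec imports parent packages (e.g. torchvision) and raises\n",
"    # ModuleNotFoundError if one of them is missing.\n",
"    try:\n",
"        return importlib.util.find_spec(name) is not None\n",
"    except ModuleNotFoundError:\n",
"        return False\n",
"\n",
"print('transforms.functional_tensor :', has_module('torchvision.transforms.functional_tensor'))\n",
"print('transforms._functional_tensor:', has_module('torchvision.transforms._functional_tensor'))"
]
},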
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iN5wA5yCS2Th",
"outputId": "a6979650-f5d7-492a-e23f-b157ad53095a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content/Real-ESRGAN/BSRGAN\n",
"LogHandlers setup!\n",
"24-05-15 03:45:52.967 : Model Name : BSRGAN\n",
"24-05-15 03:45:52.983 : GPU ID : 0\n",
"[3, 3, 64, 23, 32, 4]\n",
"24-05-15 03:45:53.578 : Input Path : testsets/RealSRSet\n",
"24-05-15 03:45:53.578 : Output Path : testsets/RealSRSet_results_x4\n",
"24-05-15 03:45:53.578 : ---1 --> BSRGAN --> x4--> OST_009.png\n",
"24-05-15 03:45:56.774 : ---2 --> BSRGAN --> x4--> building.png\n",
"24-05-15 03:45:57.025 : ---3 --> BSRGAN --> x4--> dped_crop00061.png\n",
"24-05-15 03:45:57.598 : ---4 --> BSRGAN --> x4--> foreman.png\n",
"24-05-15 03:45:58.239 : ---5 --> BSRGAN --> x4--> frog.png\n",
"24-05-15 03:45:58.796 : ---6 --> BSRGAN --> x4--> oldphoto6.png\n",
"/content/Real-ESRGAN\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n",
"Testing 0 OST_009\n",
"Testing 1 building\n",
"Testing 2 dped_crop00061\n",
"Testing 3 foreman\n",
"Testing 4 frog\n",
"Testing 5 oldphoto6\n",
"loading model from experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth\n",
"/usr/local/lib/python3.10/dist-packages/torch/functional.py:507: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3549.)\n",
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
"Testing 0 OST_009 \n",
"Testing 1 building \n",
"Testing 2 dped_crop00061 \n",
"Testing 3 foreman \n",
"Testing 4 frog \n",
"Testing 5 oldphoto6 \n"
]
}
],
"source": [
"# BSRGAN\n",
"!rm -r results\n",
"if not test_patch_wise:\n",
" %cd BSRGAN\n",
" !python main_test_bsrgan.py\n",
" %cd ..\n",
" shutil.move('BSRGAN/testsets/RealSRSet_results_x4', 'results/BSRGAN')\n",
"\n",
"# Real-ESRGAN (--face_enhance additionally runs GFPGAN on detected faces)\n",
"if test_patch_wise:\n",
" !python inference_realesrgan.py -n RealESRGAN_x4plus --input BSRGAN/testsets/RealSRSet -s 4 --output results/realESRGAN --tile 800 --face_enhance\n",
"else:\n",
" !python inference_realesrgan.py -n RealESRGAN_x4plus --input BSRGAN/testsets/RealSRSet -s 4 --output results/realESRGAN --face_enhance\n",
"\n",
"# SwinIR-Large\n",
"if test_patch_wise:\n",
" !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq BSRGAN/testsets/RealSRSet --scale 4 --large_model --tile 640\n",
"else:\n",
" !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq BSRGAN/testsets/RealSRSet --scale 4 --large_model\n",
"shutil.move('results/swinir_real_sr_x4_large', 'results/SwinIR_large')\n",
"for path in sorted(glob.glob(os.path.join('results/SwinIR_large', '*.png'))):\n",
" os.rename(path, path.replace('SwinIR.png', 'SwinIR_large.png'))  # give each file a unique name: Colab's file download does not handle same-name files"
]
},
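{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on `test_patch_wise`: when it is enabled, the cell above passes `--tile` to Real-ESRGAN and SwinIR. The image is then upscaled in tiles instead of in one forward pass, so peak GPU memory depends roughly on the tile size rather than the image size. The cell below is a minimal pure-Python sketch of the pad-and-crop variant of tiling, which is roughly what Real-ESRGAN's tiling does. SwinIR's test script instead averages overlapping tiles. `upscale_nearest` is a hypothetical stand-in for a real model, and the function names and the `tile`/`pad` defaults are illustrative assumptions rather than code from either repository. With a real network, the padded border is what hides the seams between tiles. With this purely local stand-in, the tiled and untiled results match exactly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def upscale_nearest(tile, scale):\n",
"    # Hypothetical stand-in for the network: nearest-neighbour upscaling by `scale`\n",
"    return [[v for v in row for _ in range(scale)] for row in tile for _ in range(scale)]\n",
"\n",
"def upscale_tiled(img, model, scale=4, tile=64, pad=8):\n",
"    # img is a list of pixel rows. Each tile x tile block is upscaled together with a\n",
"    # `pad`-pixel border of surrounding context, and that border is cropped off again\n",
"    # so neighbouring tiles do not overlap in the output.\n",
"    h, w = len(img), len(img[0])\n",
"    out = [[0] * (w * scale) for _ in range(h * scale)]\n",
"    for y in range(0, h, tile):\n",
"        for x in range(0, w, tile):\n",
"            # Padded input window, clamped to the image borders\n",
"            y0, y1 = max(y - pad, 0), min(y + tile + pad, h)\n",
"            x0, x1 = max(x - pad, 0), min(x + tile + pad, w)\n",
"            sr = model([row[x0:x1] for row in img[y0:y1]], scale)\n",
"            # Position and size of the un-padded tile inside the upscaled window\n",
"            oy, ox = (y - y0) * scale, (x - x0) * scale\n",
"            th, tw = (min(y + tile, h) - y) * scale, (min(x + tile, w) - x) * scale\n",
"            for i in range(th):\n",
"                out[y * scale + i][x * scale:x * scale + tw] = sr[oy + i][ox:ox + tw]\n",
"    return out\n",
"\n",
"# Usage on a small synthetic grayscale 'image': with the local stand-in, tiling does not change the result\n",
"demo = [[(r * 7 + c * 3) % 256 for c in range(10)] for r in range(6)]\n",
"print(upscale_tiled(demo, upscale_nearest, scale=2, tile=4, pad=2) == upscale_nearest(demo, 2))"
]
},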
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true,
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "7sdfXV7tTXVs",
"outputId": "5e7177b2-75dc-43f7-b39d-cdc2adbcc3e4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: BSRGAN may be better at face restoration, but worse at building restoration because it uses different datasets in training.\n",
"\n",
"\n",
"Note: BSRGAN may be better at face restoration, but worse at building restoration because it uses different datasets in training.\n",
"\n",
"\n"
]
},
{
"data": {