Last active
February 9, 2026 23:27
-
-
Save tvercaut/5dcfd7130c80bf9e87e3aa84da76a1e3 to your computer and use it in GitHub Desktop.
malvar_he_cutler_demosaic_pytorch.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "authorship_tag": "ABX9TyNLS14UWkHFg3AynYBE4FUn", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/tvercaut/5dcfd7130c80bf9e87e3aa84da76a1e3/malvar_he_cutler_demosaic_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "id": "Nc0gZGJjEh0u" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn.functional as F\n", | |
| "import time\n", | |
| "import numpy as np\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
torchdevice = torch.device('cuda') if use_cuda else torch.device('cpu')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# ============================================================================
# BASIC IMPLEMENTATION
# ============================================================================
def malvar_he_cutler_demosaic_original(bayer_image, pattern='RGGB', max_value=None):
    """Reference Malvar-He-Cutler demosaicing with full-resolution masks.

    Every missing colour sample is obtained by correlating the whole mosaic
    with the matching 5x5 MHC filter and keeping the result only at the sites
    selected by a binary mask. Simple but wasteful (eight full-size
    convolutions); kept as the baseline implementation.

    Args:
        bayer_image: floating-point mosaic of shape (H, W), (1, H, W) or
            (B, 1, H, W). H and W must be even.
        pattern: layout of the top-left 2x2 tile: 'RGGB', 'BGGR', 'GRBG'
            or 'GBRG'.
        max_value: upper clamp bound of the output. If None it is inferred
            from the input: 1 if bayer_image.max() <= 1, else 255.

    Returns:
        RGB tensor clamped to [0, max_value]: (B, 3, H, W) for 4-D input,
        (3, H, W) for 2-D or 3-D input.

    Raises:
        ValueError: if `pattern` is unknown or H / W is odd.
    """
    if bayer_image.dim() == 2:
        bayer_image = bayer_image.unsqueeze(0).unsqueeze(0)
        squeeze_output = True
    elif bayer_image.dim() == 3:
        bayer_image = bayer_image.unsqueeze(0)
        squeeze_output = True
    else:
        squeeze_output = False

    pattern_dict = {
        'RGGB': {'R': (0, 0), 'G1': (0, 1), 'G2': (1, 0), 'B': (1, 1)},
        'BGGR': {'B': (0, 0), 'G1': (0, 1), 'G2': (1, 0), 'R': (1, 1)},
        'GRBG': {'G1': (0, 0), 'R': (0, 1), 'B': (1, 0), 'G2': (1, 1)},
        'GBRG': {'G1': (0, 0), 'B': (0, 1), 'R': (1, 0), 'G2': (1, 1)},
    }
    if pattern not in pattern_dict:
        raise ValueError(
            f"Unknown Bayer pattern {pattern!r}; expected one of {sorted(pattern_dict)}")
    positions = pattern_dict[pattern]

    H, W = bayer_image.shape[-2:]
    if H % 2 or W % 2:
        raise ValueError(f"H and W must be even, got {H}x{W}")

    # Pick the clamp range from the *input*: MHC filters overshoot (up to 1.5
    # for data in [0, 1]), so testing the output maximum would select 255 and
    # leave out-of-range values unclamped.
    if max_value is None:
        max_value = 1.0 if bayer_image.max() <= 1 else 255.0

    device = bayer_image.device
    dtype = bayer_image.dtype

    # Malvar-He-Cutler filters (coefficients / 8). All are mirror-symmetric,
    # so conv2d's cross-correlation equals a true convolution here.
    G_at_R = torch.tensor([
        [0, 0, -1, 0, 0],
        [0, 0, 2, 0, 0],
        [-1, 2, 4, 2, -1],
        [0, 0, 2, 0, 0],
        [0, 0, -1, 0, 0]
    ], dtype=dtype, device=device) / 8.0

    G_at_B = G_at_R.clone()

    # Red at a green site whose row contains red (red neighbours left/right).
    R_at_G_RRow = torch.tensor([
        [0, 0, 0.5, 0, 0],
        [0, -1, 0, -1, 0],
        [-1, 4, 5, 4, -1],
        [0, -1, 0, -1, 0],
        [0, 0, 0.5, 0, 0]
    ], dtype=dtype, device=device) / 8.0

    # Red at a green site whose row contains blue (red neighbours above/below).
    R_at_G_BRow = torch.tensor([
        [0, 0, -1, 0, 0],
        [0, -1, 4, -1, 0],
        [0.5, 0, 5, 0, 0.5],
        [0, -1, 4, -1, 0],
        [0, 0, -1, 0, 0]
    ], dtype=dtype, device=device) / 8.0

    R_at_B = torch.tensor([
        [0, 0, -1.5, 0, 0],
        [0, 2, 0, 2, 0],
        [-1.5, 0, 6, 0, -1.5],
        [0, 2, 0, 2, 0],
        [0, 0, -1.5, 0, 0]
    ], dtype=dtype, device=device) / 8.0

    # Blue uses the same filters as red with the row roles swapped.
    B_at_R = R_at_B.clone()
    B_at_G_RRow = R_at_G_BRow.clone()
    B_at_G_BRow = R_at_G_RRow.clone()

    def site_mask(*sites):
        """(1, 1, H, W) mask equal to 1 on each (row::2, col::2) lattice in `sites`."""
        mask = torch.zeros((1, 1, H, W), dtype=dtype, device=device)
        for row, col in sites:
            mask[..., row::2, col::2] = 1
        return mask

    r_y, r_x = positions['R']
    b_y, b_x = positions['B']
    mask_R = site_mask((r_y, r_x))
    mask_B = site_mask((b_y, b_x))
    mask_G = site_mask(positions['G1'], positions['G2'])
    # Green sites on red rows / on blue rows (the other column of that row).
    mask_G_RRow = site_mask((r_y, 1 - r_x))
    mask_G_BRow = site_mask((b_y, 1 - b_x))

    # Reflect-pad each Bayer plane by one sample and re-interleave: a 2-pixel
    # border on the mosaic built from mirrored 2x2 tiles, so every padded
    # pixel keeps its Bayer colour.
    bayer_image_padded = F.pixel_shuffle(F.pad(
        F.pixel_unshuffle(bayer_image, 2), pad=(1, 1, 1, 1), mode='reflect'), 2)

    def apply_filter(kernel):
        """Valid 5x5 correlation of the padded mosaic -> (B, 1, H, W)."""
        return F.conv2d(bayer_image_padded, kernel[None, None], padding=0)

    # Masks are (1, 1, H, W) and broadcast over the batch dimension.
    G = bayer_image * mask_G
    G = G + apply_filter(G_at_R) * mask_R
    G = G + apply_filter(G_at_B) * mask_B

    R = bayer_image * mask_R
    R = R + apply_filter(R_at_G_RRow) * mask_G_RRow
    R = R + apply_filter(R_at_G_BRow) * mask_G_BRow
    R = R + apply_filter(R_at_B) * mask_B

    B_channel = bayer_image * mask_B
    B_channel = B_channel + apply_filter(B_at_G_RRow) * mask_G_RRow
    B_channel = B_channel + apply_filter(B_at_G_BRow) * mask_G_BRow
    B_channel = B_channel + apply_filter(B_at_R) * mask_R

    rgb = torch.clamp(torch.cat([R, G, B_channel], dim=1), 0, max_value)

    if squeeze_output:
        rgb = rgb.squeeze(0)

    return rgb
| ], | |
| "metadata": { | |
| "id": "iR0UgfAEFD6F" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# ============================================================================
# OPTIMIZED IMPLEMENTATION
# ============================================================================
# Malvar-He-Cutler filters folded onto the pixel-unshuffled RGGB grid, shape
# (8, 4, 3, 3): 8 interpolated outputs x 4 input planes [R, G1, G2, B] x 3x3
# taps on the half-resolution grid. Each is one of the 5x5 MHC filters
# (already divided by 8) redistributed over the four Bayer planes.
_MHC_FOLDED_KERNELS_RGGB = [
    # 0. Green at R
    [[[0, -0.125, 0], [-0.125, 0.5, -0.125], [0, -0.125, 0]],          # R
     [[0, 0, 0], [0.25, 0.25, 0], [0, 0, 0]],                          # G1
     [[0, 0.25, 0], [0, 0.25, 0], [0, 0, 0]],                          # G2
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],                               # B
    # 1. Blue at R
    [[[0, -0.1875, 0], [-0.1875, 0.75, -0.1875], [0, -0.1875, 0]],     # R
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # G1
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # G2
     [[0.25, 0.25, 0], [0.25, 0.25, 0], [0, 0, 0]]],                   # B
    # 2. Red at G1
    [[[0, 0, 0], [0, 0.5, 0.5], [0, 0, 0]],                            # R
     [[0, 0.0625, 0], [-0.125, 0.625, -0.125], [0, 0.0625, 0]],        # G1
     [[0, -0.125, -0.125], [0, -0.125, -0.125], [0, 0, 0]],            # G2
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],                               # B
    # 3. Blue at G1
    [[[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # R
     [[0, -0.125, 0], [0.0625, 0.625, 0.0625], [0, -0.125, 0]],        # G1
     [[0, -0.125, -0.125], [0, -0.125, -0.125], [0, 0, 0]],            # G2
     [[0, 0.5, 0], [0, 0.5, 0], [0, 0, 0]]],                           # B
    # 4. Red at G2
    [[[0, 0, 0], [0, 0.5, 0], [0, 0.5, 0]],                            # R
     [[0, 0, 0], [-0.125, -0.125, 0], [-0.125, -0.125, 0]],            # G1
     [[0, -0.125, 0], [0.0625, 0.625, 0.0625], [0, -0.125, 0]],        # G2
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],                               # B
    # 5. Blue at G2
    [[[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # R
     [[0, 0, 0], [-0.125, -0.125, 0], [-0.125, -0.125, 0]],            # G1
     [[0, 0.0625, 0], [-0.125, 0.625, -0.125], [0, 0.0625, 0]],        # G2
     [[0, 0, 0], [0.5, 0.5, 0], [0, 0, 0]]],                           # B
    # 6. Green at B
    [[[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # R
     [[0, 0, 0], [0, 0.25, 0], [0, 0.25, 0]],                          # G1
     [[0, 0, 0], [0, 0.25, 0.25], [0, 0, 0]],                          # G2
     [[0, -0.125, 0], [-0.125, 0.5, -0.125], [0, -0.125, 0]]],         # B
    # 7. Red at B
    [[[0, 0, 0], [0, 0.25, 0.25], [0, 0.25, 0.25]],                    # R
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # G1
     [[0, 0, 0], [0, 0, 0], [0, 0, 0]],                                # G2
     [[0, -0.1875, 0], [-0.1875, 0.75, -0.1875], [0, -0.1875, 0]]],    # B
]


def malvar_he_cutler_demosaic_optimized(bayer_image: torch.Tensor,
                                        pattern: str = 'RGGB',
                                        max_value=None) -> torch.Tensor:
    """
    Malvar-He-Cutler demosaicing using
    PixelUnshuffle + explicit folded kernels + PixelShuffle.

    The 5x5 MHC filters are rewritten as 3x3 filters acting on the four
    half-resolution Bayer planes, so a single conv2d produces all eight
    missing colour samples of every 2x2 tile.

    Args:
        bayer_image: (B, 1, H, W) floating-point mosaic; H and W must be even.
        pattern: layout of the top-left 2x2 tile: 'RGGB' (default), 'BGGR',
            'GRBG' or 'GBRG'. Other layouts are flipped to RGGB, demosaiced
            and flipped back; this is exact because every MHC filter is
            mirror-symmetric.
        max_value: upper clamp bound of the output. If None it is inferred
            from the input: 1 if bayer_image.max() <= 1, else 255.

    Returns:
        rgb: (B, 3, H, W), clamped to [0, max_value].

    Raises:
        ValueError: wrong rank or channel count, odd H / W, or unknown pattern.
    """
    if bayer_image.ndim != 4 or bayer_image.shape[1] != 1:
        raise ValueError(
            f"Expected a (B, 1, H, W) tensor, got shape {tuple(bayer_image.shape)}")
    height, width = bayer_image.shape[-2:]
    if height % 2 or width % 2:
        raise ValueError(f"H and W must be even, got {height}x{width}")

    # With H and W even, these flips map each layout onto RGGB.
    flips_to_rggb = {'RGGB': (), 'GRBG': (-1,), 'GBRG': (-2,), 'BGGR': (-2, -1)}
    if pattern not in flips_to_rggb:
        raise ValueError(
            f"Unknown Bayer pattern {pattern!r}; expected one of {sorted(flips_to_rggb)}")
    flip_dims = flips_to_rggb[pattern]

    # Pick the clamp range from the *input*: MHC filters overshoot (up to 1.5
    # for data in [0, 1]), so testing the output maximum would select 255 and
    # leave out-of-range values unclamped.
    if max_value is None:
        max_value = 1.0 if bayer_image.max() <= 1 else 255.0

    mosaic = bayer_image.flip(flip_dims) if flip_dims else bayer_image

    # 1. Split the mosaic into its Bayer planes [R, G1, G2, B]: (B, 4, H/2, W/2).
    planes = F.pixel_unshuffle(mosaic, 2)

    # 2. One correlation computes the 8 missing samples per tile. Reflect-
    #    padding each plane by 1 is a 2-pixel, Bayer-phase-preserving border
    #    on the mosaic.
    kernels = torch.tensor(_MHC_FOLDED_KERNELS_RGGB,
                           device=bayer_image.device, dtype=bayer_image.dtype)
    interp = F.conv2d(F.pad(planes, pad=(1, 1, 1, 1), mode='reflect'), kernels)

    # 3. Fused channels: 0-3 known R, G1, G2, B; 4-11 kernel outputs 0-7.
    #    pixel_shuffle builds output channel c from fused channels 4c..4c+3 in
    #    tile order (0,0), (0,1), (1,0), (1,1).
    fused = torch.cat([planes, interp], dim=1)  # (B, 12, H/2, W/2)
    shuffle_order = [
        0, 6, 8, 11,   # red:   R,    R@G1, R@G2, R@B
        4, 1, 2, 10,   # green: G@R,  G1,   G2,   G@B
        5, 7, 9, 3,    # blue:  B@R,  B@G1, B@G2, B
    ]
    rgb = F.pixel_shuffle(fused[:, shuffle_order], 2)

    if flip_dims:
        rgb = rgb.flip(flip_dims)

    return rgb.clamp(0, max_value)
| ], | |
| "metadata": { | |
| "id": "2VF-9nOQHzg_" | |
| }, | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
def simulate_bayer_rggb(rgb_image_tensor: torch.Tensor) -> torch.Tensor:
    """
    Converts an RGB image tensor into a single-channel Bayer pattern (RGGB).

    Each pixel keeps only the colour its Bayer site samples: R at (even row,
    even col), G at (even, odd) and (odd, even), B at (odd, odd).

    Args:
        rgb_image_tensor (torch.Tensor): Input RGB image tensor of shape
            [B, 3, H, W] with channels ordered R, G, B (any value range).

    Returns:
        torch.Tensor: Single-channel Bayer mosaic of shape [B, 1, H, W] with
        the input's dtype and device.
    """
    batch, _, height, width = rgb_image_tensor.shape
    bayer_image = rgb_image_tensor.new_zeros((batch, 1, height, width))

    # Slice the channel axis (":" / "k:k+1") on both sides so the shapes are
    # (B, 1, h, w) = (B, 1, h, w); a scalar channel index on the left only
    # produces (B, h, w) and the assignment fails for B > 1.
    bayer_image[:, :, 0::2, 0::2] = rgb_image_tensor[:, 0:1, 0::2, 0::2]  # R
    bayer_image[:, :, 0::2, 1::2] = rgb_image_tensor[:, 1:2, 0::2, 1::2]  # G1
    bayer_image[:, :, 1::2, 0::2] = rgb_image_tensor[:, 1:2, 1::2, 0::2]  # G2
    bayer_image[:, :, 1::2, 1::2] = rgb_image_tensor[:, 2:3, 1::2, 1::2]  # B

    return bayer_image
| ], | |
| "metadata": { | |
| "id": "g3Xulbo6CQTX" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
from skimage import data as skimdata

npimage = skimdata.cat()

# Crop to an even height and width: the Bayer mosaic is tiled in 2x2 blocks
# and the demosaicers' pixel_unshuffle requires even sizes.
npimage = npimage[:npimage.shape[0] - npimage.shape[0] % 2,
                  :npimage.shape[1] - npimage.shape[1] % 2, :]
print(f"npimage.shape={npimage.shape}")

# (H, W, 3) -> (1, 3, H, W) float32; the pixel values themselves are unchanged.
batchedimage = torch.from_numpy(npimage).permute(2, 0, 1)[None, ...].to(torch.float32)

bayertensor = simulate_bayer_rggb(batchedimage)
print(f"bayertensor.shape={bayertensor.shape}")

demos_bim = malvar_he_cutler_demosaic_optimized(bayertensor)

fig, axes = plt.subplots(1, 3, figsize=(25, 5))

axes[0].imshow(npimage)
axes[0].set_title('Original RGB Image')

# Single-channel mosaic: display with a grayscale colormap.
axes[1].imshow(bayertensor.squeeze().cpu().numpy(), cmap='gray')
axes[1].set_title('Simulated Bayer Pattern (RGGB)')

# Round (rather than truncate) when converting the float result to uint8.
axes[2].imshow(demos_bim.squeeze(0).permute(1, 2, 0).round().to(torch.uint8).cpu().numpy())
axes[2].set_title('Demosaiced (MHC)')

for ax in axes:
    ax.axis('off')

# Render the figure without leaking the last call's return value as cell output.
plt.show()
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 484 | |
| }, | |
| "id": "nfSIu0Zx_e5h", | |
| "outputId": "00362f6c-a47a-44d5-817d-0c7bbce0feb0" | |
| }, | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "npimage.shape=(300, 450, 3)\n", | |
| "bayertensor.shape=torch.Size([1, 1, 300, 450])\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(np.float64(-0.5), np.float64(449.5), np.float64(299.5), np.float64(-0.5))" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 5 | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 2500x500 with 3 Axes>" | |
| ], |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment