Skip to content

Instantly share code, notes, and snippets.

@tvercaut
Last active February 9, 2026 23:27
Show Gist options
  • Select an option

  • Save tvercaut/5dcfd7130c80bf9e87e3aa84da76a1e3 to your computer and use it in GitHub Desktop.

Select an option

Save tvercaut/5dcfd7130c80bf9e87e3aa84da76a1e3 to your computer and use it in GitHub Desktop.
malvar_he_cutler_demosaic_pytorch.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyNLS14UWkHFg3AynYBE4FUn",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tvercaut/5dcfd7130c80bf9e87e3aa84da76a1e3/malvar_he_cutler_demosaic_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Nc0gZGJjEh0u"
},
"outputs": [],
"source": [
"import time\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.\n",
"use_cuda = torch.cuda.is_available()\n",
"if use_cuda:\n",
"    torchdevice = torch.device('cuda')\n",
"else:\n",
"    torchdevice = torch.device('cpu')"
]
},
{
"cell_type": "code",
"source": [
"# ============================================================================\n",
"# BASIC IMPLEMENTATION\n",
"# ============================================================================\n",
"def malvar_he_cutler_demosaic_original(bayer_image, pattern='RGGB', max_value=None):\n",
"    \"\"\"Original implementation with masks and kernel cloning.\n",
"\n",
"    Filters the full-resolution mosaic with the 5×5 Malvar–He–Cutler kernels\n",
"    and uses per-site masks to keep, at each pixel, only the response that is\n",
"    valid for that pixel's CFA colour.\n",
"\n",
"    Args:\n",
"        bayer_image: Bayer mosaic of shape (H, W), (1, H, W) or (B, 1, H, W).\n",
"            H and W must be even (pixel_unshuffle is used for padding) and at\n",
"            least 4 (reflect padding of the half-resolution planes).\n",
"        pattern: CFA layout, one of 'RGGB', 'BGGR', 'GRBG', 'GBRG'.\n",
"        max_value: Upper clamping bound for the output. If None, it is\n",
"            inferred from the input: 1.0 when bayer_image.max() <= 1,\n",
"            otherwise 255.0.\n",
"\n",
"    Returns:\n",
"        RGB tensor of shape (B, 3, H, W) for 4-D input, or (3, H, W) for 2-D\n",
"        and 3-D input, clamped to [0, max_value].\n",
"    \"\"\"\n",
"    if bayer_image.dim() == 2:\n",
"        bayer_image = bayer_image.unsqueeze(0).unsqueeze(0)\n",
"        squeeze_output = True\n",
"    elif bayer_image.dim() == 3:\n",
"        bayer_image = bayer_image.unsqueeze(0)\n",
"        squeeze_output = True\n",
"    else:\n",
"        squeeze_output = False\n",
"\n",
"    device = bayer_image.device\n",
"    dtype = bayer_image.dtype\n",
"\n",
"    # Define the Malvar-He-Cutler filters (taps are given in eighths)\n",
"    G_at_R = torch.tensor([\n",
"        [0, 0, -1, 0, 0],\n",
"        [0, 0, 2, 0, 0],\n",
"        [-1, 2, 4, 2, -1],\n",
"        [0, 0, 2, 0, 0],\n",
"        [0, 0, -1, 0, 0]\n",
"    ], dtype=dtype, device=device) / 8.0\n",
"\n",
"    G_at_B = G_at_R.clone()\n",
"\n",
"    # R at a green site in an R row: the R neighbours are horizontal\n",
"    R_at_G_RRow = torch.tensor([\n",
"        [0, 0, 0.5, 0, 0],\n",
"        [0, -1, 0, -1, 0],\n",
"        [-1, 4, 5, 4, -1],\n",
"        [0, -1, 0, -1, 0],\n",
"        [0, 0, 0.5, 0, 0]\n",
"    ], dtype=dtype, device=device) / 8.0\n",
"\n",
"    # R at a green site in a B row: the R neighbours are vertical\n",
"    R_at_G_BRow = torch.tensor([\n",
"        [0, 0, -1, 0, 0],\n",
"        [0, -1, 4, -1, 0],\n",
"        [0.5, 0, 5, 0, 0.5],\n",
"        [0, -1, 4, -1, 0],\n",
"        [0, 0, -1, 0, 0]\n",
"    ], dtype=dtype, device=device) / 8.0\n",
"\n",
"    R_at_B = torch.tensor([\n",
"        [0, 0, -1.5, 0, 0],\n",
"        [0, 2, 0, 2, 0],\n",
"        [-1.5, 0, 6, 0, -1.5],\n",
"        [0, 2, 0, 2, 0],\n",
"        [0, 0, -1.5, 0, 0]\n",
"    ], dtype=dtype, device=device) / 8.0\n",
"\n",
"    # Blue neighbours are vertical in R rows and horizontal in B rows, so the\n",
"    # B filters reuse the R filters of the opposite row type.\n",
"    B_at_R = R_at_B.clone()\n",
"    B_at_G_RRow = R_at_G_BRow.clone()\n",
"    B_at_G_BRow = R_at_G_RRow.clone()\n",
"\n",
"    B, _, H, W = bayer_image.shape\n",
"\n",
"    pattern_dict = {\n",
"        'RGGB': {'R': (0, 0), 'G1': (0, 1), 'G2': (1, 0), 'B': (1, 1)},\n",
"        'BGGR': {'B': (0, 0), 'G1': (0, 1), 'G2': (1, 0), 'R': (1, 1)},\n",
"        'GRBG': {'G1': (0, 0), 'R': (0, 1), 'B': (1, 0), 'G2': (1, 1)},\n",
"        'GBRG': {'G1': (0, 0), 'B': (0, 1), 'R': (1, 0), 'G2': (1, 1)},\n",
"    }\n",
"\n",
"    positions = pattern_dict[pattern]\n",
"\n",
"    mask_R = torch.zeros((H, W), dtype=dtype, device=device)\n",
"    mask_G = torch.zeros((H, W), dtype=dtype, device=device)\n",
"    mask_B = torch.zeros((H, W), dtype=dtype, device=device)\n",
"\n",
"    r_y, r_x = positions['R']\n",
"    mask_R[r_y::2, r_x::2] = 1\n",
"\n",
"    b_y, b_x = positions['B']\n",
"    mask_B[b_y::2, b_x::2] = 1\n",
"\n",
"    g1_y, g1_x = positions['G1']\n",
"    g2_y, g2_x = positions['G2']\n",
"    mask_G[g1_y::2, g1_x::2] = 1\n",
"    mask_G[g2_y::2, g2_x::2] = 1\n",
"\n",
"    # Green sites sharing a row with R, and green sites sharing a row with B\n",
"    mask_G_RRow = torch.zeros((H, W), dtype=dtype, device=device)\n",
"    mask_G_BRow = torch.zeros((H, W), dtype=dtype, device=device)\n",
"    mask_G_RRow[r_y::2, 1-r_x::2] = 1\n",
"    mask_G_BRow[b_y::2, 1-b_x::2] = 1\n",
"\n",
"    mask_R = mask_R.unsqueeze(0).unsqueeze(0).expand(B, 1, H, W)\n",
"    mask_G = mask_G.unsqueeze(0).unsqueeze(0).expand(B, 1, H, W)\n",
"    mask_B = mask_B.unsqueeze(0).unsqueeze(0).expand(B, 1, H, W)\n",
"    mask_G_RRow = mask_G_RRow.unsqueeze(0).unsqueeze(0).expand(B, 1, H, W)\n",
"    mask_G_BRow = mask_G_BRow.unsqueeze(0).unsqueeze(0).expand(B, 1, H, W)\n",
"\n",
"    # Pad by 2 pixels (the 5×5 filter radius) by reflect-padding the\n",
"    # pixel-unshuffled colour planes by 1: every padded sample keeps its CFA\n",
"    # colour, and the borders match the reflect padding of the unshuffled\n",
"    # planes in malvar_he_cutler_demosaic_optimized.\n",
"    bayer_image_padded = F.pixel_shuffle(F.pad(\n",
"        F.pixel_unshuffle(bayer_image, 2), pad=(1, 1, 1, 1), mode='reflect'), 2)\n",
"\n",
"    def apply_filter(kernel):\n",
"        kernel = kernel.unsqueeze(0).unsqueeze(0)\n",
"        return F.conv2d(bayer_image_padded, kernel, padding=0)\n",
"\n",
"    G = bayer_image * mask_G\n",
"    G = G + apply_filter(G_at_R) * mask_R\n",
"    G = G + apply_filter(G_at_B) * mask_B\n",
"\n",
"    R = bayer_image * mask_R\n",
"    R = R + apply_filter(R_at_G_RRow) * mask_G_RRow\n",
"    R = R + apply_filter(R_at_G_BRow) * mask_G_BRow\n",
"    R = R + apply_filter(R_at_B) * mask_B\n",
"\n",
"    B_channel = bayer_image * mask_B\n",
"    B_channel = B_channel + apply_filter(B_at_G_RRow) * mask_G_RRow\n",
"    B_channel = B_channel + apply_filter(B_at_G_BRow) * mask_G_BRow\n",
"    B_channel = B_channel + apply_filter(B_at_R) * mask_R\n",
"\n",
"    rgb = torch.cat([R, G, B_channel], dim=1)\n",
"\n",
"    # Clamp to the input's dynamic range. The range is inferred from the input,\n",
"    # not from rgb: the MHC kernels have negative taps, so outputs overshoot,\n",
"    # and testing rgb.max() <= 1 sent [0, 1] images with overshoot down the\n",
"    # 255 branch, leaving values above 1 unclipped.\n",
"    if max_value is None:\n",
"        max_value = 1.0 if bayer_image.max() <= 1 else 255.0\n",
"    rgb = torch.clamp(rgb, 0, max_value)\n",
"\n",
"    if squeeze_output:\n",
"        rgb = rgb.squeeze(0)\n",
"\n",
"    return rgb"
],
"metadata": {
"id": "iR0UgfAEFD6F"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# ============================================================================\n",
"# OPTIMIZED IMPLEMENTATION\n",
"# ============================================================================\n",
"def malvar_he_cutler_demosaic_optimized(bayer_image: torch.Tensor, max_value: 'float | None' = None) -> torch.Tensor:\n",
"    \"\"\"\n",
"    Malvar–He–Cutler demosaicing using\n",
"    PixelUnshuffle + explicit folded kernels + PixelShuffle.\n",
"\n",
"    Each 5×5 Malvar–He–Cutler filter is folded into a 3×3 kernel over the four\n",
"    pixel-unshuffled colour planes, so all eight missing samples of a 2×2 RGGB\n",
"    cell come from a single conv2d call.\n",
"\n",
"    Args:\n",
"        bayer_image: (B, 1, H, W), RGGB. H and W must be even (pixel_unshuffle)\n",
"            and at least 4 (reflect padding of the half-resolution planes).\n",
"        max_value: Upper clamping bound for the output. If None, it is\n",
"            inferred from the input: 1.0 when bayer_image.max() <= 1,\n",
"            otherwise 255.0.\n",
"\n",
"    Returns:\n",
"        rgb: (B, 3, H, W), clamped to [0, max_value].\n",
"    \"\"\"\n",
"    assert bayer_image.ndim == 4 and bayer_image.shape[1] == 1, 'expected a (B, 1, H, W) RGGB Bayer tensor'\n",
"    device, dtype = bayer_image.device, bayer_image.dtype\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 1. Pixel unshuffle\n",
"    # ------------------------------------------------------------\n",
"    # Channel order: 0 = R (0,0), 1 = G1 (0,1), 2 = G2 (1,0), 3 = B (1,1)\n",
"    x = F.pixel_unshuffle(bayer_image, 2)  # (B,4,H/2,W/2)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 2. Explicit folded 3×3 kernels (8,4,3,3)\n",
"    # ------------------------------------------------------------\n",
"    # Tap [a][b] of a kernel weights the unshuffled sample at half-resolution\n",
"    # offset (a-1, b-1). Every kernel sums to 1, so flat regions are preserved.\n",
"    K = torch.zeros((8, 4, 3, 3), device=device, dtype=dtype)\n",
"    # TODO: Check if dtype is floating point and otherwise use only integers\n",
"    # All kernels should be multiplied by 16 and the output image divided by 16\n",
"\n",
"    # TODO: Reorder kernels to match the pixel shuffle operation to the extent possible\n",
"    # ------------------------------------------------------------\n",
"    # 0. Green at R (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[0] = torch.tensor([\n",
"        # R channel\n",
"        [[0,      -0.125, 0     ],\n",
"         [-0.125, 0.5,    -0.125],\n",
"         [0,      -0.125, 0     ]],\n",
"\n",
"        # G1 channel\n",
"        [[0,    0,    0],\n",
"         [0.25, 0.25, 0],\n",
"         [0,    0,    0]],\n",
"\n",
"        # G2 channel\n",
"        [[0, 0.25, 0],\n",
"         [0, 0.25, 0],\n",
"         [0, 0,    0]],\n",
"\n",
"        # B channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 1. Blue at R (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[1] = torch.tensor([\n",
"        # R channel\n",
"        [[0,       -0.1875, 0      ],\n",
"         [-0.1875, 0.75,    -0.1875],\n",
"         [0,       -0.1875, 0      ]],\n",
"\n",
"        # G1 channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # G2 channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # B channel\n",
"        [[0.25, 0.25, 0],\n",
"         [0.25, 0.25, 0],\n",
"         [0,    0,    0]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 2. Red at G1 (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[2] = torch.tensor([\n",
"        # R channel\n",
"        [[0, 0,   0  ],\n",
"         [0, 0.5, 0.5],\n",
"         [0, 0,   0  ]],\n",
"\n",
"        # G1 channel\n",
"        [[0,      0.0625, 0     ],\n",
"         [-0.125, 0.625,  -0.125],\n",
"         [0,      0.0625, 0     ]],\n",
"\n",
"        # G2 channel\n",
"        [[0, -0.125, -0.125],\n",
"         [0, -0.125, -0.125],\n",
"         [0, 0,      0     ]],\n",
"\n",
"        # B channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 3. Blue at G1 (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[3] = torch.tensor([\n",
"        # R channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # G1 channel\n",
"        [[0,      -0.125, 0     ],\n",
"         [0.0625, 0.625,  0.0625],\n",
"         [0,      -0.125, 0     ]],\n",
"\n",
"        # G2 channel\n",
"        [[0, -0.125, -0.125],\n",
"         [0, -0.125, -0.125],\n",
"         [0, 0,      0     ]],\n",
"\n",
"        # B channel\n",
"        [[0, 0.5, 0],\n",
"         [0, 0.5, 0],\n",
"         [0, 0,   0]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 4. Red at G2 (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[4] = torch.tensor([\n",
"        # R channel\n",
"        [[0, 0,   0],\n",
"         [0, 0.5, 0],\n",
"         [0, 0.5, 0]],\n",
"\n",
"        # G1 channel\n",
"        [[0,      0,      0],\n",
"         [-0.125, -0.125, 0],\n",
"         [-0.125, -0.125, 0]],\n",
"\n",
"        # G2 channel\n",
"        [[0,      -0.125, 0     ],\n",
"         [0.0625, 0.625,  0.0625],\n",
"         [0,      -0.125, 0     ]],\n",
"\n",
"        # B channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 5. Blue at G2 (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[5] = torch.tensor([\n",
"        # R channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # G1 channel\n",
"        [[0,      0,      0],\n",
"         [-0.125, -0.125, 0],\n",
"         [-0.125, -0.125, 0]],\n",
"\n",
"        # G2 channel\n",
"        [[0,      0.0625, 0     ],\n",
"         [-0.125, 0.625,  -0.125],\n",
"         [0,      0.0625, 0     ]],\n",
"\n",
"        # B channel\n",
"        [[0,   0,   0],\n",
"         [0.5, 0.5, 0],\n",
"         [0,   0,   0]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 6. Green at B (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[6] = torch.tensor([\n",
"        # R channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # G1 channel\n",
"        [[0, 0,    0],\n",
"         [0, 0.25, 0],\n",
"         [0, 0.25, 0]],\n",
"\n",
"        # G2 channel\n",
"        [[0, 0,    0   ],\n",
"         [0, 0.25, 0.25],\n",
"         [0, 0,    0   ]],\n",
"\n",
"        # B channel\n",
"        [[0,      -0.125, 0     ],\n",
"         [-0.125, 0.5,    -0.125],\n",
"         [0,      -0.125, 0     ]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 7. Red at B (OK)\n",
"    # ------------------------------------------------------------\n",
"    K[7] = torch.tensor([\n",
"        # R channel\n",
"        [[0, 0,    0   ],\n",
"         [0, 0.25, 0.25],\n",
"         [0, 0.25, 0.25]],\n",
"\n",
"        # G1 channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # G2 channel\n",
"        [[0, 0, 0],\n",
"         [0, 0, 0],\n",
"         [0, 0, 0]],\n",
"\n",
"        # B channel\n",
"        [[0,       -0.1875, 0      ],\n",
"         [-0.1875, 0.75,    -0.1875],\n",
"         [0,       -0.1875, 0      ]],\n",
"    ], device=device, dtype=dtype)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 3. Convolution\n",
"    # ------------------------------------------------------------\n",
"    # Each unshuffled channel is a single colour plane, so reflect-padding it\n",
"    # keeps every padded sample on the correct CFA colour.\n",
"    padded_x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')\n",
"    interp = F.conv2d(padded_x, K, padding=0)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 4. Concatenate known + interpolated\n",
"    # ------------------------------------------------------------\n",
"    fused = torch.cat([x, interp], dim=1)  # (B,12,H/2,W/2)\n",
"\n",
"    # pixel_shuffle reads each group of 4 channels in (0,0), (0,1), (1,0), (1,1)\n",
"    # sub-pixel order, i.e. the R, G1, G2, B sites of the RGGB cell.\n",
"    permute_idx = [\n",
"        # Red output group\n",
"        0,   # R\n",
"        6,   # R@G1\n",
"        8,   # R@G2\n",
"        11,  # R@B\n",
"\n",
"        # Green output group\n",
"        4,   # G@R\n",
"        1,   # G1\n",
"        2,   # G2\n",
"        10,  # G@B\n",
"\n",
"        # Blue output group\n",
"        5,   # B@R\n",
"        7,   # B@G1\n",
"        9,   # B@G2\n",
"        3,   # B\n",
"    ]\n",
"\n",
"    fused_reordered = fused[:, permute_idx, :, :]\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 5. Pixel shuffle → RGB\n",
"    # ------------------------------------------------------------\n",
"    rgb = F.pixel_shuffle(fused_reordered, 2)\n",
"\n",
"    # ------------------------------------------------------------\n",
"    # 6. Clamp to the input's dynamic range\n",
"    # ------------------------------------------------------------\n",
"    # The range is inferred from the input, not from rgb: the kernels have\n",
"    # negative taps, so outputs overshoot, and testing rgb.max() <= 1 sent\n",
"    # [0, 1] images with overshoot down the 255 branch, leaving values above 1\n",
"    # unclipped.\n",
"    if max_value is None:\n",
"        max_value = 1.0 if bayer_image.max() <= 1 else 255.0\n",
"    rgb = torch.clamp(rgb, 0, max_value)\n",
"\n",
"    return rgb"
],
"metadata": {
"id": "2VF-9nOQHzg_"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def simulate_bayer_rggb(rgb_image_tensor: torch.Tensor) -> torch.Tensor:\n",
"    \"\"\"\n",
"    Converts an RGB image tensor into a single-channel Bayer pattern (RGGB).\n",
"\n",
"    Args:\n",
"        rgb_image_tensor (torch.Tensor): Input RGB image tensor of shape [B, 3, H, W].\n",
"            Values are copied unchanged, so any range works (e.g. [0, 1] or [0, 255]).\n",
"\n",
"    Returns:\n",
"        torch.Tensor: Single-channel Bayer pattern image tensor of shape [B, 1, H, W],\n",
"        with the same dtype and device as the input.\n",
"    \"\"\"\n",
"    B, _, H, W = rgb_image_tensor.shape\n",
"\n",
"    # Empty single-channel mosaic on the same device / dtype as the input\n",
"    bayer_image = rgb_image_tensor.new_zeros((B, 1, H, W))\n",
"\n",
"    # Populate the Bayer pattern (RGGB). The colour channel is indexed with an\n",
"    # integer (not a 0:1 slice) so the source has the same (B, h, w) shape as\n",
"    # the destination; a (B, 1, h, w) source only broadcasts when B == 1.\n",
"    # R at (0,0)\n",
"    bayer_image[:, 0, 0::2, 0::2] = rgb_image_tensor[:, 0, 0::2, 0::2]\n",
"    # G1 at (0,1)\n",
"    bayer_image[:, 0, 0::2, 1::2] = rgb_image_tensor[:, 1, 0::2, 1::2]\n",
"    # G2 at (1,0)\n",
"    bayer_image[:, 0, 1::2, 0::2] = rgb_image_tensor[:, 1, 1::2, 0::2]\n",
"    # B at (1,1)\n",
"    bayer_image[:, 0, 1::2, 1::2] = rgb_image_tensor[:, 2, 1::2, 1::2]\n",
"\n",
"    return bayer_image"
],
"metadata": {
"id": "g3Xulbo6CQTX"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from skimage import data as skimdata\n",
"npimage = skimdata.cat()\n",
"\n",
"# Make sure the size is even (both demosaicers use pixel_unshuffle)\n",
"npimage = npimage[:npimage.shape[0]-(npimage.shape[0]%2),:npimage.shape[1]-(npimage.shape[1]%2),:]\n",
"print(f\"npimage.shape={npimage.shape}\")\n",
"\n",
"batchedimage = torch.from_numpy(npimage).permute(2, 0, 1)[None,...].to(torch.float32)\n",
"\n",
"bayertensor = simulate_bayer_rggb(batchedimage)\n",
"print(f\"bayertensor.shape={bayertensor.shape}\")\n",
"\n",
"demos_bim = malvar_he_cutler_demosaic_optimized(bayertensor)\n",
"\n",
"# Sanity check: the folded-kernel version must reproduce the mask-based\n",
"# reference (tolerance expressed on the 0–255 pixel scale).\n",
"torch.testing.assert_close(\n",
"    demos_bim, malvar_he_cutler_demosaic_original(bayertensor), rtol=1e-5, atol=1e-3)\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(25, 5))\n",
"\n",
"# Original RGB Image\n",
"axes[0].imshow(npimage)\n",
"axes[0].set_title('Original RGB Image')\n",
"axes[0].axis('off')\n",
"\n",
"# Simulated Bayer Pattern\n",
"# For single channel image, use a grayscale colormap\n",
"axes[1].imshow(bayertensor.squeeze().cpu().numpy(), cmap='gray')\n",
"axes[1].set_title('Simulated Bayer Pattern (RGGB)')\n",
"axes[1].axis('off')\n",
"\n",
"# Demosaiced (round first: a bare uint8 cast truncates the fractional part)\n",
"axes[2].imshow(demos_bim.squeeze().permute(1, 2, 0).round().cpu().numpy().astype(np.uint8))\n",
"axes[2].set_title('Demosaiced (MHC)')\n",
"axes[2].axis('off')\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 484
},
"id": "nfSIu0Zx_e5h",
"outputId": "00362f6c-a47a-44d5-817d-0c7bbce0feb0"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"npimage.shape=(300, 450, 3)\n",
"bayertensor.shape=torch.Size([1, 1, 300, 450])\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(np.float64(-0.5), np.float64(449.5), np.float64(299.5), np.float64(-0.5))"
]
},
"metadata": {},
"execution_count": 5
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 2500x500 with 3 Axes>"
],
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment