Skip to content

Instantly share code, notes, and snippets.

@BohemianHacks
Created January 20, 2025 04:00
Show Gist options
  • Save BohemianHacks/6e59abb25871a0991efd3de1942ec82d to your computer and use it in GitHub Desktop.
Save BohemianHacks/6e59abb25871a0991efd3de1942ec82d to your computer and use it in GitHub Desktop.
Lung Segmentation Visualization

Usage:

# Single image visualization
fig = visualize_lung_segmentation(xray_image, lung_mask)
plt.show()

# Multiple image comparison
fig = plot_segmentation_comparison(xray_images[:3], lung_masks[:3])
plt.show()

# Get statistics
stats = calculate_lung_statistics(lung_mask)
print(f"Lung area percentage: {stats['lung_percentage']:.2f}%")

Functions:

  1. visualize_lung_segmentation: Creates a three-panel visualization showing:

    • Original X-ray
    • Binary segmentation mask
    • Overlay of the mask on the original image with customizable transparency
  2. plot_segmentation_comparison: Displays several examples in a grid (one row per case, each with original, mask, and overlay panels) to show consistency across different cases

  3. calculate_lung_statistics: Computes useful metrics about the segmentation:

    • Total image pixels
    • Lung area pixels
    • Percentage of image occupied by lungs
    • Bounding box coordinates (xmin, ymin, xmax, ymax)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def visualize_lung_segmentation(image, mask, alpha=0.3, figsize=(12, 5)):
    """
    Visualize X-ray image with segmentation mask overlay.

    Draws three side-by-side panels: the original image, the mask on its
    own, and the mask as a translucent blue layer on top of the image.

    Parameters:
    -----------
    image : numpy.ndarray
        The original X-ray image (grayscale)
    mask : numpy.ndarray
        Segmentation mask for the lungs. Binary masks stored as bool, 0/1 or
        0/255 are all displayed correctly, including masks that are entirely
        empty or entirely lung.
    alpha : float
        Opacity of the overlay layer (0-1). The overlay colormap is itself
        half-transparent, so the blue tint is fainter than ``alpha`` alone
        would suggest.
    figsize : tuple
        Size of the figure (width, height)

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure containing the visualization
    """
    # Overlay colormap: lowest value -> fully transparent,
    # highest value -> blue at 50% alpha.
    colors = [(0, 0, 0, 0), (0, 0, 1, 0.5)]
    lung_cmap = LinearSegmentedColormap.from_list('lung_cmap', colors)

    # Pin the colour range instead of letting imshow autoscale to the data.
    # Autoscaling maps a constant mask (vmin == vmax) to 0, so an all-lung
    # mask would be drawn as background. The upper bound is the mask's own
    # maximum but never below 1, so 0/1, bool and 0/255 masks all work.
    mask = np.asarray(mask, dtype=float)
    mask_range = {'vmin': 0.0, 'vmax': float(mask.max(initial=1.0))}

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)

    # Original image
    ax1.imshow(image, cmap='gray')
    ax1.set_title('Original X-ray')
    ax1.axis('off')

    # Segmentation mask
    ax2.imshow(mask, cmap='binary', **mask_range)
    ax2.set_title('Lung Segmentation Mask')
    ax2.axis('off')

    # Overlay: grayscale image underneath, translucent mask on top
    ax3.imshow(image, cmap='gray')
    ax3.imshow(mask, cmap=lung_cmap, alpha=alpha, **mask_range)
    ax3.set_title('Overlay Visualization')
    ax3.axis('off')

    fig.tight_layout()
    return fig

def plot_segmentation_comparison(original_images, masks, num_samples=3,
                                 figsize=(15, 5), alpha=0.3):
    """
    Plot multiple examples of segmentation results, one example per row.

    Each row shows the original image, its mask, and the mask overlaid on
    the image in translucent blue (the same overlay style as
    ``visualize_lung_segmentation``).

    Parameters:
    -----------
    original_images : list or numpy.ndarray
        List of original X-ray images
    masks : list or numpy.ndarray
        List of corresponding segmentation masks
    num_samples : int
        Maximum number of examples to show. It is clamped to the number of
        available image/mask pairs.
    figsize : tuple
        Size of the figure (width, height)
    alpha : float
        Opacity of the overlay layer (0-1)

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure containing the comparison grid

    Raises:
    -------
    ValueError
        If there is not at least one image/mask pair to show.
    """
    # Never ask for more rows than there are examples; indexing past the end
    # of the inputs would otherwise raise IndexError.
    num_samples = min(num_samples, len(original_images), len(masks))
    if num_samples < 1:
        raise ValueError(
            "plot_segmentation_comparison needs at least one image/mask pair"
        )

    # Same transparent-to-blue overlay as visualize_lung_segmentation, so
    # background pixels are left untouched rather than washed out.
    colors = [(0, 0, 0, 0), (0, 0, 1, 0.5)]
    lung_cmap = LinearSegmentedColormap.from_list('lung_cmap', colors)

    # squeeze=False keeps `axes` 2-D even for a single row, so row/column
    # indexing works when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=figsize, squeeze=False)

    # zip stops at the shortest input, i.e. after num_samples rows.
    for (ax_img, ax_mask, ax_overlay), image, mask in zip(axes, original_images, masks):
        # Pin the colour range so constant masks are not autoscaled to 0.
        mask = np.asarray(mask, dtype=float)
        mask_range = {'vmin': 0.0, 'vmax': float(mask.max(initial=1.0))}

        ax_img.imshow(image, cmap='gray')
        ax_img.axis('off')

        ax_mask.imshow(mask, cmap='binary', **mask_range)
        ax_mask.axis('off')

        ax_overlay.imshow(image, cmap='gray')
        ax_overlay.imshow(mask, cmap=lung_cmap, alpha=alpha, **mask_range)
        ax_overlay.axis('off')

    # Column titles only on the top row.
    for ax, title in zip(axes[0], ('Original', 'Segmentation', 'Overlay')):
        ax.set_title(title)

    fig.tight_layout()
    return fig

def calculate_lung_statistics(mask):
    """
    Calculate basic statistics about the segmented lung regions.

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary segmentation mask; every non-zero pixel counts as lung, so
        bool, 0/1 and 0/255 masks give the same result.

    Returns:
    --------
    dict
        'total_pixels'    : number of pixels in the mask
        'lung_pixels'     : number of non-zero (lung) pixels
        'lung_percentage' : lung_pixels as a percentage of total_pixels
                            (0.0 for an empty array)
        'bbox'            : (xmin, ymin, xmax, ymax) of the lung region, or
                            None when the mask contains no lung pixels
    """
    mask = np.asarray(mask)
    total_pixels = mask.size
    # Count pixels rather than summing values: summing a 0/255 mask would
    # report 255x the real area (and percentages above 100%).
    lung_pixels = np.count_nonzero(mask)
    # Guard the division so a zero-size array doesn't divide by zero.
    lung_percentage = (lung_pixels / total_pixels) * 100 if total_pixels else 0.0

    stats = {
        'total_pixels': total_pixels,
        'lung_pixels': lung_pixels,
        'lung_percentage': lung_percentage,
        # Skip the bounding-box lookup, which needs at least one lung pixel.
        'bbox': get_bounding_box(mask) if lung_pixels else None,
    }
    return stats

def get_bounding_box(mask):
    """
    Get the bounding box coordinates for the lungs.

    Parameters:
    -----------
    mask : numpy.ndarray or array-like
        2-D segmentation mask; non-zero pixels are treated as lung.

    Returns:
    --------
    tuple or None
        (xmin, ymin, xmax, ymax) inclusive pixel indices of the non-zero
        region, or None when the mask contains no non-zero pixels.
    """
    mask = np.asarray(mask)
    # Indices of rows / columns that contain at least one lung pixel.
    row_hits = np.where(np.any(mask, axis=1))[0]
    col_hits = np.where(np.any(mask, axis=0))[0]
    # An empty mask has no bounding box; indexing [0] would raise IndexError.
    if row_hits.size == 0:
        return None
    ymin, ymax = row_hits[[0, -1]]
    xmin, xmax = col_hits[[0, -1]]
    return (xmin, ymin, xmax, ymax)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment