Skip to content

Instantly share code, notes, and snippets.

@shreyasgite
Created February 22, 2025 15:35
Show Gist options
  • Save shreyasgite/a26ec8a94613df02cae8e81b8efa8347 to your computer and use it in GitHub Desktop.
Save shreyasgite/a26ec8a94613df02cae8e81b8efa8347 to your computer and use it in GitHub Desktop.
Helper functions for Image Inpainting for Robotics datasets.
import modal
from diffusers.utils import load_image
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from PIL import ImageDraw
import cv2
from tqdm.notebook import tqdm
from typing import List, Optional, Dict, Union
def load_video_frames(
    video_path: str,
    resize: Optional[tuple] = None,
    show_progress: bool = True
) -> Dict:
    """
    Load video frames from path and convert them to RGB PIL Images.

    Args:
        video_path: Path to video file
        resize: Optional (width, height) tuple; frames are resized with LANCZOS
        show_progress: Whether to show a tqdm progress bar

    Returns:
        Dict containing:
            - frames: List of PIL Images
            - fps: Video FPS as reported by OpenCV (may be 0.0 if unavailable)
            - total_frames: Number of frames actually decoded, i.e. len(frames)
            - duration: total_frames / fps in seconds, or 0.0 if fps is not positive

    Raises:
        ValueError: If OpenCV cannot open the video file.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")

    frames = []
    try:
        # CAP_PROP_FRAME_COUNT comes from container metadata and can be
        # inaccurate, so it only bounds the loop and sizes the progress bar.
        expected_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        iterator = range(expected_frames)
        if show_progress:
            iterator = tqdm(iterator, desc="Loading frames")

        for _ in iterator:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if resize:
                image = image.resize(resize, Image.Resampling.LANCZOS)
            frames.append(image)
    finally:
        # Release the decoder even if decoding or resizing raises.
        cap.release()

    total_frames = len(frames)
    # Guard against containers that report an FPS of 0.
    duration = total_frames / fps if fps > 0 else 0.0
    return {
        "frames": frames,
        "fps": fps,
        "total_frames": total_frames,
        "duration": duration
    }
def display_views(top_im, front_im):
    """
    Show the top and front views of a scene side by side in one 15x5 figure.

    Args:
        top_im: Image drawn in the left panel, titled "Top View"
        front_im: Image drawn in the right panel, titled "Front View"
    """
    plt.figure(figsize=(15, 5))
    panels = (("Top View", top_im), ("Front View", front_im))
    for position, (label, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.axis('off')
        plt.title(label)
    plt.tight_layout()
    plt.show()
def create_inpaint_mask(
    im: Union[Image.Image, np.ndarray],
    direction: str,
) -> tuple[Image.Image, Image.Image]:
    """
    Create a mask for inpainting a corner region of the image.

    The corner is exactly (width // 4) x (height // 4) pixels for every
    direction.

    Args:
        im: Input image (PIL Image or numpy array)
        direction: Corner to mask - 'top-left', 'top-right', 'bottom-left', 'bottom-right'

    Returns:
        tuple: (mask, overlaid_image)
            - mask: RGB image, white over the corner and black elsewhere
            - overlaid_image: copy of the input with the corner painted white

    Raises:
        ValueError: If direction is not one of the four corners.
    """
    if isinstance(im, np.ndarray):
        im = Image.fromarray(im)

    width, height = im.size
    corner_width = width // 4
    corner_height = height // 4

    # Half-open corner boxes: (x0, y0, x1, y1) covers x0 <= x < x1, y0 <= y < y1.
    corners = {
        'top-left': (0, 0, corner_width, corner_height),
        'top-right': (width - corner_width, 0, width, corner_height),
        'bottom-left': (0, height - corner_height, corner_width, height),
        'bottom-right': (width - corner_width, height - corner_height, width, height),
    }
    if direction not in corners:
        raise ValueError("Direction must be one of: 'top-left', 'top-right', 'bottom-left', 'bottom-right'")
    x0, y0, x1, y1 = corners[direction]

    mask = Image.new('RGB', (width, height), 'black')
    overlaid_im = im.copy()

    # ImageDraw.rectangle includes both end points, so convert the half-open box
    # to inclusive coordinates. Otherwise left/top corners get one extra
    # row/column. Images too small to have a corner get nothing drawn.
    if x1 > x0 and y1 > y0:
        coords = [x0, y0, x1 - 1, y1 - 1]
        ImageDraw.Draw(mask).rectangle(coords, fill='white')
        ImageDraw.Draw(overlaid_im).rectangle(coords, fill='white')

    return mask, overlaid_im
def create_outpaint_mask(
    im: Union[Image.Image, np.ndarray],
    direction: str,
    plot: bool = True
) -> tuple[Image.Image, Image.Image]:
    """
    Create the input image and mask for outpainting towards ``direction``.

    The frame size is kept. The image content is shifted by 1/10 of the frame
    away from ``direction``, and whatever falls off the opposite edge is
    cropped. The uncovered strip(s) on the ``direction`` side are filled with
    white.

    Args:
        im: Input PIL Image or numpy array (output keeps its dimensions)
        direction: String containing 'left' or 'right' and/or 'top' or 'bottom',
            e.g. 'left-top', 'left', 'top', 'right-top', 'right',
            'right-bottom', 'bottom', 'left-bottom'. If both 'left' and
            'right' (or 'top' and 'bottom') appear, 'left' ('top') wins.
        plot: Currently unused; kept for backward compatibility.

    Returns:
        tuple: (modified_image, mask)
            - modified_image: RGB image with the original cropped, shifted and
              extended with white space
            - mask: RGB image, white over the new strip(s), black elsewhere

    Raises:
        ValueError: If direction names none of 'left', 'right', 'top', 'bottom'.
    """
    if isinstance(im, np.ndarray):
        im = Image.fromarray(im)
    if not any(edge in direction for edge in ('left', 'right', 'top', 'bottom')):
        raise ValueError(
            "Direction must contain at least one of: 'left', 'right', 'top', 'bottom'"
        )

    width, height = im.size
    # Size of the new strip (1/10 of the image).
    shift_x = width // 10
    shift_y = height // 10

    # Signed offset at which the cropped original is placed in the new frame.
    if 'left' in direction:
        paste_x = shift_x
    elif 'right' in direction:
        paste_x = -shift_x
    else:
        paste_x = 0
    if 'top' in direction:
        paste_y = shift_y
    elif 'bottom' in direction:
        paste_y = -shift_y
    else:
        paste_y = 0

    # Drop the part of the original that would be pushed out of the frame.
    crop_box = (
        max(-paste_x, 0),
        max(-paste_y, 0),
        width - max(paste_x, 0),
        height - max(paste_y, 0)
    )
    new_im = Image.new('RGB', (width, height), 'white')
    new_im.paste(im.crop(crop_box), (max(paste_x, 0), max(paste_y, 0)))

    # The mask is derived from the same offsets, so it always matches the
    # white area of new_im. ImageDraw.rectangle is inclusive of its end
    # coordinates, hence the "- 1"s.
    mask = Image.new('RGB', (width, height), 'black')
    mask_draw = ImageDraw.Draw(mask)
    if paste_x > 0:
        mask_draw.rectangle((0, 0, paste_x - 1, height - 1), fill='white')
    elif paste_x < 0:
        mask_draw.rectangle((width + paste_x, 0, width - 1, height - 1), fill='white')
    if paste_y > 0:
        mask_draw.rectangle((0, 0, width - 1, paste_y - 1), fill='white')
    elif paste_y < 0:
        mask_draw.rectangle((0, height + paste_y, width - 1, height - 1), fill='white')

    return new_im, mask
def apply_filled_region(
    target_im: Image.Image,
    filled_im: Image.Image,
    direction: str,
    blur_width: int = 8,
    sigma: float = 3.0
) -> Image.Image:
    """
    Copy the outpainted edge strip(s) from ``filled_im`` onto ``target_im`` and
    soften each seam with a Gaussian blur.

    The strip size is 1/10 of the width for 'left'/'right' and 1/10 of the
    height for 'top'/'bottom'. A combined direction such as 'left-top' pastes
    both strips.

    Args:
        target_im: Image to modify (a copy is returned; the input is untouched)
        filled_im: Image with filled region; presumably the same size as
            target_im (not checked)
        direction: 'left', 'right', 'top', 'bottom', or a combination of one
            horizontal and one vertical side
        blur_width: Width of the blur band on each side of a seam; values <= 0
            disable blurring
        sigma: Standard deviation of the Gaussian blur

    Returns:
        Image.Image: target with the strip(s) pasted and seams blurred.
        If direction names no side, an unmodified copy of target_im.
    """
    width, height = target_im.size
    shift_x = width // 10
    shift_y = height // 10

    result_im = target_im.copy()
    seams = []  # (numpy axis, seam position along that axis)

    if 'left' in direction:
        result_im.paste(filled_im.crop((0, 0, shift_x, height)), (0, 0))
        seams.append((1, shift_x))
    elif 'right' in direction:
        result_im.paste(filled_im.crop((width - shift_x, 0, width, height)), (width - shift_x, 0))
        seams.append((1, width - shift_x))

    if 'top' in direction:
        result_im.paste(filled_im.crop((0, 0, width, shift_y)), (0, 0))
        seams.append((0, shift_y))
    elif 'bottom' in direction:
        result_im.paste(filled_im.crop((0, height - shift_y, width, height)), (0, height - shift_y))
        seams.append((0, height - shift_y))

    if not seams or blur_width <= 0:
        return result_im

    arr = np.array(result_im)
    ksize = blur_width * 2 - 1  # GaussianBlur needs a positive odd kernel size
    for axis, pos in seams:
        # Clamp the band to the image. A band that starts at a negative index
        # would otherwise wrap around and select the wrong pixels.
        start = max(pos - blur_width, 0)
        stop = min(pos + blur_width, arr.shape[axis])
        if stop <= start:
            continue
        if axis == 1:
            band = np.ascontiguousarray(arr[:, start:stop])
            arr[:, start:stop] = cv2.GaussianBlur(band, (ksize, ksize), sigma)
        else:
            band = np.ascontiguousarray(arr[start:stop, :])
            arr[start:stop, :] = cv2.GaussianBlur(band, (ksize, ksize), sigma)

    return Image.fromarray(arr)
def save_processed_video(frames: List[Image.Image], output_path: str, fps: int = 30):
    """
    Save processed frames as an 'mp4v'-encoded video.

    The first frame fixes the output resolution. Any frame of a different size
    is resized to match, because cv2.VideoWriter silently drops frames whose
    size differs from the one it was opened with.

    Args:
        frames: PIL Images (numpy arrays are also accepted) in RGB order
        output_path: Destination file path
        fps: Output frame rate

    Raises:
        ValueError: If frames is empty.
        OSError: If the video writer cannot be opened for output_path.
    """
    if not frames:
        raise ValueError("No frames to save")

    def _to_rgb(frame):
        # Normalize to an RGB PIL image so RGBA/L frames convert cleanly to BGR.
        if isinstance(frame, np.ndarray):
            frame = Image.fromarray(frame)
        return frame.convert('RGB')

    width, height = _to_rgb(frames[0]).size
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height)
    )
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for: {output_path}")

    try:
        for frame in tqdm(frames, desc="Saving video"):
            rgb = _to_rgb(frame)
            if rgb.size != (width, height):
                rgb = rgb.resize((width, height), Image.Resampling.LANCZOS)
            # Convert PIL (RGB) to OpenCV (BGR) layout.
            writer.write(cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
def process_video_with_single_inpaint(
    inpaint_flux: modal.Function,
    video_input: Union[str, Dict],
    direction: str,
    prompt: str = "",
    blend_width: int = 30,
    output_path: Optional[str] = None,
    plot_progress: bool = False
) -> List[Image.Image]:
    """
    Outpaint the first frame once and reuse its filled strip for every frame.

    Each later frame is shifted exactly like the first frame was (via
    create_outpaint_mask). The inpainted strip from the first frame is then
    pasted in with a blurred seam, so the content lines up across the whole
    video.

    Args:
        inpaint_flux: Modal function called as
            ``inpaint_flux.remote(image, mask, prompt=..., neg_prompt=...)``;
            its return value is used as the processed first frame
        video_input: Either path to video file or frames data dictionary
            (as returned by load_video_frames; needs 'frames', 'fps' optional)
        direction: Direction for outpainting ('left', 'right', etc)
        prompt: Text prompt forwarded to the inpainting function
        blend_width: Width of the blending region
        output_path: Optional path to save processed video (written at the
            source FPS, or 30 if unknown)
        plot_progress: Whether to show visualization every 100 frames

    Returns:
        List of processed frames, the inpainted first frame followed by the
        blended remaining frames.

    Raises:
        ValueError: If the video contains no frames.
    """
    if isinstance(video_input, str):
        frames_data = load_video_frames(video_input)
    else:
        frames_data = video_input
    frames = frames_data['frames']
    if not frames:
        raise ValueError("Video contains no frames")
    # Preserve source timing; fall back to 30 FPS if missing or reported as 0.
    fps = frames_data.get('fps') or 30
    total_frames = len(frames)

    print("Processing first frame with inpainting...")
    first_frame = frames[0]
    new_im, mask = create_outpaint_mask(first_frame, direction, plot=False)
    inpainted_frame = inpaint_flux.remote(
        new_im,
        mask,
        prompt=prompt,
        neg_prompt="artifacts, blurry, distorted"
    )
    processed_frames = [inpainted_frame]

    print(f"Blending remaining {total_frames-1} frames...")
    for idx in tqdm(range(1, total_frames)):
        # Shift the frame the same way the first frame was shifted before
        # inpainting, so the pasted strip continues the frame's own content.
        shifted_frame, _ = create_outpaint_mask(frames[idx], direction, plot=False)
        blended_frame = apply_filled_region(
            shifted_frame,
            inpainted_frame,
            direction,
            blur_width=blend_width,
        )
        processed_frames.append(blended_frame)

        if plot_progress and idx % 100 == 0:
            _, axes = plt.subplots(2, 2, figsize=(15, 15))
            axes[0,0].imshow(first_frame)
            axes[0,0].set_title("Original First Frame")
            axes[0,0].axis('off')
            axes[0,1].imshow(inpainted_frame)
            axes[0,1].set_title("Inpainted First Frame")
            axes[0,1].axis('off')
            axes[1,0].imshow(frames[idx])
            axes[1,0].set_title(f"Original Frame {idx}")
            axes[1,0].axis('off')
            axes[1,1].imshow(blended_frame)
            axes[1,1].set_title(f"Blended Frame {idx}")
            axes[1,1].axis('off')
            plt.tight_layout()
            plt.show()

    if output_path:
        print("\nSaving processed video...")
        save_processed_video(processed_frames, output_path, fps=fps)
        print(f"Video saved to: {output_path}")

    return processed_frames
def process_video_with_periodic_inpaint(
    inpaint_flux: modal.Function,
    video_input: Union[str, Dict],
    direction: str,
    frames_per_inpaint: int = 100,
    prompt: str = "",
    blend_width: int = 30,
    output_path: Optional[str] = None,
    plot_progress: bool = False
) -> List[Image.Image]:
    """
    Process video frames with periodic inpainting and continuous blending.

    Every ``frames_per_inpaint`` frames, a key frame is outpainted with
    ``inpaint_flux``. The frames up to the next key frame are shifted like the
    key frame (via create_outpaint_mask), and the key frame's filled strip is
    pasted in with a blurred seam.

    Args:
        inpaint_flux: Modal function called as
            ``inpaint_flux.remote(image, mask, prompt=..., neg_prompt=...)``;
            its return value is used as the processed key frame
        video_input: Either path to video file or frames data dictionary
            (as returned by load_video_frames; needs 'frames', 'fps' optional)
        direction: Direction for outpainting ('left', 'right', etc)
        frames_per_inpaint: Number of frames per batch (key frame included);
            must be >= 1
        prompt: Text prompt forwarded to the inpainting function
        blend_width: Width of the blending region
        output_path: Optional path to save processed video (written at the
            source FPS, or 30 if unknown)
        plot_progress: Whether to show processing visualization

    Returns:
        List of processed frames in original order (empty for an empty video).

    Raises:
        ValueError: If frames_per_inpaint < 1.
    """
    if frames_per_inpaint < 1:
        raise ValueError("frames_per_inpaint must be a positive integer")

    if isinstance(video_input, str):
        frames_data = load_video_frames(video_input)
    else:
        frames_data = video_input
    frames = frames_data['frames']
    # Preserve source timing; fall back to 30 FPS if missing or reported as 0.
    fps = frames_data.get('fps') or 30
    total_frames = len(frames)
    processed_frames = []

    for batch_start in tqdm(range(0, total_frames, frames_per_inpaint), desc="Processing batches"):
        key_frame = frames[batch_start]

        print(f"\nProcessing key frame {batch_start}")
        new_im, mask = create_outpaint_mask(key_frame, direction, plot=False)

        print("Inpainting key frame...")
        inpainted_frame = inpaint_flux.remote(
            new_im,
            mask,
            prompt=prompt,
            neg_prompt="artifacts, blurry, distorted"
        )
        processed_frames.append(inpainted_frame)

        batch_end = min(batch_start + frames_per_inpaint, total_frames)
        print(f"Blending frames {batch_start+1} to {batch_end-1}")
        for idx in range(batch_start + 1, batch_end):
            # Shift like the key frame so the pasted strip continues this
            # frame's own content instead of overwriting its edge.
            shifted_frame, _ = create_outpaint_mask(frames[idx], direction, plot=False)
            blended_frame = apply_filled_region(
                shifted_frame,
                inpainted_frame,
                direction,
                blur_width=blend_width,
            )
            processed_frames.append(blended_frame)

        if plot_progress and batch_start % (frames_per_inpaint * 2) == 0:
            _, axes = plt.subplots(2, 2, figsize=(15, 15))
            axes[0,0].imshow(key_frame)
            axes[0,0].set_title(f"Original Key Frame {batch_start}")
            axes[0,0].axis('off')
            axes[0,1].imshow(inpainted_frame)
            axes[0,1].set_title("Inpainted Key Frame")
            axes[0,1].axis('off')
            if len(processed_frames) > 1:
                axes[1,0].imshow(processed_frames[-2])
                axes[1,0].set_title("Previous Blended Frame")
                axes[1,0].axis('off')
                axes[1,1].imshow(processed_frames[-1])
                axes[1,1].set_title("Latest Blended Frame")
                axes[1,1].axis('off')
            plt.tight_layout()
            plt.show()

    if output_path:
        if processed_frames:
            print("\nSaving processed video...")
            save_processed_video(processed_frames, output_path, fps=fps)
            print(f"Video saved to: {output_path}")
        else:
            print("No frames to save; skipping video output.")

    return processed_frames
def create_top_view(
    im: Union[Image.Image, np.ndarray],
    rotation_angle: float = 180.0,
    perspective_strength: float = 0.3
) -> Image.Image:
    """
    Rotate image and apply a perspective transform to simulate a top view.

    The image is rotated counter-clockwise by ``rotation_angle`` with
    ``expand=True``; the canvas may grow, with black fill. The top corners of
    the rotated image are then pulled inwards and downwards by
    ``int(rotated_height * perspective_strength)`` pixels, while the bottom
    corners stay fixed. Output pixels outside the warped quad are black.

    Args:
        im: Input image (PIL Image or numpy array)
        rotation_angle: Rotation angle in degrees
        perspective_strength: Strength of perspective transform, nominally 0 to 1.
            NOTE(review): values large enough that 2 * shift >= width make the
            top corners meet or cross, which yields a degenerate warp.

    Returns:
        Image.Image: Transformed image with the rotated image's dimensions
        (equal to the input size for multiples of 180 degrees).
    """
    if isinstance(im, np.ndarray):
        im = Image.fromarray(im)

    rotated_im = im.rotate(rotation_angle, expand=True, resample=Image.Resampling.BICUBIC)

    # expand=True can change the canvas size (e.g. swaps width/height at 90
    # degrees), so the warp must be defined on the rotated image's own size.
    width, height = rotated_im.size

    src_points = np.float32([
        [0, 0],           # top-left
        [width, 0],       # top-right
        [width, height],  # bottom-right
        [0, height]       # bottom-left
    ])
    perspective_shift = int(height * perspective_strength)
    dst_points = np.float32([
        [perspective_shift, perspective_shift],          # top-left moved down and right
        [width - perspective_shift, perspective_shift],  # top-right moved down and left
        [width, height],                                 # bottom-right stays
        [0, height]                                      # bottom-left stays
    ])

    matrix = cv2.getPerspectiveTransform(src_points, dst_points)
    result_arr = cv2.warpPerspective(
        np.array(rotated_im),
        matrix,
        (width, height),
        flags=cv2.INTER_LINEAR
    )
    return Image.fromarray(result_arr)
def paste_wrapped_image(
    target_im: Image.Image,
    wrapped_im: Image.Image,
    mask: Image.Image,
    direction: str
) -> tuple[Image.Image, Image.Image]:
    """
    Paste the content of a perspective-warped image into the masked corner of
    ``target_im`` and compute a mask of what remains unfilled.

    Pixels of ``wrapped_im`` whose channels are all <= 10 count as background;
    the threshold tolerates compression noise. The bounding box of the
    remaining content is aligned to one corner of the bounding box of the
    mask's white pixels. A direction containing 'top' aligns to its top edge,
    otherwise to its bottom edge. A direction containing 'left' aligns to its
    left edge, otherwise to its right edge.

    Args:
        target_im: Target image with masked region (PIL Image or numpy array)
        wrapped_im: Perspective-transformed image on a black background
        mask: Mask whose pure-white (255) pixels mark the region to fill
        direction: One of 'top-left', 'top-right', 'bottom-left', 'bottom-right'

    Returns:
        tuple: (modified_target, new_mask)
            - modified_target: copy of target_im with the mask's bounding box
              painted white and the content pasted on top
            - new_mask: RGB mask that is white over the mask's bounding box,
              except inside the pasted content's bounding box, where only the
              background pixels of wrapped_im are white

    Raises:
        ValueError: If wrapped_im has no content pixels or mask has no white pixels.
    """
    if isinstance(target_im, np.ndarray):
        target_im = Image.fromarray(target_im)
    if isinstance(wrapped_im, np.ndarray):
        wrapped_im = Image.fromarray(wrapped_im)
    if isinstance(mask, np.ndarray):
        mask = Image.fromarray(mask)

    # Work in RGB so the per-channel tests also accept grayscale or RGBA inputs.
    wrapped_arr = np.array(wrapped_im.convert('RGB'))
    non_black = np.any(wrapped_arr > 10, axis=2)
    if not non_black.any():
        raise ValueError("wrapped_im contains no non-black pixels to paste")
    content_ys, content_xs = np.nonzero(non_black)
    min_y, max_y = int(content_ys.min()), int(content_ys.max())
    min_x, max_x = int(content_xs.min()), int(content_xs.max())

    # RGBA copy whose background pixels are fully transparent; it also serves
    # as its own paste mask.
    rgba_arr = np.dstack([wrapped_arr, np.full(non_black.shape, 255, dtype=np.uint8)])
    rgba_arr[~non_black] = [0, 0, 0, 0]
    trapezoid = Image.fromarray(rgba_arr)

    mask_arr = np.array(mask.convert('RGB'))
    white_pixels = np.all(mask_arr == 255, axis=2)
    if not white_pixels.any():
        raise ValueError("mask contains no white pixels")
    mask_ys, mask_xs = np.nonzero(white_pixels)
    mask_min_y, mask_max_y = int(mask_ys.min()), int(mask_ys.max())
    mask_min_x, mask_max_x = int(mask_xs.min()), int(mask_xs.max())

    # Blank the masked area, then paste the content into the requested corner.
    result_im = target_im.copy()
    ImageDraw.Draw(result_im).rectangle(
        [mask_min_x, mask_min_y, mask_max_x, mask_max_y], fill='white'
    )
    paste_y = mask_min_y if 'top' in direction else mask_max_y - (max_y - min_y)
    paste_x = mask_min_x if 'left' in direction else mask_max_x - (max_x - min_x)
    result_im.paste(trapezoid, (paste_x - min_x, paste_y - min_y), trapezoid)

    # New mask: start from the masked bounding box...
    width, height = target_im.size
    new_mask_arr = np.zeros((height, width, 3), dtype=np.uint8)
    new_mask_arr[mask_min_y:mask_max_y + 1, mask_min_x:mask_max_x + 1] = 255

    # ...then, inside the pasted content box, keep only the background pixels
    # white. The box is clipped to the image so content larger than the mask
    # (negative or overflowing offsets) cannot cause a shape mismatch.
    paste_height = max_y - min_y + 1
    paste_width = max_x - min_x + 1
    y0, y1 = max(paste_y, 0), min(paste_y + paste_height, height)
    x0, x1 = max(paste_x, 0), min(paste_x + paste_width, width)
    if y0 < y1 and x0 < x1:
        background = ~non_black[min_y:max_y + 1, min_x:max_x + 1]
        region = new_mask_arr[y0:y1, x0:x1]
        region[...] = 0
        region[background[y0 - paste_y:y1 - paste_y, x0 - paste_x:x1 - paste_x]] = 255

    new_mask = Image.fromarray(new_mask_arr)
    return result_im, new_mask
def extract_masked_region(
    im: Union[Image.Image, np.ndarray],
    mask: Image.Image
) -> Image.Image:
    """
    Crop ``im`` to the bounding box of the mask's pure-white pixels.

    Args:
        im: Input image (PIL Image or numpy array)
        mask: Mask image or numpy array; pixels that are 255 in every RGB
            channel mark the region to extract. Grayscale and 1-bit masks
            are accepted.

    Returns:
        Image.Image: Cropped region from original image

    Raises:
        ValueError: If the mask contains no white pixels.
    """
    if isinstance(im, np.ndarray):
        im = Image.fromarray(im)
    if isinstance(mask, np.ndarray):
        mask = Image.fromarray(mask)

    # Normalize to RGB so 'L'/'1' masks work with the per-channel test.
    mask_arr = np.array(mask.convert('RGB'))
    white_pixels = np.all(mask_arr == 255, axis=2)
    if not white_pixels.any():
        raise ValueError("mask contains no white pixels")

    ys, xs = np.nonzero(white_pixels)
    # PIL crop boxes are half-open, hence the +1 on the max coordinates.
    return im.crop((int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1))
def paste_masked_region(
    target_im: Image.Image,
    original_im: Image.Image,
    mask: Image.Image
) -> Image.Image:
    """
    Return a copy of ``target_im`` with pixels from ``original_im`` copied in
    wherever ``mask`` is not pure black.

    The mask is converted to RGBA. Pure-black pixels become fully transparent
    and pure-white pixels fully opaque. Any other pixel keeps the alpha it has
    after the RGBA conversion.

    Args:
        target_im: Target image to paste into (PIL Image or numpy array)
        original_im: Source image to copy content from (PIL Image or numpy array)
        mask: Mask where white pixels indicate where to paste

    Returns:
        Image.Image: Modified copy of the target image
    """
    def _as_pil(obj):
        return Image.fromarray(obj) if isinstance(obj, np.ndarray) else obj

    target_im, original_im, mask = (_as_pil(item) for item in (target_im, original_im, mask))

    rgba = np.array(mask.convert('RGBA'))
    colour = rgba[..., :3]
    is_black = (colour == 0).all(axis=-1)
    is_white = (colour == 255).all(axis=-1)
    rgba[is_black] = (0, 0, 0, 0)
    rgba[is_white] = (255, 255, 255, 255)
    alpha_mask = Image.fromarray(rgba)

    source_rgba = Image.new('RGBA', original_im.size)
    source_rgba.paste(original_im)

    blended = target_im.copy()
    blended.paste(source_rgba, (0, 0), alpha_mask)
    return blended
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment