Skip to content

Instantly share code, notes, and snippets.

@caner-cetin
Created April 4, 2025 01:00
Show Gist options
  • Save caner-cetin/d95ab398db79be30e6421d8d63a98e83 to your computer and use it in GitHub Desktop.
Save caner-cetin/d95ab398db79be30e6421d8d63a98e83 to your computer and use it in GitHub Desktop.
# Version 4.11: Revised Remap Logic for Rotation Effect
import cv2
import numpy as np
import argparse
import os
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import traceback
import math
from typing import Any, TypeAlias
# --- Type Aliases ---
# float32 NumPy array of any shape; used in this file for normalized
# (0..1) images/masks and for the cv2.remap coordinate maps.
NpFloat32Array: TypeAlias = np.ndarray[Any, np.dtype[np.float32]]
# uint8 NumPy array of any shape; used for 8-bit images and masks.
NpUInt8Array: TypeAlias = np.ndarray[Any, np.dtype[np.uint8]]
# List of (x, y) points; used for the source/destination quad corners.
PointList = list[tuple[float, float]]
# Triple-nested integer point prompt (built in this file as [[[x, y]]]) that is
# handed to the SAM processor as `input_points`.
IntPointPrompt = list[list[list[int]]]
# --- Rich Integration (Same) ---
# Optional Rich support: `print_` is the file-wide logging function and
# `console` is the object whose `print_exception` the error handlers call.
# Both names are always bound, so a missing Rich install can no longer turn
# an error report into a NameError that hides the original failure.
try:
    from rich.console import Console
    from rich.traceback import install as install_rich_traceback
except ImportError:
    print("Rich library not found...")

    class _PlainConsole:
        """Minimal stand-in exposing the Console methods this file uses."""

        def print(self, *objects, **kwargs):
            # Rich-only keyword arguments are ignored; markup is printed verbatim.
            print(*objects)

        def print_exception(self, **kwargs):
            # Rich-only options such as show_locals are ignored.
            traceback.print_exc()

    console = _PlainConsole()
    print_ = print
else:
    install_rich_traceback(show_locals=True)
    console = Console()
    print_ = console.print
# --- End Rich Integration ---
# <<< get_sam_mask function definition (Same as V4.6) >>>
def get_sam_mask(image_pil: Image.Image, points: IntPointPrompt, device: str) -> NpUInt8Array | None:
    """Segment the prompted object with SAM and return it as a binary mask.

    Loads the ``facebook/sam-vit-base`` model and processor on every call,
    runs one inference pass, and keeps the candidate mask with the highest
    predicted IoU score. If scores and masks cannot be matched up, the first
    available mask is used instead.

    Args:
        image_pil: Image to segment.
        points: Point prompt forwarded to the processor as ``input_points``.
        device: Torch device string the model and inputs are moved to.

    Returns:
        A 2-D ``uint8`` array holding 0 or 255, or ``None`` if anything fails.
        Errors are logged and never propagated to the caller.
    """
    try:
        model_name: str = "facebook/sam-vit-base"
        print_(f"Loading SAM model '[cyan]{model_name}[/cyan]'...")
        model: SamModel = SamModel.from_pretrained(model_name).to(device)
        print_(f"Loading SAM processor for '[cyan]{model_name}[/cyan]'...")
        processor: SamProcessor = SamProcessor.from_pretrained(model_name)
        print_("[green]SAM model and processor loaded.[/green]")

        print_("Processing image for SAM...")
        inputs = processor(image_pil, input_points=points, return_tensors="pt").to(device)
        print_("Image processed.")

        print_("Running SAM inference...")
        with torch.no_grad():
            outputs: Any = model(**inputs)
        print_("SAM inference complete.")

        print_("Post-processing SAM masks...")
        masks_output: list[torch.Tensor] = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        if not masks_output:
            print_("[red]Error:[/red] SAM post-processing returned empty list.")
            return None
        masks_for_image: torch.Tensor = masks_output[0]
        print_(f"Shape of masks_for_image (tensor for image 0): {masks_for_image.shape}")

        # Flatten the IoU scores to 1-D so they can be compared against the
        # number of candidate masks.
        scores: torch.Tensor | None = None
        num_scores: int = 0
        if outputs.iou_scores is not None:
            scores = outputs.iou_scores.cpu().squeeze(0)
            if scores.ndim != 1:
                scores = scores.flatten()
            num_scores = scores.numel()
            print_(f"Shape of scores tensor ({scores.ndim}D): {scores.shape}, Values: {scores}")
        else:
            print_("[yellow]Warning:[/yellow] No IOU scores found.")

        best_mask_idx: int = 0
        best_score_str: str = "N/A"
        # Dimension 1 is treated as the candidate-mask axis.
        potential_num_masks: int = masks_for_image.shape[1] if masks_for_image.ndim >= 3 else 0
        selected_mask_tensor: torch.Tensor | None = None

        if scores is not None and num_scores == potential_num_masks and num_scores > 0:
            best_mask_idx = int(torch.argmax(scores).item())
            best_score_str = f"{scores[best_mask_idx]:.4f}"
            print_(f"Scores seem to match mask dim 1 ({num_scores}). Selected idx: {best_mask_idx}, score: {best_score_str}")
            if masks_for_image.ndim == 4 and masks_for_image.shape[0] == 1:
                selected_mask_tensor = masks_for_image[0, best_mask_idx]
            else:
                print_(f"[yellow]Warning:[/yellow] Mask shape {masks_for_image.shape} unexpected. Falling back.")
        else:
            print_(f"[yellow]Warning:[/yellow] Score/mask count mismatch (Scores:{num_scores}, MaskDim1:{potential_num_masks}). Defaulting.")

        if selected_mask_tensor is None:
            print_("Applying fallback mask selection logic...")
            # The fallback mask was not chosen by score, so do not report the
            # best score as if it belonged to this mask.
            best_score_str = "N/A"
            if masks_for_image.ndim == 4:
                selected_mask_tensor = masks_for_image[0, 0]
            elif masks_for_image.ndim == 3:
                selected_mask_tensor = masks_for_image[0]
            else:
                print_(f"[red]Error:[/red] Cannot select mask from shape {masks_for_image.shape}.")
                return None
            print_(f"Shape of selected_mask_tensor (fallback): {selected_mask_tensor.shape}")

        mask_squeezed: torch.Tensor = selected_mask_tensor.squeeze()
        print_(f"Shape after squeeze: {mask_squeezed.shape}")
        if mask_squeezed.ndim != 2:
            print_(f"[red]Error:[/red] Mask unexpected shape: {mask_squeezed.shape}.")
            return None

        print_("Mask is 2D.")
        mask_np: NpUInt8Array = (mask_squeezed.cpu().numpy() > 0.5).astype(np.uint8) * 255
        print_(f"Final mask_np shape: {mask_np.shape}")
        print_(f"SAM generated mask with score: {best_score_str}")
        return mask_np
    except Exception:
        print_("[bold red]Error during SAM:[/bold red]")
        # `console` only exists when Rich imported successfully; fall back to
        # the stdlib traceback so the original error is never masked.
        rich_console = globals().get("console")
        if rich_console is not None:
            rich_console.print_exception(show_locals=True)
        else:
            traceback.print_exc()
        return None
# <<< End of get_sam_mask >>>
# <<< blend_overlay function (Same as V4.3) >>>
def blend_overlay(background: NpFloat32Array, foreground: NpFloat32Array, mask: NpFloat32Array) -> NpFloat32Array:
    """Composite ``foreground`` onto ``background`` using the overlay blend mode.

    Per element, the overlay colour is ``2 * fg * bg`` where the background is
    darker than 0.5 and ``1 - 2 * (1 - fg) * (1 - bg)`` elsewhere. That colour
    is then mixed with the background in proportion to ``mask``.

    Args:
        background: Base image values, expected in the 0..1 range.
        foreground: Image values to blend in; must broadcast with ``background``.
        mask: Per-element blend weight in 0..1 (0 keeps the background,
            1 uses the overlay colour); must broadcast with ``background``.

    Returns:
        The blended values clipped to 0..1.
    """
    dark_bg = background < 0.5
    overlay = np.where(
        dark_bg,
        2.0 * foreground * background,
        1.0 - 2.0 * (1.0 - foreground) * (1.0 - background),
    )
    # Weight the overlay by the mask. Without this, partially covered
    # (feathered) pixels receive the full overlay on top of the attenuated
    # background and come out too bright.
    result = overlay * mask + background * (1.0 - mask)
    return np.clip(result, 0.0, 1.0)
# <<< End of blend_overlay >>>
# <<< overlay_logo_with_sam_remap function - Revised Rotation Logic >>>
def _print_exception_details() -> None:
    """Print the active exception's traceback, via Rich when it is available."""
    # `console` is only bound when the Rich import at the top of the file succeeded.
    rich_console = globals().get("console")
    if rich_console is not None:
        rich_console.print_exception(show_locals=True)
    else:
        traceback.print_exc()


def _load_logo_bgra(logo_path: str) -> NpUInt8Array:
    """Load the logo (OpenCV first, PIL as fallback) and return it as 4-channel BGRA."""
    logo_img = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED)
    if logo_img is None:
        try:
            print_("[yellow]Warning:[/yellow] OpenCV failed load logo.")
            logo_pil = Image.open(logo_path).convert("RGBA")
            logo_img = cv2.cvtColor(np.array(logo_pil), cv2.COLOR_RGBA2BGRA)
        except Exception as e:
            raise FileNotFoundError(f"Logo load fail: {e}") from e
    if logo_img is None:
        raise FileNotFoundError(f"Failed loading logo {logo_path}")

    # Normalize every supported channel layout to BGRA.
    if logo_img.ndim < 3 or logo_img.shape[2] == 1:
        logo_img = cv2.cvtColor(logo_img, cv2.COLOR_GRAY2BGRA)
    elif logo_img.shape[2] == 3:
        logo_img = cv2.cvtColor(logo_img, cv2.COLOR_BGR2BGRA)
    elif logo_img.shape[2] != 4:
        raise ValueError(f"Logo unsupported channels: {logo_img.shape[2]}")
    return logo_img


def _build_remap_maps(
    dst_pts_np: NpFloat32Array,
    src_pts_np: NpFloat32Array,
    *,
    centerline_x_dst: float,
    base_w: int,
    base_h: int,
    logo_w: int,
    target_logo_w: int,
    rotation_factor: float,
) -> tuple[NpFloat32Array, NpFloat32Array, tuple[int, int, int, int]] | None:
    """Build cv2.remap lookup maps for the destination quad's bounding box.

    Returns ``(map_x, map_y, (x_min, y_min, x_max, y_max))``, or ``None`` when
    the quad's bounding box does not overlap the base image.
    """
    print_("Calculating inverse perspective matrix...")
    inv_matrix = cv2.getPerspectiveTransform(dst_pts_np, src_pts_np)

    # Bounding box of the destination quad, clamped to the base image.
    x_coords = dst_pts_np[:, 0]
    y_coords = dst_pts_np[:, 1]
    x_min = max(0, int(np.min(x_coords)))
    y_min = max(0, int(np.min(y_coords)))
    x_max = min(base_w, int(np.max(x_coords)) + 1)
    y_max = min(base_h, int(np.max(y_coords)) + 1)
    print_(f"Remap grid bounding box: x=[{x_min}, {x_max}), y=[{y_min}, {y_max})")
    if x_min >= x_max or y_min >= y_max:
        print_("[red]Error:[/red] Degenerate bounding box for warp.")
        return None

    grid_x_dst, grid_y_dst = np.meshgrid(np.arange(x_min, x_max), np.arange(y_min, y_max))
    dst_coords_flat = np.vstack([grid_x_dst.ravel(), grid_y_dst.ravel()]).T
    dst_coords_hom = np.hstack([dst_coords_flat, np.ones((dst_coords_flat.shape[0], 1))])

    # Map every destination pixel back into logo (source) coordinates.
    print_("Applying inverse perspective transform...")
    src_coords_hom = inv_matrix @ dst_coords_hom.T
    src_coords_flat = (src_coords_hom[:2] / src_coords_hom[2]).T
    src_x = src_coords_flat[:, 0]
    src_y = src_coords_flat[:, 1]

    print_(f"Applying rotation adjustment (factor={rotation_factor})...")
    # Normalized horizontal distance from the quad centreline: -1 at the left
    # edge, 0 at the centre, +1 at the right edge.
    dist_x_norm = np.clip(
        (dst_coords_flat[:, 0] - centerline_x_dst) / (target_logo_w / 2.0), -1.0, 1.0
    )
    # Scale is 1 on the left half and the centre, then drops linearly to
    # (1 - rotation_factor) at the right edge. The clamp prevents inversion.
    scale = np.where(dist_x_norm > 0, 1.0 - rotation_factor * dist_x_norm, 1.0)
    scale = np.clip(scale, 0.01, 1.0)
    src_x_center = logo_w / 2.0
    adjusted_src_x = src_x_center + (src_x - src_x_center) * scale

    map_x = adjusted_src_x.reshape(grid_x_dst.shape).astype(np.float32)
    map_y = src_y.reshape(grid_y_dst.shape).astype(np.float32)
    return map_x, map_y, (x_min, y_min, x_max, y_max)


def overlay_logo_with_sam_remap(
    logo_path: str,
    base_path: str,
    output_path: str,
    sam_prompt_factor_x: float = 0.5,
    sam_prompt_factor_y: float = 0.53,
    scale_factor_rel_mask: float = 0.65,
    perspective_factor: float = 0.15,
    top_curve_factor: float = 0.08,
    rotation_factor: float = 0.30,
    vertical_center_offset: float = 0.01,
    edge_blur_sigma: float = 0.7,
    blend_mode: str = "overlay",
    debug_save: bool = False,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    **kwargs  # Accept unused args
) -> None:
    """Warp a logo onto a base image inside a SAM mask and save the result.

    The base image is segmented with SAM at a point prompt. The logo is warped
    with ``cv2.remap`` into a quad centred on that prompt, restricted to the
    SAM mask, blended with the base image, and written to ``output_path``.
    All errors are logged and swallowed; nothing is raised to the caller.

    Args:
        logo_path: Path of the logo image.
        base_path: Path of the base image.
        output_path: Path the blended image is written to.
        sam_prompt_factor_x: Prompt/anchor X as a fraction of the base width.
        sam_prompt_factor_y: Prompt/anchor Y as a fraction of the base height.
        scale_factor_rel_mask: Target logo width as a fraction of the SAM mask
            bounding-box width (base width when no usable bounding box exists).
        perspective_factor: Fraction of the target width removed from the top edge.
        top_curve_factor: Fraction of the target height by which the top edge
            is shifted downwards.
        rotation_factor: Horizontal compression applied to the right half of
            the logo, growing linearly towards the right edge.
        vertical_center_offset: Downward shift of the quad centre as a fraction
            of the target logo height.
        edge_blur_sigma: Gaussian sigma used to feather the final alpha mask
            (0 disables feathering).
        blend_mode: ``"overlay"`` uses ``blend_overlay``; any other value uses
            plain alpha blending.
        debug_save: When true, intermediate masks are written as ``debug_*.png``
            files in the current working directory.
        device: Torch device string passed to ``get_sam_mask``.
        **kwargs: Accepted and ignored.
    """
    try:
        print_(f"Using OpenCV version: [cyan]{cv2.__version__}[/cyan]")
        base_img_cv = cv2.imread(base_path)
        if base_img_cv is None:
            raise FileNotFoundError(f"Base image not found: {base_path}")
        base_h, base_w = base_img_cv.shape[:2]

        # --- Base image for SAM (PIL first, OpenCV data as fallback) ---
        base_img_pil: Image.Image
        try:
            print_(f"Loading base image '[cyan]{base_path}[/cyan]' with PIL...")
            base_img_pil = Image.open(base_path).convert("RGB")
        except Exception as e:
            print_(f"[yellow]Warning:[/yellow] PIL load failed ({e}).")
            try:
                img_rgb = cv2.cvtColor(base_img_cv, cv2.COLOR_BGR2RGB)
                base_img_pil = Image.fromarray(img_rgb)
            except Exception as ce:
                print_(f"[red]Error:[/red] Failed loading base image. Error: {ce}")
                return
        if base_img_pil.size != (base_w, base_h):
            # cv2.imread applies EXIF orientation by default while Image.open
            # does not, so the two loads can disagree on width/height. The SAM
            # mask must match the OpenCV image, so rebuild the PIL image from it.
            print_("[yellow]Warning:[/yellow] PIL/OpenCV size mismatch; using OpenCV pixels for SAM.")
            base_img_pil = Image.fromarray(cv2.cvtColor(base_img_cv, cv2.COLOR_BGR2RGB))

        # --- SAM segmentation at the prompt point ---
        prompt_x = int(base_w * sam_prompt_factor_x)
        prompt_y = int(base_h * sam_prompt_factor_y)
        input_points: IntPointPrompt = [[[prompt_x, prompt_y]]]
        print_(f"Running SAM (Prompt Anchor: {prompt_x}, {prompt_y})...")
        sam_mask = get_sam_mask(base_img_pil, input_points, device)
        if sam_mask is None:
            print_("[red]SAM failed.[/red]")
            return
        if debug_save:
            cv2.imwrite("debug_sam_mask.png", sam_mask)
            print_("Saved debug_sam_mask.png")
        if sam_mask.dtype != np.uint8:
            sam_mask = sam_mask.astype(np.uint8)
        if sam_mask.ndim != 2:
            sam_mask = np.squeeze(sam_mask)
        if sam_mask.ndim != 2:
            print_("[red]Error:[/red] SAM mask not 2D.")
            return
        sam_mask_processed = sam_mask.copy()

        # --- Width of the largest mask contour drives the logo scale ---
        print_("Finding contours...")
        contours, _ = cv2.findContours(sam_mask_processed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        mask_w: int = base_w
        if not contours:
            print_("[yellow]Warning:[/yellow] No contours found.")
        else:
            largest_contour = max(contours, key=cv2.contourArea)
            mask_x, mask_y, mask_w, mask_h = cv2.boundingRect(largest_contour)
            if mask_w <= 1 or mask_h <= 1:
                print_("[yellow]Warning:[/yellow] SAM mask bbox too small.")
                mask_w = base_w
            else:
                print_(f"SAM mask bbox (for scaling): x={mask_x}, y={mask_y}, w={mask_w}, h={mask_h}")

        # --- Logo ---
        print_(f"Loading logo '[cyan]{logo_path}[/cyan]'...")
        logo_img = _load_logo_bgra(logo_path)
        logo_h, logo_w = logo_img.shape[:2]
        # Channel slices are strided views; hand OpenCV contiguous copies.
        logo_bgr = np.ascontiguousarray(logo_img[:, :, :3])
        logo_alpha = np.ascontiguousarray(logo_img[:, :, 3])
        print_(f"Logo loaded BGRA. Dims: {logo_w}x{logo_h}")

        target_logo_w = max(1, int(mask_w * scale_factor_rel_mask))
        ratio = target_logo_w / logo_w if logo_w > 0 else 0
        target_logo_h = max(1, int(logo_h * ratio))
        center_x = float(prompt_x)
        center_y = float(prompt_y)
        print_(f"Target logo dims: {target_logo_w}x{target_logo_h}. ANCHOR (Prompt): ({center_x:.0f}, {center_y:.0f})")

        # --- Destination quad: narrower, lowered top edge; full-width bottom ---
        print_("Defining 4 destination corner points...")
        center_y_adjusted = center_y + (target_logo_h * vertical_center_offset)
        top_w = max(1, target_logo_w - target_logo_w * perspective_factor)
        bottom_w = float(target_logo_w)
        top_y = center_y_adjusted - target_logo_h / 2 + target_logo_h * top_curve_factor
        bottom_y = center_y_adjusted + target_logo_h / 2
        tl_x, tr_x = center_x - top_w / 2, center_x + top_w / 2
        bl_x, br_x = center_x - bottom_w / 2, center_x + bottom_w / 2
        dst_pts: PointList = [(tl_x, top_y), (tr_x, top_y), (br_x, bottom_y), (bl_x, bottom_y)]
        dst_pts_np = np.array(dst_pts, dtype=np.float32)
        src_pts: PointList = [(0, 0), (logo_w - 1, 0), (logo_w - 1, logo_h - 1), (0, logo_h - 1)]
        src_pts_np = np.array(src_pts, dtype=np.float32)
        print_(f"Destination Points (4 corners):\n{dst_pts_np}")

        # --- Remap maps (inverse perspective + rotation adjustment) ---
        maps = _build_remap_maps(
            dst_pts_np,
            src_pts_np,
            centerline_x_dst=(tl_x + tr_x + bl_x + br_x) / 4.0,
            base_w=base_w,
            base_h=base_h,
            logo_w=logo_w,
            target_logo_w=target_logo_w,
            rotation_factor=rotation_factor,
        )
        if maps is None:
            return
        map_x, map_y, (x_min, y_min, x_max, y_max) = maps

        # --- Warp the logo colour and alpha into base-image space ---
        print_("Applying cv2.remap using calculated maps...")
        warped_bgr = np.zeros_like(base_img_cv)
        warped_alpha = np.zeros((base_h, base_w), dtype=np.uint8)
        warped_bgr[y_min:y_max, x_min:x_max] = cv2.remap(
            logo_bgr, map_x, map_y,
            interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0),
        )
        warped_alpha[y_min:y_max, x_min:x_max] = cv2.remap(
            logo_alpha, map_x, map_y,
            interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0,
        )
        print_("Remapping complete.")
        if debug_save:
            cv2.imwrite("debug_warped_alpha.png", warped_alpha)
            print_("Saved debug_warped_alpha.png")

        # --- Final alpha: warped logo alpha restricted to the SAM mask ---
        print_("Creating final alpha mask...")
        _, warped_alpha_binary = cv2.threshold(warped_alpha, 10, 255, cv2.THRESH_BINARY)
        final_alpha = cv2.bitwise_and(warped_alpha_binary, sam_mask_processed)
        alpha_is_empty = cv2.countNonZero(final_alpha) == 0
        if alpha_is_empty:
            print_("[bold yellow]Warning:[/bold yellow] Final alpha mask is empty.")
        if debug_save:
            debug_name = "debug_final_alpha_empty.png" if alpha_is_empty else "debug_final_alpha.png"
            cv2.imwrite(debug_name, final_alpha)
            print_("Saved debug_final_alpha image")

        if edge_blur_sigma > 0:
            # Odd kernel size of roughly 6 sigma. A 1x1 kernel would leave the
            # mask unchanged, so report it instead of claiming a blur.
            ksize = int(6 * edge_blur_sigma) // 2 * 2 + 1
            if ksize >= 3:
                final_alpha = cv2.GaussianBlur(final_alpha, (ksize, ksize), edge_blur_sigma)
                print_(f"Applied edge blur sigma={edge_blur_sigma}")
            else:
                print_("[yellow]Warning:[/yellow] Blur sigma too small.")

        # --- Blend in normalized float space ---
        mask_normalized = final_alpha.astype(np.float32) / 255.0
        mask_normalized_3c = cv2.merge([mask_normalized, mask_normalized, mask_normalized])
        bgr_normalized = warped_bgr.astype(np.float32) / 255.0
        base_normalized = base_img_cv.astype(np.float32) / 255.0
        print_(f"Blending images using '[cyan]{blend_mode}[/cyan]' mode...")
        if blend_mode.lower() == 'overlay':
            result_normalized = blend_overlay(base_normalized, bgr_normalized, mask_normalized_3c)
        else:
            result_normalized = bgr_normalized * mask_normalized_3c + base_normalized * (1.0 - mask_normalized_3c)
        # Round instead of truncating: float32 round-trips such as p / 255 * 255
        # can land just below p and would otherwise darken untouched pixels.
        result_img = np.clip(np.rint(result_normalized * 255.0), 0, 255).astype(np.uint8)
        print_("Blending complete.")

        # cv2.imwrite signals some failures only through its return value.
        if not cv2.imwrite(output_path, result_img):
            print_(f"[red]Error:[/red] Failed to write output image '{output_path}'.")
            return
        print_(f"[bold green]Successfully warped (Remap) and overlaid logo. Saved to:[/bold green] '[cyan]{output_path}[/cyan]'")
    except FileNotFoundError as e:
        print_(f"[red]Error:[/red] {e}")
    except ImportError as e:
        print_(f"[red]Error:[/red] Missing libraries? {e}")
    except cv2.error as e:
        print_(f"[red]OpenCV Error:[/red] {e}")
        _print_exception_details()
    except Exception:
        print_("[bold red]An unexpected error occurred:[/bold red]")
        _print_exception_details()
# <<< Main execution block - Calls remap function >>>
if __name__ == "__main__":
    # Command-line entry point: parse the options, pick the torch device and
    # run the remap-based overlay once.
    cli = argparse.ArgumentParser(description="Overlay a warped logo using SAM guidance, cv2.remap, and blending.")

    # Positional paths, registered in the order they appear on the command line.
    for positional, help_text in (
        ("logo_image", "Path to the logo image file"),
        ("base_image", "Path to the base image file"),
        ("output_image", "Path to save the output image"),
    ):
        cli.add_argument(positional, help=help_text)

    # Float tuning knobs: (flag, default, help).
    float_options = (
        ("--spx", 0.5, "SAM prompt X factor (0-1)"),
        ("--spy", 0.53, "SAM prompt Y factor (0-1, primary placement)"),
        ("--scale_mask", 0.65, "Logo width relative to mask/estimated width"),
        ("--perspective", 0.15, "Perspective horizontal pinch factor for corners"),
        ("--top_curve", 0.08, "Top edge vertical curve factor"),
        ("--rotation_factor", 0.30, "Simulated rotation/compression factor"),
        ("--v_offset", 0.01, "Vertical center offset factor relative to LOGO HEIGHT"),
        ("--blur", 0.7, "Edge blur sigma (0=none)"),
    )
    for flag, default, help_text in float_options:
        cli.add_argument(flag, type=float, default=default, help=help_text)

    cli.add_argument("--blend", default='overlay', choices=['alpha', 'overlay'], help="Blending mode")
    cli.add_argument("--debug_save", action='store_true', help="Save intermediate debug images")
    args: argparse.Namespace = cli.parse_args()

    device_str: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device_str == 'cpu':
        print_("[yellow]Warning:[/yellow] CUDA not available, SAM on CPU (slow).")

    overlay_logo_with_sam_remap(
        logo_path=args.logo_image,
        base_path=args.base_image,
        output_path=args.output_image,
        sam_prompt_factor_x=args.spx,
        sam_prompt_factor_y=args.spy,
        scale_factor_rel_mask=args.scale_mask,
        perspective_factor=args.perspective,
        top_curve_factor=args.top_curve,
        rotation_factor=args.rotation_factor,
        vertical_center_offset=args.v_offset,
        edge_blur_sigma=args.blur,
        blend_mode=args.blend,
        debug_save=args.debug_save,
        device=device_str,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment