# Version 4.11: Revised Remap Logic for Rotation Effect
import cv2
import numpy as np
import argparse
import os
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import traceback
import math
from typing import Any, TypeAlias
# --- Type Aliases (Same) ---
NpFloat32Array: TypeAlias = np.ndarray[Any, np.dtype[np.float32]]
NpUInt8Array: TypeAlias = np.ndarray[Any, np.dtype[np.uint8]]
PointList = list[tuple[float, float]]
IntPointPrompt = list[list[list[int]]]
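# Note: IntPointPrompt is nested as [points-per-image][point][(x, y)], so a
# single point prompt for a single image looks like [[[x, y]]] -- the nesting
# used for SamProcessor's `input_points` argument below.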
# --- Rich Integration (Same) ---
try:
    from rich.console import Console
    from rich.traceback import install as install_rich_traceback
    install_rich_traceback(show_locals=True); console = Console(); print_ = console.print
except ImportError:
    print("Rich library not found, falling back to plain output."); print_ = print
    class _PlainConsole:
        """Minimal stand-in so the `console.print_exception` calls below still work."""
        def print_exception(self, **kwargs: Any) -> None: traceback.print_exc()
    console = _PlainConsole()
# --- End Rich Integration ---
# <<< get_sam_mask function definition (Same as V4.6) >>>
def get_sam_mask(image_pil: Image.Image, points: IntPointPrompt, device: str) -> NpUInt8Array | None:
    try:
        model_name: str = "facebook/sam-vit-base"; print_(f"Loading SAM model '[cyan]{model_name}[/cyan]'...")
        model: SamModel = SamModel.from_pretrained(model_name).to(device); print_(f"Loading SAM processor for '[cyan]{model_name}[/cyan]'...")
        processor: SamProcessor = SamProcessor.from_pretrained(model_name); print_("[green]SAM model and processor loaded.[/green]")
        print_("Processing image for SAM..."); inputs: dict = processor(image_pil, input_points=points, return_tensors="pt").to(device); print_("Image processed.")
        print_("Running SAM inference...")
        with torch.no_grad(): outputs: Any = model(**inputs)
        print_("SAM inference complete.")
        print_("Post-processing SAM masks..."); masks_output: list[torch.Tensor] = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
        if not masks_output: print_("[red]Error:[/red] SAM post-processing returned an empty list."); return None
        masks_for_image: torch.Tensor = masks_output[0]; print_(f"Shape of masks_for_image (tensor for image 0): {masks_for_image.shape}")
        scores: torch.Tensor | None = None; num_scores: int = 0
        if outputs.iou_scores is not None:
            scores = outputs.iou_scores.cpu().squeeze(0)
            if scores.ndim == 1: num_scores = scores.numel()
            else: scores = scores.flatten(); num_scores = scores.numel()
            print_(f"Shape of scores tensor ({scores.ndim}D): {scores.shape}, Values: {scores}")
        else: print_("[yellow]Warning:[/yellow] No IOU scores found.")
        best_mask_idx: int = 0; best_score_str: str = "N/A"; potential_num_masks: int = 0
        if masks_for_image.ndim >= 3: potential_num_masks = masks_for_image.shape[1]
        selected_mask_tensor: torch.Tensor | None = None
        if scores is not None and num_scores == potential_num_masks and num_scores > 0:
            best_mask_idx = torch.argmax(scores).item(); best_score_str = f"{scores[best_mask_idx]:.4f}"
            print_(f"Scores seem to match mask dim 1 ({num_scores}). Selected idx: {best_mask_idx}, score: {best_score_str}")
            if masks_for_image.ndim == 4 and masks_for_image.shape[0] == 1: selected_mask_tensor = masks_for_image[0, best_mask_idx]
            else: print_(f"[yellow]Warning:[/yellow] Mask shape {masks_for_image.shape} unexpected. Falling back."); selected_mask_tensor = None
        else: print_(f"[yellow]Warning:[/yellow] Score/mask count mismatch (Scores: {num_scores}, MaskDim1: {potential_num_masks}). Defaulting.")
        if selected_mask_tensor is None:
            print_("Applying fallback mask selection logic...")
            if masks_for_image.ndim == 4: selected_mask_tensor = masks_for_image[0, 0]
            elif masks_for_image.ndim == 3: selected_mask_tensor = masks_for_image[0]
            else: print_(f"[red]Error:[/red] Cannot select mask from shape {masks_for_image.shape}."); return None
            print_(f"Shape of selected_mask_tensor (fallback): {selected_mask_tensor.shape}")
        mask_squeezed: torch.Tensor = selected_mask_tensor.squeeze(); print_(f"Shape after squeeze: {mask_squeezed.shape}")
        mask_np: NpUInt8Array | None = None
        if len(mask_squeezed.shape) == 2:
            print_("Mask is 2D."); mask_np = (mask_squeezed.cpu().numpy() > 0.5).astype(np.uint8) * 255
        else: print_(f"[red]Error:[/red] Mask has unexpected shape: {mask_squeezed.shape}."); return None
        print_(f"Final mask_np shape: {mask_np.shape}"); print_(f"SAM generated mask with score: {best_score_str}"); return mask_np
    except Exception:
        print_("[bold red]Error during SAM:[/bold red]"); console.print_exception(show_locals=True); return None
# <<< End of get_sam_mask >>>
# <<< blend_overlay function (Same as V4.3) >>>
def blend_overlay(background: NpFloat32Array, foreground: NpFloat32Array, mask: NpFloat32Array) -> NpFloat32Array:
    bg_norm: NpFloat32Array = background; fg_norm: NpFloat32Array = foreground
    dark_bg = bg_norm < 0.5  # boolean array, True where the base pixel is dark
    overlay_dark: NpFloat32Array = 2.0 * fg_norm * bg_norm
    overlay_light: NpFloat32Array = 1.0 - 2.0 * (1.0 - fg_norm) * (1.0 - bg_norm)
    combined: NpFloat32Array = np.where(dark_bg, overlay_dark, overlay_light)
    # Weight the blended result by the mask. For a binary mask this matches the
    # old hard cutoff; for a feathered mask it interpolates toward the background
    # instead of over-adding at soft edges.
    result: NpFloat32Array = combined * mask + bg_norm * (1.0 - mask)
    return np.clip(result, 0.0, 1.0)
# <<< End of blend_overlay >>>
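# Worked example of the overlay formula above (illustrative values):
#   dark base  (b = 0.25, logo f = 0.8): 2 * 0.8 * 0.25                 = 0.40
#   light base (b = 0.75, logo f = 0.8): 1 - 2 * (1 - 0.8) * (1 - 0.75) = 0.90
# Dark fabric pulls the logo darker and light fabric pushes it lighter, which
# keeps the base image's shading and texture visible through the print.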
# <<< overlay_logo_with_sam_remap function - Revised Rotation Logic >>>
def overlay_logo_with_sam_remap(
    logo_path: str,
    base_path: str,
    output_path: str,
    sam_prompt_factor_x: float = 0.5,
    sam_prompt_factor_y: float = 0.53,
    scale_factor_rel_mask: float = 0.65,
    perspective_factor: float = 0.15,
    top_curve_factor: float = 0.08,
    rotation_factor: float = 0.30,  # Default slightly higher
    vertical_center_offset: float = 0.01,
    edge_blur_sigma: float = 0.7,
    blend_mode: str = "overlay",
    debug_save: bool = False,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    **kwargs  # Accept and ignore extra keyword args
) -> None:
    """
    Overlays a logo using SAM guidance and cv2.remap with improved rotation simulation.
    """
    try:
        # [ Keep initial setup, SAM call, BBox finding, Logo loading the same ]
        print_(f"Using OpenCV version: [cyan]{cv2.__version__}[/cyan]")
        base_img_cv: NpUInt8Array | None = cv2.imread(base_path)
        if base_img_cv is None: raise FileNotFoundError(f"Base image not found: {base_path}")
        base_h, base_w = base_img_cv.shape[:2]
        base_img_pil: Image.Image
        try:
            print_(f"Loading base image '[cyan]{base_path}[/cyan]' with PIL...")
            base_img_pil = Image.open(base_path).convert("RGB")
        except Exception as e:
            print_(f"[yellow]Warning:[/yellow] PIL load failed ({e}). Converting the OpenCV image instead.")
            try:
                img_rgb = cv2.cvtColor(base_img_cv, cv2.COLOR_BGR2RGB); base_img_pil = Image.fromarray(img_rgb)
            except Exception as ce:
                print_(f"[red]Error:[/red] Failed loading base image. Error: {ce}"); return
        prompt_x = int(base_w * sam_prompt_factor_x); prompt_y = int(base_h * sam_prompt_factor_y); input_points: IntPointPrompt = [[[prompt_x, prompt_y]]]
        print_(f"Running SAM (Prompt Anchor: {prompt_x}, {prompt_y})..."); sam_mask: NpUInt8Array | None = get_sam_mask(base_img_pil, input_points, device)
        if sam_mask is None: print_("[red]SAM failed.[/red]"); return
        if debug_save: cv2.imwrite("debug_sam_mask.png", sam_mask); print_("Saved debug_sam_mask.png")
        if sam_mask.dtype != np.uint8: sam_mask = sam_mask.astype(np.uint8)
        if len(sam_mask.shape) != 2: sam_mask = np.squeeze(sam_mask)
        if len(sam_mask.shape) != 2: print_("[red]Error:[/red] SAM mask not 2D."); return
        sam_mask_processed: NpUInt8Array = sam_mask.copy()
        print_("Finding contours...")
        contours, _ = cv2.findContours(sam_mask_processed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        mask_w: int = base_w  # fallback width if segmentation gives no usable bbox
        if not contours:
            print_("[yellow]Warning:[/yellow] No contours found.")
        else:
            largest_contour = max(contours, key=cv2.contourArea)
            mask_x, mask_y, mask_w, mask_h = cv2.boundingRect(largest_contour)
            if mask_w <= 1 or mask_h <= 1: print_("[yellow]Warning:[/yellow] SAM mask bbox too small."); mask_w = base_w
            else: print_(f"SAM mask bbox (for scaling): x={mask_x}, y={mask_y}, w={mask_w}, h={mask_h}")
        print_(f"Loading logo '[cyan]{logo_path}[/cyan]'..."); logo_img: NpUInt8Array | None = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED)
        if logo_img is None:
            try:
                print_("[yellow]Warning:[/yellow] OpenCV failed to load logo, retrying with PIL.")
                logo_pil = Image.open(logo_path).convert("RGBA"); logo_img = cv2.cvtColor(np.array(logo_pil), cv2.COLOR_RGBA2BGRA)
            except Exception as e:
                raise FileNotFoundError(f"Logo load failed: {e}")
        if logo_img is None: raise FileNotFoundError(f"Failed loading logo {logo_path}")
        logo_h_orig, logo_w_orig = logo_img.shape[:2]
        # Normalize the logo to 4-channel BGRA so the alpha split below always works.
        if len(logo_img.shape) < 3 or logo_img.shape[2] == 1: logo_img = cv2.cvtColor(logo_img, cv2.COLOR_GRAY2BGRA)
        elif logo_img.shape[2] == 3: logo_img = cv2.cvtColor(logo_img, cv2.COLOR_BGR2BGRA)
        elif logo_img.shape[2] != 4: raise ValueError(f"Logo has unsupported channel count: {logo_img.shape[2]}")
        logo_h, logo_w = logo_img.shape[:2]; logo_bgr: NpUInt8Array = logo_img[:, :, :3]; logo_alpha: NpUInt8Array = logo_img[:, :, 3]
        print_(f"Logo loaded BGRA. Dims: {logo_w}x{logo_h}")
        effective_scale_base_width = mask_w; target_logo_w = int(effective_scale_base_width * scale_factor_rel_mask); target_logo_w = max(1, target_logo_w)
        ratio = target_logo_w / logo_w if logo_w > 0 else 0; target_logo_h = int(logo_h * ratio); target_logo_h = max(1, target_logo_h)
        center_x = float(prompt_x); center_y = float(prompt_y); print_(f"Target logo dims: {target_logo_w}x{target_logo_h}. ANCHOR (Prompt): ({center_x:.0f}, {center_y:.0f})")
        # --- Define 4 Destination Corner Points ---
        print_("Defining 4 destination corner points...")
        center_y_adjusted = center_y + (target_logo_h * vertical_center_offset)
        top_width_reduction = target_logo_w * perspective_factor; top_w = max(1, target_logo_w - top_width_reduction); bottom_w = float(target_logo_w)
        y_curve_offset = target_logo_h * top_curve_factor
        tl_x = center_x - top_w / 2; tl_y = center_y_adjusted - target_logo_h / 2 + y_curve_offset
        tr_x = center_x + top_w / 2; tr_y = center_y_adjusted - target_logo_h / 2 + y_curve_offset
        br_x = center_x + bottom_w / 2; br_y = center_y_adjusted + target_logo_h / 2
        bl_x = center_x - bottom_w / 2; bl_y = center_y_adjusted + target_logo_h / 2
        dst_pts: PointList = [(tl_x, tl_y), (tr_x, tr_y), (br_x, br_y), (bl_x, bl_y)]
        dst_pts_np: NpFloat32Array = np.array(dst_pts, dtype=np.float32)
        src_pts: PointList = [(0, 0), (logo_w - 1, 0), (logo_w - 1, logo_h - 1), (0, logo_h - 1)]
        src_pts_np: NpFloat32Array = np.array(src_pts, dtype=np.float32)
        print_(f"Destination Points (4 corners):\n{dst_pts_np}")
        # --- Calculate Inverse Perspective Matrix ---
        print_("Calculating inverse perspective matrix...")
        inv_matrix: NpFloat32Array = cv2.getPerspectiveTransform(dst_pts_np, src_pts_np)
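        # Note on direction: cv2.remap *pulls* pixels, i.e. for every destination
        # pixel it asks "which source pixel lands here?". That is why the matrix
        # is built dst -> src (arguments reversed relative to a forward warp with
        # cv2.warpPerspective).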
        # --- Create Remap Grids ---
        x_coords = dst_pts_np[:, 0]; y_coords = dst_pts_np[:, 1]
        x_min, x_max = int(np.min(x_coords)), int(np.max(x_coords)) + 1
        y_min, y_max = int(np.min(y_coords)), int(np.max(y_coords)) + 1
        x_min, y_min = max(0, x_min), max(0, y_min); x_max, y_max = min(base_w, x_max), min(base_h, y_max)
        print_(f"Remap grid bounding box: x=[{x_min}, {x_max}), y=[{y_min}, {y_max})")
        if x_min >= x_max or y_min >= y_max: print_("[red]Error:[/red] Degenerate bounding box for warp."); return
        grid_x_dst, grid_y_dst = np.meshgrid(np.arange(x_min, x_max), np.arange(y_min, y_max))
        dst_coords_flat = np.vstack([grid_x_dst.ravel(), grid_y_dst.ravel()]).T
        dst_coords_hom = np.hstack([dst_coords_flat, np.ones((dst_coords_flat.shape[0], 1))])
        # --- Apply Inverse Perspective ---
        print_("Applying inverse perspective transform..."); src_coords_hom = inv_matrix @ dst_coords_hom.T
        src_coords_flat = (src_coords_hom[:2] / src_coords_hom[2]).T  # perspective divide back to 2D
        # --- Apply Non-linear Rotation Adjustment --- # REVISED LOGIC V4.11
        print_(f"Applying rotation adjustment (factor={rotation_factor})...")
        centerline_x_dst = (tl_x + tr_x + bl_x + br_x) / 4.0  # apparent horizontal center on destination
        # Normalized distance from center (-1 left edge, 0 center, +1 right edge)
        dist_x_norm = (dst_coords_flat[:, 0] - centerline_x_dst) / (target_logo_w / 2.0)
        dist_x_norm = np.clip(dist_x_norm, -1.0, 1.0)  # ensure it stays within bounds
        src_x = src_coords_flat[:, 0]; src_y = src_coords_flat[:, 1]; src_x_center = logo_w / 2.0
        # Scale factor: 1 on left/center, decreasing linearly to (1 - rotation_factor) at the right edge
        scale = np.where(dist_x_norm > 0, 1.0 - rotation_factor * dist_x_norm, 1.0)
        scale = np.clip(scale, 0.01, 1.0)  # prevent extreme compression or inversion
        # Apply scale to the deviation from the source center X
        src_deviation_x = src_x - src_x_center
        adjusted_src_x = src_x_center + src_deviation_x * scale
        # --- End Revised Logic ---
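        # Illustrative arithmetic for the adjustment above (rotation_factor=0.30):
        # at dist_x_norm = 1.0, scale = 1 - 0.30 = 0.70, so a destination pixel at
        # the right edge samples the source at center + 0.70 * deviation instead of
        # the full deviation, while the left half (dist_x_norm <= 0) is untouched.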
        # --- Build Final Remap Maps ---
        map_x: NpFloat32Array = adjusted_src_x.reshape(grid_x_dst.shape).astype(np.float32)
        map_y: NpFloat32Array = src_y.reshape(grid_y_dst.shape).astype(np.float32)
        # --- Apply cv2.remap ---
        print_("Applying cv2.remap using calculated maps...")
        warped_bgr_remap = np.zeros_like(base_img_cv); warped_alpha_remap = np.zeros((base_h, base_w), dtype=np.uint8)
        warped_bgr_remap[y_min:y_max, x_min:x_max] = cv2.remap(logo_bgr, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))
        warped_alpha_remap[y_min:y_max, x_min:x_max] = cv2.remap(logo_alpha, map_x, map_y, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
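        # INTER_LINEAR for the color channels vs. INTER_NEAREST for alpha keeps
        # the warped alpha hard-edged, so the binary threshold below gets clean
        # edges to work with; feathering is applied separately via GaussianBlur.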
print_("Remapping complete."); warped_bgr = warped_bgr_remap; warped_alpha = warped_alpha_remap | |
if debug_save: cv2.imwrite("debug_warped_alpha.png", warped_alpha); print_("Saved debug_warped_alpha.png") | |
# --- Final Mask Creation, Blending, Saving (Same as before) --- | |
print_("Creating final alpha mask..."); warped_alpha_binary: NpUInt8Array | |
_, warped_alpha_binary = cv2.threshold(warped_alpha, 10, 255, cv2.THRESH_BINARY) | |
final_alpha: NpUInt8Array = cv2.bitwise_and(warped_alpha_binary, sam_mask_processed) | |
if cv2.countNonZero(final_alpha) == 0: print_("[bold yellow]Warning:[/bold yellow] Final alpha mask is empty."); | |
if debug_save: cv2.imwrite("debug_final_alpha_empty.png", final_alpha) if cv2.countNonZero(final_alpha) == 0 else cv2.imwrite("debug_final_alpha.png", final_alpha); print_("Saved debug_final_alpha image") | |
if edge_blur_sigma > 0: | |
ksize = int(6 * edge_blur_sigma) // 2 * 2 + 1 | |
if ksize > 0: final_alpha = cv2.GaussianBlur(final_alpha, (ksize, ksize), edge_blur_sigma); print_(f"Applied edge blur sigma={edge_blur_sigma}") | |
else: print_("[yellow]Warning:[/yellow] Blur sigma too small.") | |
#if debug_save and cv2.countNonZero(final_alpha) > 0: cv2.imwrite("debug_final_alpha_blurred.png", final_alpha); print_("Saved debug_final_alpha_blurred image") # Optional: Save after blur too | |
        mask_normalized = final_alpha.astype(np.float32) / 255.0
        mask_normalized_3c = cv2.merge([mask_normalized, mask_normalized, mask_normalized])
        bgr_normalized = warped_bgr.astype(np.float32) / 255.0; base_normalized = base_img_cv.astype(np.float32) / 255.0
        print_(f"Blending images using '[cyan]{blend_mode}[/cyan]' mode...")
        result_normalized: NpFloat32Array
        if blend_mode.lower() == 'overlay': result_normalized = blend_overlay(base_normalized, bgr_normalized, mask_normalized_3c)
        else: result_normalized = bgr_normalized * mask_normalized_3c + base_normalized * (1.0 - mask_normalized_3c)
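        # The 'alpha' branch above is plain alpha compositing,
        # out = fg * a + bg * (1 - a), applied per channel with the feathered mask.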
        result_img: NpUInt8Array = (result_normalized * 255.0).astype(np.uint8); print_("Blending complete.")
        cv2.imwrite(output_path, result_img)
        print_(f"[bold green]Successfully warped (Remap) and overlaid logo. Saved to:[/bold green] '[cyan]{output_path}[/cyan]'")
    except FileNotFoundError as e: print_(f"[red]Error:[/red] {e}")
    except ImportError as e: print_(f"[red]Error:[/red] Missing libraries? {e}")
    except cv2.error as e: print_(f"[red]OpenCV Error:[/red] {e}"); console.print_exception(show_locals=True)
    except Exception: print_("[bold red]An unexpected error occurred:[/bold red]"); console.print_exception(show_locals=True)
# <<< Main execution block - Calls remap function >>>
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Overlay a warped logo using SAM guidance, cv2.remap, and blending.")
    # [ Arguments are the same as V4.10 ]
    parser.add_argument("logo_image", help="Path to the logo image file")
    parser.add_argument("base_image", help="Path to the base image file")
    parser.add_argument("output_image", help="Path to save the output image")
    parser.add_argument("--spx", type=float, default=0.5, help="SAM prompt X factor (0-1)")
    parser.add_argument("--spy", type=float, default=0.53, help="SAM prompt Y factor (0-1, primary placement)")
    parser.add_argument("--scale_mask", type=float, default=0.65, help="Logo width relative to mask/estimated width")
    parser.add_argument("--perspective", type=float, default=0.15, help="Perspective horizontal pinch factor for corners")
    parser.add_argument("--top_curve", type=float, default=0.08, help="Top edge vertical curve factor")
    parser.add_argument("--rotation_factor", type=float, default=0.30, help="Simulated rotation/compression factor")  # Adjusted default
    parser.add_argument("--v_offset", type=float, default=0.01, help="Vertical center offset factor relative to LOGO HEIGHT")
    parser.add_argument("--blur", type=float, default=0.7, help="Edge blur sigma (0=none)")
    parser.add_argument("--blend", default='overlay', choices=['alpha', 'overlay'], help="Blending mode")
    parser.add_argument("--debug_save", action='store_true', help="Save intermediate debug images")
    args: argparse.Namespace = parser.parse_args()
    device_str: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device_str == 'cpu': print_("[yellow]Warning:[/yellow] CUDA not available, SAM will run on CPU (slow).")
    overlay_logo_with_sam_remap(  # Call the REMAP function
        args.logo_image, args.base_image, args.output_image,
        sam_prompt_factor_x=args.spx, sam_prompt_factor_y=args.spy,
        scale_factor_rel_mask=args.scale_mask, perspective_factor=args.perspective,
        vertical_center_offset=args.v_offset, edge_blur_sigma=args.blur,
        blend_mode=args.blend, top_curve_factor=args.top_curve,
        rotation_factor=args.rotation_factor,
        debug_save=args.debug_save,
        device=device_str
    )
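# Example invocation (hypothetical file names; the script name is whatever this
# file is saved as):
#   python overlay_logo_sam_remap.py logo.png tshirt.jpg out.png \
#       --spy 0.55 --scale_mask 0.6 --rotation_factor 0.25 --blend overlay --debug_save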