
@briva
Created August 16, 2024 07:53
Basic background removal on video
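# Basic background removal on a video with the SAM 2 video predictor: the
# center of the first frame is used as a single positive click, the mask is
# propagated through all frames, and the background is replaced with green.
# Assumptions (not stated in the original gist): the sam2 package is installed,
# the sam2_hiera_large.pt checkpoint and sam2_hiera_l.yaml config are available
# locally, ffmpeg is on the PATH, a CUDA GPU is present, and the source clip
# sits at input/source.mp4.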
import os
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
# Set up the environment and model
home = os.getcwd()
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Enable TF32 matmuls on Ampere (compute capability >= 8) or newer GPUs
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
checkpoint = "sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
# Convert video to frames
input_video = "input/source.mp4"
frames_dir = "frames"
os.makedirs(frames_dir, exist_ok=True)
os.system(f"ffmpeg -i {input_video} -q:v 2 -start_number 0 {frames_dir}/%05d.jpg")
# Initialize the model
video_dir = "./frames"
frame_names = [p for p in os.listdir(video_dir) if p.endswith(('.jpg', '.jpeg', '.JPG', '.JPEG'))]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)
# Get dimensions of the first frame to calculate center
first_frame = cv2.imread(os.path.join(video_dir, frame_names[0]))
height, width, _ = first_frame.shape
center_x, center_y = width // 2, height // 2
# Add initial point for segmentation (center of the frame)
ann_frame_idx = 0
ann_obj_id = 1
points = np.array([[center_x, center_y]], dtype=np.float32)
labels = np.array([1], np.int32)
# (newer sam2 releases also expose this call as add_new_points_or_box)
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
# Propagate masks to all frames
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: out_mask_logits[i].cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
def remove_background(frame, mask, bg_color=(0, 255, 0)):  # Green background
    mask = mask.squeeze()
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    else:
        mask = (mask > 0).astype(np.uint8) * 255
    mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
    # Create a background filled with bg_color
    bg = np.full(frame.shape, bg_color, dtype=np.uint8)
    # Combine the frame and background based on the mask
    result = np.where(mask[:, :, None] == 255, frame, bg)
    return result
# Process video and save output
output_video_path = 'output_video.mp4'
frame_rate = 30
# Use the H.264 codec ('avc1'); if the local OpenCV build lacks H.264 support,
# 'mp4v' is a common fallback fourcc
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height), isColor=True)
if not out.isOpened():
    logging.error("VideoWriter failed to open")
else:
    logging.info("VideoWriter opened successfully")
frame_count = 0
for out_frame_idx in range(len(frame_names)):
    frame_path = os.path.join(video_dir, frame_names[out_frame_idx])
    frame = cv2.imread(frame_path)
    if frame is None:
        logging.error(f"Failed to read frame: {frame_path}")
        continue
    # A single object is tracked here; with several objects this loop would
    # write one composited frame per object
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        frame_with_bg_removed = remove_background(frame, out_mask)
        if out.isOpened():
            out.write(frame_with_bg_removed)
            frame_count += 1
        else:
            logging.error("VideoWriter is not opened")
            break
out.release()
logging.info(f"Processed {frame_count} frames")
logging.info(f"Background removed video saved as {output_video_path}")
# Verify the output file
if os.path.exists(output_video_path):
    file_size = os.path.getsize(output_video_path)
    logging.info(f"Output file size: {file_size} bytes")
else:
    logging.error("Output file was not created")