Created
August 16, 2024 07:53
-
-
Save briva/2c32bafcce3480b9ffee9eaf12bb58d1 to your computer and use it in GitHub Desktop.
Basic background removal on video
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import subprocess

import cv2
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor
# Logging: show messages at INFO level and above.
logging.basicConfig(level=logging.INFO)

# Environment and model setup.
home = os.getcwd()

# Enter a CUDA bfloat16 autocast context by hand; it is never exited here,
# so it stays active for everything that runs afterwards.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# When device 0 reports a major version of 8 or higher, allow TF32 in both
# the CUDA matmul backend and the cuDNN backend.
if torch.cuda.get_device_properties(0).major >= 8:
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True

# Build the video predictor from the model config and checkpoint file.
checkpoint = "sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
# Convert the input video into numbered JPEG frames (00000.jpg, 00001.jpg, ...).
input_video = "input/source.mp4"
frames_dir = "frames"
os.makedirs(frames_dir, exist_ok=True)
# Pass ffmpeg an argument list instead of a shell string so paths containing
# spaces or shell metacharacters are handled verbatim, and use check=True so
# a failed extraction raises CalledProcessError here instead of surfacing
# later as a confusing "no frames" error.
subprocess.run(
    [
        "ffmpeg",
        "-i", input_video,
        "-q:v", "2",
        "-start_number", "0",
        f"{frames_dir}/%05d.jpg",
    ],
    check=True,
)
# Gather the JPEG frame files, ordered by the integer value of their file
# stem, then hand the frame directory to the predictor to create its state.
video_dir = "./frames"
frame_names = sorted(
    (
        entry
        for entry in os.listdir(video_dir)
        if entry.endswith(('.jpg', '.jpeg', '.JPG', '.JPEG'))
    ),
    key=lambda entry: int(os.path.splitext(entry)[0]),
)
inference_state = predictor.init_state(video_path=video_dir)
# Read the first frame to learn the video dimensions and the frame center.
# Fail early with a clear message when there is nothing to read, rather than
# with an IndexError on the empty list or an AttributeError on None.
if not frame_names:
    raise RuntimeError(f"No JPEG frames found in {video_dir}")
first_frame_path = os.path.join(video_dir, frame_names[0])
first_frame = cv2.imread(first_frame_path)
if first_frame is None:
    # cv2.imread signals an unreadable image by returning None.
    raise RuntimeError(f"Failed to read first frame: {first_frame_path}")
# Only the first two axes (rows, columns) are needed.
height, width = first_frame.shape[:2]
center_x, center_y = width // 2, height // 2
# Seed the segmentation with a single click at the frame center on frame 0,
# registered under object id 1. The label value 1 is passed alongside the
# point (presumably marking it as a positive click - confirm against the
# predictor's documentation).
ann_frame_idx, ann_obj_id = 0, 1
points = np.asarray([(center_x, center_y)], dtype=np.float32)
labels = np.ones(1, dtype=np.int32)
prompt_kwargs = dict(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
_, out_obj_ids, out_mask_logits = predictor.add_new_points(**prompt_kwargs)
# Propagate the prompt through the video and collect, per frame index, a
# mapping of object id -> boolean foreground mask.
#
# The logits are thresholded (> 0) before leaving the device, so only boolean
# masks are kept in host memory instead of full float logits for every frame.
# It also avoids converting a possibly reduced-precision tensor to NumPy.
# The consumer, remove_background, accepts boolean masks and applies the same
# "> 0" rule to non-boolean ones, so the resulting video is unchanged.
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
def remove_background(frame, mask, bg_color=(0, 255, 0)):  # Green background
    """Return a copy of *frame* with everything outside *mask* painted *bg_color*.

    Args:
        frame: Image array whose first two axes are (height, width) and whose
            shape is compatible with ``bg_color`` (e.g. ``(H, W, 3)``).
        mask: Foreground mask, either boolean or numeric (values ``> 0`` are
            foreground). Singleton axes are squeezed away; the result must be
            2-D. It is resized to the frame size when the sizes differ.
        bg_color: Color used for background pixels, in the frame's channel
            order.

    Returns:
        A ``uint8`` array shaped like *frame*: foreground pixels come from
        *frame*, all others are *bg_color*.

    Raises:
        ValueError: If *mask* is not 2-D after squeezing singleton axes.
    """
    mask = np.asarray(mask).squeeze()
    if mask.ndim != 2:
        raise ValueError(
            f"mask must be 2-D after squeezing, got shape {mask.shape}"
        )

    # "> 0" is right for both inputs: True > 0 holds for boolean masks, and
    # for numeric masks it is the foreground threshold.
    binary = (mask > 0).astype(np.uint8) * 255

    frame_h, frame_w = frame.shape[:2]
    if binary.shape != (frame_h, frame_w):
        # Nearest-neighbour keeps the mask strictly binary (0 or 255).
        binary = cv2.resize(binary, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)

    # Create a background filled with bg_color
    bg = np.full(frame.shape, bg_color, dtype=np.uint8)

    # Combine the frame and background based on the mask
    return np.where(binary[:, :, None] == 255, frame, bg)
# Process video and save output
output_video_path = 'output_video.mp4'

# Use the source video's frame rate so the output plays at the original
# speed. 30 fps is only the fallback for when the rate cannot be read
# (an unopened capture or missing metadata reports a non-positive value).
frame_rate = 30
source_capture = cv2.VideoCapture(input_video)
try:
    source_fps = source_capture.get(cv2.CAP_PROP_FPS)
finally:
    source_capture.release()
if source_fps > 0:  # also False for NaN
    frame_rate = source_fps
else:
    logging.warning(f"Could not read frame rate of {input_video}; using {frame_rate} fps")

# Use H.264 codec
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height), isColor=True)
if not out.isOpened():
    logging.error("VideoWriter failed to open")
else:
    logging.info("VideoWriter opened successfully")
# Composite every frame against the background color and append it to the
# output video. Exactly one output frame is written per input frame: when a
# frame has masks for several objects they are merged (union) into a single
# foreground mask rather than emitting one video frame per object.
frame_count = 0
try:
    if not out.isOpened():
        # Report once instead of once per frame; nothing can be written.
        logging.error("VideoWriter is not opened")
    else:
        for out_frame_idx, frame_name in enumerate(frame_names):
            frame_path = os.path.join(video_dir, frame_name)
            frame = cv2.imread(frame_path)
            if frame is None:
                logging.error(f"Failed to read frame: {frame_path}")
                continue

            # A frame with no propagated masks is skipped with an error log
            # rather than aborting the whole run with a KeyError.
            object_masks = video_segments.get(out_frame_idx)
            if not object_masks:
                logging.error(f"No segmentation mask for frame: {frame_path}")
                continue

            # "> 0" is the foreground test for both boolean masks and logits.
            combined_mask = np.logical_or.reduce(
                [np.squeeze(object_mask) > 0 for object_mask in object_masks.values()]
            )
            out.write(remove_background(frame, combined_mask))
            frame_count += 1
finally:
    # Always finalize the container, even if processing fails midway.
    out.release()
logging.info(f"Processed {frame_count} frames")
logging.info(f"Background removed video saved as {output_video_path}")
# Sanity-check the result: report the size of the output file, or log an
# error when no file exists at the output path.
if not os.path.exists(output_video_path):
    logging.error("Output file was not created")
else:
    file_size = os.path.getsize(output_video_path)
    logging.info(f"Output file size: {file_size} bytes")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment