This script extracts frames from a video and generates a description using the Kimi-VL-A3B model. It takes the following arguments:

video_path (required): Path to the input video file
--max_frames (default=1): Maximum number of frames to extract and process
--save_dir (default="./test-frames"): Directory to save extracted frames
--prompt (default="Describe this video"): Text prompt for the model
--target_fps (default=1.0): Target frames per second for extraction
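A typical invocation might look like this (the script filename is illustrative; save the gist under any name):

    python kimi_vl_describe.py input.mp4 --max_frames 4 --target_fps 0.5

Note that with the defaults (max_frames=1), only a single frame is extracted, so the "video" description is effectively a single-image caption; raising --max_frames and lowering --target_fps gives the model more temporal coverage.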
import cv2
import argparse
import sys
import torch
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Function to extract frames from video, save them, and return paths
def extract_frames(video_path, save_dir, target_fps=1, max_frames=1):
    """Extracts up to max_frames from a video file at target FPS, saves them, and returns their paths."""
    frame_paths = []
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return frame_paths
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if video_fps <= 0:
            print(f"Warning: Could not get video FPS for {video_path}. Assuming 30 FPS.")
            video_fps = 30
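        # Sample every Nth source frame so the effective sampling rate approximates target_fps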
        frame_interval = max(1, int(round(video_fps / target_fps)))
        frame_count = 0
        saved_frame_count = 0
        while cap.isOpened() and saved_frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                # Convert BGR (OpenCV) to RGB (PIL)
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(frame_rgb)
                # Save frame to disk
                frame_filename = f"frame_{saved_frame_count:04d}.png"
                frame_filepath = os.path.join(save_dir, frame_filename)
                try:
                    img.save(frame_filepath)
                    frame_paths.append(frame_filepath)
                    saved_frame_count += 1
                except Exception as save_e:
                    print(f"Error saving frame {frame_filepath}: {save_e}")
                    # Skip this frame and keep processing the next one
            frame_count += 1
        cap.release()
        print(f"Extracted and saved {len(frame_paths)} frames to {save_dir} from {video_path} at ~{target_fps} FPS.")
    except Exception as e:
        print(f"An error occurred during frame extraction: {e}")
        if 'cap' in locals() and cap.isOpened():
            cap.release()
    return frame_paths
# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Describe video frame(s) using Kimi-VL model.")
parser.add_argument("video_path", help="Path to the input video file.")
parser.add_argument("--max_frames", type=int, default=1, help="Maximum number of frames to extract and process.")
parser.add_argument("--save_dir", default="./test-frames", help="Directory to save extracted frames.")
parser.add_argument("--prompt", default="Describe this video", help="Text prompt for the model.")
parser.add_argument("--target_fps", type=float, default=1.0, help="Target frames per second for extraction.")
args = parser.parse_args()
# --- Frame Extraction & Saving ---
extracted_frame_paths = extract_frames(args.video_path, args.save_dir, target_fps=args.target_fps, max_frames=args.max_frames)
if not extracted_frame_paths:
    print("No frames were extracted or saved. Exiting.")
    sys.exit(1)
# --- Model Loading ---
model_path = "moonshotai/Kimi-VL-A3B-Instruct"
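# Kimi-VL ships its modeling code with the checkpoint, so trust_remote_code=True is required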
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# --- Prepare Input for Model ---
# Load images from the saved paths
loaded_images = []
for frame_path in extracted_frame_paths:
    if os.path.exists(frame_path):
        loaded_images.append(Image.open(frame_path))
    else:
        print(f"Warning: Saved frame path not found: {frame_path}")
print(f"Loaded {len(loaded_images)} images for processing.")
# Construct the messages list
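# Each frame path becomes an image entry in a single user turn, followed by the text prompt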
content = []
for frame_path in extracted_frame_paths:
    content.append({"type": "image", "image": frame_path})
content.append({"type": "text", "text": args.prompt})
messages = [
    {"role": "user", "content": content}
]
print("Messages:")
print(messages)
# Render the chat messages (with image placeholders) into the model's prompt string
text = processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
# --- Model Generation ---
try:
    # Use the loaded images in the processor call
    inputs = processor(
        images=loaded_images[0] if len(loaded_images) == 1 else loaded_images,
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=512)
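    # Slice off the prompt tokens so only the newly generated tokens are decoded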
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    print("--- Model Response ---")
    print(response)
except Exception as e:
    print(f"An error occurred during model generation: {e}")