This script extracts frames from a video and generates a description using the Kimi-VL-A3B model. It takes the following arguments:

video_path (required): Path to the input video file
--max_frames (default=1): Maximum number of frames to extract
--save_dir (default="./test-frames"): Directory to save extracted frames
--prompt (default="Describe this video"): Text prompt for the model
--target_fps (default=1.0): Target frames per second for extraction
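A typical invocation (the script file name and argument values here are illustrative; the gist does not name the file) looks like:

python describe_video.py ./my_clip.mp4 --max_frames 4 --target_fps 0.5 --prompt "Describe this video"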
import cv2
import argparse
import torch
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Function to extract frames from a video, save them to disk, and return their paths
def extract_frames(video_path, save_dir, target_fps=1, max_frames=1):
    """Extracts up to max_frames from a video file at target FPS, saves them, and returns their paths."""
    frame_paths = []
    # Create the save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return frame_paths
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if video_fps <= 0:
            print(f"Warning: Could not get video FPS for {video_path}. Assuming 30 FPS.")
            video_fps = 30
        # Keep every Nth frame so the sampling rate approximates target_fps
        # (e.g. a 30 FPS source with target_fps=1 keeps every 30th frame)
        frame_interval = max(1, int(round(video_fps / target_fps)))
        frame_count = 0
        saved_frame_count = 0
        while cap.isOpened() and saved_frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                # Convert BGR (OpenCV) to RGB (PIL)
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(frame_rgb)
                # Save the frame to disk
                frame_filename = f"frame_{saved_frame_count:04d}.png"
                frame_filepath = os.path.join(save_dir, frame_filename)
                try:
                    img.save(frame_filepath)
                    frame_paths.append(frame_filepath)
                    saved_frame_count += 1
                except Exception as save_e:
                    print(f"Error saving frame {frame_filepath}: {save_e}")
                    # Skip this frame and keep reading
            frame_count += 1
        cap.release()
        print(f"Extracted and saved {len(frame_paths)} frames to {save_dir} from {video_path} at ~{target_fps} FPS.")
    except Exception as e:
        print(f"An error occurred during frame extraction: {e}")
        if 'cap' in locals() and cap.isOpened():
            cap.release()
    return frame_paths
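
# For example (illustrative path), a call such as
#   extract_frames("clip.mp4", "./test-frames", target_fps=1, max_frames=2)
# would return ["./test-frames/frame_0000.png", "./test-frames/frame_0001.png"].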
# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Describe video frame(s) using Kimi-VL model.")
parser.add_argument("video_path", help="Path to the input video file.")
parser.add_argument("--max_frames", type=int, default=1, help="Maximum number of frames to extract and process.")
parser.add_argument("--save_dir", default="./test-frames", help="Directory to save extracted frames.")
parser.add_argument("--prompt", default="Describe this video", help="Text prompt for the model.")
parser.add_argument("--target_fps", type=float, default=1.0, help="Target frames per second for extraction.")
args = parser.parse_args()
# --- Frame Extraction & Saving ---
extracted_frame_paths = extract_frames(args.video_path, args.save_dir, target_fps=args.target_fps, max_frames=args.max_frames)
if not extracted_frame_paths:
    print("No frames were extracted or saved. Exiting.")
    exit()
# --- Model Loading ---
model_path = "moonshotai/Kimi-VL-A3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# --- Prepare Input for Model ---
# Load images from the saved paths
loaded_images = []
for frame_path in extracted_frame_paths:
    if os.path.exists(frame_path):
        loaded_images.append(Image.open(frame_path))
    else:
        print(f"Warning: Saved frame path not found: {frame_path}")
print(f'Loaded {len(loaded_images)} images for processing.')
# Construct the messages list
content = []
for frame_path in extracted_frame_paths:
    content.append({"type": "image", "image": frame_path})
content.append({"type": "text", "text": args.prompt})
messages = [{"role": "user", "content": content}]
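# With two extracted frames, `messages` takes this shape (paths illustrative):
# [{"role": "user", "content": [
#     {"type": "image", "image": "./test-frames/frame_0000.png"},
#     {"type": "image", "image": "./test-frames/frame_0001.png"},
#     {"type": "text", "text": "Describe this video"}]}]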
print('Messages:')
print(messages)
# Render the chat template into a single prompt string
text = processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
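# add_generation_prompt=True appends the assistant-turn header so generation
# starts as the model's reply; the images themselves are attached in the
# processor call below.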
# --- Model Generation ---
try:
    # Use the loaded images in the processor call
    inputs = processor(
        images=loaded_images[0] if len(loaded_images) == 1 else loaded_images,
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Drop the prompt tokens so only the newly generated tokens are decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    print("--- Model Response ---")
    print(response)
except Exception as e:
    print(f"An error occurred during model generation: {e}")