Created
February 8, 2025 01:42
-
-
Save bfollington/c6f2b3b245d0d512d0ee70ba64499b4a to your computer and use it in GitHub Desktop.
MLX Whisper + FFMpeg video caption script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash | |
# Report whether a command can be resolved by the shell.
#   $1 - name of the command to look up
# Returns 0 when the command is found, non-zero otherwise; prints nothing.
command_exists() {
    local cmd_name=$1
    command -v "$cmd_name" &>/dev/null
}
# Install mlx_whisper using uv.
# Exits the script with status 1 when uv is unavailable or when the
# installation itself fails, so a broken install is reported here rather
# than surfacing later as a confusing "command not found".
install_mlx_whisper() {
    if ! command_exists uv; then
        echo "Error: uv is not installed." >&2
        echo "Please visit https://github.com/astral-sh/uv to install uv first" >&2
        echo "Then run this script again." >&2
        exit 1
    fi

    echo "Installing mlx_whisper using uv..."
    if ! uv pip install mlx-whisper; then
        echo "Error: failed to install mlx-whisper with uv." >&2
        exit 1
    fi
}
# Verify that every external tool this script runs is available.
# Offers to install mlx_whisper interactively; for the other tools it only
# prints guidance. Exits with status 1 if anything is still missing once all
# checks have run, so the user sees the complete list in one go.
check_dependencies() {
    local missing_deps=false

    if ! command_exists mlx_whisper; then
        echo "mlx_whisper is not installed."
        read -p "Would you like to install it now? (y/n) " -n 1 -r
        echo
        if [[ $REPLY =~ ^[Yy]$ ]]; then
            install_mlx_whisper
            # The install can succeed without the executable being reachable
            # from this shell, so confirm before carrying on.
            if ! command_exists mlx_whisper; then
                echo "mlx_whisper was installed but is still not on your PATH."
                missing_deps=true
            fi
        else
            missing_deps=true
        fi
    fi

    if ! command_exists ffmpeg; then
        echo "ffmpeg is not installed."
        echo "Please install ffmpeg using your package manager:"
        echo " macOS: brew install ffmpeg"
        echo " Ubuntu/Debian: sudo apt install ffmpeg"
        echo " Other: https://ffmpeg.org/download.html"
        missing_deps=true
    elif ! command_exists ffprobe; then
        # ffprobe is used to read the source frame rate.
        echo "ffprobe is not installed (it is usually bundled with ffmpeg)."
        missing_deps=true
    fi

    if ! command_exists python3; then
        # python3 runs the caption formatting script.
        echo "python3 is not installed."
        missing_deps=true
    fi

    if [ "$missing_deps" = true ]; then
        exit 1
    fi
}
# Print usage information and exit.
#   $1 - optional exit status (default: 1). A bare "usage" keeps signalling a
#        usage error; pass 0 when help was explicitly requested.
usage() {
    local status=${1:-1}
    cat <<EOF
Usage: $0 input_file [options]
 input_file: Path to video file
Options:
 -s, --speed SPEED Playback speed multiplier (default: 1.0)
 -q, --quality QUAL Quality preset: lo, med, hi (default: med)
 -h, --help Show this help message
EOF
    exit "$status"
}
# Translate a quality preset name into FFmpeg video encoder arguments.
#   $1 - preset name: lo, med or hi
# Prints the argument string on stdout. For any other name an error is
# written to stderr and the (sub)shell exits with status 1.
get_quality_params() {
    local preset=$1
    local crf encoder_preset

    case "$preset" in
        lo)  crf=28; encoder_preset=faster ;;
        med) crf=23; encoder_preset=medium ;;
        hi)  crf=18; encoder_preset=slower ;;
        *)
            printf 'Invalid quality preset: %s\n' "$preset" >&2
            exit 1
            ;;
    esac

    # "--" stops printf from treating the leading "-c:v" as an option.
    printf -- '-c:v libx264 -crf %s -preset %s\n' "$crf" "$encoder_preset"
}
# ---------------------------------------------------------------------------
# Main script: parse arguments, transcribe the video with MLX Whisper,
# reformat the transcript into an SRT file, then burn the captions into a
# re-encoded copy of the video with FFmpeg.
# ---------------------------------------------------------------------------

# Default values
speed=1.0
quality="med"

# Parse command line arguments
POSITIONAL_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        -s|--speed)
            # Without this guard a trailing "-s" makes "shift 2" fail without
            # consuming anything, and the loop would never terminate.
            if [[ $# -lt 2 ]]; then
                echo "Error: $1 requires a value" >&2
                usage
            fi
            speed="$2"
            shift 2
            ;;
        -q|--quality)
            if [[ $# -lt 2 ]]; then
                echo "Error: $1 requires a value" >&2
                usage
            fi
            quality="$2"
            shift 2
            ;;
        -h|--help)
            usage 0
            ;;
        *)
            POSITIONAL_ARGS+=("$1")
            shift
            ;;
    esac
done
set -- "${POSITIONAL_ARGS[@]}"

# Check minimum arguments
if [ "$#" -lt 1 ]; then
    usage
fi

# Validate speed: must be numeric and non-zero. Once the numeric pattern has
# matched, a value made only of zeros and a dot is exactly zero, which would
# divide by zero in the formatter and is not a usable tempo.
if ! [[ $speed =~ ^[0-9]*\.?[0-9]+$ ]] || [[ $speed =~ ^[0.]+$ ]]; then
    echo "Error: Speed must be a positive number" >&2
    exit 1
fi

if [ ! -f "$1" ]; then
    echo "Error: Input file not found: $1" >&2
    exit 1
fi

# Get quality parameters. get_quality_params runs in a subshell here, so its
# "exit 1" cannot stop this script; the status has to be checked explicitly.
quality_params=$(get_quality_params "$quality") || exit 1
# Split the parameter string into an argument array for ffmpeg.
read -r -a quality_args <<< "$quality_params"

# Dependencies are checked only after argument handling so that --help and
# usage errors work on machines where the tools are not installed yet.
check_dependencies

# Get the absolute paths
working_dir=$(pwd)
input_file=$(realpath "$1")
input_dir=$(dirname "$input_file")
filename=$(basename "$input_file")
filename_noext="${filename%.*}"
output_file="${input_dir}/${filename_noext}_captioned.mp4"
srt_file="${input_dir}/${filename_noext}.srt"

# Locate the caption formatter. The caller's working directory is tried first
# (the original behaviour), then the directory this script lives in.
script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
format_script="${working_dir}/format_srt.py"
if [ ! -f "$format_script" ]; then
    format_script="${script_dir}/format_srt.py"
fi
if [ ! -f "$format_script" ]; then
    echo "Error: format_srt.py not found in ${working_dir} or ${script_dir}" >&2
    exit 1
fi

# Create a private temporary directory with an unpredictable name rather than
# a fixed path under /tmp that another user or a concurrent run could collide
# with.
temp_dir=$(mktemp -d "${TMPDIR:-/tmp}/video_processing_XXXXXX") || {
    echo "Error: Could not create temporary directory" >&2
    exit 1
}
raw_output="${temp_dir}/raw_whisper_output.txt"

# Function to clean up temporary files
cleanup() {
    rm -rf "$temp_dir"
}
trap cleanup EXIT

# Log configuration
echo "Configuration:"
echo "Current directory: $working_dir"
echo "Input file: $input_file"
echo "Input directory: $input_dir"
echo "Filename: $filename"
echo "Filename without extension: $filename_noext"
echo "Speed: $speed"
echo "Quality preset: $quality"
echo "Output file: $output_file"
echo "SRT file: $srt_file"
echo "Temp directory: $temp_dir"
echo ""

cd "$input_dir" || {
    echo "Error: Could not change to directory: $input_dir" >&2
    exit 1
}

echo "Step 1: Generating captions using MLX Whisper..."
# Run mlx_whisper and capture what it prints on stdout for the formatter
if ! mlx_whisper "$input_file" \
    --language en \
    --word-timestamps True \
    --highlight-words True \
    --output-format srt > "$raw_output"; then
    echo "Error: mlx_whisper failed" >&2
    exit 1
fi

# Process the output using our Python script
echo "Processing script: ${format_script}"
if ! python3 "$format_script" "$speed" < "$raw_output" > "$srt_file"; then
    echo "Error: Caption formatting failed" >&2
    exit 1
fi

# Check if caption generation was successful. The formatter always prints at
# least a newline, so a size check alone cannot detect an empty transcript;
# require at least one cue timing line as well.
if [ ! -s "$srt_file" ] || ! grep -q -- '-->' "$srt_file"; then
    echo "Error: Caption generation failed" >&2
    exit 1
fi

echo "Step 2: Processing video with captions..."
# Get the source frame rate using ffprobe. The rational value (for example
# 30000/1001) is passed to ffmpeg unchanged: piping it through bc performed
# integer division and turned 29.97 fps into 29.
source_fps=$(ffprobe -v 0 -select_streams v:0 -show_entries stream=r_frame_rate -of csv=p=0 "$input_file" | head -n 1)

# Check if ffprobe was successful
if [ -z "$source_fps" ]; then
    echo "Failed to get the frame rate from the input file." >&2
    exit 1
fi

# Render video with burned-in captions.
# NOTE(review): the SRT path is interpolated into the filter string without
# escaping, so paths containing ':' , ',' or quotes will likely break the
# subtitles filter - confirm and escape if such paths must be supported.
# NOTE(review): atempo only accepts a limited range of factors - confirm the
# supported range for the installed ffmpeg version.
if ! ffmpeg -i "$input_file" \
    -vf "setpts=PTS/${speed},subtitles=${srt_file}:force_style='Fontname=SF\\ Mono,FontSize=16,Alignment=2,BorderStyle=3,Outline=1,Shadow=0,MarginV=10'" \
    -r "${source_fps}" \
    "${quality_args[@]}" \
    -c:a aac -b:a 128k \
    -af "atempo=${speed}" \
    -movflags +faststart \
    "$output_file"; then
    echo "Error: Video processing failed" >&2
    exit 1
fi

echo "Processing completed successfully!"
echo "Output file: $output_file"
echo "Caption file: ${filename_noext}.srt"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys | |
import re | |
from datetime import timedelta | |
def format_timestamp(seconds):
    """Format a time offset in seconds as an SRT timestamp (HH:MM:SS,mmm).

    The value is rounded to whole milliseconds *before* it is split into
    fields, so a value such as 59.9996 carries into the minutes field
    ("00:01:00,000") instead of producing an invalid "00:00:60,000".

    Args:
        seconds: Offset in seconds; anything accepted by ``float()``.
            Negative values are clamped to zero because SRT cannot
            represent them.

    Returns:
        The timestamp string, using a comma as the decimal separator.
    """
    total_ms = max(0, round(float(seconds) * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def split_text_into_lines(text, max_chars=42):
    """Greedily wrap *text* into lines of at most *max_chars* characters.

    Words are never broken: a single word longer than *max_chars* is placed
    on a line of its own. Runs of whitespace are collapsed to one space.

    Args:
        text: The caption text to wrap.
        max_chars: Maximum line length, counted in characters.

    Returns:
        A list of lines; empty when *text* contains no words. No returned
        line is ever empty, which matters because a blank line inside an
        SRT cue terminates the cue.
    """
    lines = []
    current_line = []
    current_length = 0
    for word in text.split():
        if not current_line:
            # First word on a line always fits; never emit an empty line
            # just because this word alone exceeds the limit.
            current_line = [word]
            current_length = len(word)
            continue
        # Length the line would have with a separating space plus the word.
        needed = current_length + 1 + len(word)
        if needed > max_chars:
            lines.append(' '.join(current_line))
            current_line = [word]
            current_length = len(word)
        else:
            current_line.append(word)
            current_length = needed
    if current_line:
        lines.append(' '.join(current_line))
    return lines
def _parse_time(stamp):
    """Convert 'SS.mmm', 'MM:SS.mmm' or 'HH:MM:SS.mmm' into seconds.

    Raises:
        ValueError: If a field is not numeric or there are more than three
            colon-separated fields.
    """
    parts = stamp.strip().split(':')
    if len(parts) > 3:
        raise ValueError(f"unrecognised timestamp: {stamp!r}")
    total = 0.0
    # Fold from the most significant field down: each step promotes the
    # running total by a factor of 60 before adding the next field.
    for part in parts:
        total = total * 60 + float(part)
    return total


def process_whisper_output(input_str, speed=1.0):
    """Convert timestamped transcript lines into SRT text.

    Each usable input line has the form ``[START --> END] text`` where START
    and END are ``SS``, ``MM:SS`` or ``HH:MM:SS`` values (fractions allowed).
    Lines without brackets or without ``-->`` are skipped silently; lines
    whose timestamps cannot be parsed are reported on stderr and skipped.

    For every kept segment the times are divided by *speed*, the cue is
    stretched to last at least one second, the standalone filler words
    "um"/"uh" are removed, and the text is wrapped with
    ``split_text_into_lines``. Segments left with no text are dropped.

    Args:
        input_str: The raw transcript text.
        speed: Playback speed multiplier the video will be rendered at.

    Returns:
        The SRT document as a string; empty when no segment was found.

    Raises:
        ValueError: If *speed* is not a positive number.
    """
    speed = float(speed)
    if speed <= 0:
        raise ValueError(f"speed must be a positive number, got {speed}")

    segments = []
    for line in input_str.strip().split('\n'):
        if '[' not in line or ']' not in line:
            continue

        time_str, text = line.split(']', 1)
        # Tolerate leading whitespace before the opening bracket.
        time_str = time_str.strip().strip('[')
        if '-->' not in time_str:
            continue

        try:
            start_str, end_str = time_str.split('-->')
            start_time = _parse_time(start_str) / speed
            end_time = _parse_time(end_str) / speed
        except (ValueError, IndexError) as e:
            print(f"Error parsing line: {line}", file=sys.stderr)
            print(f"Error details: {e}", file=sys.stderr)
            continue

        # Ensure minimum duration
        if end_time - start_time < 1.0:
            end_time = start_time + 1.0

        # Remove filler words
        text = re.sub(r'\b(um|uh)\b', '', text.strip()).strip()
        if text:  # Only add segments with actual text
            segments.append({
                'start': start_time,
                'end': end_time,
                'text': text
            })

    # Generate SRT format with proper text splitting
    srt_output = []
    for i, segment in enumerate(segments, 1):
        # Split text into lines if it's too long
        text_lines = split_text_into_lines(segment['text'])
        srt_output.extend([
            str(i),
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}",
            '\n'.join(text_lines),
            ''  # blank line terminates the cue
        ])
    return '\n'.join(srt_output)
if __name__ == '__main__':
    # Command-line entry point: the optional first argument is the playback
    # speed (defaults to 1.0); the transcript is read from stdin and the
    # resulting SRT text is written to stdout.
    cli_args = sys.argv[1:]
    speed = float(cli_args[0]) if cli_args else 1.0
    print(process_whisper_output(sys.stdin.read(), speed))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment