Skip to content

Instantly share code, notes, and snippets.

@justinledwards
Last active August 28, 2025 19:35
Show Gist options
  • Select an option

  • Save justinledwards/d6bd4caeaf620bb0d7528604e324449b to your computer and use it in GitHub Desktop.

Select an option

Save justinledwards/d6bd4caeaf620bb0d7528604e324449b to your computer and use it in GitHub Desktop.
white noise analysis
#!/usr/bin/env python3
"""
audio_whiteness_batch_v2.py
Batch detector that flags files which are not white noise, with better
handling of decode errors and blank regions. It
1) tries robust ffmpeg decode options
2) removes long silent or near silent spans before PSD analysis
3) requires a minimum usable coverage ratio
Default scan: *.trm *.m4a *.mp3 in the current directory.
Usage
python3 audio_whiteness_batch_v2.py
python3 audio_whiteness_batch_v2.py --recursive
python3 audio_whiteness_batch_v2.py --threshold 0.70 --ratio 0.80 --min-coverage 0.60
python3 audio_whiteness_batch_v2.py --list-only
Notes
- Temp WAVs are deleted by default
- Set --keep to retain them
"""
from __future__ import annotations
import argparse
import json
import math
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
from scipy.io import wavfile
from scipy.signal import welch
# ---------------- ffmpeg helpers ----------------
def check_tool(name: str) -> None:
    """Terminate the program unless ``name -version`` runs successfully.

    Any failure (the executable cannot be launched, or it exits non-zero)
    is reported on stderr and the interpreter exits with status 1.
    """
    probe_cmd = [name, "-version"]
    try:
        subprocess.run(
            probe_cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except Exception:
        message = f"Error: {name} not found on PATH. Please install ffmpeg suite."
        print(message, file=sys.stderr)
        sys.exit(1)
def ffprobe_audio_info(input_path: Path) -> Dict:
    """Return ffprobe's JSON entry for the first audio stream of ``input_path``.

    The entry carries the ``index``, ``channels``, ``sample_rate`` and
    ``channel_layout`` fields requested from ffprobe.

    Raises:
        RuntimeError: ffprobe exited non-zero, or it reported no audio stream.
    """
    probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "stream=index,channels,sample_rate,channel_layout",
            "-of", "json", str(input_path),
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    # An empty stdout is treated like an empty JSON object.
    audio_streams = json.loads(probe.stdout or "{}").get("streams", [])
    if audio_streams:
        return audio_streams[0]
    raise RuntimeError(f"No audio stream found in {input_path}")
# Case-insensitive patterns counted in ffmpeg's stderr to estimate how many
# decode problems were reported.
DECODER_PATTERNS = [
    re.compile(text, re.I)
    for text in (
        r"Error submitting a packet",
        r"Invalid data",
        r"scalefactor bands",
        r"corrupt",
    )
]
def decode_multichannel_wav_robust(inpath: Path, out_mc_wav: Path, stream_index: int) -> Dict:
    """Decode audio stream ``a:stream_index`` of ``inpath`` to a 16-bit PCM WAV.

    Up to three ffmpeg invocations are tried, each more tolerant than the
    previous one: plain decode, decode ignoring errors / discarding corrupt
    packets, and the same with larger analyzeduration / probesize.  An attempt
    counts as successful when ffmpeg exits 0 and the output file is larger
    than 44 bytes.

    Returns:
        ``{"stderr": <stderr of the successful attempt>, "err_count": <number
        of DECODER_PATTERNS matches in that stderr>}``.

    Raises:
        RuntimeError: every attempt failed; the message includes the stderr
        of the last attempt.
    """
    tolerant_opts = ["-err_detect", "ignore_err", "-fflags", "+discardcorrupt"]
    deep_probe_opts = ["-analyzeduration", "100M", "-probesize", "100M"]

    def build_command(input_opts: List[str]) -> List[str]:
        # All attempts share the same skeleton; only the options placed
        # before "-i" differ.
        return [
            "ffmpeg", "-hide_banner", "-loglevel", "warning", "-y",
            *input_opts,
            "-i", str(inpath),
            "-map", f"0:a:{stream_index}", "-vn", "-sn",
            "-c:a", "pcm_s16le",
            str(out_mc_wav),
        ]

    attempts = [
        build_command([]),
        build_command(tolerant_opts),
        build_command(deep_probe_opts + tolerant_opts),
    ]

    last_stderr = ""
    for cmd in attempts:
        res = subprocess.run(cmd, capture_output=True, text=True)
        last_stderr = res.stderr or ""
        produced_audio = out_mc_wav.exists() and out_mc_wav.stat().st_size > 44
        if res.returncode == 0 and produced_audio:
            break
    else:
        # No attempt succeeded.
        raise RuntimeError(f"ffmpeg failed to decode {inpath}\n{last_stderr.strip()}")

    err_count = sum(len(pat.findall(last_stderr)) for pat in DECODER_PATTERNS)
    return {"stderr": last_stderr, "err_count": err_count}
# ---------------- analysis helpers ----------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return the spectral flatness (geometric mean / arithmetic mean) of ``psd``.

    Bins are floored at ``eps`` before taking logarithms.  Returns 0.0 when
    the arithmetic mean is not strictly positive.
    """
    floored = np.maximum(psd, eps)
    geometric = np.exp(np.mean(np.log(floored)))
    arithmetic = float(np.mean(floored))
    if not arithmetic > 0:
        return 0.0
    return float(geometric / arithmetic)
def find_passband(freqs: np.ndarray, Pxx: np.ndarray) -> Tuple[float, float, np.ndarray]:
    """Estimate the band holding the central 90% of the power in ``Pxx``.

    The band runs from the 5% to the 95% point of the cumulative power and is
    clipped to the guard range ``[max(50, freqs[1]), 0.95 * freqs.max()]``.
    If the total power is not positive, or the estimated band is empty, the
    guard range itself is used.

    Returns:
        ``(f_lo, f_hi, mask)`` where ``mask`` selects ``f_lo <= freqs <= f_hi``.
    """
    guard_lo = max(50.0, freqs[1])
    guard_hi = 0.95 * freqs.max()

    power = Pxx.copy()
    power[~np.isfinite(power)] = 0.0  # NaN/inf bins contribute no power
    cumulative = np.cumsum(power)
    total = cumulative[-1]
    if total <= 0:
        return guard_lo, guard_hi, (freqs >= guard_lo) & (freqs <= guard_hi)

    fraction = cumulative / total
    idx_lo = int(np.searchsorted(fraction, 0.05))
    idx_hi = int(np.searchsorted(fraction, 0.95))
    band_lo = max(guard_lo, float(freqs[max(1, idx_lo)]))
    band_hi = min(guard_hi, float(freqs[min(len(freqs) - 1, idx_hi)]))
    if band_hi <= band_lo:
        band_lo, band_hi = guard_lo, guard_hi
    return band_lo, band_hi, (freqs >= band_lo) & (freqs <= band_hi)
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is spread over ``k`` log-spaced bands.

    The mean PSD of every non-empty band between ``f_lo`` and ``f_hi`` is
    normalised by the mean over bands; the score is ``1 / (1 + 3 * std)`` of
    those normalised values.  Returns 0.0 when fewer than three bands contain
    any frequency bin.
    """
    edges = np.geomspace(max(1.0, f_lo), max(f_lo * 1.0001, f_hi), k + 1)
    band_means = []
    for lower, upper in zip(edges[:-1], edges[1:]):
        selected = (freqs >= lower) & (freqs < upper)
        if np.any(selected):
            band_means.append(float(np.mean(Pxx[selected])))
    if len(band_means) < 3:
        return 0.0

    levels = np.array(band_means)
    # Non-positive band means are replaced by the smallest positive one
    # (or a tiny constant if none is positive).
    floor = np.min(levels[levels > 0]) if np.any(levels > 0) else 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Return the least-squares slope of log10(PSD) versus log10(frequency).

    Only bins selected by ``mask`` with positive frequency and finite,
    positive power are fitted.  Returns 0.0 when fewer than 16 such bins
    remain.
    """
    band_f = freqs[mask]
    band_p = Pxx[mask]
    usable = (band_f > 0) & np.isfinite(band_p) & (band_p > 0)
    if np.count_nonzero(usable) < 16:
        return 0.0
    coeffs = np.polyfit(np.log10(band_f[usable]), np.log10(band_p[usable]), 1)
    return float(coeffs[0])
def whiteness_probability(sf: float, slope: float, even: float) -> float:
    """Blend flatness, slope and evenness into one score.

    ``sf`` and ``even`` are clamped to [0, 1]; the slope contributes
    ``exp(-|slope| / 0.25)``.  The weights are 0.4 / 0.3 / 0.3.
    """
    def clamp01(value: float) -> float:
        return max(0.0, min(1.0, value))

    slope_score = math.exp(-abs(slope) / 0.25)
    return float(0.4 * clamp01(sf) + 0.3 * slope_score + 0.3 * clamp01(even))
# ---------------- channel analysis ----------------
def remove_silence(x: np.ndarray, sr: int, win_s: float = 0.05, thr_rel: float = 0.2, min_seg_s: float = 0.10) -> Tuple[np.ndarray, float]:
    """Remove near silent spans and return the concatenated active audio plus coverage.

    A moving-average envelope of ``|x|`` (window ``win_s`` seconds, at least
    8 samples) is compared against ``thr_rel`` times its 75th percentile
    (never below 1e-4).  Contiguous runs above the threshold that last at
    least ``min_seg_s`` seconds (at least 8 samples) are kept.

    Args:
        x: One channel of samples.  int16 input is scaled by 1/32768, other
            integer types by their peak absolute value, floats are used as is.
        sr: Sample rate in Hz.
        win_s: Envelope smoothing window in seconds.
        thr_rel: Threshold relative to the 75th percentile of the envelope.
        min_seg_s: Minimum length in seconds of an active run to keep.

    Returns:
        ``(x_active, coverage)``: the kept samples as float32 and the fraction
        of the input they represent.  Empty input is returned unchanged with
        coverage 0.0; if nothing qualifies an empty float32 array is returned.
    """
    if x.size == 0:
        return x, 0.0
    if x.dtype.kind in ("i", "u"):
        if x.dtype == np.int16:
            xf = x.astype(np.float32) / 32768.0
        else:
            max_abs = float(np.max(np.abs(x)) or 1)
            xf = x.astype(np.float32) / max_abs
    else:
        xf = x.astype(np.float32)

    L = max(8, int(sr * win_s))
    kernel = np.ones(L, dtype=np.float32) / L
    env = np.convolve(np.abs(xf), kernel, mode="same")
    base = np.percentile(env, 75)
    thr = max(1e-4, thr_rel * base)
    mask = env > thr

    # Locate contiguous active runs without a per-sample Python loop: pad the
    # mask with False on both sides so every run has a rising and a falling
    # edge, then read run boundaries from the positions where the mask flips.
    min_len = max(8, int(sr * min_seg_s))
    padded = np.concatenate(([False], mask, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    starts, stops = edges[0::2], edges[1::2]
    long_enough = (stops - starts) >= min_len
    if not np.any(long_enough):
        return np.array([], dtype=np.float32), 0.0

    x_active = np.concatenate(
        [xf[a:b] for a, b in zip(starts[long_enough], stops[long_enough])]
    )
    coverage = float(len(x_active) / len(xf))
    return x_active, coverage
def analyze_channel_vector(x: np.ndarray, sr: int) -> Dict:
    """Compute whiteness metrics for one channel after stripping silence.

    Returns a dict with ``sample_rate``, ``coverage``, ``flat_in``,
    ``slope_in``, ``even_in``, ``white_prob`` and ``label``.  When less than
    half a second of active audio remains, all metrics are 0.0 and the label
    is "insufficient audio".  Otherwise the label is "likely white" when the
    probability exceeds 0.7 and |slope| is below 0.3, else "uncertain".
    """
    x_active, coverage = remove_silence(x, sr)
    report = {
        "sample_rate": sr,
        "coverage": coverage,
        "flat_in": 0.0,
        "slope_in": 0.0,
        "even_in": 0.0,
        "white_prob": 0.0,
        "label": "insufficient audio",
    }
    if x_active.size < sr // 2:  # require at least half a second of usable audio
        return report

    # Welch PSD on active audio only; segment length is the largest power of
    # two not above sr, kept within [2048, 16384].
    seg_len = max(2048, min(16384, 1 << int(np.floor(np.log2(sr)))))
    freqs, Pxx = welch(
        x_active,
        fs=sr,
        window="hann",
        nperseg=seg_len,
        noverlap=seg_len // 2,
        detrend="constant",
        scaling="density",
        average="mean",
    )
    f_lo, f_hi, mask = find_passband(freqs, Pxx)
    sf = spectral_flatness_vector(Pxx[mask])
    slope = slope_in_band(freqs, Pxx, mask)
    even = band_evenness(freqs, Pxx, f_lo, f_hi)
    prob = whiteness_probability(sf, slope, even)

    looks_white = prob > 0.7 and abs(slope) < 0.3
    report.update(
        flat_in=sf,
        slope_in=slope,
        even_in=even,
        white_prob=prob,
        label="likely white" if looks_white else "uncertain",
    )
    return report
# ---------------- per file ----------------
def analyze_file(inpath: Path, tmpdir: Path) -> Dict:
    """Decode one file to a temporary WAV and summarise whiteness per channel.

    Args:
        inpath: Audio file to analyse.
        tmpdir: Directory that receives the decoded ``<stem>_all.wav``.

    Returns:
        A dict with the file name, channel count, median whiteness
        probability, fraction of channels labelled "likely white", median
        coverage, the per-channel metric dicts, the temporary WAV path and
        the decoder error count.

    Raises:
        RuntimeError: propagated from ffprobe / ffmpeg helpers when the file
        has no audio stream or cannot be decoded.
    """
    # Called for validation: raises if the file has no audio stream.
    ffprobe_audio_info(inpath)
    # The decoder maps "0:a:N", where N counts audio streams only.  ffprobe's
    # "index" field is the container-wide stream index, so using it as N
    # selects the wrong stream (or none) whenever the first audio stream is
    # not container stream 0.  ffprobe was asked for a:0, hence N = 0.
    stream_index = 0

    mc_wav = tmpdir / f"{inpath.stem}_all.wav"
    decinfo = decode_multichannel_wav_robust(inpath, mc_wav, stream_index)
    sr, data = wavfile.read(str(mc_wav))
    if data.ndim == 1:
        # Mono files come back as a 1-D array; give them a channel axis.
        data = data[:, None]

    ch_rows = [analyze_channel_vector(data[:, i], sr) for i in range(data.shape[1])]
    probs = np.array([row["white_prob"] for row in ch_rows])
    covers = np.array([row["coverage"] for row in ch_rows])
    frac_likely = float(
        np.mean([1.0 if row["label"] == "likely white" else 0.0 for row in ch_rows])
    )
    return {
        "file": str(inpath.name),
        "channels": data.shape[1],
        "median_prob": float(np.median(probs)),
        "frac_likely": frac_likely,
        "median_coverage": float(np.median(covers)),
        "per_channel": ch_rows,
        "tmp_mc_wav": str(mc_wav),
        "decode_errs": decinfo.get("err_count", 0),
    }
# ---------------- batch ----------------
def find_candidate_files(root: Path, exts: List[str], recursive: bool) -> List[Path]:
    """Return the sorted, de-duplicated files under ``root`` matching ``exts``.

    With ``recursive`` the search descends into subdirectories.  Files whose
    name ends in ``_all.wav`` (this script's temporary decodes) are excluded.
    """
    prefix = "**/" if recursive else ""
    found = set()
    for ext in exts:
        found.update(root.glob(f"{prefix}*.{ext}"))
    return sorted(p for p in found if not p.name.endswith("_all.wav"))
def main():
    """Scan the current directory and list files that do not look like white noise.

    A file is called white when its median channel probability, its fraction
    of "likely white" channels and its median coverage all reach the
    thresholds given on the command line.  Files that fail to analyse are
    reported as not white.  Temporary WAVs (and the temp directory, if it ends
    up empty) are removed unless ``--keep`` is given.
    """
    ap = argparse.ArgumentParser(description="Batch white noise detector with decode retries and silence skipping")
    ap.add_argument("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="File extensions to scan")
    ap.add_argument("--threshold", type=float, default=0.70, help="Median probability threshold for white classification")
    ap.add_argument("--ratio", type=float, default=0.80, help="Minimum fraction of channels labeled likely white")
    ap.add_argument("--min-coverage", type=float, default=0.60, help="Require at least this median coverage of usable audio")
    ap.add_argument("--recursive", action="store_true", help="Recurse into subdirectories")
    ap.add_argument("--outdir", type=str, default=".whiteness_tmp", help="Temp working directory for decoded WAVs")
    ap.add_argument("--keep", action="store_true", help="Keep intermediate WAVs")
    ap.add_argument("--list-only", action="store_true", help="Only print the list of non white files")
    args = ap.parse_args()

    check_tool("ffprobe")
    check_tool("ffmpeg")

    root = Path.cwd()
    files = find_candidate_files(root, [e.lower() for e in args.exts], args.recursive)
    if not files:
        print("No files found.")
        return

    # Create the scratch directory only once there is work to do, so an
    # empty scan does not leave it behind.
    tmpdir = Path(args.outdir)
    tmpdir.mkdir(parents=True, exist_ok=True)

    non_white: List[str] = []
    if not args.list_only:
        print("\nScanning files\n")
        print("file ch median_prob frac_likely med_cov errs result")
        print("-" * 96)

    for f in files:
        try:
            res = analyze_file(f, tmpdir)
            is_white = (
                res["median_prob"] >= args.threshold
                and res["frac_likely"] >= args.ratio
                and res["median_coverage"] >= args.min_coverage
            )
            result = "white" if is_white else "not white"
            if not args.list_only:
                print(f"{f.name[:38]:<38} {res['channels']:>2} {res['median_prob']:.3f} {res['frac_likely']:.2f} {res['median_coverage']:.2f} {res['decode_errs']:>3} {result}")
            if not is_white:
                non_white.append(f.name)
        except Exception as e:
            # A file that cannot be analysed is reported and counted as not white.
            if not args.list_only:
                print(f"{f.name[:38]:<38} -- -- -- -- -- error: {e}")
            non_white.append(f.name)
        finally:
            if not args.keep:
                # Address the temp WAV by its exact path.  Globbing on the
                # stem would treat characters such as "[" or "*" in the file
                # name as wildcards and miss the file.
                try:
                    (tmpdir / f"{f.stem}_all.wav").unlink()
                except OSError:
                    pass  # best effort: the WAV may never have been written

    if non_white:
        print("\nNot white noise files:")
        for n in non_white:
            print(f" {n}")
    else:
        print("\nAll scanned files look like white noise under current thresholds.")

    if not args.keep:
        try:
            tmpdir.rmdir()  # only succeeds when the directory is empty
        except OSError:
            pass  # best effort: leave a non-empty or missing directory alone


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
audio_whiteness_batch_v6_pipe.py
Streaming, low-RAM batch white-noise detector
- Pipes raw PCM from ffmpeg. No WAV headers, no giant arrays
- Parallel, one file per worker
- Time based classification with smoothing to ignore tiny blips
- Reports coarse non-white intervals
- Summary shows percent of files all white and total white minutes
Key tuning flags
--white-time 0.95 time ratio threshold to call a file all white
--min-nonwhite-sec 1.0 fill non-white gaps shorter than this
--min-white-sec 0.5 remove white blips shorter than this
--warmup-sec 2.0 duration used to learn passband
Usage
python3 audio_whiteness_batch_v6_pipe.py
python3 audio_whiteness_batch_v6_pipe.py --workers 4 --recursive
python3 audio_whiteness_batch_v6_pipe.py --white-time 0.95 --min-nonwhite-sec 1.5 --min-white-sec 0.5
Requirements: Python 3.9+, ffmpeg, ffprobe, numpy
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import json
import math
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
import numpy as np
# --------------- tools ---------------
def check_tool(name: str) -> None:
    """Exit with status 1 unless running ``name -version`` succeeds.

    Any exception from the probe (tool missing, non-zero exit) leads to an
    error message on stderr followed by ``sys.exit(1)``.
    """
    devnull = subprocess.DEVNULL
    try:
        subprocess.run([name, "-version"], stdout=devnull, stderr=devnull, check=True)
    except Exception:
        complaint = f"Error: {name} not found on PATH. Please install ffmpeg suite."
        print(complaint, file=sys.stderr)
        sys.exit(1)
def ffprobe_info(input_path: Path) -> Dict:
    """Return ffprobe's JSON entry for the first audio stream of ``input_path``.

    Raises:
        RuntimeError: ffprobe exited non-zero, or no audio stream was listed.
    """
    probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "stream=index,channels,sample_rate,channel_layout",
            "-of", "json", str(input_path),
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    # Empty stdout is parsed as an empty JSON object.
    listed = json.loads(probe.stdout or "{}").get("streams", [])
    if listed:
        return listed[0]
    raise RuntimeError(f"No audio stream found in {input_path}")
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Start ffmpeg writing raw signed 16-bit little-endian PCM to its stdout.

    Audio stream ``a:stream_index`` of ``inpath`` is converted to ``ch``
    channels at ``sr`` Hz with error-tolerant decode options.  The returned
    process has a piped stdout (1 MiB buffer); its stderr is discarded.
    """
    global_opts = ["-hide_banner", "-nostats", "-loglevel", "error"]
    input_opts = [
        "-err_detect", "ignore_err", "-fflags", "+discardcorrupt",
        "-analyzeduration", "100M", "-probesize", "100M",
        "-i", str(inpath),
    ]
    output_opts = [
        "-map", f"0:a:{stream_index}", "-vn", "-sn",
        "-ac", str(ch), "-ar", str(sr),
        "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1",
    ]
    return subprocess.Popen(
        ["ffmpeg", *global_opts, *input_opts, *output_opts],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        bufsize=1024 * 1024,
    )
# --------------- whiteness metrics ---------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return geometric mean over arithmetic mean of ``psd`` floored at ``eps``.

    Yields 0.0 when the arithmetic mean is not strictly positive.
    """
    clipped = np.maximum(psd, eps)
    log_mean = np.mean(np.log(clipped))
    geometric = np.exp(log_mean)
    arithmetic = float(np.mean(clipped))
    if arithmetic > 0:
        return float(geometric / arithmetic)
    return 0.0
def find_passband_from_psd(Pxx: np.ndarray, freqs: np.ndarray) -> Tuple[float, float, np.ndarray]:
    """Estimate the band holding the central 90% of the power in ``Pxx``.

    The band spans the 5% to 95% points of the cumulative power, clipped to
    the guard range ``[max(50, freqs[1]), 0.95 * max(freqs)]``.  The guard
    range itself is returned when total power is not positive or the
    estimated band is empty.

    Returns:
        ``(f_lo, f_hi, mask)`` with ``mask`` selecting ``f_lo <= freqs <= f_hi``.
    """
    guard_lo = max(50.0, float(freqs[1]))
    guard_hi = 0.95 * float(freqs.max())

    power = Pxx.copy()
    power[~np.isfinite(power)] = 0.0  # ignore NaN/inf bins
    cumulative = np.cumsum(power)
    total = cumulative[-1]
    if total <= 0:
        return guard_lo, guard_hi, (freqs >= guard_lo) & (freqs <= guard_hi)

    fraction = cumulative / total
    idx_lo = int(np.searchsorted(fraction, 0.05))
    idx_hi = int(np.searchsorted(fraction, 0.95))
    band_lo = max(guard_lo, float(freqs[max(1, idx_lo)]))
    band_hi = min(guard_hi, float(freqs[min(len(freqs) - 1, idx_hi)]))
    if band_hi <= band_lo:
        band_lo, band_hi = guard_lo, guard_hi
    return band_lo, band_hi, (freqs >= band_lo) & (freqs <= band_hi)
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is distributed over ``k`` log-spaced bands.

    Mean PSD values of the non-empty bands between ``f_lo`` and ``f_hi`` are
    normalised by their mean; the result is ``1 / (1 + 3 * std)``.  Returns
    0.0 when fewer than three bands contain a frequency bin.
    """
    edges = np.geomspace(max(1.0, f_lo), max(f_lo * 1.0001, f_hi), k + 1)
    band_means = []
    for lower, upper in zip(edges[:-1], edges[1:]):
        inside = (freqs >= lower) & (freqs < upper)
        if np.any(inside):
            band_means.append(float(np.mean(Pxx[inside])))
    if len(band_means) < 3:
        return 0.0

    levels = np.array(band_means)
    # Replace non-positive means by the smallest positive one (or 1e-20).
    floor = np.min(levels[levels > 0]) if np.any(levels > 0) else 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Fit a line to log10(PSD) versus log10(frequency) and return its slope.

    Bins outside ``mask``, at non-positive frequency, or with non-finite or
    non-positive power are ignored.  Returns 0.0 with fewer than 16 usable
    bins.
    """
    band_f, band_p = freqs[mask], Pxx[mask]
    usable = (band_f > 0) & np.isfinite(band_p) & (band_p > 0)
    if np.count_nonzero(usable) < 16:
        return 0.0
    log_f = np.log10(band_f[usable])
    log_p = np.log10(band_p[usable])
    return float(np.polyfit(log_f, log_p, 1)[0])
def whiteness_probability(sf: float, slope: float, even: float) -> float:
    """Combine flatness, slope and evenness into a single score.

    Flatness and evenness are clamped to [0, 1] and weighted 0.4 and 0.3;
    the slope term ``exp(-|slope| / 0.25)`` is weighted 0.3.
    """
    flatness_term = max(0.0, min(1.0, sf))
    slope_term = math.exp(-abs(slope) / 0.25)
    evenness_term = max(0.0, min(1.0, even))
    combined = 0.4 * flatness_term + 0.3 * slope_term + 0.3 * evenness_term
    return float(combined)
# --------------- FFT helpers ---------------
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(freqs, Pxx)`` for one windowed frame, both as float32.

    ``Pxx`` is ``|rfft(x * win)|**2`` divided by ``sum(win**2) * sr``;
    ``freqs`` are the matching rfft bin frequencies in Hz.
    """
    spectrum = np.fft.rfft(x * win)
    normaliser = np.sum(win**2) * sr
    power = (np.abs(spectrum) ** 2) / normaliser
    bin_freqs = np.fft.rfftfreq(len(x), d=1.0 / sr)
    return bin_freqs.astype(np.float32), power.astype(np.float32)
# --------------- smoothing helpers ---------------
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Flip, in place, every run of ``value`` shorter than ``min_len``.

    Each too-short run is overwritten with ``1 - value``; nothing is returned.
    """
    total = flags.size
    start = 0
    while start < total:
        if flags[start] != value:
            start += 1
            continue
        # Walk to the end of the current run of `value`.
        stop = start + 1
        while stop < total and flags[stop] == value:
            stop += 1
        if stop - start < min_len:
            flags[start:stop] = 1 - value
        start = stop
def flags_to_intervals(flags: Sequence[bool], hop: int, win: int, sr: int) -> Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]:
    """Convert per-window flags into white and non-white time intervals.

    Flag ``k`` is taken to describe the window that starts at sample
    ``k * hop`` and is ``win`` samples long.  A run of equal flags from index
    ``a`` to index ``b`` (inclusive) therefore becomes the interval
    ``(a * hop / sr, (b * hop + win) / sr)`` in seconds.

    Interior runs previously ended at ``(b + 1) * hop + win``, one hop past
    their last window, while the final run used its own last index.  All runs
    now use the index of their own last window.

    Returns:
        ``(whites, nonwhites)``: lists of ``(start_sec, end_sec)`` tuples for
        truthy and falsy runs respectively, in time order.
    """
    whites: List[Tuple[float, float]] = []
    nonwhites: List[Tuple[float, float]] = []
    count = len(flags)
    if count == 0:
        return whites, nonwhites

    run_start = 0
    for i in range(1, count + 1):
        if i < count and flags[i] == flags[run_start]:
            continue
        # The run covers windows run_start .. i-1; its last window is i-1.
        interval = (run_start * hop / sr, ((i - 1) * hop + win) / sr)
        if flags[run_start]:
            whites.append(interval)
        else:
            nonwhites.append(interval)
        run_start = i
    return whites, nonwhites
def intervals_total(intervals: List[Tuple[float, float]]) -> float:
    """Return the total length covered by ``intervals``, counting overlaps once.

    Intervals are sorted and merged when they touch or overlap; an empty list
    gives 0.0.
    """
    if not intervals:
        return 0.0
    ordered = sorted(intervals)
    covered = 0.0
    span_start, span_end = ordered[0]
    for start, end in ordered[1:]:
        if start > span_end:
            # Gap: close the current merged span and open a new one.
            covered += max(0.0, span_end - span_start)
            span_start, span_end = start, end
        else:
            span_end = max(span_end, end)
    covered += max(0.0, span_end - span_start)
    return covered
# --------------- streaming engine ---------------
def segment_streaming_from_pipe(inpath: Path, stream_index: int, sr: int, ch: int, threshold: float, ratio: float, win_s: float, hop_s: float, warmup_sec: float, min_white_sec: float, min_nonwhite_sec: float) -> Dict:
    """Stream PCM from ffmpeg and classify sliding windows as white or not.

    Audio is read one hop at a time into a ``win``-sample sliding buffer.
    Every full window gets one boolean flag: True when at least ``ratio`` of
    the channels have whiteness probability >= ``threshold`` and |slope| < 0.3.
    Channels whose window peak is below 1e-6 are never counted as white.  The
    analysis passband starts at 50 Hz .. 0.45 * sr and is re-estimated once
    from the PSD accumulated over the warm-up windows (first 4 channels at
    most).  Flags are then smoothed by flipping short non-white and short
    white runs.

    Args:
        inpath: File to decode.
        stream_index: Audio-relative stream index handed to ffmpeg's ``-map``.
        sr: Sample rate requested from ffmpeg, in Hz.
        ch: Channel count requested from ffmpeg.
        threshold: Per-channel probability needed for a window to be white.
        ratio: Fraction of channels that must be white per window.
        win_s: Window length in seconds (at least 256 samples are used).
        hop_s: Hop length in seconds (at least 1 sample is used).
        warmup_sec: Duration used to learn the passband (at least 8 windows).
        min_white_sec: White runs shorter than this are flipped.
        min_nonwhite_sec: Non-white runs shorter than this are flipped.

    Returns:
        Dict with ``duration_sec``, ``white_ratio_time``, ``white_intervals``
        and ``nonwhite_intervals``.
    """
    proc = open_ffmpeg_pcm_pipe(inpath, stream_index, sr, ch)
    assert proc.stdout is not None
    win = max(256, int(sr * win_s))
    hop = max(1, int(sr * hop_s))
    bytes_per_sample = 2
    bytes_per_frame = bytes_per_sample * ch
    read_size_bytes = hop * bytes_per_frame
    buf = np.zeros((win, ch), dtype=np.float32)
    filled = 0
    hann = np.hanning(win).astype(np.float32)

    # learn passband over warmup windows using up to 4 channels
    warm_total = max(8, int(round(warmup_sec / max(1e-6, hop_s))))
    warm_left = warm_total
    psd_accum = None
    freqs_ref = None
    f_lo = 50.0
    f_hi = sr * 0.45
    mask = None
    flags: List[bool] = []
    total_frames = 0

    try:
        while True:
            raw = proc.stdout.read(read_size_bytes)
            if not raw:
                break
            # Drop any trailing partial frame from a short final read.
            n_frames = len(raw) // bytes_per_frame
            if n_frames == 0:
                continue
            total_frames += n_frames
            arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
            arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
            step = arr.shape[0]

            if step >= win:
                buf[...] = arr[-win:]
                filled = win
            elif filled + step <= win:
                buf[filled:filled + step, :] = arr
                filled += step
            else:
                # Slide existing samples left to make room for the new ones.
                overflow = filled + step - win
                buf[:win - overflow, :] = buf[overflow:, :]
                buf[win - step:, :] = arr
                filled = win
            if filled < win:
                continue

            # warmup passband
            if warm_left > 0:
                Psum = None
                for c in range(min(ch, 4)):
                    fr, P = rfft_psd_window(buf[:, c], hann, sr)
                    if Psum is None:
                        Psum = P.astype(np.float64)
                    else:
                        Psum += P
                    freqs_ref = fr
                if psd_accum is None:
                    psd_accum = Psum
                else:
                    psd_accum += Psum
                warm_left -= 1
                if warm_left == 0:
                    # Average over the number of windows actually accumulated.
                    Pavg = psd_accum / warm_total
                    f_lo, f_hi, mask = find_passband_from_psd(Pavg, freqs_ref)
            if mask is None and freqs_ref is not None:
                mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)

            # per window whiteness across channels
            ch_likely = 0
            for c in range(ch):
                xw = buf[:, c]
                if np.max(np.abs(xw)) < 1e-6:
                    continue
                fr, P = rfft_psd_window(xw, hann, sr)
                m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
                sf = spectral_flatness_vector(P[m])
                slope = slope_in_band(fr, P, m)
                even = band_evenness(fr, P, f_lo, f_hi)
                prob = whiteness_probability(sf, slope, even)
                if prob >= threshold and abs(slope) < 0.3:
                    ch_likely += 1
            frac = ch_likely / float(ch)
            flags.append(frac >= ratio)
    except BaseException:
        # Do not leave ffmpeg running if analysis fails part-way through.
        proc.kill()
        raise
    finally:
        proc.stdout.close()
        proc.wait()

    # smoothing of boolean flags
    f_arr = np.array(flags, dtype=np.int8)
    if f_arr.size:
        min_white_frames = max(1, int(math.ceil(min_white_sec / hop_s)))
        min_nonwhite_frames = max(1, int(math.ceil(min_nonwhite_sec / hop_s)))
        remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
        remove_short_runs(f_arr, value=1, min_len=min_white_frames)
    whites, nonwhites = flags_to_intervals(list(f_arr.astype(bool)), hop, win, sr)
    total_dur = total_frames / float(sr)
    white_dur = intervals_total(whites)
    return {
        "duration_sec": total_dur,
        "white_ratio_time": white_dur / max(1e-9, total_dur),
        "white_intervals": whites,
        "nonwhite_intervals": nonwhites,
    }
# --------------- worker and batch ---------------
def analyze_file_worker(args_tuple) -> Dict:
    """Probe and analyse one file; entry point for the process-pool workers.

    Args:
        args_tuple: ``(inpath_str, threshold, ratio, win_s, hop_s, warmup_sec,
            min_white_sec, min_nonwhite_sec)``.

    Returns:
        The dict from ``segment_streaming_from_pipe`` extended with ``file``
        (base name) and ``channels``.
    """
    (
        inpath_str,
        threshold,
        ratio,
        win_s,
        hop_s,
        warmup_sec,
        min_white_sec,
        min_nonwhite_sec,
    ) = args_tuple
    inpath = Path(inpath_str)
    info = ffprobe_info(inpath)
    # ffmpeg is invoked with "-map 0:a:N", where N counts audio streams only.
    # ffprobe's "index" is the container-wide stream index, so using it as N
    # picks the wrong stream (or none) when the first audio stream is not
    # container stream 0.  ffprobe was asked for a:0, hence N = 0.
    stream_index = 0
    # Fall back to 48 kHz when ffprobe reports no sample rate.
    sr = int(info.get("sample_rate") or 48000)
    ch = int(info.get("channels", 1))
    res = segment_streaming_from_pipe(
        inpath, stream_index, sr, ch, threshold, ratio,
        win_s, hop_s, warmup_sec, min_white_sec, min_nonwhite_sec,
    )
    res.update({"file": inpath.name, "channels": ch})
    return res
def find_candidate_files(root: Path, exts: List[str], recursive: bool) -> List[Path]:
    """Return the sorted, de-duplicated files under ``root`` matching ``exts``.

    With ``recursive`` the search descends into subdirectories.
    """
    prefix = "**/" if recursive else ""
    found = set()
    for ext in exts:
        found.update(root.glob(f"{prefix}*.{ext}"))
    return sorted(found)
def fmt_hms(t: float) -> str:
    """Format ``t`` seconds as ``M:SS.mmm``, or ``H:MM:SS.mmm`` from one hour up.

    Negative values are treated as 0.  The value is rounded to whole
    milliseconds before being split into fields, so a time just below a
    minute boundary (e.g. 59.9996) carries into the minutes instead of
    printing as ``0:60.000``.
    """
    total_ms = int(round(max(0.0, t) * 1000.0))
    h, rest_ms = divmod(total_ms, 3_600_000)
    m, sec_ms = divmod(rest_ms, 60_000)
    s = sec_ms / 1000.0
    if h:
        return f"{h:d}:{m:02d}:{s:06.3f}"
    return f"{m:d}:{s:06.3f}"
def main():
    """Analyse every candidate file in parallel and print per-file and summary results.

    A file counts as white when its white time ratio reaches ``--white-time``.
    Files whose analysis raises are reported and listed as not white.
    """
    parser = argparse.ArgumentParser(description="Streaming parallel white noise detector with smoothing")
    parser.add_argument("--workers", type=int, default=2, help="Number of parallel workers")
    parser.add_argument("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="File extensions to scan")
    parser.add_argument("--threshold", type=float, default=0.70, help="Per window per channel whiteness probability threshold")
    parser.add_argument("--ratio", type=float, default=0.80, help="Fraction of channels that must be white per window")
    parser.add_argument("--white-time", type=float, default=0.95, help="Time ratio threshold to call a file all white")
    parser.add_argument("--warmup-sec", type=float, default=2.0, help="Warmup duration to learn passband")
    parser.add_argument("--min-nonwhite-sec", type=float, default=1.0, help="Fill non-white gaps shorter than this")
    parser.add_argument("--min-white-sec", type=float, default=0.5, help="Remove white blips shorter than this")
    parser.add_argument("--win-sec", type=float, default=1.0, help="Window length for timeline in seconds")
    parser.add_argument("--hop-sec", type=float, default=0.5, help="Hop size for timeline in seconds")
    parser.add_argument("--recursive", action="store_true", help="Recurse into subdirectories")
    parser.add_argument("--max-intervals", type=int, default=6, help="Print up to this many intervals per partial file")
    args = parser.parse_args()

    for tool in ("ffprobe", "ffmpeg"):
        check_tool(tool)

    files = find_candidate_files(Path.cwd(), [e.lower() for e in args.exts], args.recursive)
    total = len(files)
    if not files:
        print("No files found.")
        return

    print("\nScanning files (streaming with smoothing)\n")
    print("file ch dur(min) white% by time result")
    print("-" * 88)

    # Settings shared by every job; the file path is prepended per job.
    shared_settings = (
        float(args.threshold),
        float(args.ratio),
        float(args.win_sec),
        float(args.hop_sec),
        float(args.warmup_sec),
        float(args.min_white_sec),
        float(args.min_nonwhite_sec),
    )

    processed = 0
    white_files = 0
    total_white_seconds = 0.0
    total_seconds = 0.0
    non_white_files: List[str] = []
    partial_intervals: Dict[str, List[Tuple[float, float]]] = {}

    with cf.ProcessPoolExecutor(max_workers=max(1, args.workers)) as pool:
        pending = {
            pool.submit(analyze_file_worker, (str(path),) + shared_settings): path
            for path in files
        }
        for done in cf.as_completed(pending):
            processed += 1
            path = pending[done]
            try:
                res = done.result()
            except Exception as e:
                print(f"{path.name[:38]:<38} -- -- -- error: {e}")
                non_white_files.append(path.name)
                continue

            seconds = float(res["duration_sec"])
            total_seconds += seconds
            total_white_seconds += seconds * float(res["white_ratio_time"])

            is_white = res["white_ratio_time"] >= args.white_time
            if is_white:
                white_files += 1
            else:
                non_white_files.append(res["file"])
                if res.get("nonwhite_intervals"):
                    partial_intervals[res["file"]] = res["nonwhite_intervals"]

            verdict = "white" if is_white else "not white"
            print(
                f"{res['file'][:38]:<38} {res['channels']:>2} "
                f"{seconds/60.0:8.2f} {100.0*res['white_ratio_time']:6.1f}% "
                f"{verdict}"
            )
            print(f"progress: {processed}/{total} files ({100.0*processed/total:.1f}%) done")

    percent_files_white = 100.0 * white_files / max(1, total)
    minutes_white = total_white_seconds / 60.0
    minutes_total = total_seconds / 60.0
    print("\nSummary")
    print(f"Files all white: {white_files}/{total} ({percent_files_white:.1f}%)")
    print(f"White audio time: {minutes_white:.2f} min out of {minutes_total:.2f} min total")

    if non_white_files:
        print("\nNot white noise files:")
        for name in non_white_files:
            print(f" {name}")

    if partial_intervals:
        print("\nNon white intervals (coarse) for partial files")
        for fname, ivals in partial_intervals.items():
            print(f"- {fname}")
            for rank, (start, end) in enumerate(ivals[:args.max_intervals], 1):
                print(f" {rank:2d}. {fmt_hms(start)} to {fmt_hms(end)}")
            if len(ivals) > args.max_intervals:
                print(f" ... {len(ivals)-args.max_intervals} more")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
audio_whiteness_batch_v6c_pipe.py
Streaming low RAM white noise detector
- Same fast pipeline as v6b
- Accepts tuner style flags: flat-weight, slope-weight, even-weight, file-threshold, file-ratio
- Keeps passband guards that matched your good runs
Usage example
python3 audio_whiteness_batch_v6c_pipe.py \
--flat-weight 0.616 --slope-weight 0.358 --even-weight 0.026 \
--file-threshold 0.912 --file-ratio 0.825 \
--threshold 0.636 --ratio 0.789 --white-time 0.941 \
--min-nonwhite-sec 1.164 --min-white-sec 0.455 --workers 4
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import json
import math
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
import numpy as np
# --------------- tools ---------------
def check_tool(name: str) -> None:
    """Verify that ``name -version`` can be run; otherwise exit with status 1.

    On any exception from the probe an error line is written to stderr
    before exiting.
    """
    try:
        subprocess.run(
            [name, "-version"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        error_line = f"Error: {name} not found on PATH. Please install ffmpeg suite."
        print(error_line, file=sys.stderr)
        sys.exit(1)
def ffprobe_info(input_path: Path) -> Dict:
    """Return ffprobe's JSON entry for the first audio stream of ``input_path``.

    Raises:
        RuntimeError: ffprobe exited non-zero, or no audio stream was listed.
    """
    command = [
        "ffprobe", "-v", "error",
        "-select_streams", "a:0",
        "-show_entries", "stream=index,channels,sample_rate,channel_layout",
        "-of", "json", str(input_path),
    ]
    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    # Empty stdout is parsed as an empty JSON object.
    parsed = json.loads(completed.stdout or "{}")
    audio_streams = parsed.get("streams", [])
    if audio_streams:
        return audio_streams[0]
    raise RuntimeError(f"No audio stream found in {input_path}")
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Launch ffmpeg so that it writes raw s16le PCM to its stdout.

    Audio stream ``a:stream_index`` of ``inpath`` is converted to ``ch``
    channels at ``sr`` Hz using error-tolerant decode options.  The returned
    process has a piped stdout (1 MiB buffer) and a discarded stderr.
    """
    decode_opts = [
        "-err_detect", "ignore_err", "-fflags", "+discardcorrupt",
        "-analyzeduration", "100M", "-probesize", "100M",
    ]
    stream_opts = ["-map", f"0:a:{stream_index}", "-vn", "-sn"]
    format_opts = [
        "-ac", str(ch), "-ar", str(sr),
        "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1",
    ]
    command = (
        ["ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error"]
        + decode_opts
        + ["-i", str(inpath)]
        + stream_opts
        + format_opts
    )
    return subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        bufsize=1024 * 1024,
    )
# --------------- metrics ---------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return the ratio of geometric to arithmetic mean of ``psd``.

    Values are floored at ``eps`` first; 0.0 is returned when the arithmetic
    mean is not strictly positive.
    """
    safe = np.maximum(psd, eps)
    geometric = np.exp(np.mean(np.log(safe)))
    arithmetic = float(np.mean(safe))
    if not arithmetic > 0:
        return 0.0
    return float(geometric / arithmetic)
def find_passband_from_psd(Pxx: np.ndarray, freqs: np.ndarray) -> Tuple[float, float, np.ndarray]:
    """Estimate the band holding the central 90% of the power in ``Pxx``.

    The band spans the 5% to 95% points of the cumulative power, clipped to
    the guard range ``[max(100, freqs[1]), 0.95 * max(freqs)]``.  The guard
    range itself is returned when total power is not positive or the
    estimated band is empty.

    Returns:
        ``(f_lo, f_hi, mask)`` with ``mask`` selecting ``f_lo <= freqs <= f_hi``.
    """
    guard_lo = max(100.0, float(freqs[1]))
    guard_hi = 0.95 * float(freqs.max())

    power = Pxx.copy()
    power[~np.isfinite(power)] = 0.0  # ignore NaN/inf bins
    cumulative = np.cumsum(power)
    total = cumulative[-1]
    if total <= 0:
        return guard_lo, guard_hi, (freqs >= guard_lo) & (freqs <= guard_hi)

    fraction = cumulative / total
    idx_lo = int(np.searchsorted(fraction, 0.05))
    idx_hi = int(np.searchsorted(fraction, 0.95))
    band_lo = max(guard_lo, float(freqs[max(1, idx_lo)]))
    band_hi = min(guard_hi, float(freqs[min(len(freqs) - 1, idx_hi)]))
    if band_hi <= band_lo:
        band_lo, band_hi = guard_lo, guard_hi
    return band_lo, band_hi, (freqs >= band_lo) & (freqs <= band_hi)
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is distributed over ``k`` log-spaced bands.

    Mean PSD values of the non-empty bands between ``f_lo`` and ``f_hi`` are
    normalised by their mean; the result is ``1 / (1 + 3 * std)``.  Returns
    0.0 when fewer than three bands contain a frequency bin.
    """
    edges = np.geomspace(max(1.0, f_lo), max(f_lo * 1.0001, f_hi), k + 1)
    per_band = []
    for band_no in range(k):
        members = (freqs >= edges[band_no]) & (freqs < edges[band_no + 1])
        if np.any(members):
            per_band.append(float(np.mean(Pxx[members])))
    if len(per_band) < 3:
        return 0.0

    levels = np.array(per_band)
    # Replace non-positive means by the smallest positive one (or 1e-20).
    if np.any(levels > 0):
        floor = np.min(levels[levels > 0])
    else:
        floor = 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    return float(1.0 / (1.0 + 3.0 * float(np.std(levels))))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Return the slope of a straight-line fit of log10(PSD) on log10(frequency).

    Only bins inside ``mask`` with positive frequency and finite, positive
    power take part.  With fewer than 16 such bins the result is 0.0.
    """
    in_band_f = freqs[mask]
    in_band_p = Pxx[mask]
    valid = (in_band_f > 0) & np.isfinite(in_band_p) & (in_band_p > 0)
    if np.count_nonzero(valid) < 16:
        return 0.0
    gradient, _intercept = np.polyfit(
        np.log10(in_band_f[valid]), np.log10(in_band_p[valid]), 1
    )
    return float(gradient)
# --------------- FFT helper ---------------
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(freqs, Pxx)`` of a single windowed frame as float32 arrays.

    ``Pxx`` equals ``|rfft(x * win)|**2 / (sum(win**2) * sr)`` and ``freqs``
    holds the corresponding rfft bin frequencies in Hz.
    """
    windowed = x * win
    magnitude_sq = np.abs(np.fft.rfft(windowed)) ** 2
    density = magnitude_sq / (np.sum(win**2) * sr)
    bin_freqs = np.fft.rfftfreq(len(x), d=1.0 / sr)
    return bin_freqs.astype(np.float32), density.astype(np.float32)
# --------------- smoothing helpers ---------------
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Overwrite, in place, runs of ``value`` that are shorter than ``min_len``.

    Such runs are replaced by ``1 - value``.  The function returns nothing.
    """
    length = flags.size
    pos = 0
    while pos < length:
        if flags[pos] != value:
            pos += 1
            continue
        # Find the first index past the current run of `value`.
        run_end = pos + 1
        while run_end < length and flags[run_end] == value:
            run_end += 1
        if run_end - pos < min_len:
            flags[pos:run_end] = 1 - value
        pos = run_end
def flags_to_intervals(flags, hop: int, win: int, sr: int):
    """Convert per-window flags into white and non-white time intervals.

    Flag ``k`` is taken to describe the window that starts at sample
    ``k * hop`` and is ``win`` samples long.  A run of equal flags from index
    ``a`` to index ``b`` (inclusive) therefore becomes the interval
    ``(a * hop / sr, (b * hop + win) / sr)`` in seconds.

    Interior runs previously ended at ``(b + 1) * hop + win``, one hop past
    their last window, while the final run used its own last index.  All runs
    now use the index of their own last window.

    Returns:
        ``(whites, nonwhites)``: lists of ``(start_sec, end_sec)`` tuples for
        truthy and falsy runs respectively, in time order.
    """
    whites = []
    nonwhites = []
    count = len(flags)
    if count == 0:
        return whites, nonwhites

    run_start = 0
    for i in range(1, count + 1):
        if i < count and flags[i] == flags[run_start]:
            continue
        # The run covers windows run_start .. i-1; its last window is i-1.
        interval = (run_start * hop / sr, ((i - 1) * hop + win) / sr)
        if flags[run_start]:
            whites.append(interval)
        else:
            nonwhites.append(interval)
        run_start = i
    return whites, nonwhites
def intervals_total(intervals):
    """Return the total length covered by ``intervals``, counting overlaps once.

    The ``(start, end)`` pairs are sorted and merged when they touch or
    overlap.  An empty input gives 0.0.
    """
    if not intervals:
        return 0.0
    ordered = sorted(intervals)
    covered = 0.0
    open_start, open_end = ordered[0]
    for start, end in ordered[1:]:
        if start > open_end:
            # Gap: account for the finished span and start a new one.
            covered += max(0.0, open_end - open_start)
            open_start, open_end = start, end
        else:
            open_end = max(open_end, end)
    covered += max(0.0, open_end - open_start)
    return covered
# --------------- streaming engine ---------------
def segment_streaming_from_pipe(
    inpath: Path,
    stream_index: int,
    sr: int,
    ch: int,
    threshold: float,
    ratio: float,
    win_s: float,
    hop_s: float,
    warmup_sec: float,
    min_white_sec: float,
    min_nonwhite_sec: float,
    flat_w: float,
    slope_w: float,
    even_w: float,
) -> Dict:
    """Stream a file through ffmpeg and classify each analysis window.

    PCM is read one hop at a time into a sliding window of ``win_s`` seconds.
    Every full window is scored per channel as
    ``flat_w * flatness + slope_w * exp(-|slope| / 0.25) + even_w * evenness``.
    A channel counts as white when its score reaches ``threshold`` and
    ``|slope| < 0.3``; the window is white when at least ``ratio`` of the
    channels are. Runs shorter than ``min_nonwhite_sec`` / ``min_white_sec``
    are then flipped.

    Returns a dict with ``duration_sec``, ``white_ratio_time``,
    ``white_intervals``, ``nonwhite_intervals``, ``avg_prob_per_ch`` and
    ``channels``.
    """
    proc = open_ffmpeg_pcm_pipe(inpath, stream_index, sr, ch)
    assert proc.stdout is not None
    flags: List[bool] = []
    total_frames = 0
    # file level running averages per channel
    sum_prob = np.zeros(ch, dtype=np.float64)
    count_prob = np.zeros(ch, dtype=np.int64)
    win = max(256, int(sr * win_s))
    hop = max(1, int(sr * hop_s))
    try:
        bytes_per_sample = 2
        bytes_per_frame = bytes_per_sample * ch
        read_size_bytes = hop * bytes_per_frame
        buf = np.zeros((win, ch), dtype=np.float32)
        filled = 0
        hann = np.hanning(win).astype(np.float32)
        warm_windows = max(8, int(round(warmup_sec / max(1e-6, hop_s))))
        psd_accum = None
        freqs_ref = None
        # passband guards that worked well for your data
        # NOTE(review): sr * 0.65 is above the highest rfft bin (sr / 2), so
        # until the passband is learned only f_lo actually limits the band.
        f_lo = 350.0
        f_hi = sr * 0.65
        mask = None
        while True:
            raw = proc.stdout.read(read_size_bytes)
            if not raw:
                break
            n_samples = len(raw) // bytes_per_sample
            n_frames = n_samples // ch
            if n_frames == 0:
                continue
            total_frames += n_frames
            arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
            arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
            step = arr.shape[0]
            if step >= win:
                buf[...] = arr[-win:]
                filled = win
            else:
                if filled + step <= win:
                    buf[filled:filled+step, :] = arr
                    filled += step
                else:
                    # Slide the existing samples left and append the new hop.
                    overflow = filled + step - win
                    buf[:win-overflow, :] = buf[overflow:, :]
                    buf[win-step:, :] = arr
                    filled = win
            if filled < win:
                continue
            if warm_windows > 0:
                # Accumulate the PSD of up to four channels to learn the passband.
                Psum = None
                for c in range(min(ch, 4)):
                    fr, P = rfft_psd_window(buf[:, c], hann, sr)
                    if Psum is None:
                        Psum = P.astype(np.float64)
                    else:
                        Psum += P
                    freqs_ref = fr
                if psd_accum is None:
                    psd_accum = Psum
                else:
                    psd_accum += Psum
                warm_windows -= 1
                if warm_windows == 0:
                    # NOTE(review): this divisor is not clamped to 8 like
                    # warm_windows is, so Pavg can be a scaled mean.
                    Pavg = psd_accum / max(1, int(round(warmup_sec / max(1e-6, hop_s))))
                    f_lo, f_hi, mask = find_passband_from_psd(Pavg, freqs_ref)
            if mask is None and freqs_ref is not None:
                mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)
            ch_likely = 0
            for c in range(ch):
                xw = buf[:, c]
                if np.max(np.abs(xw)) < 1e-6:
                    # Silent channel: not scored, so it never counts as white.
                    continue
                fr, P = rfft_psd_window(xw, hann, sr)
                m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
                sf = spectral_flatness_vector(P[m])
                slope = slope_in_band(fr, P, m)
                even = band_evenness(fr, P, f_lo, f_hi)
                slope_score = math.exp(-abs(slope) / 0.25)
                prob = float(flat_w * max(0.0, min(1.0, sf)) + slope_w * slope_score + even_w * max(0.0, min(1.0, even)))
                sum_prob[c] += prob
                count_prob[c] += 1
                if prob >= threshold and abs(slope) < 0.3:
                    ch_likely += 1
            frac = ch_likely / float(ch)
            flags.append(frac >= ratio)
    finally:
        # Always release the pipe and reap ffmpeg, even when analysis raises.
        proc.stdout.close()
        proc.wait()
    f_arr = np.array(flags, dtype=np.int8)
    if f_arr.size:
        min_white_frames = max(1, int(math.ceil(min_white_sec / hop_s)))
        min_nonwhite_frames = max(1, int(math.ceil(min_nonwhite_sec / hop_s)))
        remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
        remove_short_runs(f_arr, value=1, min_len=min_white_frames)
    whites, nonwhites = flags_to_intervals(list(f_arr.astype(bool)), hop, win, sr)
    total_dur = total_frames / float(sr)
    white_dur = intervals_total(whites)
    with np.errstate(divide='ignore', invalid='ignore'):
        avg_prob_per_ch = np.where(count_prob > 0, sum_prob / np.maximum(1, count_prob), 0.0)
    return {
        "duration_sec": total_dur,
        "white_ratio_time": white_dur / max(1e-9, total_dur),
        "white_intervals": whites,
        "nonwhite_intervals": nonwhites,
        "avg_prob_per_ch": avg_prob_per_ch.tolist(),
        "channels": int(ch),
    }
# --------------- worker and batch ---------------
def analyze_file_worker(args_tuple) -> Dict:
    """Analyze one file in a worker process and apply the file-level verdict.

    ``args_tuple`` packs the input path followed by every tuning parameter
    (see the unpacking below for the order). The result of
    ``segment_streaming_from_pipe`` is extended with ``file``,
    ``file_median_prob``, ``file_frac_likely`` and ``white_final``. When the
    file is judged white, its intervals are replaced by a single span that
    covers the whole duration.
    """
    (
        inpath_str, threshold, ratio, win_sec, hop_sec, warmup_sec,
        min_white_sec, min_nonwhite_sec, flat_w, slope_w, even_w,
        file_threshold, file_ratio, white_time_cut,
    ) = args_tuple
    inpath = Path(inpath_str)
    info = ffprobe_info(inpath)
    stream_index = int(info.get("index", 0))
    # Fall back to 48 kHz when ffprobe reports no (or an empty) sample rate.
    sr = int(info.get("sample_rate", 0)) if info.get("sample_rate") else 48000
    ch = int(info.get("channels", 1))
    res = segment_streaming_from_pipe(
        inpath, stream_index, sr, ch,
        threshold, ratio, win_sec, hop_sec, warmup_sec,
        min_white_sec, min_nonwhite_sec,
        flat_w, slope_w, even_w,
    )
    avg = np.array(res.get("avg_prob_per_ch", []), dtype=np.float64)
    if avg.size:
        file_median = float(np.median(avg))
        file_frac_likely = float(np.mean(avg >= file_threshold))
    else:
        file_median = 0.0
        file_frac_likely = 0.0
    white_by_time = res["white_ratio_time"] >= white_time_cut
    white_by_file = (file_median >= file_threshold) and (file_frac_likely >= file_ratio)
    res["file"] = inpath.name
    res["file_median_prob"] = file_median
    res["file_frac_likely"] = file_frac_likely
    res["white_final"] = bool(white_by_time or white_by_file)
    if not res["white_final"]:
        return res
    # A white verdict overrides the per-window picture for reporting.
    res["white_ratio_time"] = max(res["white_ratio_time"], white_time_cut)
    res["white_intervals"] = [(0.0, res["duration_sec"])]
    res["nonwhite_intervals"] = []
    return res
def find_candidate_files(root: Path, exts: List[str], recursive: bool) -> List[Path]:
    """Return the sorted, de-duplicated files under ``root`` matching ``exts``.

    Extensions are given without the leading dot. With ``recursive`` the
    search descends into subdirectories.
    """
    prefix = "**/*." if recursive else "*."
    found = set()
    for ext in exts:
        found.update(root.glob(prefix + ext))
    return sorted(found)
def fmt_hms(t: float) -> str:
    """Format ``t`` seconds as ``M:SS.mmm``, or ``H:MM:SS.mmm`` from one hour up.

    Negative inputs are clamped to zero. The value is rounded to whole
    milliseconds before it is split into fields, so a value such as 59.9996
    carries into the minutes ("1:00.000") instead of printing "0:60.000".
    """
    total_ms = int(round(max(0.0, t) * 1000.0))
    h, rem_ms = divmod(total_ms, 3_600_000)
    m, sec_ms = divmod(rem_ms, 60_000)
    s = sec_ms / 1000.0
    if h:
        return f"{h:d}:{m:02d}:{s:06.3f}"
    return f"{m:d}:{s:06.3f}"
def main():
    """Scan the current directory in parallel and print per-file and summary reports."""
    parser = argparse.ArgumentParser(description="Streaming parallel white noise detector with file level override and weights")
    add = parser.add_argument
    add("--workers", type=int, default=2, help="Number of parallel workers")
    add("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="File extensions to scan")
    # window classification
    add("--threshold", type=float, default=0.70, help="Per window per channel whiteness threshold")
    add("--ratio", type=float, default=0.80, help="Fraction of channels that must be white per window")
    add("--win-sec", type=float, default=1.0, help="Window length in seconds")
    add("--hop-sec", type=float, default=0.5, help="Hop in seconds")
    # file level override
    add("--file-threshold", type=float, default=0.88, help="Median of channel avg prob across time to call white")
    add("--file-ratio", type=float, default=0.80, help="Fraction of channels whose avg prob >= file-threshold")
    # smoothing and passband
    add("--white-time", type=float, default=0.95, help="Time ratio to call file white when no override")
    add("--min-nonwhite-sec", type=float, default=1.0, help="Fill non white gaps shorter than this")
    add("--min-white-sec", type=float, default=0.5, help="Remove white blips shorter than this")
    add("--warmup-sec", type=float, default=2.0, help="Warmup duration to learn passband")
    # weights
    add("--flat-weight", type=float, default=0.65, help="Weight of spectral flatness in probability")
    add("--slope-weight", type=float, default=0.35, help="Weight of slope score in probability")
    add("--even-weight", type=float, default=0.00, help="Weight of evenness in probability")
    add("--recursive", action="store_true", help="Recurse into subdirectories")
    add("--max-intervals", type=int, default=6, help="Print up to this many intervals per partial file")
    add("--print-file-scores", action="store_true", help="Print file level median prob and fraction")
    args = parser.parse_args()

    check_tool("ffprobe")
    check_tool("ffmpeg")

    files = find_candidate_files(Path.cwd(), [e.lower() for e in args.exts], args.recursive)
    total = len(files)
    if not files:
        print("No files found.")
        return

    print("\nScanning files (streaming, file override)\n")
    if args.print_file_scores:
        header = "file ch dur(min) white% file_med file_frac result"
    else:
        header = "file ch dur(min) white% by time result"
    print(header)
    print("-" * max(88, len(header)))

    # Every job carries the same tuning values; only the path differs.
    shared = (
        float(args.threshold),
        float(args.ratio),
        float(args.win_sec),
        float(args.hop_sec),
        float(args.warmup_sec),
        float(args.min_white_sec),
        float(args.min_nonwhite_sec),
        float(args.flat_weight),
        float(args.slope_weight),
        float(args.even_weight),
        float(args.file_threshold),
        float(args.file_ratio),
        float(args.white_time),
    )
    planned = [(str(p),) + shared for p in files]

    processed = 0
    white_files = 0
    total_white_seconds = 0.0
    total_seconds = 0.0
    non_white_files: List[str] = []
    partial_intervals: Dict[str, List[Tuple[float, float]]] = {}

    with cf.ProcessPoolExecutor(max_workers=max(1, args.workers)) as ex:
        fut_to_path = {ex.submit(analyze_file_worker, job): path for path, job in zip(files, planned)}
        for fut in cf.as_completed(fut_to_path):
            processed += 1
            p = fut_to_path[fut]
            try:
                res = fut.result()
            except Exception as e:
                # A failed file is reported and listed as not white.
                print(f"{p.name[:38]:<38} -- -- -- error: {e}")
                non_white_files.append(p.name)
                continue
            seconds = float(res["duration_sec"])
            total_seconds += seconds
            total_white_seconds += seconds * float(res["white_ratio_time"])
            is_white = bool(res.get("white_final", False))
            if is_white:
                white_files += 1
            else:
                non_white_files.append(res["file"])
                if res.get("nonwhite_intervals"):
                    partial_intervals[res["file"]] = res["nonwhite_intervals"]
            row = (
                f"{res['file'][:38]:<38} {res['channels']:>2} "
                f"{seconds/60.0:8.2f} {100.0*res['white_ratio_time']:6.1f}% "
            )
            if args.print_file_scores:
                row += f"{res['file_median_prob']:6.3f} {res['file_frac_likely']:6.2f} "
            print(row + f"{'white' if is_white else 'not white'}")
            print(f"progress: {processed}/{total} files ({100.0*processed/total:.1f}%) done")

    percent_files_white = 100.0 * white_files / max(1, total)
    minutes_white = total_white_seconds / 60.0
    minutes_total = total_seconds / 60.0
    print("\nSummary")
    print(f"Files all white: {white_files}/{total} ({percent_files_white:.1f}%)")
    print(f"White audio time: {minutes_white:.2f} min out of {minutes_total:.2f} min total")
    if non_white_files:
        print("\nNot white noise files:")
        for n in non_white_files:
            print(f" {n}")
    if partial_intervals:
        print("\nNon white intervals (coarse) for partial files")
        for fname, ivals in partial_intervals.items():
            print(f"- {fname}")
            for i, (s, e) in enumerate(ivals[:args.max_intervals], 1):
                print(f" {i:2d}. {fmt_hms(s)} to {fmt_hms(e)}")
            if len(ivals) > args.max_intervals:
                print(f" ... {len(ivals)-args.max_intervals} more")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
audio_whiteness_tui_tuner.py
Interactive terminal UI for labeling files and auto-tuning parameters for the
streaming white noise detector. Designed to match your v2 judgments while
keeping the v6b speed and tiny RAM.
What it does
- Lets you browse files in the current directory and assign labels:
[1] not white, [2] partial white, [3] all white
- Extracts per-window features once per labeled file and caches them to .npz
so tuning does not re-decode audio
- Runs a fast tuner that searches parameter space in smart increments and
reports a recommended set with validation metrics
Requirements
- Python 3.9+
- ffmpeg + ffprobe on PATH
- numpy
- Standard curses module (ships with Python)
Quick start
1) Label a handful of files that you are confident about
$ python3 audio_whiteness_tui_tuner.py
- Up/Down to move
- 1 = not white, 2 = partial, 3 = all white
- e = extract features for current file
- E = extract features for all labeled files
- t = run tuner on labeled set
- r = evaluate the default parameters on the labeled set
- s = save labels to disk
- q = quit
2) Copy the printed recommended flags into your batch script
Notes
- Features compute with win-sec = 1.0 and hop-sec = 0.5 by default
- Passband guards match your good runs: f_lo about 350 to 450 Hz at 24 kHz,
f_hi about 0.65 of Nyquist which is near 7.8 kHz at 24 kHz
- Tuner first does a coarse random search, then a local refinement around the
best candidates. This is gradient free and robust.
"""
from __future__ import annotations
import curses
import json
import math
import os
import random
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
# -------------------------- config --------------------------
# Directory holding the per-file feature caches (<file name>.npz).
# It is created as a side effect of importing this module.
CACHE_DIR = Path(".aw_cache")
CACHE_DIR.mkdir(exist_ok=True)
# Extensions (without the dot) listed in the TUI.
DEFAULT_EXTS = ["trm", "m4a", "mp3"]
# Analysis window and hop, in seconds, used during feature extraction.
WIN_SEC = 1.0
HOP_SEC = 0.5
# Analyzer defaults that we will tune
# threshold / ratio: per-window, per-channel score cut and required channel fraction
# white_time: white time ratio at which a file is called white
# min_*_sec: minimum run lengths kept by the smoothing step
# file_threshold / file_ratio: file-level override on per-channel average scores
# *_weight: weights of the flatness, slope and evenness components of the score
DEFAULT_PARAMS = {
"threshold": 0.70,
"ratio": 0.80,
"white_time": 0.95,
"min_nonwhite_sec": 1.0,
"min_white_sec": 0.5,
"file_threshold": 0.88,
"file_ratio": 0.80,
"flat_weight": 0.65,
"slope_weight": 0.35,
"even_weight": 0.00, # leave 0 by default to ignore notches
}
# -------------------------- helpers --------------------------
def run_cmd(cmd: Sequence[str]) -> Tuple[int, str, str]:
    """Run ``cmd`` to completion and return ``(returncode, stdout, stderr)`` as text."""
    done = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return done.returncode, done.stdout, done.stderr
def check_tools() -> Optional[str]:
    """Return an error message if ffmpeg or ffprobe cannot be run, else ``None``.

    Launching a missing executable raises ``OSError`` rather than returning a
    non-zero exit code, so that case is caught here and reported with the
    same message instead of escaping to the caller.
    """
    for tool in ("ffmpeg", "ffprobe"):
        try:
            code, _, _ = run_cmd([tool, "-version"])
        except OSError:
            return f"{tool} not found on PATH"
        if code != 0:
            return f"{tool} not found on PATH"
    return None
def ffprobe_info(path: Path) -> Tuple[int, int, int]:
    """Probe the first audio stream and return ``(index, channels, sample_rate)``.

    Missing fields default to index 0, 1 channel and 48000 Hz.

    Raises:
        RuntimeError: if ffprobe exits with an error, or reports no audio
            stream (previously an empty stream list raised ``IndexError``).
    """
    code, out, _ = run_cmd([
        "ffprobe", "-v", "error",
        "-select_streams", "a:0",
        "-show_entries", "stream=index,channels,sample_rate",
        "-of", "json", str(path),
    ])
    if code != 0:
        raise RuntimeError(f"ffprobe failed for {path.name}")
    streams = json.loads(out or "{}").get("streams") or []
    if not streams:
        raise RuntimeError(f"no audio stream found in {path.name}")
    st = streams[0]
    idx = int(st.get("index", 0))
    ch = int(st.get("channels", 1))
    sr = int(st.get("sample_rate", 48000))
    return idx, ch, sr
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Start ffmpeg decoding ``inpath`` to raw s16le PCM on its stdout.

    ``stream_index`` is the container-level stream index as reported by
    ffprobe's ``index`` field. It is mapped as ``0:<index>``; the previous
    ``0:a:<index>`` form selects the N-th *audio* stream and matches nothing
    when the audio stream is not the first stream in the container.

    Output is resampled to ``sr`` Hz with ``ch`` channels. ffmpeg's stderr is
    discarded. The caller must close ``stdout`` and wait on the process.
    """
    cmd = [
        "ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error",
        "-err_detect", "ignore_err", "-fflags", "+discardcorrupt",
        "-analyzeduration", "100M", "-probesize", "100M",
        "-i", str(inpath),
        "-map", f"0:{stream_index}", "-vn", "-sn",
        "-ac", str(ch), "-ar", str(sr),
        "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1",
    ]
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=1024*1024)
# -------------------------- DSP --------------------------
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(freqs, Pxx)`` for one frame: a windowed one-sided power spectrum.

    Power is ``|rfft(x * win)|**2 / (sum(win**2) * sr)``; both outputs are
    cast to float32.
    """
    windowed = x * win
    magnitude_sq = np.abs(np.fft.rfft(windowed)) ** 2
    Pxx = magnitude_sq / (np.sum(win**2) * sr)
    freqs = np.fft.rfftfreq(len(x), d=1.0 / sr)
    return freqs.astype(np.float32), Pxx.astype(np.float32)
def find_passband_from_psd(Pxx: np.ndarray, freqs: np.ndarray) -> Tuple[float, float, np.ndarray]:
    """Pick the band holding roughly the central 90% of the spectrum's power.

    The band edges are the frequencies where cumulative power crosses 5% and
    95%, clamped to guards of ``max(100 Hz, freqs[1])`` and
    ``0.95 * freqs.max()``. If total power is not positive, or the band comes
    out empty, the guards themselves are used.

    Returns ``(f_lo, f_hi, mask)`` where ``mask`` selects ``freqs`` in the band.
    """
    lo_guard = max(100.0, float(freqs[1]))
    hi_guard = 0.95 * float(freqs.max())
    power = Pxx.copy()
    power[~np.isfinite(power)] = 0.0
    cumulative = np.cumsum(power)
    f_lo, f_hi = lo_guard, hi_guard
    if cumulative[-1] > 0:
        share = cumulative / cumulative[-1]
        i_lo = max(1, int(np.searchsorted(share, 0.05)))
        i_hi = min(len(freqs)-1, int(np.searchsorted(share, 0.95)))
        cand_lo = max(lo_guard, float(freqs[i_lo]))
        cand_hi = min(hi_guard, float(freqs[i_hi]))
        if cand_hi > cand_lo:
            f_lo, f_hi = cand_lo, cand_hi
    mask = (freqs >= f_lo) & (freqs <= f_hi)
    return f_lo, f_hi, mask
# probability components
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return geometric mean / arithmetic mean of ``psd`` after flooring at ``eps``.

    Returns 0.0 when the arithmetic mean is not positive.
    """
    floored = np.maximum(psd, eps)
    mean_log = np.mean(np.log(floored))
    arith = float(np.mean(floored))
    if arith > 0:
        return float(np.exp(mean_log) / arith)
    return 0.0
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Fit a line to log10(power) against log10(frequency) in the band; return its slope.

    Only bins with positive frequency and finite, positive power take part.
    With fewer than 16 such bins the result is 0.0.
    """
    f_sel, p_sel = freqs[mask], Pxx[mask]
    keep = (f_sel > 0) & np.isfinite(p_sel) & (p_sel > 0)
    if np.count_nonzero(keep) >= 16:
        slope, _intercept = np.polyfit(np.log10(f_sel[keep]), np.log10(p_sel[keep]), 1)
        return float(slope)
    return 0.0
def evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is spread over ``k`` log-spaced sub-bands.

    The mean power of each non-empty sub-band between ``f_lo`` and ``f_hi``
    is normalized by the overall mean; the score is ``1 / (1 + 3 * std)``, so
    identical band levels give 1.0. Returns 0.0 when fewer than three
    sub-bands contain any bins.
    """
    edges = np.geomspace(max(1.0, f_lo), max(f_lo * 1.0001, f_hi), k + 1)
    band_means = []
    for lo_edge, hi_edge in zip(edges[:-1], edges[1:]):
        inside = (freqs >= lo_edge) & (freqs < hi_edge)
        if np.any(inside):
            band_means.append(float(np.mean(Pxx[inside])))
    if len(band_means) < 3:
        return 0.0
    levels = np.array(band_means)
    # Non-positive levels are raised to the smallest positive one (or 1e-20).
    nonpos = levels <= 0
    floor = np.min(levels[levels > 0]) if np.any(levels > 0) else 1e-20
    levels[nonpos] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
# -------------------------- feature extraction --------------------------
def extract_features(path: Path) -> Path:
    """Decode ``path`` once and cache its per-window features as an ``.npz`` file.

    Audio is streamed from ffmpeg in hops of ``HOP_SEC`` into a sliding window
    of ``WIN_SEC``. For every full window and channel, spectral flatness,
    log-log slope and band evenness are computed inside the passband; silent
    channels keep zeros. The arrays ``sf``, ``slope`` and ``even`` (windows x
    channels) are saved together with ``sr``, ``ch``, ``win``, ``hop``,
    ``f_lo`` and ``f_hi``.

    Returns the path of the cache file under ``CACHE_DIR``.

    Raises:
        RuntimeError: if no complete analysis window could be decoded.
    """
    idx, ch, sr = ffprobe_info(path)
    win = max(256, int(sr * WIN_SEC))
    hop = max(1, int(sr * HOP_SEC))
    hann = np.hanning(win).astype(np.float32)
    # NOTE(review): sr * 0.65 is above the highest rfft bin (sr / 2), so until
    # the passband is learned only f_lo actually limits the band.
    f_lo = 350.0
    f_hi = sr * 0.65
    # store per-window features
    sf_list = []  # shape T x C
    slope_list = []
    even_list = []
    total_frames = 0
    proc = open_ffmpeg_pcm_pipe(path, idx, sr, ch)
    assert proc.stdout is not None
    try:
        bytes_per_sample = 2
        bytes_per_frame = ch * bytes_per_sample
        read_bytes = hop * bytes_per_frame
        buf = np.zeros((win, ch), dtype=np.float32)
        filled = 0
        warm_windows = max(8, int(round(2.0 / max(1e-6, HOP_SEC))))
        psd_accum = None
        freqs_ref = None
        mask = None
        while True:
            raw = proc.stdout.read(read_bytes)
            if not raw:
                break
            n_samples = len(raw) // bytes_per_sample
            n_frames = n_samples // ch
            if n_frames == 0:
                continue
            total_frames += n_frames
            arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
            arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
            step = arr.shape[0]
            if step >= win:
                buf[...] = arr[-win:]
                filled = win
            else:
                if filled + step <= win:
                    buf[filled:filled+step, :] = arr
                    filled += step
                else:
                    # Slide the existing samples left and append the new hop.
                    overflow = filled + step - win
                    buf[:win-overflow, :] = buf[overflow:, :]
                    buf[win-step:, :] = arr
                    filled = win
            if filled < win:
                continue
            if warm_windows > 0:
                # Accumulate the PSD of up to four channels to learn the passband.
                Psum = None
                for c in range(min(ch, 4)):
                    fr, P = rfft_psd_window(buf[:, c], hann, sr)
                    if Psum is None:
                        Psum = P.astype(np.float64)
                    else:
                        Psum += P
                    freqs_ref = fr
                if psd_accum is None:
                    psd_accum = Psum
                else:
                    psd_accum += Psum
                warm_windows -= 1
                if warm_windows == 0:
                    Pavg = psd_accum / max(1, int(round(2.0 / max(1e-6, HOP_SEC))))
                    f_lo, f_hi, mask = find_passband_from_psd(Pavg, freqs_ref)
            if mask is None and freqs_ref is not None:
                mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)
            sf_row = np.zeros(ch, dtype=np.float32)
            slope_row = np.zeros(ch, dtype=np.float32)
            even_row = np.zeros(ch, dtype=np.float32)
            for c in range(ch):
                xw = buf[:, c]
                if np.max(np.abs(xw)) < 1e-6:
                    # Silent channel: leave its features at zero.
                    continue
                fr, P = rfft_psd_window(xw, hann, sr)
                m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
                sf_row[c] = spectral_flatness_vector(P[m])
                slope_row[c] = slope_in_band(fr, P, m)
                even_row[c] = evenness(fr, P, f_lo, f_hi)
            sf_list.append(sf_row)
            slope_list.append(slope_row)
            even_list.append(even_row)
    finally:
        # Always release the pipe and reap ffmpeg, even when analysis raises.
        proc.stdout.close()
        proc.wait()
    T = len(sf_list)
    if T == 0:
        raise RuntimeError(f"no audio windows found in {path.name}")
    sf = np.stack(sf_list, axis=0)  # T x C
    slope = np.stack(slope_list, axis=0)
    even = np.stack(even_list, axis=0)
    out_path = CACHE_DIR / f"{path.name}.npz"
    np.savez_compressed(
        out_path,
        sf=sf.astype(np.float32),
        slope=slope.astype(np.float32),
        even=even.astype(np.float32),
        sr=np.int32(sr),
        ch=np.int16(ch),
        win=np.int32(win),
        hop=np.int32(hop),
        f_lo=np.float32(f_lo),
        f_hi=np.float32(f_hi),
    )
    return out_path
# -------------------------- inference on cached features --------------------------
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Overwrite, in place, runs of ``value`` shorter than ``min_len`` with ``1 - value``.

    ``flags`` is a 1-D 0/1 array. Runs are taken from the array as it is on
    entry, so flipping one run never changes the length of another.
    """
    if flags.size == 0:
        return
    matches = flags == value
    # False sentinels at both ends give every run a start and a stop edge.
    framed = np.concatenate(([False], matches, [False]))
    boundaries = np.flatnonzero(framed[1:] != framed[:-1])
    run_starts = boundaries[0::2]
    run_stops = boundaries[1::2]
    for lo, hi in zip(run_starts, run_stops):
        if (hi - lo) < min_len:
            flags[lo:hi] = 1 - value
def intervals_total(intervals: List[Tuple[float, float]]) -> float:
    """Sum the length of the union of ``intervals`` (overlaps counted once).

    Returns 0.0 for an empty list.
    """
    if not intervals:
        return 0.0
    ordered = sorted(intervals)
    open_start, open_end = ordered[0]
    covered = 0.0
    for start, end in ordered[1:]:
        if start > open_end:
            # Gap: close the current span and open a new one.
            covered += max(0.0, open_end - open_start)
            open_start, open_end = start, end
        else:
            open_end = max(open_end, end)
    return covered + max(0.0, open_end - open_start)
def flags_to_intervals(flags: Sequence[bool], hop: int, win: int, sr: int) -> Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]:
    """Convert per-window flags into ``(whites, nonwhites)`` time spans in seconds.

    Each run of equal flags yields one ``(start, end)`` tuple starting at
    ``first_index * hop / sr``. An interior run ends at
    ``(next_run_first_index * hop + win) / sr``; the last run ends at
    ``(last_index * hop + win) / sr``. Neighbouring spans therefore overlap.
    """
    whites: List[Tuple[float, float]] = []
    nonwhites: List[Tuple[float, float]] = []
    if not flags:
        return whites, nonwhites
    count = len(flags)
    run_start = 0
    for idx in range(1, count + 1):
        at_end = idx == count
        if not at_end and flags[idx] == flags[run_start]:
            continue
        # Interior runs are closed using the index of the next run's first window.
        closing = (count - 1) if at_end else idx
        span = (run_start * hop / sr, (closing * hop + win) / sr)
        target = whites if flags[run_start] else nonwhites
        target.append(span)
        run_start = idx
    return whites, nonwhites
def eval_on_cache(npz_path: Path, params: Dict[str, float]) -> Dict[str, float]:
    """Re-run the white / not-white decision on cached features with ``params``.

    The cache is opened with a context manager so its file handle is released
    on every call; the tuner calls this many times per file.

    Returns a dict with ``white_ratio_time``, ``file_median_prob``,
    ``file_frac_likely``, ``white_final`` and ``duration_sec``.
    """
    with np.load(npz_path) as d:
        sf = d['sf']  # T x C
        slope = d['slope']
        even = d['even']
        sr = int(d['sr'])
        ch = int(d['ch'])
        win = int(d['win'])
        hop = int(d['hop'])
    # time base
    T = sf.shape[0]
    # components
    slope_score = np.exp(-np.abs(slope) / 0.25)
    prob = params['flat_weight'] * np.clip(sf, 0, 1) + \
        params['slope_weight'] * np.clip(slope_score, 0, 1) + \
        params['even_weight'] * np.clip(even, 0, 1)
    # per window decision across channels
    ch_likely = (prob >= params['threshold']).astype(np.int8)
    frac = ch_likely.mean(axis=1)
    flags = (frac >= params['ratio']).astype(np.int8)
    # smoothing
    min_white_frames = max(1, int(math.ceil(params['min_white_sec'] / (hop / sr))))
    min_nonwhite_frames = max(1, int(math.ceil(params['min_nonwhite_sec'] / (hop / sr))))
    f_arr = flags.copy()
    remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
    remove_short_runs(f_arr, value=1, min_len=min_white_frames)
    whites, nonwhites = flags_to_intervals(list(f_arr.astype(bool)), hop, win, sr)
    # Duration is estimated from the window count; the cache stores no frame count.
    total_dur = (T * hop + win) / sr
    white_dur = intervals_total(whites)
    white_ratio_time = white_dur / max(1e-9, total_dur)
    # file level override using average channel probability over time
    avg_prob_per_ch = prob.mean(axis=0)
    file_median = float(np.median(avg_prob_per_ch))
    file_frac = float(np.mean(avg_prob_per_ch >= params['file_threshold']))
    white_by_time = white_ratio_time >= params['white_time']
    white_by_file = (file_median >= params['file_threshold']) and (file_frac >= params['file_ratio'])
    white_final = white_by_time or white_by_file
    return {
        'white_ratio_time': white_ratio_time,
        'file_median_prob': file_median,
        'file_frac_likely': file_frac,
        'white_final': bool(white_final),
        'duration_sec': total_dur,
    }
# -------------------------- tuning --------------------------
# Numeric label codes (the 1 / 2 / 3 keys in the TUI) mapped to readable names.
LABELS = {
1: 'not_white',
2: 'partial',
3: 'all_white',
}
@dataclass
class LabeledItem:
    """A file paired with the whiteness label assigned to it by hand."""

    # File the label applies to.
    path: Path
    # Label code: 1 = not white, 2 = partial, 3 = all white.
    label: int
def score_prediction(pred: Dict[str, float], label: int) -> float:
    """Return a loss (lower is better) for one prediction against its label.

    - label 3 (all white): 1.0 if not called white, plus the non-white time share.
    - label 1 (not white): 1.0 if called white, plus the white time share.
    - any other label (partial): a quadratic penalty for a white ratio away
      from 0.5, plus 0.5 if the file was called white.
    """
    called_white = pred['white_final']
    wr = pred['white_ratio_time']
    if label == 3:
        return (0.0 if called_white else 1.0) + (1.0 - wr)
    if label == 1:
        return (1.0 if called_white else 0.0) + wr
    # prefer mid white ratios and not white_final
    return 0.5 * (wr - 0.5) ** 2 * 4.0 + (0.5 if called_white else 0.0)
# Search bounds as (low, high) for each tunable parameter. random_params
# clamps every sampled value to these bounds before re-normalizing the weights.
PARAM_SPACE = {
'threshold': (0.60, 0.80),
'ratio': (0.70, 0.90),
'white_time': (0.92, 0.98),
'min_nonwhite_sec': (0.5, 2.0),
'min_white_sec': (0.2, 1.0),
'file_threshold': (0.80, 0.92),
'file_ratio': (0.70, 0.90),
'flat_weight': (0.55, 0.75),
'slope_weight': (0.25, 0.45),
'even_weight': (0.00, 0.10),
}
def random_params(base: Dict[str, float], scale: float = 1.0) -> Dict[str, float]:
    """Return a copy of ``base`` with every tunable parameter randomly jittered.

    Each parameter in ``PARAM_SPACE`` moves by up to a quarter of its range
    times ``scale`` in either direction and is clamped to its bounds. The
    three weights are then rescaled to sum to 1, so they may end up outside
    their bounds.
    """
    params = dict(base)
    for name, (lo, hi) in PARAM_SPACE.items():
        span = (hi - lo)
        center = params.get(name, (lo + hi) / 2)
        offset = (random.random() - 0.5) * span * 0.5 * scale
        params[name] = float(max(lo, min(hi, center + offset)))
    # re-normalize weights
    weight_sum = params['flat_weight'] + params['slope_weight'] + params['even_weight']
    if weight_sum > 0:
        for key in ('flat_weight', 'slope_weight', 'even_weight'):
            params[key] /= weight_sum
    else:
        params['flat_weight'], params['slope_weight'], params['even_weight'] = 0.65, 0.35, 0.0
    return params
def tune_parameters(items: List[LabeledItem], caches: Dict[Path, Path], base: Dict[str, float]) -> Tuple[Dict[str, float], float]:
    """Search for the parameter set with the lowest mean loss on the labeled files.

    Runs 60 random trials around ``base``, then 40 smaller-jitter trials
    around each of the five best. Returns ``(best_params, best_mean_loss)``.
    """
    def mean_loss(params: Dict[str, float]) -> float:
        # Average score_prediction over every labeled item, from cached features.
        loss = 0.0
        for it in items:
            pred = eval_on_cache(caches[it.path], params)
            loss += score_prediction(pred, it.label)
        return loss / max(1, len(items))

    best = None
    best_loss = 1e9
    trials = []
    # coarse random search
    for _ in range(60):
        candidate = random_params(base, scale=1.0)
        candidate_loss = mean_loss(candidate)
        trials.append((candidate_loss, candidate))
        if candidate_loss < best_loss:
            best_loss, best = candidate_loss, candidate
    # local refinement around top 5
    seeds = [params for _, params in sorted(trials, key=lambda x: x[0])[:5]]
    for seed in seeds:
        for _ in range(40):
            candidate = random_params(seed, scale=0.35)
            candidate_loss = mean_loss(candidate)
            if candidate_loss < best_loss:
                best_loss, best = candidate_loss, candidate
    return best, best_loss
# -------------------------- TUI --------------------------
# One-line key legend drawn on the second row of the screen.
HELP = "Up/Down move 1=not 2=partial 3=all e=extract E=extract all t=tune r=run eval s=save q=quit"
# Short text shown in the label column; 0 means the file is unlabeled.
LABEL_STR = {0: "", 1: "not", 2: "part", 3: "white"}
def draw_screen(stdscr, files: List[Path], labels: Dict[str, int], pos: int, msg: str) -> None:
    """Redraw the whole screen: title, key legend, file table and status line.

    The table is scrolled so the row at ``pos`` stays in view, and that row is
    drawn in reverse video. ``msg`` goes on the bottom line.
    """
    stdscr.clear()
    h, w = stdscr.getmaxyx()
    rule = "".ljust(w-1, '-')
    body_rows = h - 6  # rows left for the table after header and footer lines
    stdscr.addstr(0, 0, "audio_whiteness_tui_tuner")
    stdscr.addstr(1, 0, HELP[:w-1])
    stdscr.addstr(2, 0, rule)
    start = max(0, min(pos - body_rows // 2, max(0, len(files) - body_rows)))
    end = min(len(files), start + body_rows)
    stdscr.addstr(3, 0, f"{'file':38} label cached")
    stdscr.addstr(4, 0, rule)
    for row, i in enumerate(range(start, end), start=5):
        f = files[i]
        lab = LABEL_STR.get(labels.get(f.name, 0), "")
        cached = "yes" if (CACHE_DIR / f"{f.name}.npz").exists() else "no"
        line = f"{f.name[:38]:<38} {lab:>5} {cached:>6}"
        selected = i == pos
        if selected:
            stdscr.attron(curses.A_REVERSE)
        stdscr.addstr(row, 0, line[:w-1])
        if selected:
            stdscr.attroff(curses.A_REVERSE)
    stdscr.addstr(h-2, 0, rule)
    stdscr.addstr(h-1, 0, msg[:w-1])
    stdscr.refresh()
def tui_main(stdscr):
    """Run the curses labeling and tuning loop until the user quits.

    Keys: Up/Down or j/k move, 1/2/3 label the current file, s saves labels,
    e / E extract features for the current / all labeled files, r evaluates
    ``DEFAULT_PARAMS`` on the labeled set, t runs the tuner and prints the
    recommended flags to stdout, q or Esc quits.

    File-specific keys are ignored with a message when the directory holds no
    audio files (previously they raised ``IndexError``), and a labels file
    that cannot be read is reported on the status line instead of being
    silently skipped.
    """
    curses.curs_set(0)
    err = check_tools()
    if err:
        stdscr.addstr(0, 0, err)
        stdscr.getch()
        return
    files: List[Path] = []
    for ext in DEFAULT_EXTS:
        files.extend(Path.cwd().glob(f"*.{ext}"))
    files = sorted(files)
    labels_path = Path("aw_labels.json")
    labels: Dict[str, int] = {}
    msg = "Select files and label them. Press E to extract features for labeled files."
    if labels_path.exists():
        try:
            labels.update(json.loads(labels_path.read_text()))
        except Exception as e:
            # Best effort: start with no labels, but tell the user why.
            msg = f"could not read {labels_path.name}: {e}"
    pos = 0
    draw_screen(stdscr, files, labels, pos, msg)
    while True:
        ch = stdscr.getch()
        if ch in (ord('q'), 27):
            break
        elif ch in (curses.KEY_DOWN, ord('j')):
            # Clamp at 0 so an empty file list cannot yield -1.
            pos = max(0, min(pos + 1, len(files) - 1))
        elif ch in (curses.KEY_UP, ord('k')):
            pos = max(pos - 1, 0)
        elif ch in (ord('1'), ord('2'), ord('3'), ord('e')) and not files:
            msg = "no audio files in this directory"
        elif ch in (ord('1'), ord('2'), ord('3')):
            labels[files[pos].name] = int(chr(ch))
        elif ch in (ord('s'),):
            labels_path.write_text(json.dumps(labels, indent=2))
            msg = f"saved labels to {labels_path.name}"
        elif ch in (ord('e'),):
            f = files[pos]
            try:
                path = extract_features(f)
                msg = f"cached features: {path.name}"
            except Exception as e:
                msg = f"extract failed: {e}"
        elif ch in (ord('E'),):
            done = 0
            for f in files:
                if labels.get(f.name, 0) == 0:
                    continue
                try:
                    extract_features(f)
                    done += 1
                    msg = f"cached {done} files"
                    draw_screen(stdscr, files, labels, pos, msg)
                except Exception as e:
                    msg = f"extract failed for {f.name}: {e}"
                    draw_screen(stdscr, files, labels, pos, msg)
            msg = f"extraction complete for {done} files"
        elif ch in (ord('r'),):
            # run eval for current params
            labeled_items = [LabeledItem(Path(name), lab) for name, lab in labels.items() if lab in (1,2,3)]
            if not labeled_items:
                msg = "no labeled files"
            else:
                # ensure caches
                caches = {}
                missing = 0
                for it in labeled_items:
                    npz = CACHE_DIR / f"{it.path.name}.npz"
                    if not npz.exists():
                        missing += 1
                    else:
                        caches[it.path] = npz
                if missing:
                    msg = f"missing {missing} caches - run E"
                else:
                    # evaluate
                    total = 0.0
                    acc = 0
                    for it in labeled_items:
                        pred = eval_on_cache(caches[it.path], DEFAULT_PARAMS)
                        loss = score_prediction(pred, it.label)
                        total += loss
                        y_hat = pred['white_final']
                        y = (it.label == 3)
                        acc += 1 if (y_hat == y) else 0
                    msg = f"eval - avg_loss {total/len(labeled_items):.3f} acc {acc}/{len(labeled_items)}"
        elif ch in (ord('t'),):
            # tune using caches
            labeled_items = [LabeledItem(Path(name), lab) for name, lab in labels.items() if lab in (1,2,3)]
            if not labeled_items:
                msg = "no labeled files"
            else:
                caches = {}
                for it in labeled_items:
                    npz = CACHE_DIR / f"{it.path.name}.npz"
                    if not npz.exists():
                        msg = f"missing cache for {it.path.name} - run E"
                        draw_screen(stdscr, files, labels, pos, msg)
                        break
                    caches[it.path] = npz
                else:
                    # Reached only when every labeled file has a cache.
                    best, best_loss = tune_parameters(labeled_items, caches, DEFAULT_PARAMS)
                    msg = "tuner done - press q to quit and copy params from stdout"
                    # print to stdout outside curses
                    curses.endwin()
                    print("\nRecommended parameters\n")
                    for k in sorted(best.keys()):
                        print(f"--{k.replace('_','-')} {best[k]:.3f}")
                    print(f"\navg_loss {best_loss:.4f} on {len(labeled_items)} files")
                    print("\nPaste these flags into audio_whiteness_batch_v6b_pipe.py")
                    return
        draw_screen(stdscr, files, labels, pos, msg)
def main():
    """Start the curses UI, or exit with status 1 on Windows where it is unsupported."""
    if sys.platform != 'win32':
        curses.wrapper(tui_main)
        return
    print('curses UI is not supported on Windows')
    sys.exit(1)


if __name__ == '__main__':
    main()
#!/usr/bin/env python3
"""
white2whisper_chain_v2.py
Two pass pipeline
1) Fast parallel scan to find non white intervals per file using your v6c logic
2) Build long "sessions" by joining nearby intervals within and across files
3) Process sessions sequentially to keep disk and CPU low
- Make one mono WAV per channel for the whole session
- Run WhisperX once per channel per session
- Merge channels with a simple CRDT style rule
- Optional diarization on a mono mix of the session
- Save JSON and SRT named with session start and end time
Improvements vs v1
- Joins intervals across files using timestamps in names like *_YYYYMMDD-HHMM_*
- Joins within file if gaps are small to avoid tiny clips
- Ensures sessions are at least a target length when possible
- Prints live progress with channel and a short preview of first words
- Whisper phase runs exactly one channel at a time. Scanning can still use N workers
Requirements: ffmpeg, ffprobe, numpy, whisperx
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import datetime as dt
import json
import math
import os
import re
import shutil
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
# ---------------- ffmpeg helpers ----------------
def _run_ok(cmd: Sequence[str]) -> None:
    """Run ``cmd`` and raise ``RuntimeError`` (including its stderr) on a non-zero exit."""
    done = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if done.returncode == 0:
        return
    raise RuntimeError("command failed: " + " ".join(cmd) + "\n" + (done.stderr or ""))
def _check_tool(name: str) -> None:
    """Exit the program with status 1 unless ``<name> -version`` runs successfully."""
    probe = [name, "-version"]
    try:
        subprocess.run(probe, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
    except Exception:
        # Covers both a missing executable and a non-zero exit status.
        print(f"Error: {name} not found on PATH", file=sys.stderr)
        sys.exit(1)
def ffprobe_info(input_path: Path) -> Dict:
    """Return ffprobe's JSON entry for the first audio stream of ``input_path``.

    The entry carries the requested ``index``, ``channels``, ``sample_rate``
    and ``duration`` fields, as far as ffprobe reports them.

    Raises:
        RuntimeError: if ffprobe fails or the file has no audio stream.
    """
    probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "stream=index,channels,sample_rate,duration",
            "-of", "json", str(input_path),
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    streams = json.loads(probe.stdout or "{}").get("streams", [])
    if streams:
        return streams[0]
    raise RuntimeError(f"No audio stream found in {input_path}")
# ---------------- DSP bits from v6c ----------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Spectral flatness: geometric mean over arithmetic mean of ``psd``.

    Values are floored at ``eps`` first. Returns 0.0 when the arithmetic mean
    is not positive.
    """
    safe = np.maximum(psd, eps)
    log_avg = np.mean(np.log(safe))
    lin_avg = float(np.mean(safe))
    if not lin_avg > 0:
        return 0.0
    return float(np.exp(log_avg) / lin_avg)
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int) -> Tuple[np.ndarray, np.ndarray]:
    """Windowed one-sided power spectrum of a single frame.

    Returns ``(freqs, Pxx)`` as float32, with
    ``Pxx = |rfft(x * win)|**2 / (sum(win**2) * sr)``.
    """
    frame_len = len(x)
    energy = np.abs(np.fft.rfft(x * win)) ** 2
    normalizer = np.sum(win**2) * sr
    return (
        np.fft.rfftfreq(frame_len, d=1.0 / sr).astype(np.float32),
        (energy / normalizer).astype(np.float32),
    )
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is spread over *k* log-spaced bands.

    The mean power of every non-empty band between ``f_lo`` and ``f_hi`` is
    normalised by the average band level; the result is
    ``1 / (1 + 3 * std)`` of those normalised levels, i.e. 1.0 for perfectly
    even bands. Returns 0.0 when fewer than three bands contain any bin.
    """
    lowest = max(1.0, f_lo)
    highest = max(f_lo * 1.0001, f_hi)  # keeps the upper edge above f_lo
    edges = np.geomspace(lowest, highest, k + 1)

    band_means = []
    for left, right in zip(edges[:-1], edges[1:]):
        in_band = (freqs >= left) & (freqs < right)
        if np.any(in_band):
            band_means.append(float(np.mean(Pxx[in_band])))
    if len(band_means) < 3:
        return 0.0

    levels = np.array(band_means)
    # Non-positive bands take the smallest positive level (or a tiny floor).
    floor = np.min(levels[levels > 0]) if np.any(levels > 0) else 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Fit a line to log10(power) versus log10(frequency) inside *mask*.

    Only bins with positive frequency and finite, positive power take part.
    Returns the fitted slope, or 0.0 when fewer than 16 bins qualify.
    """
    band_f = freqs[mask]
    band_p = Pxx[mask]
    usable = (band_f > 0) & np.isfinite(band_p) & (band_p > 0)
    if np.count_nonzero(usable) < 16:
        return 0.0
    coeffs = np.polyfit(np.log10(band_f[usable]), np.log10(band_p[usable]), 1)
    return float(coeffs[0])
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Start ffmpeg decoding one stream of *inpath* to raw PCM on its stdout.

    Args:
        inpath: media file to decode.
        stream_index: index of the stream inside the input, as reported by
            ffprobe's ``index`` field (an absolute index, not a position
            among the audio streams).
        sr: output sample rate passed to ``-ar``.
        ch: output channel count passed to ``-ac``.

    Returns:
        The running ``Popen`` object. ``stdout`` yields interleaved signed
        16-bit little-endian samples; stderr is discarded. The caller is
        responsible for closing the pipe and waiting on the process.
    """
    cmd = [
        "ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error",
        # Keep decoding through damaged packets instead of aborting.
        "-err_detect", "ignore_err", "-fflags", "+discardcorrupt",
        "-analyzeduration", "100M", "-probesize", "100M",
        "-i", str(inpath),
        # ffprobe's ``index`` is absolute, so map by absolute index.
        # ``0:a:N`` would mean "the N-th audio stream" and selects the wrong
        # stream whenever the audio is not stream 0 of the container.
        "-map", f"0:{stream_index}", "-vn", "-sn",
        "-ac", str(ch), "-ar", str(sr),
        "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1",
    ]
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=1024 * 1024)
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Flip, in place, every run of *value* shorter than *min_len*.

    *flags* is expected to hold 0/1 entries; a short run is overwritten with
    ``1 - value``. Runs are examined left to right in a single pass, so a run
    that grows because a neighbour was flipped is not re-examined.
    """
    total = flags.size
    pos = 0
    while pos < total:
        if flags[pos] != value:
            pos += 1
            continue
        # Find the end (exclusive) of the run starting at pos.
        stop = pos + 1
        while stop < total and flags[stop] == value:
            stop += 1
        if stop - pos < min_len:
            flags[pos:stop] = 1 - value
        pos = stop
def flags_to_intervals(flags: Sequence[bool], hop: int, win: int, sr: int) -> Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]:
    """Turn per-window flags into time intervals, split by flag value.

    Args:
        flags: one truthy (white) or falsy (non-white) entry per window.
        hop: window step in samples.
        win: window length in samples.
        sr: sample rate used to convert samples to seconds.

    Returns:
        ``(whites, nonwhites)``: lists of ``(start_s, end_s)`` tuples, one per
        run of equal flags, in order of appearance.
    """
    whites: List[Tuple[float, float]] = []
    nonwhites: List[Tuple[float, float]] = []
    if not flags:
        return whites, nonwhites

    count = len(flags)
    run_start = 0
    for idx in range(1, count + 1):
        at_end = idx == count
        if not at_end and flags[idx] == flags[run_start]:
            continue
        # NOTE: an interior run ends at (idx * hop + win), idx being the first
        # window of the next run, while the final run ends at its own last
        # window; this asymmetry is kept as-is.
        end_window = count - 1 if at_end else idx
        span = (run_start * hop / sr, (end_window * hop + win) / sr)
        target = whites if flags[run_start] else nonwhites
        target.append(span)
        run_start = idx
    return whites, nonwhites
def intervals_total(intervals: List[Tuple[float, float]]) -> float:
    """Return the total length covered by *intervals*, counting overlaps once."""
    if not intervals:
        return 0.0
    covered = 0.0
    span_start = span_end = None
    for begin, finish in sorted(intervals):
        if span_start is None:
            span_start, span_end = begin, finish
        elif begin <= span_end:
            # Overlapping or touching: grow the current span.
            span_end = max(span_end, finish)
        else:
            covered += max(0.0, span_end - span_start)
            span_start, span_end = begin, finish
    covered += max(0.0, span_end - span_start)
    return covered
# ---------------- scanning engine ----------------
def segment_streaming_from_pipe(
    inpath: Path,
    stream_index: int,
    sr: int,
    ch: int,
    threshold: float,
    ratio: float,
    win_s: float,
    hop_s: float,
    warmup_sec: float,
    min_white_sec: float,
    min_nonwhite_sec: float,
    flat_w: float,
    slope_w: float,
    even_w: float,
) -> Dict:
    """Decode a file through ffmpeg and label each analysis window white or not.

    PCM is read one hop at a time into a sliding window of ``win_s`` seconds.
    The first windows also accumulate an average spectrum; once warm-up ends,
    the analysis band is narrowed to the 5%..95% cumulative-power range. In
    every full window a channel counts as white when its weighted spectral
    flatness reaches *threshold* and its log-log slope magnitude is below
    0.3; the window is white when at least *ratio* of all channels are white.
    Short runs are then smoothed away and runs are converted to intervals.

    Args:
        inpath, stream_index, sr, ch: forwarded to ``open_ffmpeg_pcm_pipe``.
        threshold: minimum per-channel score for a white channel.
        ratio: minimum fraction of white channels for a white window.
        win_s, hop_s: window length and step in seconds.
        warmup_sec: nominal warm-up duration (at least 8 windows are used).
        min_white_sec, min_nonwhite_sec: runs shorter than this are flipped.
        flat_w: weight applied to the flatness term of the score.
        slope_w, even_w: accepted but not used by the score in this version
            (only the flatness term contributes).

    Returns:
        ``{"duration": seconds_decoded, "nonwhite_intervals": [(start, end), ...]}``.
    """
    safe_hop_s = max(1e-6, hop_s)
    win = max(256, int(sr * win_s))
    hop = max(1, int(sr * hop_s))
    bytes_per_sample = 2  # ffmpeg is asked for signed 16-bit samples
    bytes_per_frame = bytes_per_sample * ch
    read_size_bytes = hop * bytes_per_frame
    buf = np.zeros((win, ch), dtype=np.float32)
    filled = 0
    hann = np.hanning(win).astype(np.float32)
    warm_windows = max(8, int(round(warmup_sec / safe_hop_s)))
    psd_accum = None
    freqs_ref = None
    # Default analysis band, used until (and unless) warm-up refines it.
    f_lo = 350.0
    f_hi = sr * 0.65
    mask = None
    flags: List[bool] = []
    total_frames = 0

    proc = open_ffmpeg_pcm_pipe(inpath, stream_index, sr, ch)
    assert proc.stdout is not None
    try:
        while True:
            raw = proc.stdout.read(read_size_bytes)
            if not raw:
                break
            n_frames = (len(raw) // bytes_per_sample) // ch
            if n_frames == 0:
                continue
            total_frames += n_frames
            arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
            arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
            step = arr.shape[0]

            # Slide the new frames into the analysis window.
            if step >= win:
                buf[...] = arr[-win:]
                filled = win
            elif filled + step <= win:
                buf[filled:filled + step, :] = arr
                filled += step
            else:
                overflow = filled + step - win
                buf[:win - overflow, :] = buf[overflow:, :]
                buf[win - step:, :] = arr
                filled = win
            if filled < win:
                continue

            if warm_windows > 0:
                # Accumulate the spectrum of (at most) the first four channels.
                Psum = None
                for c in range(min(ch, 4)):
                    fr, P = rfft_psd_window(buf[:, c], hann, sr)
                    if Psum is None:
                        Psum = P.astype(np.float64)
                    else:
                        Psum += P
                    freqs_ref = fr
                if psd_accum is None:
                    psd_accum = Psum
                else:
                    psd_accum += Psum
                warm_windows -= 1
                if warm_windows == 0:
                    power = psd_accum / max(1, int(round(warmup_sec / safe_hop_s)))
                    cum = np.cumsum(power)
                    # With no energy at all (e.g. digital silence) the
                    # percentiles are undefined and searchsorted would point
                    # past the last bin; keep the default band in that case.
                    if cum[-1] > 0:
                        cum /= cum[-1]
                        last_bin = len(freqs_ref) - 1
                        low_idx = min(last_bin, int(np.searchsorted(cum, 0.05)))
                        high_idx = min(last_bin, int(np.searchsorted(cum, 0.95)))
                        f_lo = max(100.0, float(freqs_ref[max(1, low_idx)]))
                        f_hi = min(0.95 * float(freqs_ref.max()), float(freqs_ref[high_idx]))
                    mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)

            if mask is None and freqs_ref is not None:
                mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)

            ch_likely = 0
            for c in range(ch):
                xw = buf[:, c]
                if np.max(np.abs(xw)) < 1e-6:
                    # A silent channel never counts as white.
                    continue
                fr, P = rfft_psd_window(xw, hann, sr)
                m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
                sf = spectral_flatness_vector(P[m])
                slope = slope_in_band(fr, P, m)
                # Only the flatness term contributes to the score here; the
                # slope acts purely as a gate to avoid layout dependence.
                prob = float(flat_w * max(0.0, min(1.0, sf)))
                if prob >= threshold and abs(slope) < 0.3:
                    ch_likely += 1
            flags.append(ch_likely / float(ch) >= ratio)
    finally:
        # Always release the pipe and reap ffmpeg, even if analysis failed.
        proc.stdout.close()
        proc.wait()

    f_arr = np.array(flags, dtype=np.int8)
    if f_arr.size:
        min_white_frames = max(1, int(math.ceil(min_white_sec / safe_hop_s)))
        min_nonwhite_frames = max(1, int(math.ceil(min_nonwhite_sec / safe_hop_s)))
        remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
        remove_short_runs(f_arr, value=1, min_len=min_white_frames)
    _, nonwhites = flags_to_intervals(f_arr.astype(bool).tolist(), hop, win, sr)
    return {
        "duration": total_frames / float(sr),
        "nonwhite_intervals": nonwhites,
    }
# ---------------- session planning ----------------
# Matches "<anything>_YYYYMMDD-HHMM_" at the start of a file name.
_TS_RE = re.compile(r".*_(\d{8})-(\d{4})_", re.IGNORECASE)


def file_start_time(p: Path) -> Optional[dt.datetime]:
    """Parse the ``_YYYYMMDD-HHMM_`` stamp embedded in a file name.

    Returns a naive ``datetime``, or ``None`` when the name carries no stamp
    or the digits do not form a valid date and time.
    """
    found = _TS_RE.match(p.name)
    if found is None:
        return None
    stamp = "".join(found.group(1, 2))
    try:
        return dt.datetime.strptime(stamp, "%Y%m%d%H%M")
    except Exception:
        return None
@dataclass
class FileScan:
    """Scan result for one input file."""

    # Location of the scanned file.
    path: Path
    # Recording start parsed from the file name, if any.
    start: Optional[dt.datetime]
    # Decoded length in seconds.
    duration: float
    # Intervals of interest, in seconds relative to the file start.
    intervals: List[Tuple[float, float]]
@dataclass
class SessionPiece:
    """One contiguous excerpt of a source file that belongs to a session."""

    # Source file the excerpt is cut from.
    file: Path
    # Excerpt bounds in seconds, relative to the start of ``file``.
    start: float
    end: float
@dataclass
class Session:
    """A group of pieces that is extracted and transcribed as one unit."""

    # Sequence number assigned while planning.
    idx: int
    # Absolute start/end of the session; None when no timestamp is known.
    start_dt: Optional[dt.datetime]
    end_dt: Optional[dt.datetime]
    # Excerpts making up the session, in planning order.
    pieces: List[SessionPiece]
def merge_intervals(iv: List[Tuple[float, float]], join_gap: float) -> List[Tuple[float, float]]:
    """Sort intervals and fuse neighbours separated by at most *join_gap*.

    Returns a new list of ``(start, end)`` float tuples; the input list is
    not modified.
    """
    if not iv:
        return []
    ordered = sorted(iv)
    first_start, first_end = ordered[0]
    starts = [first_start]
    ends = [first_end]
    for begin, finish in ordered[1:]:
        if begin <= ends[-1] + join_gap:
            # Close enough to the previous span: extend it.
            ends[-1] = max(ends[-1], finish)
        else:
            starts.append(begin)
            ends.append(finish)
    return [(float(a), float(b)) for a, b in zip(starts, ends)]
def plan_sessions(scans: List[FileScan], join_gap: float, min_session: float) -> List[Session]:
    """Group the intervals of all scanned files into transcription sessions.

    Files are visited by start time (files without one sort first), then by
    name. Inside a file, intervals are first merged with ``merge_intervals``.
    A piece joins the open session when its absolute start lies within
    *join_gap* seconds of the session's end, or when its file has no start
    time at all; otherwise the open session is closed and a new one begins.
    A closed session whose pieces sum to less than *min_session* seconds is
    folded into the previously emitted session, if there is one.
    """
    ordered = sorted(scans, key=lambda item: (item.start or dt.datetime.min, item.path.name))
    sessions: List[Session] = []
    pending: List[SessionPiece] = []
    span_start: Optional[dt.datetime] = None
    span_end: Optional[dt.datetime] = None
    next_idx = 1

    def emit_pending() -> None:
        """Close the open session and append it (or fold it into the last one)."""
        nonlocal pending, span_start, span_end, next_idx
        if not pending:
            return
        finished = Session(idx=next_idx, start_dt=span_start, end_dt=span_end, pieces=pending)
        pending, span_start, span_end = [], None, None
        length = sum(piece.end - piece.start for piece in finished.pieces)
        if length < min_session and sessions:
            # Too short to stand alone: merge backward.
            sessions[-1].pieces.extend(finished.pieces)
            sessions[-1].end_dt = finished.end_dt
            return
        sessions.append(finished)
        next_idx += 1

    for fs in ordered:
        for s_rel, e_rel in merge_intervals(fs.intervals, join_gap):
            # Absolute times exist only for files with a known start.
            if fs.start:
                s_abs = fs.start + dt.timedelta(seconds=s_rel)
                e_abs = fs.start + dt.timedelta(seconds=e_rel)
            else:
                s_abs = e_abs = None
            piece = SessionPiece(fs.path, s_rel, e_rel)

            if pending:
                near_previous = (
                    bool(span_end and e_abs and s_abs)
                    and (s_abs - span_end).total_seconds() <= join_gap
                )
                if near_previous or not fs.start:
                    pending.append(piece)
                    if e_abs and (not span_end or e_abs > span_end):
                        span_end = e_abs
                    continue
                emit_pending()

            # Start a fresh session with this piece.
            pending = [piece]
            span_start = s_abs
            span_end = e_abs

    emit_pending()
    return sessions
# ---------------- worker for parallel scan ----------------
def scan_one_worker(job):
    """Scan one file; top-level so ProcessPool can pickle it (see macOS note).

    *job* is the tuple
        (inpath_str, threshold, ratio, win_sec, hop_sec, warmup_sec,
         min_white_sec, min_nonwhite_sec, flat_weight, slope_weight,
         even_weight, join_gap_sec, overlap)

    Non-white intervals are widened by *overlap* seconds on both sides
    (clamped to the file), merged with ``join_gap_sec`` and returned in a
    plain dict with keys ``path``, ``start`` (ISO string or ""), ``duration``
    and ``intervals``.
    """
    (inpath_str, threshold, ratio, win_sec, hop_sec, warmup_sec,
     min_white_sec, min_nonwhite_sec, flat_weight, slope_weight, even_weight,
     join_gap_sec, overlap) = job

    src = Path(inpath_str)
    stream = ffprobe_info(src)
    stream_index = int(stream.get("index", 0))
    sr = int(stream.get("sample_rate", 48000))
    ch = int(stream.get("channels", 1))
    result = segment_streaming_from_pipe(
        src, stream_index, sr, ch,
        threshold, ratio, win_sec, hop_sec, warmup_sec,
        min_white_sec, min_nonwhite_sec,
        flat_weight, slope_weight, even_weight,
    )

    duration = float(result.get("duration", 0.0))
    widened = []
    for begin, finish in result.get("nonwhite_intervals", []):
        lo = max(0.0, begin - overlap)
        hi = min(duration, finish + overlap)
        if hi > lo:
            widened.append((float(lo), float(hi)))

    started = file_start_time(src)
    return {
        "path": str(src),
        "start": started.isoformat() if started else "",
        "duration": duration,
        "intervals": merge_intervals(widened, float(join_gap_sec)),
    }
# ---------------- audio extraction ----------------
def extract_session_channel_wav(session: Session, ch_idx: int, sr_out: Optional[int], out_wav: Path) -> None:
    """Write one channel of every piece of *session* into a single mono WAV.

    Each piece is first cut to its own 16-bit PCM part file with ffmpeg
    (``pan`` picks channel ``ch_idx`` of the first audio stream); the parts
    are then joined with ffmpeg's concat demuxer using stream copy. The part
    files and the concat list are always removed afterwards.

    Args:
        session: session whose ``pieces`` are extracted in order.
        ch_idx: zero-based channel to keep.
        sr_out: output sample rate; falsy or non-positive keeps the source rate.
        out_wav: destination file; its parent directory is created if needed.
            An existing file is overwritten.

    Raises:
        RuntimeError: an ffmpeg invocation failed (raised by ``_run_ok``).
    """
    out_wav.parent.mkdir(parents=True, exist_ok=True)
    list_txt = out_wav.with_suffix('.list.txt')
    part_files: List[Path] = []
    # -y: never block on ffmpeg's "overwrite?" question for leftovers.
    base = ["ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error", "-y"]
    try:
        for i, piece in enumerate(session.pieces, 1):
            part = out_wav.parent / f"{out_wav.stem}_part{i:03d}.wav"
            part_files.append(part)
            cmd = base + [
                "-ss", f"{piece.start:.3f}", "-t", f"{max(0.0, piece.end - piece.start):.3f}",
                "-i", str(piece.file),
                "-map", "0:a:0", "-vn", "-sn", "-af", f"pan=mono|c0=c{ch_idx}",
                "-ac", "1", "-acodec", "pcm_s16le",
            ]
            if sr_out and sr_out > 0:
                cmd += ["-ar", str(sr_out)]
            cmd.append(str(part))
            _run_ok(cmd)

        with open(list_txt, 'w', encoding="utf-8") as f:
            for pf in part_files:
                # The concat demuxer resolves relative entries against the
                # directory of the list file, so relative part paths would be
                # looked up in the wrong place; always write absolute paths.
                escaped = str(pf.resolve()).replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        _run_ok(base + ["-f", "concat", "-safe", "0", "-i", str(list_txt), "-c", "copy", str(out_wav)])
    finally:
        # Best-effort cleanup of intermediates; missing files are fine.
        for leftover in [list_txt, *part_files]:
            try:
                os.remove(leftover)
            except OSError:
                pass
# ---------------- whisper and merge ----------------
@dataclass
class Word:
    """A single transcribed word with timing, confidence and source channel."""

    # Word bounds in seconds.
    start: float
    end: float
    # The word itself.
    text: str
    # Score reported for the word.
    score: float
    # Channel the word came from.
    ch: int
def run_whisperx_on_wav_modelreused(wav_path: Path, device: str, model, align_model, align_meta) -> List[Word]:
    """Transcribe and word-align one WAV with already-loaded WhisperX models.

    Args:
        wav_path: audio file to transcribe.
        device: device string handed to ``whisperx.align``.
        model: object whose ``transcribe(audio)`` returns a dict with "segments".
        align_model, align_meta: alignment model and metadata for ``whisperx.align``.

    Returns:
        One ``Word`` per non-empty aligned word, with ``ch`` set to -1.
        Missing word timings fall back to those of the enclosing segment.
    """
    import whisperx

    audio = whisperx.load_audio(str(wav_path))
    transcript = model.transcribe(audio)
    aligned = whisperx.align(
        transcript.get("segments", []), align_model, align_meta, audio, device,
        return_char_alignments=False,
    )

    collected: List[Word] = []
    for segment in aligned.get("segments", []):
        for item in segment.get("words", []) or []:
            begin = float(item.get("start", segment.get("start", 0.0)) or 0.0)
            finish = float(item.get("end", segment.get("end", begin)))
            text = str(item.get("word", "")).strip()
            confidence = float(item.get("score", 0.0))
            if not text:
                continue
            collected.append(Word(begin, finish, text, confidence, ch=-1))
    return collected
def crdt_merge_words(words_by_channel: Dict[int, List[Word]]) -> List[Word]:
    """Merge per-channel word lists into one time-ordered list.

    Every word is copied and tagged with its channel, then all words are
    sorted by ``(start, end)``. A word that overlaps the most recently kept
    word replaces it only if it has a higher score, or an (almost) equal
    score and a longer duration; otherwise the overlapping word is dropped.
    """
    tagged: List[Word] = []
    for channel, channel_words in words_by_channel.items():
        tagged.extend(
            Word(w.start, w.end, w.text, w.score, channel) for w in channel_words
        )
    tagged.sort(key=lambda w: (w.start, w.end))

    merged: List[Word] = []
    for cand in tagged:
        if not merged:
            merged.append(cand)
            continue
        kept = merged[-1]
        shared = min(kept.end, cand.end) - max(kept.start, cand.start)
        if not shared > 0:
            merged.append(cand)
            continue
        better_score = cand.score > kept.score
        tie = abs(cand.score - kept.score) < 1e-6
        longer = (cand.end - cand.start) > (kept.end - kept.start)
        if better_score or (tie and longer):
            merged[-1] = cand
    return merged
def assign_speakers_if_available(audio_path: Path, merged: List[Word], device: str, hf_token: Optional[str]) -> List[Dict]:
    """Attach a speaker label to each word, falling back to ``None`` on failure.

    Runs the WhisperX diarization pipeline on *audio_path* and labels every
    word with the speaker of the first diarization segment that contains the
    word's midpoint. If anything goes wrong (WhisperX missing, no token,
    pipeline error, ...) a message is printed and every word is returned with
    ``"speaker": None``.

    Args:
        audio_path: audio file to diarize.
        merged: words to label; each needs ``start``, ``end``, ``text``, ``score``.
        device: device string for the diarization pipeline.
        hf_token: access token; when falsy, the ``HF_TOKEN`` environment
            variable is used instead.

    Returns:
        One dict per word with keys ``start``, ``end``, ``text``, ``score``
        and ``speaker``.
    """
    def as_dict(word, speaker):
        return {"start": word.start, "end": word.end, "text": word.text, "score": word.score, "speaker": speaker}

    try:
        import whisperx
        if not hf_token:
            hf_token = os.environ.get("HF_TOKEN")
        diar = whisperx.diarize.DiarizationPipeline(use_auth_token=hf_token, device=device)
        diar_segs = diar(str(audio_path))
        # The pipeline returns a table (pandas DataFrame); iterating that
        # directly yields column names, not rows. Convert table-like results
        # to row dicts, and accept an iterable of mappings unchanged.
        if hasattr(diar_segs, "to_dict") and hasattr(diar_segs, "columns"):
            rows = diar_segs.to_dict("records")
        else:
            rows = list(diar_segs)
        spans = [
            (float(seg.get('start', 0)), float(seg.get('end', 0)), seg.get('speaker'))
            for seg in rows
        ]
        out = []
        for w in merged:
            midpoint = 0.5 * (w.start + w.end)
            speaker = next((name for s, e, name in spans if s <= midpoint <= e), None)
            out.append(as_dict(w, speaker))
        return out
    except Exception as e:
        # Diarization is optional: report and continue without labels.
        print(" diarization skipped:", e)
        return [as_dict(w, None) for w in merged]
def write_srt(words: List[Dict], out_path: Path, max_chars: int = 60, max_gap: float = 0.6) -> None:
    """Group words into subtitle cues and write them as a UTF-8 SRT file.

    A new cue starts when the pause before a word exceeds *max_gap* seconds
    or when adding the word would make the cue text longer than *max_chars*.
    An empty *words* list produces an empty file.

    Args:
        words: dicts with at least ``start``, ``end`` and ``text``.
        out_path: file to write (overwritten).
        max_chars: maximum cue text length.
        max_gap: maximum pause inside one cue, in seconds.
    """
    def fmt(t: float) -> str:
        """Format seconds as ``HH:MM:SS,mmm``."""
        ms = int(round(t * 1000))
        h, rest = divmod(ms, 3600000)
        m, rest = divmod(rest, 60000)
        s, ms2 = divmod(rest, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms2:03d}"

    # Transcripts routinely contain non-ASCII text, so the encoding is fixed
    # to UTF-8 instead of depending on the platform default.
    if not words:
        out_path.write_text("", encoding="utf-8")
        return

    entries = []
    cur_text: List[str] = []
    cur_start = words[0]['start']
    cur_end = words[0]['end']
    for w in words:
        too_far = w['start'] - cur_end > max_gap
        too_long = len(' '.join(cur_text) + ' ' + w['text']) > max_chars
        if cur_text and (too_far or too_long):
            entries.append((cur_start, cur_end, ' '.join(cur_text)))
            cur_text = [w['text']]
            cur_start = w['start']
            cur_end = w['end']
        else:
            cur_text.append(w['text'])
            cur_end = w['end']
    if cur_text:
        entries.append((cur_start, cur_end, ' '.join(cur_text)))

    lines = []
    for i, (s, e, txt) in enumerate(entries, 1):
        lines += [str(i), f"{fmt(s)} --> {fmt(e)}", txt.strip(), ""]
    out_path.write_text('\n'.join(lines), encoding="utf-8")
# ---------------- main ----------------
def main():
    """Command-line entry point: scan, plan sessions, extract, transcribe, merge.

    Steps, in order:
      1. Parse options and verify that ffprobe and ffmpeg can be run.
      2. Collect input files from the current working directory.
      3. Scan all files in a process pool with ``scan_one_worker``.
      4. Drop files without intervals and group the rest with ``plan_sessions``.
      5. Load the WhisperX models once.
      6. For every session: extract one WAV per channel, transcribe it,
         merge the channels, optionally diarize, and write
         ``<label>_merged.json`` and ``<label>_merged.srt`` into a
         per-session directory under ``--outdir``.

    NOTE(review): ``--white-time``, ``--file-threshold`` and ``--file-ratio``
    are parsed but not read anywhere in this function.
    """
    ap = argparse.ArgumentParser(description="Two pass white to WhisperX with session join and CRDT merge")
    ap.add_argument("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="Extensions to process")
    ap.add_argument("--recursive", action="store_true", help="Recurse into subdirectories")
    # scan knobs (v6c style)
    ap.add_argument("--threshold", type=float, default=0.70)
    ap.add_argument("--ratio", type=float, default=0.80)
    ap.add_argument("--white-time", type=float, default=0.95)
    ap.add_argument("--file-threshold", type=float, default=0.88)
    ap.add_argument("--file-ratio", type=float, default=0.80)
    ap.add_argument("--flat-weight", type=float, default=0.65)
    ap.add_argument("--slope-weight", type=float, default=0.35)
    ap.add_argument("--even-weight", type=float, default=0.00)
    ap.add_argument("--win-sec", type=float, default=1.0)
    ap.add_argument("--hop-sec", type=float, default=0.5)
    ap.add_argument("--warmup-sec", type=float, default=2.0)
    ap.add_argument("--min-nonwhite-sec", type=float, default=1.5)
    ap.add_argument("--min-white-sec", type=float, default=0.5)
    ap.add_argument("--scan-workers", type=int, default=4, help="Parallel workers for scanning phase")
    # session building
    ap.add_argument("--join-gap-sec", type=float, default=8.0, help="Join intervals with gaps up to this many seconds, across and within files")
    ap.add_argument("--min-session-sec", type=float, default=35.0, help="Try not to emit sessions shorter than this")
    ap.add_argument("--overlap", type=float, default=0.25, help="Extra padding seconds that will already be covered in scanning stage")
    # extraction and whisper
    ap.add_argument("--sr", type=int, default=0, help="Resample WAVs to this rate. 0 keeps source rate")
    ap.add_argument("--outdir", type=str, default="whisper_out")
    ap.add_argument("--keep-session-wavs", action="store_true")
    ap.add_argument("--whisper-model", type=str, default="large-v2")
    ap.add_argument("--device", type=str, default="cpu")
    ap.add_argument("--compute-type", type=str, default="int8")
    ap.add_argument("--diarize", action="store_true")
    ap.add_argument("--hf-token", type=str, default="")
    args = ap.parse_args()
    # Both tools are required; _check_tool exits the program if one is missing.
    _check_tool("ffprobe"); _check_tool("ffmpeg")
    # collect files
    root = Path.cwd()
    files: List[Path] = []
    for ext in args.exts:
        glob = "**/*." + ext if args.recursive else "*." + ext
        files.extend(root.glob(glob))
    # De-duplicate (patterns may overlap) and fix a stable order.
    files = sorted(set(files))
    if not files:
        print("No input files found"); return
    # pass 1: scan in parallel
    print("Scanning for non white intervals..."); sys.stdout.flush()
    scans: List[FileScan] = []
    with cf.ProcessPoolExecutor(max_workers=max(1, args.scan_workers)) as ex:
        # One plain tuple per file so the job can be pickled for the worker.
        jobs = [
            (str(p), float(args.threshold), float(args.ratio), float(args.win_sec), float(args.hop_sec),
             float(args.warmup_sec), float(args.min_white_sec), float(args.min_nonwhite_sec),
             float(args.flat_weight), float(args.slope_weight), float(args.even_weight),
             float(args.join_gap_sec), float(args.overlap))
            for p in files
        ]
        # Map each future back to its file for error reporting.
        futs = {ex.submit(scan_one_worker, job): files[i] for i, job in enumerate(jobs)}
        done = 0
        for fu in cf.as_completed(futs):
            done += 1
            try:
                _r = fu.result()
                # The worker returns the start time as an ISO string ("" if unknown).
                _start = dt.datetime.fromisoformat(_r["start"]) if _r.get("start") else None
                scans.append(FileScan(path=Path(_r["path"]), start=_start, duration=float(_r["duration"]), intervals=[(float(a), float(b)) for a, b in _r.get("intervals", [])]))
            except Exception as e:
                # A failed scan only drops that file; the run continues.
                print(" scan failed for", futs[fu].name, "-", e)
            if done % 8 == 0 or done == len(futs):
                print(f" scanned {done}/{len(futs)} files"); sys.stdout.flush()
    # filter out files that are all white
    scans = [s for s in scans if s.intervals]
    if not scans:
        print("All files are white. Nothing to transcribe"); return
    # plan sessions
    sessions = plan_sessions(scans, join_gap=args.join_gap_sec, min_session=args.min_session_sec)
    if not sessions:
        print("No sessions found. Nothing to transcribe"); return
    print(f"Planned {len(sessions)} sessions. Processing sequentially...")
    outdir = Path(args.outdir); outdir.mkdir(parents=True, exist_ok=True)
    # prepare whisper models once
    print("Loading WhisperX models..."); sys.stdout.flush()
    import whisperx
    model = whisperx.load_model(args.whisper_model, args.device, compute_type=args.compute_type)
    # The alignment model is loaded for English only.
    model_a, meta = whisperx.load_align_model(language_code="en", device=args.device)
    # process sessions
    for i, sess in enumerate(sessions, 1):
        # build nice name with time range
        if sess.start_dt and sess.end_dt:
            label = sess.start_dt.strftime("%Y%m%d-%H%M%S") + "_to_" + sess.end_dt.strftime("%H%M%S")
        else:
            # No timestamps: name the session after its first piece.
            first = sess.pieces[0]
            label = Path(first.file).stem + f"_{int(first.start):06d}"
        sess_dir = outdir / f"session_{i:03d}_{label}"; sess_dir.mkdir(parents=True, exist_ok=True)
        print(f"\nSession {i}/{len(sessions)} {label}"); sys.stdout.flush()
        # get channels and sr from first file
        info0 = ffprobe_info(sess.pieces[0].file)
        ch = int(info0.get("channels", 1)); sr0 = int(info0.get("sample_rate", 48000))
        sr_out = args.sr if args.sr and args.sr > 0 else sr0
        # extract per channel wav once for the full session
        words_by_ch: Dict[int, List[Word]] = {}
        for c in range(ch):
            wav_path = sess_dir / f"session_{i:03d}_ch{c+1:02d}.wav"
            print(f" ch {c+1:02d}/{ch:02d}: extracting...", end=" "); sys.stdout.flush()
            try:
                extract_session_channel_wav(sess, c, sr_out, wav_path)
                print("ok. transcribing..."); sys.stdout.flush()
            except Exception as e:
                # Skip this channel; the other channels are still processed.
                print("extract failed:", e); continue
            try:
                words = run_whisperx_on_wav_modelreused(wav_path, args.device, model, model_a, meta)
                words_by_ch[c] = words
                # quick preview
                preview = " ".join([w.text for w in sorted(words, key=lambda w: w.start)[:8]])
                print(f" preview: {preview[:80]}")
            except Exception as e:
                print(" whisper failed:", e)
            if not args.keep_session_wavs:
                try:
                    os.remove(wav_path)
                except Exception:
                    pass
        # merge
        merged = crdt_merge_words(words_by_ch)
        # diarize using first channel mix if requested
        merged_dicts: List[Dict]
        if args.diarize:
            mixwav = sess_dir / f"session_{i:03d}_mix.wav"
            try:
                # mix first channel file from parts again but with pan average over all channels
                # For speed we just reuse channel 0 wav if it still exists; else rebuild a quick mono
                # NOTE(review): the call below extracts channel 0 only; no
                # averaging over channels happens in this function.
                if not mixwav.exists():
                    extract_session_channel_wav(sess, 0, sr_out, mixwav)
                merged_dicts = assign_speakers_if_available(mixwav, merged, args.device, args.hf_token or None)
            except Exception as e:
                print(" diarization error:", e)
                merged_dicts = [{"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": None} for w in merged]
            finally:
                if not args.keep_session_wavs:
                    try: os.remove(mixwav)
                    except Exception: pass
        else:
            merged_dicts = [{"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": None} for w in merged]
        # save outputs
        out_json = sess_dir / f"{label}_merged.json"
        out_srt = sess_dir / f"{label}_merged.srt"
        out_json.write_text(json.dumps({
            "session_label": label,
            "files": [str(p.file) for p in sess.pieces],
            "pieces": [{"file": str(p.file), "start": p.start, "end": p.end} for p in sess.pieces],
            "words": merged_dicts,
        }, indent=2))
        write_srt(merged_dicts, out_srt)
        print(f" wrote {out_json.name} and {out_srt.name}")
        sys.stdout.flush()
# Run the pipeline only when executed as a script, not when imported
# (e.g. by worker processes that import this module).
if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
white2whisper_chain_v3.py
Two pass pipeline
1) Fast parallel scan to find non white intervals per file using v6c logic
2) Plan long sessions by joining nearby intervals within and across files
3) Process sessions sequentially
- Build one mono WAV per channel per session using safe concat list with absolute paths
- Run WhisperX once per channel, preview a few words for progress
- Merge channels with a simple CRDT style rule
- Optional diarization per session
Key fixes vs v2
- ProcessPool worker is top level and picklable
- Concat list uses absolute paths and correct quoting, no unterminated string
- Fallback to re-encode on concat if stream copy fails
- Better error messages when extraction fails
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import datetime as dt
import json
import math
import os
import re
import shutil
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
# ---------------- utils ----------------
def _run_ok(cmd: Sequence[str]) -> None:
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if p.returncode != 0:
raise RuntimeError("command failed:\n" + " ".join(cmd) + "\n" + (p.stderr[-800:] if p.stderr else ""))
def _check_tool(name: str) -> None:
try:
subprocess.run([name, "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
except Exception:
print(f"Error: {name} not found on PATH", file=sys.stderr)
sys.exit(1)
# ---------------- ffprobe ----------------
def ffprobe_info(input_path: Path) -> Dict:
    """Return ffprobe's description of the first audio stream of a file.

    ffprobe is asked for the ``index``, ``channels``, ``sample_rate`` and
    ``duration`` entries of stream ``a:0`` in JSON form; the parsed dict for
    that stream is returned as-is.

    Raises:
        RuntimeError: ffprobe exited with an error, or reported no stream.
    """
    wanted = "stream=index,channels,sample_rate,duration"
    argv = [
        "ffprobe", "-v", "error",
        "-select_streams", "a:0",
        "-show_entries", wanted,
        "-of", "json", str(input_path),
    ]
    probe = subprocess.run(argv, capture_output=True, text=True)
    if probe.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    # Empty output is treated like an empty JSON object.
    payload = json.loads(probe.stdout or "{}")
    audio_streams = payload.get("streams", [])
    if audio_streams:
        return audio_streams[0]
    raise RuntimeError(f"No audio stream found in {input_path}")
# ---------------- DSP core ----------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return the spectral flatness (geometric mean / arithmetic mean) of *psd*.

    Values below *eps* are raised to *eps* first so the logarithm is defined.
    Returns 0.0 when the arithmetic mean is not strictly positive.
    """
    floored = np.maximum(psd, eps)
    log_mean = np.mean(np.log(floored))
    geometric = np.exp(log_mean)
    arithmetic = float(np.mean(floored))
    if arithmetic > 0:
        return float(geometric / arithmetic)
    return 0.0
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int):
    """Compute a one-sided periodogram of a single windowed frame.

    Args:
        x: samples of one frame.
        win: window function, multiplied element-wise with *x*.
        sr: sample rate, used for the frequency axis and the normalisation.

    Returns:
        ``(freqs, Pxx)`` as float32 arrays, where ``Pxx`` is
        ``|rfft(x * win)|**2 / (sum(win**2) * sr)``.
    """
    spectrum = np.fft.rfft(x * win)
    normaliser = np.sum(win**2) * sr
    power = (np.abs(spectrum) ** 2) / normaliser
    bins = np.fft.rfftfreq(len(x), d=1.0 / sr)
    return bins.astype(np.float32), power.astype(np.float32)
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is spread over *k* log-spaced bands.

    The mean power of every non-empty band between ``f_lo`` and ``f_hi`` is
    normalised by the average band level; the result is
    ``1 / (1 + 3 * std)`` of those normalised levels, i.e. 1.0 for perfectly
    even bands. Returns 0.0 when fewer than three bands contain any bin.
    """
    lowest = max(1.0, f_lo)
    highest = max(f_lo * 1.0001, f_hi)  # keeps the upper edge above f_lo
    edges = np.geomspace(lowest, highest, k + 1)

    band_means = []
    for left, right in zip(edges[:-1], edges[1:]):
        in_band = (freqs >= left) & (freqs < right)
        if np.any(in_band):
            band_means.append(float(np.mean(Pxx[in_band])))
    if len(band_means) < 3:
        return 0.0

    levels = np.array(band_means)
    # Non-positive bands take the smallest positive level (or a tiny floor).
    floor = np.min(levels[levels > 0]) if np.any(levels > 0) else 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Fit a line to log10(power) versus log10(frequency) inside *mask*.

    Only bins with positive frequency and finite, positive power take part.
    Returns the fitted slope, or 0.0 when fewer than 16 bins qualify.
    """
    band_f = freqs[mask]
    band_p = Pxx[mask]
    usable = (band_f > 0) & np.isfinite(band_p) & (band_p > 0)
    if np.count_nonzero(usable) < 16:
        return 0.0
    coeffs = np.polyfit(np.log10(band_f[usable]), np.log10(band_p[usable]), 1)
    return float(coeffs[0])
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Start ffmpeg decoding one stream of *inpath* to raw PCM on its stdout.

    Args:
        inpath: media file to decode.
        stream_index: index of the stream inside the input, as reported by
            ffprobe's ``index`` field (an absolute index, not a position
            among the audio streams).
        sr: output sample rate passed to ``-ar``.
        ch: output channel count passed to ``-ac``.

    Returns:
        The running ``Popen`` object. ``stdout`` yields interleaved signed
        16-bit little-endian samples; stderr is discarded. The caller is
        responsible for closing the pipe and waiting on the process.
    """
    cmd = [
        "ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error",
        # Keep decoding through damaged packets instead of aborting.
        "-err_detect", "ignore_err", "-fflags", "+discardcorrupt",
        "-analyzeduration", "100M", "-probesize", "100M",
        "-i", str(inpath),
        # ffprobe's ``index`` is absolute, so map by absolute index.
        # ``0:a:N`` would mean "the N-th audio stream" and selects the wrong
        # stream whenever the audio is not stream 0 of the container.
        "-map", f"0:{stream_index}", "-vn", "-sn",
        "-ac", str(ch), "-ar", str(sr),
        "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1",
    ]
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=1024 * 1024)
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Flip, in place, every run of *value* shorter than *min_len*.

    *flags* is expected to hold 0/1 entries; a short run is overwritten with
    ``1 - value``. Runs are examined left to right in a single pass, so a run
    that grows because a neighbour was flipped is not re-examined.
    """
    total = flags.size
    pos = 0
    while pos < total:
        if flags[pos] != value:
            pos += 1
            continue
        # Find the end (exclusive) of the run starting at pos.
        stop = pos + 1
        while stop < total and flags[stop] == value:
            stop += 1
        if stop - pos < min_len:
            flags[pos:stop] = 1 - value
        pos = stop
def flags_to_intervals(flags: List[bool], hop: int, win: int, sr: int):
    """Turn per-window flags into time intervals, split by flag value.

    Args:
        flags: one truthy (white) or falsy (non-white) entry per window.
        hop: window step in samples.
        win: window length in samples.
        sr: sample rate used to convert samples to seconds.

    Returns:
        ``(whites, nonwhites)``: lists of ``(start_s, end_s)`` tuples, one per
        run of equal flags, in order of appearance.
    """
    whites = []
    nonwhites = []
    if not flags:
        return whites, nonwhites

    count = len(flags)
    run_start = 0
    for idx in range(1, count + 1):
        at_end = idx == count
        if not at_end and flags[idx] == flags[run_start]:
            continue
        # NOTE: an interior run ends at (idx * hop + win), idx being the first
        # window of the next run, while the final run ends at its own last
        # window; this asymmetry is kept as-is.
        end_window = count - 1 if at_end else idx
        span = (run_start * hop / sr, (end_window * hop + win) / sr)
        target = whites if flags[run_start] else nonwhites
        target.append(span)
        run_start = idx
    return whites, nonwhites
def intervals_total(intervals: List[Tuple[float, float]]) -> float:
    """Return the total length covered by *intervals*, counting overlaps once."""
    if not intervals:
        return 0.0
    covered = 0.0
    span_start = span_end = None
    for begin, finish in sorted(intervals):
        if span_start is None:
            span_start, span_end = begin, finish
        elif begin <= span_end:
            # Overlapping or touching: grow the current span.
            span_end = max(span_end, finish)
        else:
            covered += max(0.0, span_end - span_start)
            span_start, span_end = begin, finish
    covered += max(0.0, span_end - span_start)
    return covered
# ---------------- scanning engine ----------------
def segment_streaming_from_pipe(
    inpath: Path,
    stream_index: int,
    sr: int,
    ch: int,
    threshold: float,
    ratio: float,
    win_s: float,
    hop_s: float,
    warmup_sec: float,
    min_white_sec: float,
    min_nonwhite_sec: float,
    flat_w: float,
    slope_w: float,
    even_w: float,
) -> Dict:
    """Decode a file through ffmpeg and label each analysis window white or not.

    PCM is read one hop at a time into a sliding window of ``win_s`` seconds.
    The first windows also accumulate an average spectrum; once warm-up ends,
    the analysis band is narrowed to the 5%..95% cumulative-power range. In
    every full window a channel counts as white when its weighted score
    (flatness, slope and band evenness) reaches *threshold* and its log-log
    slope magnitude is below 0.3; the window is white when at least *ratio*
    of all channels are white. Short runs are then smoothed away and runs are
    converted to intervals.

    Args:
        inpath, stream_index, sr, ch: forwarded to ``open_ffmpeg_pcm_pipe``.
        threshold: minimum per-channel score for a white channel.
        ratio: minimum fraction of white channels for a white window.
        win_s, hop_s: window length and step in seconds.
        warmup_sec: nominal warm-up duration (at least 8 windows are used).
        min_white_sec, min_nonwhite_sec: runs shorter than this are flipped.
        flat_w, slope_w, even_w: weights of the flatness, slope and evenness
            terms of the score.

    Returns:
        ``{"duration": seconds_decoded, "nonwhite_intervals": [(start, end), ...]}``.
    """
    safe_hop_s = max(1e-6, hop_s)
    win = max(256, int(sr * win_s))
    hop = max(1, int(sr * hop_s))
    bytes_per_sample = 2  # ffmpeg is asked for signed 16-bit samples
    bytes_per_frame = bytes_per_sample * ch
    read_size_bytes = hop * bytes_per_frame
    buf = np.zeros((win, ch), dtype=np.float32)
    filled = 0
    hann = np.hanning(win).astype(np.float32)
    warm_windows = max(8, int(round(warmup_sec / safe_hop_s)))
    psd_accum = None
    freqs_ref = None
    # Default analysis band, used until (and unless) warm-up refines it.
    f_lo = 350.0
    f_hi = sr * 0.65
    mask = None
    flags: List[bool] = []
    total_frames = 0

    proc = open_ffmpeg_pcm_pipe(inpath, stream_index, sr, ch)
    assert proc.stdout is not None
    try:
        while True:
            raw = proc.stdout.read(read_size_bytes)
            if not raw:
                break
            n_frames = (len(raw) // bytes_per_sample) // ch
            if n_frames == 0:
                continue
            total_frames += n_frames
            arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
            arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
            step = arr.shape[0]

            # Slide the new frames into the analysis window.
            if step >= win:
                buf[...] = arr[-win:]
                filled = win
            elif filled + step <= win:
                buf[filled:filled + step, :] = arr
                filled += step
            else:
                overflow = filled + step - win
                buf[:win - overflow, :] = buf[overflow:, :]
                buf[win - step:, :] = arr
                filled = win
            if filled < win:
                continue

            if warm_windows > 0:
                # Accumulate the spectrum of (at most) the first four channels.
                Psum = None
                for c in range(min(ch, 4)):
                    fr, P = rfft_psd_window(buf[:, c], hann, sr)
                    if Psum is None:
                        Psum = P.astype(np.float64)
                    else:
                        Psum += P
                    freqs_ref = fr
                if psd_accum is None:
                    psd_accum = Psum
                else:
                    psd_accum += Psum
                warm_windows -= 1
                if warm_windows == 0:
                    power = psd_accum / max(1, int(round(warmup_sec / safe_hop_s)))
                    cum = np.cumsum(power)
                    # With no energy at all (e.g. digital silence) the
                    # percentiles are undefined and searchsorted would point
                    # past the last bin; keep the default band in that case.
                    if cum[-1] > 0:
                        cum /= cum[-1]
                        last_bin = len(freqs_ref) - 1
                        low_idx = min(last_bin, int(np.searchsorted(cum, 0.05)))
                        high_idx = min(last_bin, int(np.searchsorted(cum, 0.95)))
                        f_lo = max(100.0, float(freqs_ref[max(1, low_idx)]))
                        f_hi = min(0.95 * float(freqs_ref.max()), float(freqs_ref[high_idx]))
                    mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)

            if mask is None and freqs_ref is not None:
                mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)

            ch_likely = 0
            for c in range(ch):
                xw = buf[:, c]
                if np.max(np.abs(xw)) < 1e-6:
                    # A silent channel never counts as white.
                    continue
                fr, P = rfft_psd_window(xw, hann, sr)
                m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
                sf = spectral_flatness_vector(P[m])
                slope = slope_in_band(fr, P, m)
                even = band_evenness(fr, P, f_lo, f_hi)
                # A flat spectrum has slope 0, which maps to a score of 1.
                slope_score = math.exp(-abs(slope) / 0.25)
                prob = float(
                    flat_w * max(0.0, min(1.0, sf))
                    + slope_w * slope_score
                    + even_w * max(0.0, min(1.0, even))
                )
                if prob >= threshold and abs(slope) < 0.3:
                    ch_likely += 1
            flags.append(ch_likely / float(ch) >= ratio)
    finally:
        # Always release the pipe and reap ffmpeg, even if analysis failed.
        proc.stdout.close()
        proc.wait()

    f_arr = np.array(flags, dtype=np.int8)
    if f_arr.size:
        min_white_frames = max(1, int(math.ceil(min_white_sec / safe_hop_s)))
        min_nonwhite_frames = max(1, int(math.ceil(min_nonwhite_sec / safe_hop_s)))
        remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
        remove_short_runs(f_arr, value=1, min_len=min_white_frames)
    _, nonwhites = flags_to_intervals(f_arr.astype(bool).tolist(), hop, win, sr)
    return {"duration": total_frames / float(sr), "nonwhite_intervals": nonwhites}
# ---------------- filenames and sessions ----------------
# Matches "<anything>_YYYYMMDD-HHMM_" at the start of a file name.
_TS_RE = re.compile(r".*_(\d{8})-(\d{4})_", re.IGNORECASE)


def file_start_time(p: Path) -> Optional[dt.datetime]:
    """Parse the ``_YYYYMMDD-HHMM_`` stamp embedded in a file name.

    Returns a naive ``datetime``, or ``None`` when the name carries no stamp
    or the digits do not form a valid date and time.
    """
    found = _TS_RE.match(p.name)
    if found is None:
        return None
    stamp = "".join(found.group(1, 2))
    try:
        return dt.datetime.strptime(stamp, "%Y%m%d%H%M")
    except Exception:
        return None
@dataclass
class FileScan:
    """Scan result for one input file."""

    # Location of the scanned file.
    path: Path
    # Recording start parsed from the file name, if any.
    start: Optional[dt.datetime]
    # Decoded length in seconds.
    duration: float
    # Intervals of interest as (start, end) pairs.
    intervals: List[Tuple[float, float]]
@dataclass
class SessionPiece:
    """One contiguous excerpt of a source file that belongs to a session."""

    # Source file the excerpt is cut from.
    file: Path
    # Excerpt bounds within ``file``.
    start: float
    end: float
@dataclass
class Session:
    """A group of pieces that is handled as one unit."""

    # Sequence number assigned while planning.
    idx: int
    # Absolute start/end of the session; None when no timestamp is known.
    start_dt: Optional[dt.datetime]
    end_dt: Optional[dt.datetime]
    # Excerpts making up the session, in planning order.
    pieces: List[SessionPiece]
def merge_intervals(iv: List[Tuple[float, float]], join_gap: float) -> List[Tuple[float, float]]:
    """Sort intervals and fuse neighbours separated by at most *join_gap*.

    Returns a new list of ``(start, end)`` float tuples; the input list is
    not modified.
    """
    if not iv:
        return []
    ordered = sorted(iv)
    first_start, first_end = ordered[0]
    starts = [first_start]
    ends = [first_end]
    for begin, finish in ordered[1:]:
        if begin <= ends[-1] + join_gap:
            # Close enough to the previous span: extend it.
            ends[-1] = max(ends[-1], finish)
        else:
            starts.append(begin)
            ends.append(finish)
    return [(float(a), float(b)) for a, b in zip(starts, ends)]
def plan_sessions(scans: List[FileScan], join_gap: float, min_session: float) -> List[Session]:
    """Group the intervals of all scanned files into sessions.

    Files are visited by start time (files without one sort first), then by
    name. Inside a file, intervals are first merged with ``merge_intervals``.
    A piece joins the open session when its absolute start lies within
    *join_gap* seconds of the session's end, or when its file has no start
    time at all; otherwise the open session is closed and a new one begins.
    A closed session whose pieces sum to less than *min_session* seconds is
    folded into the previously emitted session, if there is one.
    """
    ordered = sorted(scans, key=lambda item: (item.start or dt.datetime.min, item.path.name))
    sessions: List[Session] = []
    pending: List[SessionPiece] = []
    span_start: Optional[dt.datetime] = None
    span_end: Optional[dt.datetime] = None
    next_idx = 1

    def emit_pending() -> None:
        """Close the open session and append it (or fold it into the last one)."""
        nonlocal pending, span_start, span_end, next_idx
        if not pending:
            return
        finished = Session(idx=next_idx, start_dt=span_start, end_dt=span_end, pieces=pending)
        pending, span_start, span_end = [], None, None
        length = sum(piece.end - piece.start for piece in finished.pieces)
        if length < min_session and sessions:
            # Too short to stand alone: merge backward.
            sessions[-1].pieces.extend(finished.pieces)
            sessions[-1].end_dt = finished.end_dt
            return
        sessions.append(finished)
        next_idx += 1

    for fs in ordered:
        for s_rel, e_rel in merge_intervals(fs.intervals, join_gap):
            # Absolute times exist only for files with a known start.
            if fs.start:
                s_abs = fs.start + dt.timedelta(seconds=s_rel)
                e_abs = fs.start + dt.timedelta(seconds=e_rel)
            else:
                s_abs = e_abs = None
            piece = SessionPiece(fs.path, s_rel, e_rel)

            if pending:
                near_previous = (
                    bool(span_end and e_abs and s_abs)
                    and (s_abs - span_end).total_seconds() <= join_gap
                )
                if near_previous or not fs.start:
                    pending.append(piece)
                    if e_abs and (not span_end or e_abs > span_end):
                        span_end = e_abs
                    continue
                emit_pending()

            # Start a fresh session with this piece.
            pending = [piece]
            span_start = s_abs
            span_end = e_abs

    emit_pending()
    return sessions
# ---------------- worker for scan ----------------
def scan_one_worker(job):
    """Scan a single file and return its padded, merged non-white intervals.

    Args:
        job: A 13-tuple ``(path, threshold, ratio, win_sec, hop_sec,
            warmup_sec, min_white_sec, min_nonwhite_sec, flat_w, slope_w,
            even_w, join_gap_sec, overlap)``.

    Returns:
        A dict of plain values with keys ``path`` (str), ``start`` (ISO
        timestamp or ``""``), ``duration`` (seconds) and ``intervals``
        (list of ``(start, end)`` seconds).
    """
    (
        inpath_str, threshold, ratio, win_sec, hop_sec, warmup_sec,
        min_white_sec, min_nonwhite_sec, flat_w, slope_w, even_w,
        join_gap_sec, overlap,
    ) = job
    src = Path(inpath_str)

    info = ffprobe_info(src)
    stream_index = int(info.get("index", 0))
    sr = int(info.get("sample_rate", 48000))
    ch = int(info.get("channels", 1))

    res = segment_streaming_from_pipe(
        src, stream_index, sr, ch, threshold, ratio, win_sec, hop_sec,
        warmup_sec, min_white_sec, min_nonwhite_sec, flat_w, slope_w, even_w,
    )
    found = res.get("nonwhite_intervals", [])
    dur = float(res.get("duration", 0.0))

    # Widen each interval by `overlap` on both sides, clipped to the file.
    widened = ((max(0.0, s - overlap), min(dur, e + overlap)) for s, e in found)
    padded = [(float(a), float(b)) for a, b in widened if b > a]
    merged = merge_intervals(padded, float(join_gap_sec))

    start_dt = file_start_time(src)
    return {
        "path": str(src),
        "start": start_dt.isoformat() if start_dt else "",
        "duration": dur,
        "intervals": merged,
    }
# ---------------- audio extraction ----------------
def extract_session_channel_wav(session: Session, ch_idx: int, sr_out: Optional[int], out_wav: Path) -> None:
    """Write one channel of a session to a single mono 16-bit PCM WAV.

    Each piece is cut out of its source file into a temporary
    ``<stem>_partNNN.wav`` next to ``out_wav`` using ffmpeg's ``pan``
    filter to pick input channel ``ch_idx``. The parts are then joined with
    the concat demuxer, first by stream copy and, if that fails, by
    re-encoding. The concat list and the part files are removed afterwards,
    whether or not extraction succeeded.

    Args:
        session: Session whose pieces are extracted in order.
        ch_idx: Zero-based index of the input channel to keep.
        sr_out: Output sample rate; a falsy or non-positive value adds no
            ``-ar`` option.
        out_wav: Destination WAV path; its parent directory is created.

    Raises:
        Whatever ``_run_ok`` raises when an ffmpeg command fails.
    """
    quiet = ["-hide_banner", "-nostats", "-loglevel", "error"]
    resample = ["-ar", str(sr_out)] if sr_out and sr_out > 0 else []
    list_txt = out_wav.with_suffix('.list.txt')
    part_files: List[Path] = []
    out_wav.parent.mkdir(parents=True, exist_ok=True)
    try:
        for n, piece in enumerate(session.pieces, 1):
            part = out_wav.with_suffix("").parent / f"{out_wav.stem}_part{n:03d}.wav"
            part_files.append(part)
            length = max(0.0, piece.end - piece.start)
            _run_ok([
                "ffmpeg", *quiet,
                "-ss", f"{piece.start:.3f}", "-t", f"{length:.3f}",
                "-i", str(piece.file),
                "-map", "0:a:0", "-vn", "-sn", "-af", f"pan=mono|c0=c{ch_idx}",
                "-ac", "1", "-acodec", "pcm_s16le",
                *resample,
                str(part),
            ])

        # Absolute paths so ffmpeg does not resolve them against the list file.
        with open(list_txt, 'w', encoding='utf-8') as listing:
            for pf in part_files:
                quoted = str(pf.resolve()).replace("'", "'\\''")  # escape single quotes
                listing.write("file '{}'\n".format(quoted))

        concat_in = ["ffmpeg", *quiet, "-f", "concat", "-safe", "0", "-i", str(list_txt)]
        try:
            # Fast path: join the parts without re-encoding.
            _run_ok(concat_in + ["-c", "copy", str(out_wav)])
        except Exception:
            # Fall back to re-encoding if the stream copy fails.
            _run_ok(concat_in + ["-acodec", "pcm_s16le", "-ac", "1", *resample, str(out_wav)])
    finally:
        # Best-effort removal of temporaries; each failure is ignored on its own.
        for leftover in [list_txt, *part_files]:
            try:
                os.remove(leftover)
            except Exception:
                pass
# ---------------- whisper and merge ----------------
@dataclass
class Word:
    """A single transcribed word with timing and confidence."""

    # Word start time in seconds.
    start: float
    # Word end time in seconds.
    end: float
    # Word text.
    text: str
    # Alignment score reported for the word.
    score: float
    # Channel the word came from; -1 until a channel is assigned.
    ch: int
def run_whisperx_on_wav_modelreused(wav_path: Path, device: str, model, align_model, align_meta) -> List[Word]:
    """Transcribe and word-align one WAV with already-loaded WhisperX models.

    Args:
        wav_path: Audio file to transcribe.
        device: Device string passed to ``whisperx.align``.
        model: Loaded transcription model exposing ``transcribe(audio)``.
        align_model: Loaded alignment model.
        align_meta: Metadata that accompanies the alignment model.

    Returns:
        Words with non-empty text, each created with ``ch=-1``. A word
        without its own ``start``/``end`` falls back to the enclosing
        segment's values.
    """
    import whisperx

    audio = whisperx.load_audio(str(wav_path))
    transcript = model.transcribe(audio)
    aligned = whisperx.align(
        transcript.get("segments", []), align_model, align_meta, audio, device,
        return_char_alignments=False,
    )

    def raw_words():
        """Yield (segment, word-dict) pairs in transcript order."""
        for seg in aligned.get("segments", []):
            for item in seg.get("words", []) or []:
                yield seg, item

    words: List[Word] = []
    for seg, item in raw_words():
        begin = float(item.get("start", seg.get("start", 0.0)) or 0.0)
        finish = float(item.get("end", seg.get("end", begin)))
        text = str(item.get("word", "")).strip()
        score = float(item.get("score", 0.0))
        if text:
            words.append(Word(begin, finish, text, score, ch=-1))
    return words
def crdt_merge_words(words_by_channel: Dict[int, List[Word]]) -> List[Word]:
    """Merge per-channel word lists into one timeline.

    Every word is copied with its channel key as ``ch``, and the copies are
    sorted by ``(start, end)``. Each word is then compared with the most
    recently kept word only:

    * no time overlap: the word is kept;
    * overlap: the word replaces the kept one if it has a higher score, or
      an equal score (within 1e-6) and a longer duration; otherwise it is
      dropped.

    Returns:
        The kept words in time order.
    """
    tagged: List[Word] = [
        Word(w.start, w.end, w.text, w.score, ch)
        for ch, channel_words in words_by_channel.items()
        for w in channel_words
    ]
    tagged.sort(key=lambda w: (w.start, w.end))

    merged: List[Word] = []
    for cand in tagged:
        if not merged:
            merged.append(cand)
            continue
        kept = merged[-1]
        overlap = min(kept.end, cand.end) - max(kept.start, cand.start)
        if overlap > 0:
            higher_score = cand.score > kept.score
            tie_but_longer = (
                abs(cand.score - kept.score) < 1e-6
                and (cand.end - cand.start) > (kept.end - kept.start)
            )
            if higher_score or tie_but_longer:
                merged[-1] = cand
        else:
            merged.append(cand)
    return merged
def assign_speakers_if_available(audio_path: Path, merged: List[Word], device: str, hf_token: Optional[str]) -> List[Dict]:
    """Label each word with a diarization speaker, best effort.

    Runs WhisperX's diarization pipeline on ``audio_path`` and gives every
    word the speaker of the first diarization segment that contains the
    word's midpoint. If diarization cannot run for any reason (WhisperX
    missing, no token, model failure), a message is printed and every word
    is returned with ``speaker`` set to ``None``.

    Args:
        audio_path: Audio file to diarize.
        merged: Words to label; only ``start``, ``end``, ``text`` and
            ``score`` are read.
        device: Device string passed to the diarization pipeline.
        hf_token: Hugging Face token; when falsy, the ``HF_TOKEN``
            environment variable is used instead.

    Returns:
        One dict per word with keys ``start``, ``end``, ``text``, ``score``
        and ``speaker``.
    """

    def as_dict(w, speaker):
        return {"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": speaker}

    try:
        # Import from the submodule so this does not rely on the package
        # exposing `diarize` as an attribute.
        from whisperx.diarize import DiarizationPipeline

        if not hf_token:
            hf_token = os.environ.get("HF_TOKEN")
        diar = DiarizationPipeline(use_auth_token=hf_token, device=device)
        diar_segs = diar(str(audio_path))

        # WhisperX returns a pandas DataFrame here. Iterating a DataFrame
        # yields column names, not rows, so convert it to row dicts first.
        # Any other iterable of mappings is accepted unchanged.
        if hasattr(diar_segs, "to_dict"):
            rows = diar_segs.to_dict("records")
        else:
            rows = list(diar_segs)
        spans = [
            (float(row.get('start', 0)), float(row.get('end', 0)), row.get('speaker'))
            for row in rows
        ]

        out = []
        for w in merged:
            mid = 0.5 * (w.start + w.end)
            speaker = next((name for s, e, name in spans if s <= mid <= e), None)
            out.append(as_dict(w, speaker))
        return out
    except Exception as e:
        # Diarization is optional: report and fall back to unlabeled words.
        print(" diarization skipped:", e)
        return [as_dict(w, None) for w in merged]
def write_srt(words: List[Dict], out_path: Path, max_chars: int = 60, max_gap: float = 0.6) -> None:
    """Group words into subtitle cues and write them as an SRT file.

    A new cue starts when the gap between the previous word's end and the
    next word's start exceeds ``max_gap`` seconds, or when adding the word
    would make the cue text longer than ``max_chars`` characters. The file
    is always written as UTF-8 so non-ASCII transcripts do not depend on
    the platform's default encoding.

    Args:
        words: Dicts with ``start``, ``end`` (seconds) and ``text``, in
            time order.
        out_path: Destination file; an empty file is written when ``words``
            is empty.
        max_chars: Maximum cue text length before a break is forced.
        max_gap: Maximum silence in seconds inside a cue.
    """
    if not words:
        out_path.write_text("", encoding="utf-8")
        return

    entries = []
    cur_text: List[str] = []
    cur_start = words[0]['start']
    cur_end = words[0]['end']
    for w in words:
        if cur_text and (
            w['start'] - cur_end > max_gap
            or len(' '.join(cur_text) + ' ' + w['text']) > max_chars
        ):
            entries.append((cur_start, cur_end, ' '.join(cur_text)))
            cur_text = [w['text']]
            cur_start = w['start']
            cur_end = w['end']
        else:
            cur_text.append(w['text'])
            cur_end = w['end']
    if cur_text:
        entries.append((cur_start, cur_end, ' '.join(cur_text)))

    def fmt(t: float) -> str:
        """Format seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
        total_ms = int(round(t * 1000))
        h, rem = divmod(total_ms, 3600000)
        m, rem = divmod(rem, 60000)
        s, ms = divmod(rem, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    lines = []
    for i, (s, e, txt) in enumerate(entries, 1):
        # Cue number, time range, text, then the blank separator line.
        lines.extend([str(i), f"{fmt(s)} --> {fmt(e)}", txt.strip(), ""])
    out_path.write_text('\n'.join(lines), encoding="utf-8")
# ---------------- main ----------------
def main():
"""Command-line entry point for the two-pass pipeline.

Pass 1 scans every matching file under the current directory in a process
pool and collects non-white intervals. The intervals are planned into
sessions. Pass 2 then handles the sessions one at a time: each channel is
extracted to a mono WAV and transcribed with WhisperX, the per-channel
words are merged, speakers are optionally assigned, and a JSON and an SRT
file are written per session under --outdir.
"""
ap = argparse.ArgumentParser(description="Two pass white to WhisperX with session join and CRDT merge")
ap.add_argument("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="Extensions to process")
ap.add_argument("--recursive", action="store_true", help="Recurse into subdirectories")
# scan knobs
ap.add_argument("--threshold", type=float, default=0.70)
ap.add_argument("--ratio", type=float, default=0.80)
# NOTE(review): --white-time, --file-threshold and --file-ratio are parsed
# but never read anywhere in this function.
ap.add_argument("--white-time", type=float, default=0.95)
ap.add_argument("--file-threshold", type=float, default=0.88)
ap.add_argument("--file-ratio", type=float, default=0.80)
ap.add_argument("--flat-weight", type=float, default=0.65)
ap.add_argument("--slope-weight", type=float, default=0.35)
ap.add_argument("--even-weight", type=float, default=0.00)
ap.add_argument("--win-sec", type=float, default=1.0)
ap.add_argument("--hop-sec", type=float, default=0.5)
ap.add_argument("--warmup-sec", type=float, default=2.0)
ap.add_argument("--min-nonwhite-sec", type=float, default=1.5)
ap.add_argument("--min-white-sec", type=float, default=0.5)
ap.add_argument("--scan-workers", type=int, default=4, help="Parallel workers for scanning phase")
# sessions
ap.add_argument("--join-gap-sec", type=float, default=8.0, help="Join intervals up to this gap")
ap.add_argument("--min-session-sec", type=float, default=35.0, help="Try not to emit sessions shorter than this")
ap.add_argument("--overlap", type=float, default=0.25, help="Padding applied around intervals before join")
# extraction and whisper
ap.add_argument("--sr", type=int, default=0, help="Resample WAVs to this rate. 0 keeps source rate")
ap.add_argument("--outdir", type=str, default="whisper_out")
ap.add_argument("--keep-session-wavs", action="store_true")
ap.add_argument("--whisper-model", type=str, default="large-v2")
ap.add_argument("--device", type=str, default="cpu")
ap.add_argument("--compute-type", type=str, default="int8")
ap.add_argument("--diarize", action="store_true")
ap.add_argument("--hf-token", type=str, default="")
args = ap.parse_args()
# Both tools must be runnable before any work starts.
_check_tool("ffprobe"); _check_tool("ffmpeg")
# gather files
root = Path.cwd()
files: List[Path] = []
for ext in args.exts:
glob = "**/*." + ext if args.recursive else "*." + ext
files.extend(root.glob(glob))
# De-duplicate and fix the processing order.
files = sorted(set(files))
if not files:
print("No input files found"); return
# pass 1: scan in parallel
print("Scanning for non white intervals..."); sys.stdout.flush()
# One plain tuple per file; the field order must match the unpacking in
# scan_one_worker.
jobs = [
(str(p), float(args.threshold), float(args.ratio), float(args.win_sec), float(args.hop_sec),
float(args.warmup_sec), float(args.min_white_sec), float(args.min_nonwhite_sec),
float(args.flat_weight), float(args.slope_weight), float(args.even_weight),
float(args.join_gap_sec), float(args.overlap))
for p in files
]
scans: List[FileScan] = []
with cf.ProcessPoolExecutor(max_workers=max(1, args.scan_workers)) as ex:
# Map each future back to its input file for error reporting.
futs = {ex.submit(scan_one_worker, job): files[i] for i, job in enumerate(jobs)}
done = 0
for fu in cf.as_completed(futs):
done += 1
try:
# Workers return plain dicts; rebuild the FileScan here.
_r = fu.result()
_start = dt.datetime.fromisoformat(_r["start"]) if _r.get("start") else None
scans.append(FileScan(path=Path(_r["path"]), start=_start, duration=float(_r["duration"]), intervals=[(float(a), float(b)) for a, b in _r.get("intervals", [])]))
except Exception as e:
# A failed file is reported and left out of the plan.
print(" scan failed for", futs[fu].name, "-", e)
if done % 8 == 0 or done == len(futs):
print(f" scanned {done}/{len(futs)} files"); sys.stdout.flush()
# Files with no non-white interval have nothing to transcribe.
scans = [s for s in scans if s.intervals]
if not scans:
print("All files are white. Nothing to transcribe"); return
sessions = plan_sessions(scans, join_gap=args.join_gap_sec, min_session=args.min_session_sec)
if not sessions:
print("No sessions found. Nothing to transcribe"); return
print(f"Planned {len(sessions)} sessions. Processing sequentially...")
outdir = Path(args.outdir); outdir.mkdir(parents=True, exist_ok=True)
# load whisper once
print("Loading WhisperX models..."); sys.stdout.flush()
import whisperx
model = whisperx.load_model(args.whisper_model, args.device, compute_type=args.compute_type)
# The alignment model is always loaded for English.
model_a, meta = whisperx.load_align_model(language_code="en", device=args.device)
# process sessions
for i, sess in enumerate(sessions, 1):
# Label by wall-clock span when known, else by first file and offset.
if sess.start_dt and sess.end_dt:
label = sess.start_dt.strftime("%Y%m%d-%H%M%S") + "_to_" + sess.end_dt.strftime("%H%M%S")
else:
first = sess.pieces[0]
label = Path(first.file).stem + f"_{int(first.start):06d}"
sess_dir = outdir / f"session_{i:03d}_{label}"; sess_dir.mkdir(parents=True, exist_ok=True)
print(f"\nSession {i}/{len(sessions)} {label}"); sys.stdout.flush()
# Channel count and sample rate come from the session's first piece only.
info0 = ffprobe_info(sess.pieces[0].file)
ch = int(info0.get("channels", 1)); sr0 = int(info0.get("sample_rate", 48000))
sr_out = args.sr if args.sr and args.sr > 0 else sr0
words_by_ch: Dict[int, List[Word]] = {}
for c in range(ch):
wav_path = sess_dir / f"session_{i:03d}_ch{c+1:02d}.wav"
print(f" ch {c+1:02d}/{ch:02d}: extracting...", end=" "); sys.stdout.flush()
try:
extract_session_channel_wav(sess, c, sr_out, wav_path)
print("ok. transcribing...")
except Exception as e:
# A channel that cannot be extracted is skipped.
print("extract failed:", e)
continue
try:
words = run_whisperx_on_wav_modelreused(wav_path, args.device, model, model_a, meta)
words_by_ch[c] = words
preview = " ".join([w.text for w in sorted(words, key=lambda w: w.start)[:8]])
print(f" preview: {preview[:80]}")
except Exception as e:
print(" whisper failed:", e)
# The channel WAV is removed unless --keep-session-wavs was given.
if not args.keep_session_wavs:
try:
os.remove(wav_path)
except Exception:
pass
merged = crdt_merge_words(words_by_ch)
if args.diarize:
# NOTE(review): the "_mix" file is extracted with channel index 0, so
# diarization runs on the first channel only, not on a mixdown.
mixwav = sess_dir / f"session_{i:03d}_mix.wav"
try:
extract_session_channel_wav(sess, 0, sr_out, mixwav)
merged_dicts = assign_speakers_if_available(mixwav, merged, args.device, args.hf_token or None)
except Exception as e:
print(" diarization error:", e)
merged_dicts = [{"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": None} for w in merged]
finally:
if not args.keep_session_wavs:
try: os.remove(mixwav)
except Exception: pass
else:
merged_dicts = [{"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": None} for w in merged]
# Per-session outputs: merged words as JSON plus an SRT built from them.
out_json = sess_dir / f"{label}_merged.json"
out_srt = sess_dir / f"{label}_merged.srt"
out_json.write_text(json.dumps({
"session_label": label,
"files": [str(p.file) for p in sess.pieces],
"pieces": [{"file": str(p.file), "start": p.start, "end": p.end} for p in sess.pieces],
"words": merged_dicts,
}, indent=2))
write_srt(merged_dicts, out_srt)
print(f" wrote {out_json.name} and {out_srt.name}")
sys.stdout.flush()
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
white2whisper_chain_v4.py
Two pass pipeline
1) Fast parallel scan to find non white intervals per file using v6c-style logic
2) Plan long sessions by joining nearby intervals within and across files
3) Process sessions sequentially
- Build one mono WAV per channel per session using concat lists with absolute paths
- Run WhisperX once per channel, print a heartbeat while it runs
- Merge channels with a simple CRDT-style rule
- Optional diarization per session
New in v4
- --resume skips sessions that already have final outputs
- ffmpeg calls are forced non-interactive: -y and -nostdin
- Clean stale outputs before starting each channel extract
- Heartbeat during Whisper so you can see progress
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import datetime as dt
import json
import math
import os
import re
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
# ---------------- utils ----------------
def _run_ok(cmd: Sequence[str]) -> None:
# make ffmpeg non-interactive
if cmd and Path(cmd[0]).name == "ffmpeg":
cmd = [cmd[0], "-y", "-nostdin"] + list(cmd[1:])
p = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.DEVNULL,
text=True,
)
if p.returncode != 0:
raise RuntimeError(
"command failed:\n" + " ".join(cmd) + "\n" + (p.stderr[-2000:] if p.stderr else "")
)
def _check_tool(name: str) -> None:
try:
subprocess.run([name, "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
except Exception:
print(f"Error: {name} not found on PATH", file=sys.stderr)
sys.exit(1)
# ---------------- ffprobe ----------------
def ffprobe_info(input_path: Path) -> Dict:
    """Return ffprobe's description of the first audio stream of a file.

    Only the ``index``, ``channels``, ``sample_rate`` and ``duration``
    stream entries are requested.

    Raises:
        RuntimeError: ffprobe exited non-zero, or reported no audio stream.
    """
    probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "stream=index,channels,sample_rate,duration",
            "-of", "json", str(input_path),
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    # Empty output is treated as an empty JSON object.
    streams = json.loads(probe.stdout or "{}").get("streams", [])
    if not streams:
        raise RuntimeError(f"No audio stream found in {input_path}")
    return streams[0]
# ---------------- DSP core ----------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return the spectral flatness of a power spectrum.

    Flatness is the geometric mean divided by the arithmetic mean, after
    every value has been raised to at least ``eps``. Returns 0.0 when the
    arithmetic mean is not positive.
    """
    floored = np.maximum(psd, eps)
    arith = float(np.mean(floored))
    if not arith > 0:
        return 0.0
    # Geometric mean computed in the log domain.
    geo = np.exp(np.mean(np.log(floored)))
    return float(geo / arith)
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int):
    """Compute a one-sided windowed periodogram of ``x``.

    The squared FFT magnitude is divided by ``sum(win**2) * sr``.

    Returns:
        ``(freqs, Pxx)`` as float32 arrays: bin frequencies in Hz and the
        power at each bin.
    """
    spectrum = np.fft.rfft(x * win)
    scale = np.sum(win**2) * sr
    power = (np.abs(spectrum) ** 2) / scale
    bins = np.fft.rfftfreq(len(x), d=1.0 / sr)
    return bins.astype(np.float32), power.astype(np.float32)
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is spread over ``k`` log-spaced bands.

    The range from ``max(1, f_lo)`` to ``f_hi`` is split into ``k``
    geometrically spaced bands. Bands with no frequency bin are skipped.
    The mean powers are normalised by their own mean, and the score is
    ``1 / (1 + 3 * std)``, so 1.0 means perfectly even.

    Returns:
        The evenness score, or 0.0 when fewer than three bands hold data.
    """
    edges = np.geomspace(max(1.0, f_lo), max(f_lo * 1.0001, f_hi), k + 1)
    band_means = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_band = (freqs >= lo) & (freqs < hi)
        if np.any(in_band):
            band_means.append(float(np.mean(Pxx[in_band])))
    if len(band_means) < 3:
        return 0.0

    levels = np.array(band_means)
    # Replace non-positive bands with the smallest positive one (or a tiny floor).
    floor = np.min(levels[levels > 0]) if np.any(levels > 0) else 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Fit a line to log10(power) against log10(frequency) inside a band.

    Only bins selected by ``mask`` with positive frequency and finite,
    positive power are used.

    Returns:
        The fitted slope, or 0.0 when fewer than 16 bins are usable.
    """
    band_f = freqs[mask]
    band_p = Pxx[mask]
    usable = (band_f > 0) & np.isfinite(band_p) & (band_p > 0)
    if np.count_nonzero(usable) < 16:
        return 0.0
    coeffs = np.polyfit(np.log10(band_f[usable]), np.log10(band_p[usable]), 1)
    return float(coeffs[0])
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Start ffmpeg decoding one stream to raw PCM on its stdout.

    The output is interleaved signed 16-bit little-endian samples with
    ``ch`` channels at ``sr`` Hz. Decode errors are tolerated: corrupt
    packets are discarded rather than aborting the run, and ffmpeg's stderr
    is dropped.

    Args:
        inpath: Input media file.
        stream_index: Absolute index of the stream inside the container, as
            reported by ffprobe's ``index`` entry.
        sr: Output sample rate in Hz.
        ch: Output channel count.

    Returns:
        The running process; the caller reads ``stdout`` and must wait on it.
    """
    cmd = [
        "ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error", "-nostdin",
        "-err_detect", "ignore_err", "-fflags", "+discardcorrupt",
        "-analyzeduration", "100M", "-probesize", "100M",
        "-i", str(inpath),
        # Map by absolute stream index. "0:a:N" would mean "the N-th audio
        # stream", which picks the wrong stream (or none) whenever the first
        # audio stream is not stream 0 of the file.
        "-map", f"0:{stream_index}", "-vn", "-sn",
        "-ac", str(ch), "-ar", str(sr),
        "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1",
    ]
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=1024*1024)
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Flip, in place, every run of ``value`` shorter than ``min_len``.

    A run is a maximal stretch of consecutive elements equal to ``value``.
    Short runs are overwritten with ``1 - value``; nothing is returned.
    """
    is_value = flags == value
    # Pad with False so runs touching either end still produce two edges.
    padded = np.concatenate(([False], is_value, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    # Edges alternate: run start (inclusive), run end (exclusive).
    for begin, end in zip(edges[0::2], edges[1::2]):
        if (end - begin) < min_len:
            flags[begin:end] = 1 - value
def flags_to_intervals(flags: List[bool], hop: int, win: int, sr: int):
    """Turn per-window flags into white and non-white time intervals.

    Window ``i`` starts at sample ``i * hop`` and spans ``win`` samples.
    Consecutive windows with the same flag form one run.

    A run followed by another run is closed at ``(i * hop + win) / sr``,
    where ``i`` is the first window of the following run, so it includes
    that window's span. The final run is closed at the end of the last
    window.

    Returns:
        ``(whites, nonwhites)``: lists of ``(start, end)`` in seconds for
        truthy and falsy runs. Both are empty for empty input.
    """
    whites = []
    nonwhites = []
    if not flags:
        return whites, nonwhites

    count = len(flags)
    # Index of the first window of every run.
    run_starts = [0] + [i for i in range(1, count) if flags[i] != flags[i - 1]]
    for pos, first in enumerate(run_starts):
        if pos + 1 < len(run_starts):
            end_t = (run_starts[pos + 1] * hop + win) / sr
        else:
            end_t = ((count - 1) * hop + win) / sr
        start_t = first * hop / sr
        bucket = whites if flags[first] else nonwhites
        bucket.append((start_t, end_t))
    return whites, nonwhites
def intervals_total(intervals: List[Tuple[float, float]]) -> float:
    """Return the total length covered by the intervals, counting overlaps once.

    Intervals are processed in sorted order. An interval whose end precedes
    its start contributes nothing. Returns 0.0 for empty input.
    """
    covered = 0.0
    span_start = None
    span_end = None
    for start, end in sorted(intervals):
        if span_start is None:
            span_start, span_end = start, end
        elif start <= span_end:
            # Overlaps or touches the open span: extend it.
            span_end = max(span_end, end)
        else:
            covered += max(0.0, span_end - span_start)
            span_start, span_end = start, end
    if span_start is not None:
        covered += max(0.0, span_end - span_start)
    return covered
# ---------------- scanning engine ----------------
def segment_streaming_from_pipe(
inpath: Path,
stream_index: int,
sr: int,
ch: int,
threshold: float,
ratio: float,
win_s: float,
hop_s: float,
warmup_sec: float,
min_white_sec: float,
min_nonwhite_sec: float,
flat_w: float,
slope_w: float,
even_w: float,
) -> Dict:
"""Classify a file window by window as white noise or not, while streaming.

The file is decoded by ffmpeg to 16-bit PCM and read one hop at a time into
a sliding buffer of `win` samples per channel. Once the buffer is full, each
hop yields one flag: True when at least `ratio` of the channels look white.
Short runs of either flag are then flipped, and the flags are converted to
time intervals.

Returns a dict with "duration" (decoded length in seconds) and
"nonwhite_intervals" (list of (start, end) seconds).
"""
proc = open_ffmpeg_pcm_pipe(inpath, stream_index, sr, ch)
assert proc.stdout is not None
# Window and hop lengths in samples; the window is never below 256 samples.
win = max(256, int(sr * win_s))
hop = max(1, int(sr * hop_s))
bytes_per_sample = 2
bytes_per_frame = bytes_per_sample * ch
# Each read asks for exactly one hop of interleaved frames.
read_size_bytes = hop * bytes_per_frame
buf = np.zeros((win, ch), dtype=np.float32)
filled = 0
hann = np.hanning(win).astype(np.float32)
# Number of full windows averaged to choose the analysis band (at least 8).
warm_windows = max(8, int(round(warmup_sec / max(1e-6, hop_s))))
psd_accum = None
freqs_ref = None
# Default analysis band, used until the warm-up has chosen one.
f_lo = 350.0
f_hi = sr * 0.65
mask = None
flags: List[bool] = []
total_frames = 0
while True:
raw = proc.stdout.read(read_size_bytes)
if not raw:
break
n_samples = len(raw) // bytes_per_sample
n_frames = n_samples // ch
if n_frames == 0:
continue
total_frames += n_frames
# Drop any trailing partial frame, then scale int16 to [-1, 1).
arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
step = arr.shape[0]
# Slide the buffer so it always holds the most recent `win` frames.
if step >= win:
buf[...] = arr[-win:]
filled = win
else:
if filled + step <= win:
buf[filled:filled+step, :] = arr
filled += step
else:
overflow = filled + step - win
buf[:win-overflow, :] = buf[overflow:, :]
buf[win-step:, :] = arr
filled = win
# No analysis until a full window is available.
if filled < win:
continue
if warm_windows > 0:
# Warm-up: accumulate the PSD of up to the first four channels.
Psum = None
for c in range(min(ch, 4)):
fr, P = rfft_psd_window(buf[:, c], hann, sr)
if Psum is None:
Psum = P.astype(np.float64)
else:
Psum += P
freqs_ref = fr
if psd_accum is None:
psd_accum = Psum
else:
psd_accum += Psum
warm_windows -= 1
if warm_windows == 0:
# Pick the band between the 5% and 95% points of cumulative power.
# The divisor below only rescales `power`; the cumulative sum is
# normalised to 1 before it is used.
power = psd_accum / max(1, int(round(warmup_sec / max(1e-6, hop_s))))
c = np.cumsum(power)
if c[-1] > 0:
c /= c[-1]
low_idx = int(np.searchsorted(c, 0.05))
high_idx = int(np.searchsorted(c, 0.95))
f_lo = max(100.0, float(freqs_ref[max(1, low_idx)]))
f_hi = min(0.95 * float(freqs_ref.max()), float(freqs_ref[min(len(freqs_ref)-1, high_idx)]))
mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)
if mask is None and freqs_ref is not None:
mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)
# Count the channels whose current window looks white.
ch_likely = 0
for c in range(ch):
xw = buf[:, c]
# A silent channel is never counted as white.
if np.max(np.abs(xw)) < 1e-6:
continue
fr, P = rfft_psd_window(xw, hann, sr)
m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
sf = spectral_flatness_vector(P[m])
slope = slope_in_band(fr, P, m)
even = band_evenness(fr, P, f_lo, f_hi)
slope_score = math.exp(-abs(slope) / 0.25)
# NOTE(review): slope_score and even are multiplied by the literal 0.0, so
# the slope_w and even_w parameters have no effect and prob never exceeds
# flat_w. When flat_w < threshold no window can be flagged white. Confirm
# this is intended.
prob = float(flat_w * max(0.0, min(1.0, sf)) + 0.0 * slope_score + 0.0 * even)
if prob >= threshold and abs(slope) < 0.3:
ch_likely += 1
frac = ch_likely / float(ch)
flags.append(frac >= ratio)
proc.stdout.close(); proc.wait()
f_arr = np.array(flags, dtype=np.int8)
if f_arr.size:
# NOTE(review): these divide by hop_s directly, so hop_s == 0 raises
# ZeroDivisionError even though `hop` above is clamped to at least 1.
min_white_frames = max(1, int(math.ceil(min_white_sec / hop_s)))
min_nonwhite_frames = max(1, int(math.ceil(min_nonwhite_sec / hop_s)))
# Flip short non-white runs first, then short white runs.
remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
remove_short_runs(f_arr, value=1, min_len=min_white_frames)
whites, nonwhites = flags_to_intervals(list(f_arr.astype(bool)), hop, win, sr)
total_dur = total_frames / float(sr)
_ = intervals_total(whites) # white duration if you need it later
return {"duration": total_dur, "nonwhite_intervals": nonwhites}
# ---------------- filenames and sessions ----------------
# Matches names shaped like "<prefix>_YYYYMMDD-HHMM_<rest>" and captures the
# eight date digits and the four time digits.
_TS_RE = re.compile(r".*_(\d{8})-(\d{4})_", re.IGNORECASE)


def file_start_time(p: Path) -> Optional[dt.datetime]:
    """Return the start time encoded in a file name, if any.

    Only ``p.name`` is inspected. The name must contain
    ``_YYYYMMDD-HHMM_``; the captured digits are parsed as a naive
    ``datetime``.

    Returns:
        The parsed timestamp, or ``None`` when the name does not match or
        the digits do not form a valid date and time.
    """
    found = _TS_RE.match(p.name)
    if found is None:
        return None
    stamp = "".join(found.group(1, 2))
    try:
        parsed = dt.datetime.strptime(stamp, "%Y%m%d%H%M")
    except Exception:
        # Digits matched the pattern but are not a real date/time.
        return None
    return parsed
@dataclass
class FileScan:
    """Result of scanning one input file for non-white intervals."""

    # File that was scanned.
    path: Path
    # Start time taken from the file name; None when the name carries none.
    start: Optional[dt.datetime]
    # Decoded length of the file in seconds.
    duration: float
    # (start, end) offsets in seconds, relative to the start of the file.
    intervals: List[Tuple[float, float]]
@dataclass
class SessionPiece:
    """One contiguous slice of a source file that belongs to a session."""

    # Source file the slice is cut from.
    file: Path
    # Slice start offset in seconds from the beginning of the file.
    start: float
    # Slice end offset in seconds from the beginning of the file.
    end: float
@dataclass
class Session:
    """A group of pieces that are extracted and transcribed together."""

    # Sequence number assigned while planning sessions.
    idx: int
    # Wall-clock start of the session; None when the source file is undated.
    start_dt: Optional[dt.datetime]
    # Wall-clock end of the session; None when the source file is undated.
    end_dt: Optional[dt.datetime]
    # Slices that make up the session, in the order they were planned.
    pieces: List[SessionPiece]
def merge_intervals(iv: List[Tuple[float, float]], join_gap: float) -> List[Tuple[float, float]]:
    """Sort intervals and fuse those that overlap or lie within ``join_gap``.

    Args:
        iv: ``(start, end)`` pairs in any order.
        join_gap: Largest distance between one interval's end and the next
            interval's start for the two to be fused.

    Returns:
        A new sorted list of ``(float, float)`` pairs; empty for empty input.
    """
    fused: List[Tuple[float, float]] = []
    for start, end in sorted(iv):
        if fused and start <= fused[-1][1] + join_gap:
            # Close enough to the previous interval: extend it.
            prev_start, prev_end = fused[-1]
            fused[-1] = (prev_start, max(prev_end, end))
        else:
            fused.append((start, end))
    return [(float(a), float(b)) for a, b in fused]
def plan_sessions(scans: List[FileScan], join_gap: float, min_session: float) -> List[Session]:
    """Group per-file intervals into sessions.

    Files are visited in order of start time (undated files first), then by
    name. Each file's intervals are merged with ``join_gap``. A piece joins
    the open session when its absolute start is within ``join_gap`` seconds
    of the session's current end, or when its file has no start time.
    Otherwise the open session is closed and a new one is started.

    A closed session whose pieces total less than ``min_session`` seconds is
    folded into the previously emitted session, if there is one.

    Args:
        scans: Scan results, one per file.
        join_gap: Maximum gap in seconds for joining intervals and pieces.
        min_session: Minimum total piece length in seconds for a session to
            be emitted on its own.

    Returns:
        The planned sessions in emission order.
    """
    ordered = sorted(scans, key=lambda x: (x.start or dt.datetime.min, x.path.name))
    sessions: List[Session] = []
    open_pieces: List[SessionPiece] = []
    open_start: Optional[dt.datetime] = None
    open_end: Optional[dt.datetime] = None
    next_idx = 1

    def flush() -> None:
        """Close the open session and either emit it or fold it into the last one."""
        nonlocal open_pieces, open_start, open_end, next_idx
        if not open_pieces:
            return
        closed = Session(idx=next_idx, start_dt=open_start, end_dt=open_end, pieces=open_pieces)
        open_pieces = []
        open_start = None
        open_end = None
        total_len = sum(p.end - p.start for p in closed.pieces)
        if total_len < min_session and sessions:
            # Too short to stand alone: attach to the previous session.
            sessions[-1].pieces.extend(closed.pieces)
            sessions[-1].end_dt = closed.end_dt
        else:
            sessions.append(closed)
            next_idx += 1

    for fs in ordered:
        for s_rel, e_rel in merge_intervals(fs.intervals, join_gap):
            if fs.start:
                s_abs = fs.start + dt.timedelta(seconds=s_rel)
                e_abs = fs.start + dt.timedelta(seconds=e_rel)
            else:
                s_abs = None
                e_abs = None
            piece = SessionPiece(fs.path, s_rel, e_rel)

            near_in_time = bool(
                open_end and e_abs and s_abs
                and (s_abs - open_end).total_seconds() <= join_gap
            )
            if open_pieces and (near_in_time or not fs.start):
                open_pieces.append(piece)
                if e_abs and (not open_end or e_abs > open_end):
                    open_end = e_abs
            else:
                # No-op when nothing is open yet.
                flush()
                open_pieces = [piece]
                open_start = s_abs
                open_end = e_abs

    flush()
    return sessions
# ---------------- worker for scan ----------------
def scan_one_worker(job):
    """Scan a single file and return its padded, merged non-white intervals.

    Args:
        job: A 13-tuple ``(path, threshold, ratio, win_sec, hop_sec,
            warmup_sec, min_white_sec, min_nonwhite_sec, flat_w, slope_w,
            even_w, join_gap_sec, overlap)``.

    Returns:
        A dict of plain values with keys ``path`` (str), ``start`` (ISO
        timestamp or ``""``), ``duration`` (seconds) and ``intervals``
        (list of ``(start, end)`` seconds).
    """
    (
        inpath_str, threshold, ratio, win_sec, hop_sec, warmup_sec,
        min_white_sec, min_nonwhite_sec, flat_w, slope_w, even_w,
        join_gap_sec, overlap,
    ) = job
    src = Path(inpath_str)

    info = ffprobe_info(src)
    stream_index = int(info.get("index", 0))
    sr = int(info.get("sample_rate", 48000))
    ch = int(info.get("channels", 1))

    res = segment_streaming_from_pipe(
        src, stream_index, sr, ch, threshold, ratio, win_sec, hop_sec,
        warmup_sec, min_white_sec, min_nonwhite_sec, flat_w, slope_w, even_w,
    )
    found = res.get("nonwhite_intervals", [])
    dur = float(res.get("duration", 0.0))

    # Widen each interval by `overlap` on both sides, clipped to the file.
    widened = ((max(0.0, s - overlap), min(dur, e + overlap)) for s, e in found)
    padded = [(float(a), float(b)) for a, b in widened if b > a]
    merged = merge_intervals(padded, float(join_gap_sec))

    start_dt = file_start_time(src)
    return {
        "path": str(src),
        "start": start_dt.isoformat() if start_dt else "",
        "duration": dur,
        "intervals": merged,
    }
# ---------------- audio extraction ----------------
def extract_session_channel_wav(session: Session, ch_idx: int, sr_out: Optional[int], out_wav: Path) -> None:
    """Write one channel of a session to a single mono 16-bit PCM WAV.

    Stale outputs from an earlier run (the WAV, the concat list and any
    ``<stem>_part*.wav``) are removed first. Each piece is then cut out of
    its source file into a temporary part file using ffmpeg's ``pan``
    filter to pick input channel ``ch_idx``. The parts are joined with the
    concat demuxer, first by stream copy and, if that fails, by
    re-encoding. The concat list and part files are removed afterwards,
    whether or not extraction succeeded.

    Args:
        session: Session whose pieces are extracted in order.
        ch_idx: Zero-based index of the input channel to keep.
        sr_out: Output sample rate; a falsy or non-positive value adds no
            ``-ar`` option.
        out_wav: Destination WAV path; its parent directory is created.

    Raises:
        RuntimeError: An ffmpeg command failed (raised by ``_run_ok``).
    """
    resample = ["-ar", str(sr_out)] if sr_out and sr_out > 0 else []
    list_txt = out_wav.with_suffix('.list.txt')
    part_files: List[Path] = []
    out_wav.parent.mkdir(parents=True, exist_ok=True)

    # clean stale outputs to avoid prompts
    try:
        for stale in (out_wav, list_txt):
            if stale.exists():
                stale.unlink()
        for old in out_wav.parent.glob(f"{out_wav.stem}_part*.wav"):
            try:
                old.unlink()
            except Exception:
                pass
    except Exception:
        pass

    try:
        for n, piece in enumerate(session.pieces, 1):
            part = out_wav.with_suffix("").parent / f"{out_wav.stem}_part{n:03d}.wav"
            part_files.append(part)
            length = max(0.0, piece.end - piece.start)
            _run_ok([
                "ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error",
                "-ss", f"{piece.start:.3f}", "-t", f"{length:.3f}",
                "-i", str(piece.file),
                "-map", "0:a:0", "-vn", "-sn", "-af", f"pan=mono|c0=c{ch_idx}",
                "-ac", "1", "-acodec", "pcm_s16le",
                *resample,
                str(part),
            ])

        # write absolute paths for concat
        with open(list_txt, 'w', encoding='utf-8') as listing:
            for pf in part_files:
                quoted = str(pf.resolve()).replace("'", "'\\''")
                listing.write("file '{}'\n".format(quoted))

        concat_in = ["ffmpeg", "-f", "concat", "-safe", "0", "-i", str(list_txt)]
        try:
            # try stream copy first
            _run_ok(concat_in + ["-c", "copy", str(out_wav)])
        except Exception:
            # Fall back to re-encoding if the stream copy fails.
            _run_ok(concat_in + ["-acodec", "pcm_s16le", "-ac", "1", *resample, str(out_wav)])
    finally:
        # Best-effort removal of temporaries; each failure is ignored on its own.
        for leftover in [list_txt, *part_files]:
            try:
                if leftover.exists():
                    leftover.unlink()
            except Exception:
                pass
# ---------------- whisper and merge ----------------
@dataclass
class Word:
    """A single transcribed word with timing and confidence."""

    # Word start time in seconds.
    start: float
    # Word end time in seconds.
    end: float
    # Word text.
    text: str
    # Alignment score reported for the word.
    score: float
    # Channel the word came from; -1 until a channel is assigned.
    ch: int
class Heartbeat:
    """Context manager that prints a progress line at a fixed interval.

    While the ``with`` block runs, a daemon thread prints
    ``working... <label> t+<seconds>s`` every ``interval`` seconds, where
    the seconds are counted from construction. A non-positive ``interval``
    disables the heartbeat: no thread is started and nothing is printed.
    """

    def __init__(self, label: str, interval: float = 15.0):
        """Store the label and interval; the thread starts on ``__enter__``."""
        self.label = label
        self.interval = interval
        self._stop = threading.Event()
        self._t = threading.Thread(target=self._run, daemon=True)
        self._t0 = time.time()

    def __enter__(self):
        # With interval <= 0, Event.wait() would return at once and the
        # thread would print in a tight loop, so it is not started at all.
        if self.interval > 0:
            self._t.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._stop.set()
        # Joining a thread that was never started raises RuntimeError.
        if self._t.is_alive():
            self._t.join(timeout=1.0)

    def _run(self):
        # wait() returns False on timeout (print again) and True once stopped.
        while not self._stop.wait(self.interval):
            dt_s = int(time.time() - self._t0)
            print(f" working... {self.label} t+{dt_s}s")
            sys.stdout.flush()
def run_whisperx_on_wav_modelreused(wav_path: Path, device: str, model, align_model, align_meta, hb_sec: float) -> List[Word]:
    """Transcribe and word-align one WAV with already-loaded WhisperX models.

    A heartbeat line is printed every ``hb_sec`` seconds while the
    transcription call runs.

    Args:
        wav_path: Audio file to transcribe.
        device: Device string passed to ``whisperx.align``.
        model: Loaded transcription model exposing ``transcribe(audio)``.
        align_model: Loaded alignment model.
        align_meta: Metadata that accompanies the alignment model.
        hb_sec: Heartbeat interval in seconds.

    Returns:
        Words with non-empty text, each created with ``ch=-1``. A word
        without its own ``start``/``end`` falls back to the enclosing
        segment's values.
    """
    import whisperx

    audio = whisperx.load_audio(str(wav_path))
    with Heartbeat(f"whisper {wav_path.name}", hb_sec):
        transcript = model.transcribe(audio)
    aligned = whisperx.align(
        transcript.get("segments", []), align_model, align_meta, audio, device,
        return_char_alignments=False,
    )

    def raw_words():
        """Yield (segment, word-dict) pairs in transcript order."""
        for seg in aligned.get("segments", []):
            for item in seg.get("words", []) or []:
                yield seg, item

    words: List[Word] = []
    for seg, item in raw_words():
        begin = float(item.get("start", seg.get("start", 0.0)) or 0.0)
        finish = float(item.get("end", seg.get("end", begin)))
        text = str(item.get("word", "")).strip()
        score = float(item.get("score", 0.0))
        if text:
            words.append(Word(begin, finish, text, score, ch=-1))
    return words
def crdt_merge_words(words_by_channel: Dict[int, List[Word]]) -> List[Word]:
    """Merge per-channel word lists into one timeline.

    Every word is copied with its channel key as ``ch``, and the copies are
    sorted by ``(start, end)``. Each word is then compared with the most
    recently kept word only:

    * no time overlap: the word is kept;
    * overlap: the word replaces the kept one if it has a higher score, or
      an equal score (within 1e-6) and a longer duration; otherwise it is
      dropped.

    Returns:
        The kept words in time order.
    """
    tagged: List[Word] = [
        Word(w.start, w.end, w.text, w.score, ch)
        for ch, channel_words in words_by_channel.items()
        for w in channel_words
    ]
    tagged.sort(key=lambda w: (w.start, w.end))

    merged: List[Word] = []
    for cand in tagged:
        if not merged:
            merged.append(cand)
            continue
        kept = merged[-1]
        overlap = min(kept.end, cand.end) - max(kept.start, cand.start)
        if overlap > 0:
            higher_score = cand.score > kept.score
            tie_but_longer = (
                abs(cand.score - kept.score) < 1e-6
                and (cand.end - cand.start) > (kept.end - kept.start)
            )
            if higher_score or tie_but_longer:
                merged[-1] = cand
        else:
            merged.append(cand)
    return merged
def assign_speakers_if_available(audio_path: Path, merged: List[Word], device: str, hf_token: Optional[str]) -> List[Dict]:
    """Label each word with a diarization speaker, best effort.

    Runs WhisperX's diarization pipeline on ``audio_path`` and gives every
    word the speaker of the first diarization segment that contains the
    word's midpoint. If diarization cannot run for any reason (WhisperX
    missing, no token, model failure), a message is printed and every word
    is returned with ``speaker`` set to ``None``.

    Args:
        audio_path: Audio file to diarize.
        merged: Words to label; only ``start``, ``end``, ``text`` and
            ``score`` are read.
        device: Device string passed to the diarization pipeline.
        hf_token: Hugging Face token; when falsy, the ``HF_TOKEN``
            environment variable is used instead.

    Returns:
        One dict per word with keys ``start``, ``end``, ``text``, ``score``
        and ``speaker``.
    """

    def as_dict(w, speaker):
        return {"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": speaker}

    try:
        # Import from the submodule so this does not rely on the package
        # exposing `diarize` as an attribute.
        from whisperx.diarize import DiarizationPipeline

        if not hf_token:
            hf_token = os.environ.get("HF_TOKEN")
        diar = DiarizationPipeline(use_auth_token=hf_token, device=device)
        diar_segs = diar(str(audio_path))

        # WhisperX returns a pandas DataFrame here. Iterating a DataFrame
        # yields column names, not rows, so convert it to row dicts first.
        # Any other iterable of mappings is accepted unchanged.
        if hasattr(diar_segs, "to_dict"):
            rows = diar_segs.to_dict("records")
        else:
            rows = list(diar_segs)
        spans = [
            (float(row.get('start', 0)), float(row.get('end', 0)), row.get('speaker'))
            for row in rows
        ]

        out = []
        for w in merged:
            mid = 0.5 * (w.start + w.end)
            speaker = next((name for s, e, name in spans if s <= mid <= e), None)
            out.append(as_dict(w, speaker))
        return out
    except Exception as e:
        # Diarization is optional: report and fall back to unlabeled words.
        print(" diarization skipped:", e)
        return [as_dict(w, None) for w in merged]
def write_srt(words: List[Dict], out_path: Path, max_chars: int = 60, max_gap: float = 0.6) -> None:
    """Group words into subtitle cues and write them as an SRT file.

    A new cue starts when the gap between the previous word's end and the
    next word's start exceeds ``max_gap`` seconds, or when adding the word
    would make the cue text longer than ``max_chars`` characters. The file
    is always written as UTF-8 so non-ASCII transcripts do not depend on
    the platform's default encoding.

    Args:
        words: Dicts with ``start``, ``end`` (seconds) and ``text``, in
            time order.
        out_path: Destination file; an empty file is written when ``words``
            is empty.
        max_chars: Maximum cue text length before a break is forced.
        max_gap: Maximum silence in seconds inside a cue.
    """
    if not words:
        out_path.write_text("", encoding="utf-8")
        return

    entries = []
    cur_text: List[str] = []
    cur_start = words[0]['start']
    cur_end = words[0]['end']
    for w in words:
        if cur_text and (
            w['start'] - cur_end > max_gap
            or len(' '.join(cur_text) + ' ' + w['text']) > max_chars
        ):
            entries.append((cur_start, cur_end, ' '.join(cur_text)))
            cur_text = [w['text']]
            cur_start = w['start']
            cur_end = w['end']
        else:
            cur_text.append(w['text'])
            cur_end = w['end']
    if cur_text:
        entries.append((cur_start, cur_end, ' '.join(cur_text)))

    def fmt(t: float) -> str:
        """Format seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
        total_ms = int(round(t * 1000))
        h, rem = divmod(total_ms, 3600000)
        m, rem = divmod(rem, 60000)
        s, ms = divmod(rem, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    lines = []
    for i, (s, e, txt) in enumerate(entries, 1):
        # Cue number, time range, text, then the blank separator line.
        lines.extend([str(i), f"{fmt(s)} --> {fmt(e)}", txt.strip(), ""])
    out_path.write_text('\n'.join(lines), encoding="utf-8")
# ---------------- main ----------------
def main():
    """Command line entry point: scan, plan sessions, transcribe, merge.

    Steps, as implemented below:
      1. Collect input files by extension from the current directory.
      2. Scan every file in a process pool for non white intervals.
      3. Join the intervals into sessions with ``plan_sessions``.
      4. For each session, extract one mono WAV per channel, run WhisperX on
         it, merge the per channel words, optionally diarize, and write
         ``<label>_merged.json`` and ``<label>_merged.srt``.

    Failures on a single file, channel or diarization run are printed and
    skipped so the rest of the batch continues.
    """
    ap = argparse.ArgumentParser(description="Two pass white to WhisperX with session join and CRDT merge")
    ap.add_argument("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="Extensions to process")
    ap.add_argument("--recursive", action="store_true", help="Recurse into subdirectories")
    # scan knobs
    ap.add_argument("--threshold", type=float, default=0.70)
    ap.add_argument("--ratio", type=float, default=0.80)
    # --white-time, --file-threshold and --file-ratio are parsed but not read
    # anywhere in this function.
    ap.add_argument("--white-time", type=float, default=0.95)
    ap.add_argument("--file-threshold", type=float, default=0.88)
    ap.add_argument("--file-ratio", type=float, default=0.80)
    ap.add_argument("--flat-weight", type=float, default=0.65)
    ap.add_argument("--slope-weight", type=float, default=0.35)
    ap.add_argument("--even-weight", type=float, default=0.00)
    ap.add_argument("--win-sec", type=float, default=1.0)
    ap.add_argument("--hop-sec", type=float, default=0.5)
    ap.add_argument("--warmup-sec", type=float, default=2.0)
    ap.add_argument("--min-nonwhite-sec", type=float, default=1.5)
    ap.add_argument("--min-white-sec", type=float, default=0.5)
    ap.add_argument("--scan-workers", type=int, default=4, help="Parallel workers for scanning phase")
    # sessions
    ap.add_argument("--join-gap-sec", type=float, default=8.0, help="Join intervals up to this gap")
    ap.add_argument("--min-session-sec", type=float, default=35.0, help="Try not to emit sessions shorter than this")
    ap.add_argument("--overlap", type=float, default=0.25, help="Padding applied around intervals before join")
    # extraction and whisper
    ap.add_argument("--sr", type=int, default=0, help="Resample WAVs to this rate. 0 keeps source rate")
    ap.add_argument("--outdir", type=str, default="whisper_out")
    ap.add_argument("--keep-session-wavs", action="store_true")
    ap.add_argument("--whisper-model", type=str, default="large-v2")
    ap.add_argument("--device", type=str, default="cpu")
    ap.add_argument("--compute-type", type=str, default="int8")
    ap.add_argument("--diarize", action="store_true")
    ap.add_argument("--hf-token", type=str, default="")
    # new
    ap.add_argument("--resume", action="store_true", help="Skip sessions that already have final outputs")
    ap.add_argument("--heartbeat-sec", type=float, default=15.0, help="Heartbeat print interval during Whisper")
    args = ap.parse_args()

    _check_tool("ffprobe")
    _check_tool("ffmpeg")

    # gather files
    root = Path.cwd()
    files: List[Path] = []
    for ext in args.exts:
        pattern = "**/*." + ext if args.recursive else "*." + ext
        files.extend(root.glob(pattern))
    files = sorted(set(files))
    if not files:
        print("No input files found")
        return

    # pass 1: scan in parallel
    print("Scanning for non white intervals...")
    sys.stdout.flush()
    jobs = [
        (str(p), float(args.threshold), float(args.ratio), float(args.win_sec), float(args.hop_sec),
         float(args.warmup_sec), float(args.min_white_sec), float(args.min_nonwhite_sec),
         float(args.flat_weight), float(args.slope_weight), float(args.even_weight),
         float(args.join_gap_sec), float(args.overlap))
        for p in files
    ]
    scans: List[FileScan] = []
    with cf.ProcessPoolExecutor(max_workers=max(1, args.scan_workers)) as ex:
        futs = {ex.submit(scan_one_worker, job): path for job, path in zip(jobs, files)}
        done = 0
        for fu in cf.as_completed(futs):
            done += 1
            try:
                _r = fu.result()
                # Workers return the start time as an ISO string ("" when unknown).
                _start = dt.datetime.fromisoformat(_r["start"]) if _r.get("start") else None
                scans.append(FileScan(
                    path=Path(_r["path"]),
                    start=_start,
                    duration=float(_r["duration"]),
                    intervals=[(float(a), float(b)) for a, b in _r.get("intervals", [])],
                ))
            except Exception as e:
                print(" scan failed for", futs[fu].name, "-", e)
            if done % 8 == 0 or done == len(futs):
                print(f" scanned {done}/{len(futs)} files")
                sys.stdout.flush()

    scans = [s for s in scans if s.intervals]
    if not scans:
        print("All files are white. Nothing to transcribe")
        return
    sessions = plan_sessions(scans, join_gap=args.join_gap_sec, min_session=args.min_session_sec)
    if not sessions:
        print("No sessions found. Nothing to transcribe")
        return
    print(f"Planned {len(sessions)} sessions. Processing sequentially...")
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # load whisper once
    print("Loading WhisperX models...")
    import whisperx
    model = whisperx.load_model(args.whisper_model, args.device, compute_type=args.compute_type)
    model_a, meta = whisperx.load_align_model(language_code="en", device=args.device)

    # Fix: --hf-token used to be parsed and then ignored (only the HF_TOKEN
    # environment variable reached diarization). The flag now takes priority.
    hf_token = args.hf_token or os.environ.get("HF_TOKEN") or ""

    def plain_dicts(words):
        return [{"start": w.start, "end": w.end, "text": w.text, "score": w.score, "speaker": None} for w in words]

    # process sessions
    for i, sess in enumerate(sessions, 1):
        if sess.start_dt and sess.end_dt:
            label = sess.start_dt.strftime("%Y%m%d-%H%M%S") + "_to_" + sess.end_dt.strftime("%H%M%S")
        else:
            first = sess.pieces[0]
            label = Path(first.file).stem + f"_{int(first.start):06d}"
        sess_dir = outdir / f"session_{i:03d}_{label}"
        sess_dir.mkdir(parents=True, exist_ok=True)
        out_json = sess_dir / f"{label}_merged.json"
        out_srt = sess_dir / f"{label}_merged.srt"
        if args.resume and out_json.exists() and out_srt.exists():
            print(f"\nSession {i}/{len(sessions)} {label} - resume skip (outputs exist)")
            continue
        print(f"\nSession {i}/{len(sessions)} {label}")

        # Channel count and sample rate are taken from the session's first piece.
        info0 = ffprobe_info(sess.pieces[0].file)
        ch = int(info0.get("channels", 1))
        sr0 = int(info0.get("sample_rate", 48000))
        sr_out = args.sr if args.sr and args.sr > 0 else sr0

        words_by_ch: Dict[int, List[Word]] = {}
        for c in range(ch):
            wav_path = sess_dir / f"session_{i:03d}_ch{c+1:02d}.wav"
            print(f" ch {c+1:02d}/{ch:02d}: extracting...", end=" ")
            try:
                extract_session_channel_wav(sess, c, sr_out, wav_path)
                print("ok. transcribing...")
            except Exception as e:
                print("extract failed:", e)
                continue
            try:
                words = run_whisperx_on_wav_modelreused(wav_path, args.device, model, model_a, meta, args.heartbeat_sec)
                words_by_ch[c] = words
                preview = " ".join([w.text for w in sorted(words, key=lambda w: w.start)[:8]])
                print(f" preview: {preview[:80]}")
            except Exception as e:
                print(" whisper failed:", e)
            if not args.keep_session_wavs:
                try:
                    wav_path.unlink()
                except Exception:
                    pass

        merged = crdt_merge_words(words_by_ch)
        if args.diarize:
            # The "mix" WAV is channel 0 of the session, not a downmix.
            mixwav = sess_dir / f"session_{i:03d}_mix.wav"
            try:
                extract_session_channel_wav(sess, 0, sr_out, mixwav)
                merged_dicts = assign_speakers_if_available(mixwav, merged, args.device, hf_token)
            except Exception as e:
                print(" diarization error:", e)
                merged_dicts = plain_dicts(merged)
            finally:
                if not args.keep_session_wavs:
                    try:
                        mixwav.unlink()
                    except Exception:
                        pass
        else:
            merged_dicts = plain_dicts(merged)

        out_json.write_text(json.dumps({
            "session_label": label,
            "files": [str(p.file) for p in sess.pieces],
            "pieces": [{"file": str(p.file), "start": p.start, "end": p.end} for p in sess.pieces],
            "words": merged_dicts,
        }, indent=2))
        write_srt(merged_dicts, out_srt)
        print(f" wrote {out_json.name} and {out_srt.name}")
        sys.stdout.flush()


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
white2whisper_chain_v5.py
Two pass pipeline
1) Fast parallel scan to find non white intervals per file
2) Plan long sessions by joining nearby intervals within and across files
3) Process sessions sequentially
- Build one mono WAV per channel per session using concat lists with absolute paths
- Run WhisperX once per channel, print a heartbeat while it runs
- Save per channel transcripts to their own files for later debugging
- Optional final merge of channels
New in v5
- Per channel artifacts saved inside each session directory for every channel:
chXX/words.json, chXX/words.srt, chXX/transcript.txt
- Stable and deterministic CRDT merge, with a flag to disable merge
- Same non interactive ffmpeg and cleanup hardening as v4
- Resume support
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import datetime as dt
import json
import math
import os
import re
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
# ---------------- utils ----------------
def _run_ok(cmd: Sequence[str]) -> None:
if cmd and Path(cmd[0]).name == "ffmpeg":
cmd = [cmd[0], "-y", "-nostdin"] + list(cmd[1:])
p = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.DEVNULL,
text=True,
)
if p.returncode != 0:
raise RuntimeError(
"command failed:\n" + " ".join(cmd) + "\n" + (p.stderr[-2000:] if p.stderr else "")
)
def _check_tool(name: str) -> None:
try:
subprocess.run([name, "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
except Exception:
print(f"Error: {name} not found on PATH", file=sys.stderr)
sys.exit(1)
# ---------------- ffprobe ----------------
def ffprobe_info(input_path: Path) -> Dict:
    """Return ffprobe's stream entry for the first audio stream of a file.

    The entry is requested with the fields ``index``, ``channels``,
    ``sample_rate`` and ``duration`` and is returned as parsed from ffprobe's
    JSON output.

    Raises:
        RuntimeError: if ffprobe exits non-zero or reports no audio stream.
    """
    argv = ["ffprobe", "-v", "error"]
    argv += ["-select_streams", "a:0"]
    argv += ["-show_entries", "stream=index,channels,sample_rate,duration"]
    argv += ["-of", "json", str(input_path)]
    probe = subprocess.run(argv, capture_output=True, text=True)
    if probe.returncode != 0:
        raise RuntimeError(f"ffprobe failed on {input_path}")
    streams = json.loads(probe.stdout or "{}").get("streams", [])
    if streams:
        return streams[0]
    raise RuntimeError(f"No audio stream found in {input_path}")
# ---------------- DSP core ----------------
def spectral_flatness_vector(psd: np.ndarray, eps: float = 1e-20) -> float:
    """Return the geometric mean divided by the arithmetic mean of ``psd``.

    Values are floored at ``eps`` first so the logarithm is defined. The
    result is 0.0 when the arithmetic mean is not positive.
    """
    floored = np.maximum(psd, eps)
    log_mean = np.mean(np.log(floored))
    geometric = np.exp(log_mean)
    arithmetic = float(np.mean(floored))
    if arithmetic > 0:
        return float(geometric / arithmetic)
    return 0.0
def rfft_psd_window(x: np.ndarray, win: np.ndarray, sr: int):
    """Return ``(freqs, Pxx)`` for one windowed frame, both as float32.

    ``Pxx`` is the squared magnitude of the real FFT of ``x * win`` divided
    by ``sum(win**2) * sr``; ``freqs`` are the matching bin frequencies.
    """
    spectrum = np.fft.rfft(x * win)
    scale = np.sum(win**2) * sr
    power = (np.abs(spectrum) ** 2) / scale
    bins = np.fft.rfftfreq(len(x), d=1.0 / sr)
    return bins.astype(np.float32), power.astype(np.float32)
def band_evenness(freqs: np.ndarray, Pxx: np.ndarray, f_lo: float, f_hi: float, k: int = 12) -> float:
    """Score how evenly power is spread over ``k`` geometrically spaced bands.

    The mean power of every non-empty band between ``f_lo`` and ``f_hi`` is
    normalised by the overall mean of those band means; the score is
    ``1 / (1 + 3 * std)``, so identical band levels give 1.0. Returns 0.0
    when fewer than three bands contain any frequency bin.
    """
    lower = max(1.0, f_lo)
    upper = max(f_lo * 1.0001, f_hi)
    edges = np.geomspace(lower, upper, k + 1)
    band_means = []
    for left, right in zip(edges[:-1], edges[1:]):
        members = (freqs >= left) & (freqs < right)
        if np.any(members):
            band_means.append(float(np.mean(Pxx[members])))
    if len(band_means) < 3:
        return 0.0
    levels = np.array(band_means)
    # Non-positive band levels are replaced by the smallest positive level.
    positive = levels > 0
    floor = np.min(levels[positive]) if np.any(positive) else 1e-20
    levels[levels <= 0] = floor
    levels /= np.mean(levels)
    spread = float(np.std(levels))
    return float(1.0 / (1.0 + 3.0 * spread))
def slope_in_band(freqs: np.ndarray, Pxx: np.ndarray, mask: np.ndarray) -> float:
    """Return the slope of a straight-line fit to log10(power) vs log10(freq).

    Only bins selected by ``mask`` with positive frequency and finite,
    positive power take part. Returns 0.0 when fewer than 16 bins remain.
    """
    band_f = freqs[mask]
    band_p = Pxx[mask]
    usable = (band_f > 0) & np.isfinite(band_p) & (band_p > 0)
    if np.count_nonzero(usable) < 16:
        return 0.0
    coeffs = np.polyfit(np.log10(band_f[usable]), np.log10(band_p[usable]), 1)
    return float(coeffs[0])
def open_ffmpeg_pcm_pipe(inpath: Path, stream_index: int, sr: int, ch: int) -> subprocess.Popen:
    """Start ffmpeg decoding ``inpath`` to raw signed 16-bit PCM on stdout.

    ``stream_index`` is substituted into ``-map 0:a:<n>``; ``ch`` and ``sr``
    are passed as the output channel count and sample rate. ffmpeg is asked
    to ignore decode errors and discard corrupt packets, and its stderr is
    sent to the null device. The caller owns the returned process and its
    stdout pipe.
    """
    quiet = ["-hide_banner", "-nostats", "-loglevel", "error", "-nostdin"]
    tolerant = ["-err_detect", "ignore_err", "-fflags", "+discardcorrupt"]
    probing = ["-analyzeduration", "100M", "-probesize", "100M"]
    source = ["-i", str(inpath)]
    selection = ["-map", f"0:a:{stream_index}", "-vn", "-sn"]
    layout = ["-ac", str(ch), "-ar", str(sr)]
    sink = ["-f", "s16le", "-acodec", "pcm_s16le", "pipe:1"]
    argv = ["ffmpeg", *quiet, *tolerant, *probing, *source, *selection, *layout, *sink]
    return subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=1024*1024)
def remove_short_runs(flags: np.ndarray, value: int, min_len: int) -> None:
    """Flip, in place, every run of ``value`` shorter than ``min_len``.

    ``flags`` is a 1-D array of 0/1 entries; short runs are overwritten with
    ``1 - value``. Nothing is returned.
    """
    # Pad with zeros so runs touching either end still have both edges.
    hits = np.concatenate(([0], (flags == value).astype(np.int8), [0]))
    edges = np.diff(hits)
    run_starts = np.flatnonzero(edges == 1)
    run_stops = np.flatnonzero(edges == -1)
    for lo, hi in zip(run_starts, run_stops):
        if (hi - lo) < min_len:
            flags[lo:hi] = 1 - value
def flags_to_intervals(flags: List[bool], hop: int, win: int, sr: int):
    """Convert per-window flags into ``(whites, nonwhites)`` time intervals.

    Each run of equal flags becomes one ``(start_sec, end_sec)`` tuple; runs
    of truthy flags go to ``whites``, the rest to ``nonwhites``. A run that
    begins at window ``a`` starts at ``a * hop / sr``.

    NOTE(review): a run followed by another run ends at
    ``(next_run_start * hop + win) / sr``, while the final run ends at
    ``(last_index * hop + win) / sr``, so interior runs extend one hop
    further than the final one. Kept as is; confirm whether the extra hop is
    intended padding.
    """
    whites = []
    nonwhites = []
    if not flags:
        return whites, nonwhites
    total = len(flags)
    # Index of the first window of every run.
    run_starts = [0] + [k for k in range(1, total) if flags[k] != flags[k - 1]]
    end_anchors = run_starts[1:] + [total - 1]
    for first, anchor in zip(run_starts, end_anchors):
        span = (first * hop / sr, (anchor * hop + win) / sr)
        bucket = whites if flags[first] else nonwhites
        bucket.append(span)
    return whites, nonwhites
def intervals_total(intervals: List[Tuple[float, float]]) -> float:
    """Return the total length covered by ``intervals``, counting overlaps once.

    Intervals whose end precedes their start contribute nothing.
    """
    if not intervals:
        return 0.0
    ordered = sorted(intervals)
    run_lo, run_hi = ordered[0]
    covered = 0.0
    for lo, hi in ordered[1:]:
        if lo > run_hi:
            # Disjoint from the running span: bank it and start a new one.
            covered += max(0.0, run_hi - run_lo)
            run_lo, run_hi = lo, hi
        else:
            run_hi = max(run_hi, hi)
    covered += max(0.0, run_hi - run_lo)
    return covered
# ---------------- scanning engine ----------------
def segment_streaming_from_pipe(
    inpath: Path,
    stream_index: int,
    sr: int,
    ch: int,
    threshold: float,
    ratio: float,
    win_s: float,
    hop_s: float,
    warmup_sec: float,
    min_white_sec: float,
    min_nonwhite_sec: float,
    flat_w: float,
    slope_w: float,
    even_w: float,
) -> Dict:
    """Stream a file through ffmpeg and classify sliding windows as white or not.

    PCM is read one hop at a time into a rolling window of ``win_s`` seconds.
    For every full window each channel gets a whiteness score; the window is
    flagged white when at least ``ratio`` of the channels pass. Runs shorter
    than ``min_nonwhite_sec`` / ``min_white_sec`` are then flipped and the
    remaining non white runs are returned as time intervals.

    During the first warm up windows the summed spectrum (up to four
    channels) is accumulated; once warm up ends the analysis band is narrowed
    to the frequencies holding the 5%..95% span of that cumulative power.

    Returns:
        ``{"duration": seconds_decoded, "nonwhite_intervals": [(start, end), ...]}``

    Changes from the previous revision: the ffmpeg child is always reaped
    (and killed on error) instead of leaking when an exception escapes the
    read loop, and ``hop_s == 0`` no longer divides by zero in the run-length
    post-processing.
    """
    proc = open_ffmpeg_pcm_pipe(inpath, stream_index, sr, ch)
    assert proc.stdout is not None
    win = max(256, int(sr * win_s))
    hop = max(1, int(sr * hop_s))
    hop_s_safe = max(1e-6, hop_s)
    flags: List[bool] = []
    total_frames = 0
    try:
        bytes_per_sample = 2
        bytes_per_frame = bytes_per_sample * ch
        read_size_bytes = hop * bytes_per_frame
        buf = np.zeros((win, ch), dtype=np.float32)
        filled = 0
        hann = np.hanning(win).astype(np.float32)
        warm_windows = max(8, int(round(warmup_sec / hop_s_safe)))
        psd_accum = None
        freqs_ref = None
        # Provisional band, used until warm up finishes.
        f_lo = 350.0
        f_hi = sr * 0.65
        mask = None
        while True:
            raw = proc.stdout.read(read_size_bytes)
            if not raw:
                break
            n_samples = len(raw) // bytes_per_sample
            n_frames = n_samples // ch
            if n_frames == 0:
                continue
            total_frames += n_frames
            # Drop any trailing partial frame before reshaping.
            arr = np.frombuffer(raw[: n_frames * bytes_per_frame], dtype=np.int16)
            arr = arr.reshape((-1, ch)).astype(np.float32) / 32768.0
            step = arr.shape[0]
            if step >= win:
                buf[...] = arr[-win:]
                filled = win
            else:
                if filled + step <= win:
                    buf[filled:filled+step, :] = arr
                    filled += step
                else:
                    # Slide the window left and append the new samples at the end.
                    overflow = filled + step - win
                    buf[:win-overflow, :] = buf[overflow:, :]
                    buf[win-step:, :] = arr
                    filled = win
            if filled < win:
                continue

            if warm_windows > 0:
                Psum = None
                for c in range(min(ch, 4)):
                    fr, P = rfft_psd_window(buf[:, c], hann, sr)
                    if Psum is None:
                        Psum = P.astype(np.float64)
                    else:
                        Psum += P
                    freqs_ref = fr
                if psd_accum is None:
                    psd_accum = Psum
                else:
                    psd_accum += Psum
                warm_windows -= 1
                if warm_windows == 0:
                    power = psd_accum / max(1, int(round(warmup_sec / hop_s_safe)))
                    cum = np.cumsum(power)
                    if cum[-1] > 0:
                        cum /= cum[-1]
                        low_idx = int(np.searchsorted(cum, 0.05))
                        high_idx = int(np.searchsorted(cum, 0.95))
                        f_lo = max(100.0, float(freqs_ref[max(1, low_idx)]))
                        f_hi = min(0.95 * float(freqs_ref.max()), float(freqs_ref[min(len(freqs_ref)-1, high_idx)]))
                    mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)
            if mask is None and freqs_ref is not None:
                mask = (freqs_ref >= f_lo) & (freqs_ref <= f_hi)

            ch_likely = 0
            for c in range(ch):
                xw = buf[:, c]
                # Effectively silent channel: never counted as white.
                if np.max(np.abs(xw)) < 1e-6:
                    continue
                fr, P = rfft_psd_window(xw, hann, sr)
                m = mask if mask is not None else (fr >= f_lo) & (fr <= f_hi)
                sf = spectral_flatness_vector(P[m])
                slope = slope_in_band(fr, P, m)
                even = band_evenness(fr, P, f_lo, f_hi)
                slope_score = math.exp(-abs(slope) / 0.25)
                # NOTE(review): slope_w and even_w are accepted but the slope and
                # evenness terms are multiplied by a literal 0.0 here, so only
                # flat_w * flatness contributes (slope acts as a hard gate below).
                # With flat_w < threshold no window can ever pass. Left unchanged
                # because altering it changes detection results; confirm intent.
                prob = float(flat_w * max(0.0, min(1.0, sf)) + 0.0 * slope_score + 0.0 * even)
                if prob >= threshold and abs(slope) < 0.3:
                    ch_likely += 1
            frac = ch_likely / float(ch)
            flags.append(frac >= ratio)
    except BaseException:
        # Abnormal exit: ffmpeg may still be writing to a pipe nobody reads.
        proc.kill()
        raise
    finally:
        proc.stdout.close()
        proc.wait()

    f_arr = np.array(flags, dtype=np.int8)
    if f_arr.size:
        min_white_frames = max(1, int(math.ceil(min_white_sec / hop_s_safe)))
        min_nonwhite_frames = max(1, int(math.ceil(min_nonwhite_sec / hop_s_safe)))
        remove_short_runs(f_arr, value=0, min_len=min_nonwhite_frames)
        remove_short_runs(f_arr, value=1, min_len=min_white_frames)
    whites, nonwhites = flags_to_intervals(list(f_arr.astype(bool)), hop, win, sr)
    total_dur = total_frames / float(sr)
    return {"duration": total_dur, "nonwhite_intervals": nonwhites}
# ---------------- filenames and sessions ----------------
# Matches "<anything>_YYYYMMDD-HHMM_" at the start of a file name.
_TS_RE = re.compile(r".*_(\d{8})-(\d{4})_", re.IGNORECASE)


def file_start_time(p: Path) -> Optional[dt.datetime]:
    """Parse the recording start time embedded in a file name.

    Returns a naive ``datetime`` with minute resolution, or ``None`` when the
    name does not match ``_TS_RE`` or the digits are not a valid date/time.
    """
    found = _TS_RE.match(p.name)
    if found is None:
        return None
    stamp = "".join(found.groups())
    try:
        return dt.datetime.strptime(stamp, "%Y%m%d%H%M")
    except Exception:
        return None
@dataclass
class FileScan:
    """Scan result for one input file."""

    # Path of the scanned file.
    path: Path
    # Start time parsed from the file name, or None when it has no timestamp.
    start: Optional[dt.datetime]
    # Decoded duration in seconds.
    duration: float
    # Non white (start, end) intervals in seconds relative to the file start.
    intervals: List[Tuple[float, float]]
@dataclass
class SessionPiece:
    """One contiguous excerpt of a source file that belongs to a session."""

    # Source file the excerpt is cut from.
    file: Path
    # Excerpt start in seconds from the beginning of the file.
    start: float
    # Excerpt end in seconds from the beginning of the file.
    end: float
@dataclass
class Session:
    """A group of pieces that is extracted and transcribed together."""

    # Sequence number assigned while planning.
    idx: int
    # Absolute start of the first piece, or None when no timestamp is known.
    start_dt: Optional[dt.datetime]
    # Absolute end of the session, or None when no timestamp is known.
    end_dt: Optional[dt.datetime]
    # Pieces in the order they are concatenated.
    pieces: List[SessionPiece]
def merge_intervals(iv: List[Tuple[float, float]], join_gap: float) -> List[Tuple[float, float]]:
    """Sort intervals and fuse neighbours separated by at most ``join_gap``.

    Returns a new list of ``(float, float)`` tuples; the input is not modified.
    """
    if not iv:
        return []
    ordered = sorted(iv)
    fused = []
    lo, hi = ordered[0]
    for s, e in ordered[1:]:
        if s <= hi + join_gap:
            hi = max(hi, e)
        else:
            fused.append((float(lo), float(hi)))
            lo, hi = s, e
    fused.append((float(lo), float(hi)))
    return fused
def plan_sessions(scans: List[FileScan], join_gap: float, min_session: float) -> List[Session]:
    """Group non white intervals from all files into transcription sessions.

    Files are visited in order of start time (files without one sort first),
    then by name. Within a file, intervals are first fused with
    ``merge_intervals``. An interval joins the session being built when its
    absolute start is within ``join_gap`` seconds of the session's current
    end, or unconditionally when its file has no timestamp; otherwise the
    session is closed and a new one begins.

    A closed session whose pieces total less than ``min_session`` seconds is
    folded into the previously emitted session (if there is one) instead of
    being emitted on its own.
    """
    ordered = sorted(scans, key=lambda x: (x.start or dt.datetime.min, x.path.name))
    sessions: List[Session] = []
    pending: List[SessionPiece] = []
    pending_start: Optional[dt.datetime] = None
    pending_end: Optional[dt.datetime] = None
    next_idx = 1

    def flush() -> None:
        # Close the session under construction and emit or fold it.
        nonlocal pending, pending_start, pending_end, next_idx
        if not pending:
            return
        closed = Session(idx=next_idx, start_dt=pending_start, end_dt=pending_end, pieces=pending)
        pending = []
        pending_start = None
        pending_end = None
        length = sum(p.end - p.start for p in closed.pieces)
        if length < min_session and sessions:
            sessions[-1].pieces.extend(closed.pieces)
            sessions[-1].end_dt = closed.end_dt
        else:
            sessions.append(closed)
            next_idx += 1

    for fs in ordered:
        for s_rel, e_rel in merge_intervals(fs.intervals, join_gap):
            if fs.start:
                s_abs = fs.start + dt.timedelta(seconds=s_rel)
                e_abs = fs.start + dt.timedelta(seconds=e_rel)
            else:
                s_abs = None
                e_abs = None
            piece = SessionPiece(fs.path, s_rel, e_rel)

            if pending:
                if not fs.start:
                    joins = True
                else:
                    joins = bool(cur_end_ok(pending_end, s_abs, join_gap)) if False else bool(
                        pending_end and e_abs and s_abs
                        and (s_abs - pending_end).total_seconds() <= join_gap
                    )
                if joins:
                    pending.append(piece)
                    if e_abs and (not pending_end or e_abs > pending_end):
                        pending_end = e_abs
                    continue
                flush()

            pending = [piece]
            pending_start = s_abs
            pending_end = e_abs

    flush()
    return sessions
# ---------------- worker for scan ----------------
def scan_one_worker(job):
    """Scan one file for non white intervals (process pool worker).

    Args:
        job: Tuple of ``(path_str, threshold, ratio, win_sec, hop_sec,
            warmup_sec, min_white_sec, min_nonwhite_sec, flat_w, slope_w,
            even_w, join_gap_sec, overlap)``.

    Returns:
        A dict of plain types: ``path`` (str), ``start`` (ISO string, or
        ``""`` when the file name carries no timestamp), ``duration``
        (seconds) and ``intervals`` (padded, merged ``(start, end)`` tuples).
    """
    (inpath_str, threshold, ratio, win_sec, hop_sec, warmup_sec, min_white_sec,
     min_nonwhite_sec, flat_w, slope_w, even_w, join_gap_sec, overlap) = job
    p = Path(inpath_str)
    info = ffprobe_info(p)
    # Fix: ffprobe's "index" is the absolute stream number inside the
    # container, while the decoder maps "0:a:<n>", where n counts audio
    # streams only. ffprobe_info always describes audio stream a:0, so the
    # matching audio-relative index is 0. Passing the absolute index selected
    # the wrong (or a non-existent) stream whenever the first audio stream was
    # not stream #0 of the file.
    stream_index = 0
    sr = int(info.get("sample_rate", 48000))
    ch = int(info.get("channels", 1))
    res = segment_streaming_from_pipe(
        p, stream_index, sr, ch, threshold, ratio, win_sec, hop_sec, warmup_sec,
        min_white_sec, min_nonwhite_sec, flat_w, slope_w, even_w,
    )
    iv = res.get("nonwhite_intervals", [])
    dur = float(res.get("duration", 0.0))
    # Pad each interval by `overlap` seconds, clamped to the file bounds.
    padded = []
    for s, e in iv:
        s2 = max(0.0, s - overlap)
        e2 = min(dur, e + overlap)
        if e2 > s2:
            padded.append((float(s2), float(e2)))
    merged = merge_intervals(padded, float(join_gap_sec))
    start_dt = file_start_time(p)
    start_iso = start_dt.isoformat() if start_dt else ""
    return {"path": str(p), "start": start_iso, "duration": dur, "intervals": merged}
# ---------------- audio extraction ----------------
def extract_session_channel_wav(session: Session, ch_idx: int, sr_out: Optional[int], out_wav: Path) -> None:
    """Build one mono 16-bit WAV holding channel ``ch_idx`` of a whole session.

    Every piece of the session is cut from its source file into a temporary
    ``<stem>_partNNN.wav`` next to ``out_wav``; the parts are then joined with
    ffmpeg's concat demuxer (stream copy first, re-encode as fallback). Stale
    outputs from an earlier run are removed up front, and the concat list and
    part files are always removed afterwards.

    Args:
        session: Session whose ``pieces`` are extracted in order.
        ch_idx: Zero based source channel placed in the mono output.
        sr_out: Output sample rate; falsy or non-positive keeps ffmpeg's default.
        out_wav: Destination WAV path; its directory is created if missing.

    Raises:
        RuntimeError: propagated from ``_run_ok`` when ffmpeg fails.
    """
    workdir = out_wav.parent
    concat_list = out_wav.with_suffix('.list.txt')
    rate_args = ["-ar", str(sr_out)] if (sr_out and sr_out > 0) else []
    parts: List[Path] = []
    workdir.mkdir(parents=True, exist_ok=True)

    # Best effort removal of leftovers from a previous run.
    try:
        for stale in (out_wav, concat_list):
            if stale.exists():
                stale.unlink()
        for leftover in workdir.glob(f"{out_wav.stem}_part*.wav"):
            try:
                leftover.unlink()
            except Exception:
                pass
    except Exception:
        pass

    try:
        for number, piece in enumerate(session.pieces, 1):
            part_path = workdir / f"{out_wav.stem}_part{number:03d}.wav"
            parts.append(part_path)
            length = max(0.0, piece.end - piece.start)
            _run_ok([
                "ffmpeg", "-hide_banner", "-nostats", "-loglevel", "error",
                "-ss", f"{piece.start:.3f}", "-t", f"{length:.3f}",
                "-i", str(piece.file),
                "-map", "0:a:0", "-vn", "-sn", "-af", f"pan=mono|c0=c{ch_idx}",
                "-ac", "1", "-acodec", "pcm_s16le",
                *rate_args,
                str(part_path),
            ])

        # Concat list uses absolute paths; single quotes are shell-style escaped.
        with open(concat_list, 'w', encoding='utf-8') as handle:
            for part_path in parts:
                quoted = str(part_path.resolve()).replace("'", "'\\''")
                handle.write("file '{}'\n".format(quoted))

        concat_head = ["ffmpeg", "-f", "concat", "-safe", "0", "-i", str(concat_list)]
        try:
            _run_ok(concat_head + ["-c", "copy", str(out_wav)])
        except Exception:
            # Stream copy failed; re-encode the joined audio instead.
            _run_ok(concat_head + ["-acodec", "pcm_s16le", "-ac", "1", *rate_args, str(out_wav)])
    finally:
        try:
            if concat_list.exists():
                concat_list.unlink()
        except Exception:
            pass
        for part_path in parts:
            try:
                if part_path.exists():
                    part_path.unlink()
            except Exception:
                pass
# ---------------- whisper and merge ----------------
@dataclass
class Word:
    """A single transcribed word with timing, confidence and source channel."""

    # Word start in seconds.
    start: float
    # Word end in seconds.
    end: float
    # Word text.
    text: str
    # Alignment score reported for the word.
    score: float
    # Channel the word came from (set to -1 until a channel is assigned).
    ch: int
class Heartbeat:
    """Context manager that prints a progress line at a fixed interval.

    While the ``with`` body runs, a daemon thread prints
    ``" working... <label> t+<seconds>s"`` every ``interval`` seconds.

    A non-positive (or ``None``) interval disables the heartbeat. Previously
    an interval of 0 made the wait return immediately, so the thread printed
    in a tight loop for the whole duration of the body. Elapsed time is
    measured with a monotonic clock, starting when the block is entered.
    """

    def __init__(self, label: str, interval: float = 15.0):
        self.label = label
        self.interval = interval
        self._stop = threading.Event()
        self._t = threading.Thread(target=self._run, daemon=True)
        self._t0 = time.monotonic()

    def __enter__(self):
        self._t0 = time.monotonic()
        if self.interval is not None and self.interval > 0:
            self._t.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._stop.set()
        # join() on a thread that was never started raises, so check first.
        if self._t.is_alive():
            self._t.join(timeout=1.0)

    def _run(self):
        # Event.wait returns False on timeout, True once stop is requested.
        while not self._stop.wait(self.interval):
            dt_s = int(time.monotonic() - self._t0)
            print(f" working... {self.label} t+{dt_s}s")
            sys.stdout.flush()
def run_whisperx_on_wav_modelreused(wav_path: Path, device: str, model, align_model, align_meta, hb_sec: float) -> List[Word]:
    """Transcribe and word-align one WAV using already loaded WhisperX models.

    Args:
        wav_path: Audio file to transcribe.
        device: Device string passed to ``whisperx.align``.
        model: Loaded transcription model (its ``transcribe`` is called).
        align_model: Loaded alignment model.
        align_meta: Metadata returned alongside the alignment model.
        hb_sec: Heartbeat print interval while transcription runs.

    Returns:
        Words with non-empty text, each with ``ch=-1``. A word without its
        own start/end falls back to the enclosing segment's times, and a
        missing score becomes 0.0.

    Change from the previous revision: ``end`` and ``score`` are now guarded
    against ``None`` the same way ``start`` already was, instead of raising
    ``TypeError`` from ``float(None)``.
    """
    import whisperx

    def first_number(*candidates) -> float:
        # First candidate that is not None, as a float.
        for value in candidates:
            if value is not None:
                return float(value)
        return 0.0

    audio = whisperx.load_audio(str(wav_path))
    with Heartbeat(f"whisper {wav_path.name}", hb_sec):
        res = model.transcribe(audio)
    aligned = whisperx.align(res.get("segments", []), align_model, align_meta, audio, device, return_char_alignments=False)
    words: List[Word] = []
    for seg in aligned.get("segments", []):
        for w in seg.get("words", []) or []:
            s = first_number(w.get("start"), seg.get("start"), 0.0)
            e = first_number(w.get("end"), seg.get("end"), s)
            txt = str(w.get("word", "")).strip()
            sc = first_number(w.get("score"), 0.0)
            if txt:
                words.append(Word(s, e, txt, sc, ch=-1))
    return words
def crdt_merge_words(words_by_channel: Dict[int, List[Word]]) -> List[Word]:
    """Merge per channel word lists into one de-duplicated timeline.

    All words are tagged with their channel and sorted by
    ``(start, end, channel)``. Walking that order, a word that overlaps in
    time with the last kept word competes with it: the higher score wins,
    near-equal scores (within 1e-6) go to the longer word, and near-equal
    durations go to the lower channel id. The loser is dropped. A word that
    does not overlap the last kept word is appended.
    """
    tagged = [
        Word(w.start, w.end, w.text, w.score, ch)
        for ch, lst in words_by_channel.items()
        for w in lst
    ]
    # stable sort to reduce jitter
    tagged.sort(key=lambda w: (w.start, w.end, w.ch))

    kept: List[Word] = []
    for cand in tagged:
        if not kept:
            kept.append(cand)
            continue
        prev = kept[-1]
        shared = min(prev.end, cand.end) - max(prev.start, cand.start)
        if shared <= 0:
            kept.append(cand)
            continue
        cand_len = cand.end - cand.start
        prev_len = prev.end - prev.start
        score_tie = abs(cand.score - prev.score) < 1e-6
        length_tie = abs(cand_len - prev_len) < 1e-6
        wins = (
            cand.score > prev.score
            or (score_tie and cand_len > prev_len)
            or (score_tie and length_tie and cand.ch < prev.ch)
        )
        if wins:
            kept[-1] = cand
    return kept
def write_srt(words: List[Dict], out_path: Path, max_chars: int = 60, max_gap: float = 0.6) -> None:
    """Group timed words into subtitle cues and write them as a UTF-8 SRT file.

    Args:
        words: Dicts with ``start``, ``end`` (seconds) and ``text``, in
            playback order.
        out_path: Destination file; overwritten.
        max_chars: A cue is closed before a word that would make its text
            longer than this.
        max_gap: A cue is closed before a word that starts more than this
            many seconds after the previous word ended.

    An empty ``words`` list produces an empty file. The file is always
    written as UTF-8: relying on the locale default could raise
    ``UnicodeEncodeError`` (or produce a non-UTF-8 file) for transcripts
    containing non-ASCII characters.
    """
    if not words:
        out_path.write_text("", encoding="utf-8")
        return

    def fmt(t: float) -> str:
        # SRT timestamps are HH:MM:SS,mmm.
        ms = int(round(t * 1000))
        h, rest = divmod(ms, 3600000)
        m, rest = divmod(rest, 60000)
        s, ms2 = divmod(rest, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms2:03d}"

    entries = []
    cur_text = []
    cur_start = words[0]['start']
    cur_end = words[0]['end']
    for w in words:
        if cur_text and (w['start'] - cur_end > max_gap or len(' '.join(cur_text) + ' ' + w['text']) > max_chars):
            entries.append((cur_start, cur_end, ' '.join(cur_text)))
            cur_text = [w['text']]
            cur_start = w['start']
            cur_end = w['end']
        else:
            cur_text.append(w['text'])
            cur_end = w['end']
    if cur_text:
        entries.append((cur_start, cur_end, ' '.join(cur_text)))

    lines = []
    for i, (s, e, txt) in enumerate(entries, 1):
        lines.extend([str(i), f"{fmt(s)} --> {fmt(e)}", txt.strip(), ""])
    out_path.write_text('\n'.join(lines), encoding="utf-8")
# ---------------- main ----------------
def main():
    """Command line entry point: scan, plan sessions, transcribe per channel.

    Steps, as implemented below:
      1. Collect input files by extension from the current directory.
      2. Scan every file in a process pool for non white intervals.
      3. Join the intervals into sessions with ``plan_sessions``.
      4. For each session and channel: extract a mono WAV, run WhisperX, and
         save ``chXX/words.json``, ``chXX/words.srt``, ``chXX/transcript.txt``.
      5. Depending on ``--merge``: write ``<label>_merged.json`` / ``.srt``
         (``crdt``) or only ``<label>_index.json`` (``none``).

    Changes from the previous revision:
      - ``--resume`` now also works with ``--merge none``; it used to look
        only for the merged outputs, which that mode never writes, so every
        session was redone.
      - ``--diarize`` / ``--hf-token`` are still accepted, but a note is
        printed because this function performs no diarization.
      - ``transcript.txt`` is written as UTF-8 regardless of locale.
    """
    ap = argparse.ArgumentParser(description="Two pass white to WhisperX with session join and optional merge")
    ap.add_argument("--exts", nargs="*", default=["trm", "m4a", "mp3"], help="Extensions to process")
    ap.add_argument("--recursive", action="store_true", help="Recurse into subdirectories")
    # scan knobs
    ap.add_argument("--threshold", type=float, default=0.70)
    ap.add_argument("--ratio", type=float, default=0.80)
    # --white-time, --file-threshold and --file-ratio are parsed but not read
    # anywhere in this function.
    ap.add_argument("--white-time", type=float, default=0.95)
    ap.add_argument("--file-threshold", type=float, default=0.88)
    ap.add_argument("--file-ratio", type=float, default=0.80)
    ap.add_argument("--flat-weight", type=float, default=0.65)
    ap.add_argument("--slope-weight", type=float, default=0.35)
    ap.add_argument("--even-weight", type=float, default=0.00)
    ap.add_argument("--win-sec", type=float, default=1.0)
    ap.add_argument("--hop-sec", type=float, default=0.5)
    ap.add_argument("--warmup-sec", type=float, default=2.0)
    ap.add_argument("--min-nonwhite-sec", type=float, default=1.5)
    ap.add_argument("--min-white-sec", type=float, default=0.5)
    ap.add_argument("--scan-workers", type=int, default=4, help="Parallel workers for scanning phase")
    # sessions
    ap.add_argument("--join-gap-sec", type=float, default=8.0)
    ap.add_argument("--min-session-sec", type=float, default=35.0)
    ap.add_argument("--overlap", type=float, default=0.25)
    # extraction and whisper
    ap.add_argument("--sr", type=int, default=0)
    ap.add_argument("--outdir", type=str, default="whisper_out")
    ap.add_argument("--keep-session-wavs", action="store_true")
    ap.add_argument("--whisper-model", type=str, default="large-v2")
    ap.add_argument("--device", type=str, default="cpu")
    ap.add_argument("--compute-type", type=str, default="int8")
    ap.add_argument("--diarize", action="store_true")
    ap.add_argument("--hf-token", type=str, default="")
    # new options
    ap.add_argument("--resume", action="store_true", help="Skip sessions that already have final outputs")
    ap.add_argument("--heartbeat-sec", type=float, default=15.0, help="Heartbeat print interval during Whisper")
    ap.add_argument("--merge", type=str, default="crdt", choices=["crdt", "none"], help="Final merge strategy")
    args = ap.parse_args()

    _check_tool("ffprobe")
    _check_tool("ffmpeg")
    if args.diarize or args.hf_token:
        print("Note: --diarize/--hf-token are accepted but no diarization is performed by this script; ignoring.",
              file=sys.stderr)

    # gather files
    root = Path.cwd()
    files: List[Path] = []
    for ext in args.exts:
        pattern = "**/*." + ext if args.recursive else "*." + ext
        files.extend(root.glob(pattern))
    files = sorted(set(files))
    if not files:
        print("No input files found")
        return

    # pass 1: scan in parallel
    print("Scanning for non white intervals...")
    sys.stdout.flush()
    jobs = [
        (str(p), float(args.threshold), float(args.ratio), float(args.win_sec), float(args.hop_sec),
         float(args.warmup_sec), float(args.min_white_sec), float(args.min_nonwhite_sec),
         float(args.flat_weight), float(args.slope_weight), float(args.even_weight),
         float(args.join_gap_sec), float(args.overlap))
        for p in files
    ]
    scans: List[FileScan] = []
    with cf.ProcessPoolExecutor(max_workers=max(1, args.scan_workers)) as ex:
        futs = {ex.submit(scan_one_worker, job): path for job, path in zip(jobs, files)}
        done = 0
        for fu in cf.as_completed(futs):
            done += 1
            try:
                _r = fu.result()
                # Workers return the start time as an ISO string ("" when unknown).
                _start = dt.datetime.fromisoformat(_r["start"]) if _r.get("start") else None
                scans.append(FileScan(
                    path=Path(_r["path"]),
                    start=_start,
                    duration=float(_r["duration"]),
                    intervals=[(float(a), float(b)) for a, b in _r.get("intervals", [])],
                ))
            except Exception as e:
                print(" scan failed for", futs[fu].name, "-", e)
            if done % 8 == 0 or done == len(futs):
                print(f" scanned {done}/{len(futs)} files")
                sys.stdout.flush()

    scans = [s for s in scans if s.intervals]
    if not scans:
        print("All files are white. Nothing to transcribe")
        return
    sessions = plan_sessions(scans, join_gap=args.join_gap_sec, min_session=args.min_session_sec)
    if not sessions:
        print("No sessions found. Nothing to transcribe")
        return
    print(f"Planned {len(sessions)} sessions. Processing sequentially...")
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # load whisper once
    print("Loading WhisperX models...")
    import whisperx
    model = whisperx.load_model(args.whisper_model, args.device, compute_type=args.compute_type)
    model_a, meta = whisperx.load_align_model(language_code="en", device=args.device)

    # process sessions
    for i, sess in enumerate(sessions, 1):
        if sess.start_dt and sess.end_dt:
            label = sess.start_dt.strftime("%Y%m%d-%H%M%S") + "_to_" + sess.end_dt.strftime("%H%M%S")
        else:
            first = sess.pieces[0]
            label = Path(first.file).stem + f"_{int(first.start):06d}"
        sess_dir = outdir / f"session_{i:03d}_{label}"
        sess_dir.mkdir(parents=True, exist_ok=True)
        out_json = sess_dir / f"{label}_merged.json"
        out_srt = sess_dir / f"{label}_merged.srt"
        out_index = sess_dir / f"{label}_index.json"

        # "Final outputs" depend on the merge mode: merged files for crdt,
        # the index file for none.
        if args.merge == "none":
            finished = out_index.exists()
        else:
            finished = out_json.exists() and out_srt.exists()
        if args.resume and finished:
            print(f"\nSession {i}/{len(sessions)} {label} - resume skip (outputs exist)")
            continue
        print(f"\nSession {i}/{len(sessions)} {label}")

        # Channel count and sample rate are taken from the session's first piece.
        info0 = ffprobe_info(sess.pieces[0].file)
        ch = int(info0.get("channels", 1))
        sr0 = int(info0.get("sample_rate", 48000))
        sr_out = args.sr if args.sr and args.sr > 0 else sr0

        words_by_ch: Dict[int, List[Word]] = {}
        for c in range(ch):
            wav_path = sess_dir / f"session_{i:03d}_ch{c+1:02d}.wav"
            ch_dir = sess_dir / f"ch{c+1:02d}"
            ch_dir.mkdir(exist_ok=True)
            print(f" ch {c+1:02d}/{ch:02d}: extracting...", end=" ")
            try:
                extract_session_channel_wav(sess, c, sr_out, wav_path)
                print("ok. transcribing...")
            except Exception as e:
                print("extract failed:", e)
                continue
            try:
                words = run_whisperx_on_wav_modelreused(wav_path, args.device, model, model_a, meta, args.heartbeat_sec)
                words_by_ch[c] = words
                preview = " ".join([w.text for w in sorted(words, key=lambda w: w.start)[:8]])
                print(f" preview: {preview[:80]}")
                # save per channel artifacts
                words_dicts = [
                    {"start": w.start, "end": w.end, "text": w.text, "score": w.score, "channel": c+1}
                    for w in sorted(words, key=lambda w: (w.start, w.end))
                ]
                (ch_dir / "words.json").write_text(json.dumps({
                    "session_label": label,
                    "channel": c+1,
                    "sr": sr_out,
                    "words": words_dicts,
                }, indent=2))
                # srt and text
                write_srt(words_dicts, ch_dir / "words.srt")
                (ch_dir / "transcript.txt").write_text(
                    " ".join([w["text"] for w in words_dicts]), encoding="utf-8"
                )
            except Exception as e:
                print(" whisper failed:", e)
            if not args.keep_session_wavs:
                try:
                    wav_path.unlink()
                except Exception:
                    pass

        if args.merge == "none":
            index = {
                "session_label": label,
                "files": [str(p.file) for p in sess.pieces],
                "channels": [str((sess_dir / f"ch{c+1:02d}").resolve()) for c in range(ch)],
                "note": "Per channel outputs saved. No merged transcript created in --merge none mode.",
            }
            out_index.write_text(json.dumps(index, indent=2))
            print(" merge skipped (--merge none). Per channel files are ready.")
            continue

        # CRDT merge
        merged_words = crdt_merge_words(words_by_ch)
        merged_dicts = [
            {"start": w.start, "end": w.end, "text": w.text, "score": w.score, "channel": w.ch + 1}
            for w in merged_words
        ]
        out_json.write_text(json.dumps({
            "session_label": label,
            "files": [str(p.file) for p in sess.pieces],
            "pieces": [{"file": str(p.file), "start": p.start, "end": p.end} for p in sess.pieces],
            "words": merged_dicts,
        }, indent=2))
        write_srt(merged_dicts, out_srt)
        print(f" wrote {out_json.name} and {out_srt.name}")
        sys.stdout.flush()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment