"""
Batch-generate embeddings for the casts table.
Claims unprocessed rows from Postgres, embeds them with a TorchScript-traced MiniLM model in float16 on MPS,
and stores the embeddings as int8 vectors for efficiency.
"""
# Configuration settings
BATCH_SIZE = 256 # Embedding batch size
BATCH_SIZE_ROWS = 200000 # Number of rows to process per instance
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MAX_SEQ_LENGTH = 256 # Maximum sequence length for tokenization
CHUNK_SIZE = 20000 # Rows to update in one transaction chunk
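# With the defaults above, each claimed batch of 200,000 rows is embedded in ~782 sub-batches of 256 texts
# and written back to the database in 10 update chunks of 20,000 rows each.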
import os
import warnings
import sys
import psutil
import time
from pathlib import Path
import subprocess
import asyncio
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass, field
import numpy as np
from dotenv import load_dotenv
import torch
import gc
import multiprocessing
from transformers import AutoTokenizer, AutoModel
from src.db.connect import db
# Load environment variables
load_dotenv()
warnings.filterwarnings("ignore", category=FutureWarning)
@dataclass
class InferenceMetrics:
"""Track detailed inference metrics."""
tokenization_time: float
model_forward_time: float
pooling_time: float
quantization_time: float
total_time: float
batch_size: int
texts_per_second: float
cosine_sim: float
# Configure warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message=".*_is_quantized_training_enabled.*")
warnings.filterwarnings(
"ignore",
message=r".*loss_type=None.*Unrecognised.*", # partial text from the warning
category=UserWarning
)
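# Embeddings are quantized symmetrically to int8: every component is divided by a per-batch scale
# (max absolute value / 127), rounded, and clamped to [-128, 127]. quantize_worker implements this in a
# separate process, but the current pipeline does not spawn it; encode() performs the same quantization inline.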
def quantize_worker(queue, shared_dict, shared_event):
"""Worker process for quantizing embeddings."""
while True:
item = queue.get()
if item is None: # Poison pill
break
embeddings = item
# Quantize on CPU
emb_cpu = embeddings.float() # Already on CPU
global_max_abs = emb_cpu.abs().max()
global_scale = global_max_abs / 127.0
quantized = (emb_cpu / global_scale).round().clamp(-128, 127).to(torch.int8)
# Store results
shared_dict['quantized'] = quantized.numpy()
shared_dict['scale'] = global_scale.item()
shared_event.set()
class OptimizedEmbeddingModel:
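"""Sentence embedding model wrapper: TorchScript-traced MiniLM forward pass in float16 on MPS,
preallocated input/output buffers, mean pooling, and inline int8 quantization.
Assumes an Apple Silicon GPU; there is no CPU/CUDA fallback."""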
def __init__(self, batch_size: int, quiet: bool = False):
self.device = torch.device("mps")
self.batch_size = batch_size
self.quiet = quiet
self.model_name = MODEL_NAME
if not quiet:
print(f"\nInitializing model with batch_size={batch_size} on MPS...")
self._initialize_model()
def _initialize_model(self):
print("\nInitializing model...")
init_start = time.time()
torch.set_num_threads(multiprocessing.cpu_count())
gc.collect()
# Loading tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# Loading base model
base_model = AutoModel.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
# Moving model to MPS
base_model.to(self.device)
base_model.eval()
# Preparing sample inputs
sample_texts = ["This is a longer initialization text that will ensure adequate buffer sizes"] * 32
inputs = self.tokenizer(
sample_texts,
padding=True,
truncation=True,
max_length=MAX_SEQ_LENGTH,
return_tensors="pt"
).to(self.device)
# Tracing and optimizing model
with torch.inference_mode():
traced_model = torch.jit.trace(
base_model,
(inputs['input_ids'], inputs['attention_mask']),
strict=False
)
self.model = torch.jit.optimize_for_inference(traced_model)
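# torch.jit.trace records the forward pass on the sample inputs above (strict=False tolerates the
# dict-style HuggingFace output), and optimize_for_inference freezes the traced graph for faster repeated calls.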
# Initialize buffers
self.input_buffers = {
'input_ids': torch.zeros((self.batch_size, MAX_SEQ_LENGTH), dtype=torch.long, device=self.device),
'attention_mask': torch.zeros((self.batch_size, MAX_SEQ_LENGTH), dtype=torch.long, device=self.device)
}
# Warmup inference
with torch.inference_mode():
outputs = self.model(inputs['input_ids'], inputs['attention_mask'])
embeddings = self._mean_pooling(outputs, inputs['attention_mask'])
self.output_buffer = torch.zeros(
(self.batch_size, embeddings.shape[1]),
dtype=torch.float16,
device=self.device
)
# Cleanup
del base_model, traced_model, outputs, embeddings
gc.collect()
# Warmup encode
warmup_texts = ["Warm-up sentence"] * 8
_ = self.encode(warmup_texts)
print(f"✨ Model initialization completed in {time.time() - init_start:.1f}s")
def _mean_pooling(self, model_output, attention_mask):
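"""Attention-mask-weighted mean pooling over token embeddings, followed by L2 normalization."""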
token_embeddings = model_output['last_hidden_state']
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
normalized = sum_embeddings / sum_mask
return torch.nn.functional.normalize(normalized, p=2, dim=1)
def encode(self, texts: List[str]) -> Tuple[np.ndarray, float, InferenceMetrics]:
"""Single-process inline quantization (no external worker)."""
start_time = time.time()
# 1. Tokenize
tokenize_start = time.time()
inputs = self.tokenizer(
texts, padding=True, truncation=True, max_length=MAX_SEQ_LENGTH, return_tensors="pt"
)
tokenize_time = time.time() - tokenize_start
# Move to MPS
for k, v in inputs.items():
if k in self.input_buffers:
self.input_buffers[k][:v.size(0), :v.size(1)] = v.to(self.device)
inputs[k] = self.input_buffers[k][:v.size(0), :v.size(1)]
# 2. Forward
with torch.inference_mode():
forward_start = time.time()
outputs = self.model(inputs['input_ids'], inputs['attention_mask'])
forward_time = time.time() - forward_start
# 3. Pooling
pool_start = time.time()
embeddings = self._mean_pooling(outputs, inputs['attention_mask'])
self.output_buffer[:embeddings.size(0)] = embeddings
float16_embeddings = self.output_buffer[:embeddings.size(0)]
pool_time = time.time() - pool_start
# 4. Quantize inline
quantize_start = time.time()
emb_cpu = float16_embeddings.detach().cpu().float()
global_max_abs = emb_cpu.abs().max()
global_scale = global_max_abs / 127.0
quantized_cpu = (emb_cpu / global_scale).round().clamp(-128, 127).to(torch.int8)
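# The per-batch scale is not returned or persisted; cosine similarity is invariant to per-vector scaling,
# so the stored int8 vectors can be compared directly without dequantization.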
sim = float(
torch.nn.functional.cosine_similarity(
emb_cpu, (quantized_cpu.float() * global_scale), dim=1
).mean().item()
)
quantize_time = time.time() - quantize_start
# 5. Metrics
total_time = time.time() - start_time
metrics = InferenceMetrics(
tokenization_time=tokenize_time,
model_forward_time=forward_time,
pooling_time=pool_time,
quantization_time=quantize_time,
total_time=total_time,
batch_size=len(texts),
texts_per_second=len(texts)/total_time,
cosine_sim=sim
)
return quantized_cpu.numpy(), sim, metrics
@dataclass
class ProcessingStats:
"""Track processing statistics."""
start_time: float = field(default_factory=time.time)
total_processed: int = 0
total_remaining: int = 0
last_log_time: float = field(default_factory=time.time)
last_processed: int = 0
def get_recent_tps(self) -> float:
"""Calculate recent transactions per second."""
current_time = time.time()
time_since_last = current_time - self.last_log_time
if time_since_last > 0:
return (self.total_processed - self.last_processed) / time_since_last
return 0.0
def get_eta_minutes(self) -> float:
"""Calculate estimated time remaining in minutes."""
current_time = time.time()
total_time = current_time - self.start_time
if self.total_processed > 0 and total_time > 0:
overall_tps = self.total_processed / total_time
if overall_tps > 0:
return (self.total_remaining / overall_tps) / 60
return 0.0
def update(self, processed: int):
"""Update stats with newly processed rows."""
self.total_processed += processed
self.total_remaining = max(0, self.total_remaining - processed)
# Log every 5 seconds
current_time = time.time()
time_since_last = current_time - self.last_log_time
if time_since_last >= 5:
total_time = current_time - self.start_time
recent_tps = self.get_recent_tps()
overall_tps = self.total_processed / total_time if total_time > 0 else 0
timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
print(f"\nProgress update [{timestamp}]:")
print(f"Recent TPS: {recent_tps:,.1f}")
print(f"Overall TPS: {overall_tps:,.1f}")
print(f"Processed: {self.total_processed:,} rows")
print(f"Remaining: {self.total_remaining:,} rows")
if self.total_remaining > 0 and overall_tps > 0:
print(f"ETA: {self.get_eta_minutes():.1f} minutes")
self.last_log_time = current_time
self.last_processed = self.total_processed
def get_process_memory(pid):
"""Get memory usage for a process."""
try:
process = psutil.Process(pid)
return process.memory_info().rss / (1024 * 1024) # Convert to MB
except (psutil.NoSuchProcess, psutil.AccessDenied):
return 0
def get_gpu_memory():
"""Approximate MPS memory usage via this process's RSS (Apple Silicon uses unified memory)."""
try:
result = subprocess.run(['ps', '-o', 'rss=', '-p', str(os.getpid())],
capture_output=True, text=True)
return int(result.stdout.strip()) / 1024 # Convert KB to MB
except Exception:
return 0
async def get_unprocessed_estimate(pool) -> int:
"""Get current count of unprocessed rows."""
async with pool.acquire() as conn:
await conn.execute("SET statement_timeout = '60s'")
result = await conn.fetchval("SELECT get_unprocessed_count()")
return int(result or 0)
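# NOTE: PrefetchMetrics below is not referenced by the current non-prefetching BatchProcessor.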
@dataclass
class PrefetchMetrics:
"""Track prefetch performance metrics."""
fetch_start: float = 0.0
fetch_end: float = 0.0
queue_wait_start: float = 0.0
queue_depth: int = 0
rows_fetched: int = 0
@property
def fetch_time(self) -> float:
return self.fetch_end - self.fetch_start if self.fetch_end > 0 else 0.0
def log(self):
print(f"\nPrefetch Metrics:")
print(f" Fetch Time: {self.fetch_time:.1f}s")
print(f" Queue Depth: {self.queue_depth}")
print(f" Rows Fetched: {self.rows_fetched:,}")
async def fetch_next_batch(pool, limit: int = 100) -> Tuple[List[Dict[str, Any]], float]:
"""Fetch next batch of unprocessed casts."""
fetch_start = time.time()
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute("SET statement_timeout = '60s'")
select_start = time.time()
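# Claim rows atomically: FOR UPDATE SKIP LOCKED lets multiple instances run concurrently without
# double-processing, and setting embedding384_updated_at marks the selected rows as in-flight.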
rows = await conn.fetch("""
WITH selected AS (
SELECT id, text
FROM public.casts
WHERE embedding384 IS NULL
AND text IS NOT NULL
AND length(trim(text)) > 0
AND (embedding384_updated_at IS NULL OR embedding384_updated_at < NOW() - interval '1 hour')
ORDER BY RANDOM() -- random order spreads concurrent instances across the table instead of contending on the lowest ids
LIMIT $1
FOR UPDATE SKIP LOCKED
)
UPDATE casts c
SET embedding384_updated_at = NOW()
FROM selected s
WHERE c.id = s.id
RETURNING c.id, c.text
""", limit)
select_time = time.time() - select_start
if not rows:
return [], time.time() - fetch_start
# Log detailed fetch timing
total_time = time.time() - fetch_start
print(f"\nFetch Timing:")
print(f" Query Time: {select_time:.3f}s")
print(f" Total Time: {total_time:.3f}s")
print(f" Rows: {len(rows):,}")
return [dict(row) for row in rows], total_time
class BatchProcessor:
"""Simple batch processor without prefetching."""
def __init__(self, pool, batch_size: int):
self.pool = pool
self.batch_size = batch_size
self.done = False
async def start(self):
"""No-op since we're not prefetching."""
pass
async def get_next_batch(self) -> Optional[Tuple[List[Dict[str, Any]], float]]:
"""Directly fetch next batch."""
if self.done:
return None
batch, fetch_time = await fetch_next_batch(
self.pool,
self.batch_size
)
if not batch:
remaining = await get_unprocessed_estimate(self.pool)
if remaining == 0:
self.done = True
return None
return batch, fetch_time
class EmbeddingProcessor:
"""Handles model initialization and inference."""
def __init__(self, batch_size: int, quiet: bool = True):
self.model = OptimizedEmbeddingModel(batch_size, quiet=quiet)
self.batch_size = batch_size
def process_batch(self, texts: List[str], ids: List[int]) -> Tuple[np.ndarray, List[int], List[InferenceMetrics]]:
"""Process a batch of texts."""
all_embeddings = []
all_batch_ids = []
all_metrics = []
total_batches = len(texts) // self.batch_size + (1 if len(texts) % self.batch_size else 0)
for i in range(0, len(texts), self.batch_size):
batch_texts = texts[i:i + self.batch_size]
batch_ids = ids[i:i + self.batch_size]
# Process batch and collect metrics
embeddings, sim, metrics = self.model.encode(batch_texts)
all_embeddings.append(embeddings)
all_batch_ids.extend(batch_ids)
all_metrics.append(metrics)
# Create progress bar and metrics summary
current_batch = i // self.batch_size + 1
progress = current_batch / total_batches
bar_length = 30
filled_length = int(bar_length * progress)
bar = '=' * filled_length + '-' * (bar_length - filled_length)
# Only print every 5% progress or at the end
if current_batch == total_batches or current_batch % max(1, total_batches // 20) == 0:
print(f"\rProcessing: [{bar}] {progress*100:.1f}% "
f"({current_batch}/{total_batches} batches) | "
f"Speed: {metrics.texts_per_second:.0f} texts/s | "
f"Batch: {len(batch_texts)} texts", end='')
if current_batch == total_batches:
# Final summary
avg_metrics = InferenceMetrics(
tokenization_time=np.mean([m.tokenization_time for m in all_metrics]),
model_forward_time=np.mean([m.model_forward_time for m in all_metrics]),
pooling_time=np.mean([m.pooling_time for m in all_metrics]),
quantization_time=np.mean([m.quantization_time for m in all_metrics]),
total_time=np.mean([m.total_time for m in all_metrics]),
batch_size=np.mean([m.batch_size for m in all_metrics]),
texts_per_second=np.mean([m.texts_per_second for m in all_metrics]),
cosine_sim=np.mean([m.cosine_sim for m in all_metrics])
)
print("\n\nBatch Processing Summary:")
print(f" Average Speed: {avg_metrics.texts_per_second:.0f} texts/s")
print(f" Times (avg): tokenize={avg_metrics.tokenization_time:.3f}s, "
f"forward={avg_metrics.model_forward_time:.3f}s, "
f"pool={avg_metrics.pooling_time:.3f}s, "
f"quantize={avg_metrics.quantization_time:.3f}s")
print(f" Total Processed: {len(texts):,} texts")
# Force garbage collection between batches
gc.collect()
torch.mps.empty_cache()
# Combine all batches
combined_embeddings = np.vstack(all_embeddings)
return combined_embeddings, all_batch_ids, all_metrics
@dataclass
class BatchSummary:
"""Track total time spent in each operation for a batch."""
batch_size: int = 0
fetch_time: float = 0.0
embedding_time: float = 0.0
tokenization_time: float = 0.0
forward_time: float = 0.0
pooling_time: float = 0.0
quantization_time: float = 0.0
db_update_time: float = 0.0
avg_cosine_sim: float = 0.0
start_time: float = field(default_factory=time.time)
def log_summary(self):
"""Log summary of all operations for this batch."""
total_time = self.fetch_time + self.embedding_time + self.db_update_time
print(f"\nBatch Summary ({self.batch_size:,} rows):")
print(f" Fetch Time: {self.fetch_time:.1f}s ({(self.fetch_time/total_time)*100:.1f}%)")
print(f" Embedding Time: {self.embedding_time:.1f}s ({(self.embedding_time/total_time)*100:.1f}%)")
print(f" → Tokenization: {self.tokenization_time:.1f}s ({(self.tokenization_time/self.embedding_time)*100:.1f}%)")
print(f" → Model Forward: {self.forward_time:.1f}s ({(self.forward_time/self.embedding_time)*100:.1f}%)")
print(f" → Pooling: {self.pooling_time:.1f}s ({(self.pooling_time/self.embedding_time)*100:.1f}%)")
print(f" → Quantization: {self.quantization_time:.1f}s ({(self.quantization_time/self.embedding_time)*100:.1f}%)")
print(f" DB Update Time: {self.db_update_time:.1f}s ({(self.db_update_time/total_time)*100:.1f}%)")
print(f" Total Time: {total_time:.1f}s")
print(f" Throughput: {self.batch_size/total_time:.1f} rows/sec")
print(f" Cosine Sim: {self.avg_cosine_sim:.4f}")
class PipelineLogger:
"""Consolidated logging for the embedding pipeline."""
def __init__(self):
self.start_time = time.time()
self.last_log_time = self.start_time
self.total_processed = 0
self.total_remaining = 0
self.batch_inference_metrics: List[InferenceMetrics] = []
def set_total_remaining(self, total: int):
self.total_remaining = total
def log_event(self, message: str, force: bool = False):
"""Log a timestamped event."""
now = time.time()
if force or now - self.last_log_time >= 5:
now_str = time.strftime("%Y-%m-%d %H:%M:%S")
elapsed = now - self.start_time
print(f"\n[{now_str} | +{elapsed:.1f}s] {message}")
self.last_log_time = now
def log_query_time(self, operation: str, duration: float):
"""Log database query timing."""
self.log_event(f"{operation} query time: {duration:.3f}s", force=True)
def log_batch_progress(self, current_batch: int, total_batches: int, metrics: InferenceMetrics):
"""Log progress after each sub-batch."""
self.batch_inference_metrics.append(metrics)
# Log every quarter batch
if total_batches > 0 and current_batch % max(1, total_batches // 4) == 0:
recent = self.batch_inference_metrics[-10:] # Last 10 metrics
self.log_event(
f"Processing {current_batch}/{total_batches} batches\n"
f" Recent metrics:\n"
f" Tokenization: {np.mean([m.tokenization_time for m in recent]):.3f}s\n"
f" Model Forward: {np.mean([m.model_forward_time for m in recent]):.3f}s\n"
f" Pooling: {np.mean([m.pooling_time for m in recent]):.3f}s\n"
f" Quantization: {np.mean([m.quantization_time for m in recent]):.3f}s\n"
f" Total: {np.mean([m.total_time for m in recent]):.3f}s\n"
f" Throughput: {np.mean([m.texts_per_second for m in recent]):.1f} texts/sec",
force=True
)
def update_processed(self, count: int):
"""Update progress counters and log if needed."""
self.total_processed += count
now = time.time()
if now - self.last_log_time >= 5:
elapsed = now - self.start_time
rows_per_sec = self.total_processed / elapsed if elapsed > 0 else 0
remaining = max(0, self.total_remaining - self.total_processed)
status = (
f"Processed: {self.total_processed:,} / {self.total_remaining:,} rows\n"
f"Elapsed: {elapsed/60:.1f} minutes, Overall TPS: {rows_per_sec:,.1f}"
)
if rows_per_sec > 0:
eta_min = (remaining / rows_per_sec) / 60
status += f"\nETA: {eta_min:.1f} minutes for {remaining:,} rows"
self.log_event(status)
@dataclass
class OperationTiming:
"""Track detailed timing of overlapping operations."""
operation_start: float = field(default_factory=time.time)
fetch_start: float = 0.0
fetch_end: float = 0.0
inference_start: float = 0.0
inference_end: float = 0.0
db_update_start: float = 0.0
db_update_end: float = 0.0
def log_overlap(self):
"""Log timing and overlap of operations."""
print("\nOperation Timing:")
fetch_time = self.fetch_end - self.fetch_start
inference_time = self.inference_end - self.inference_start
update_time = self.db_update_end - self.db_update_start
# Calculate overlaps
fetch_inference_gap = self.inference_start - self.fetch_end
inference_update_gap = self.db_update_start - self.inference_end
print(f" Fetch: {self.fetch_start:.1f}s → {self.fetch_end:.1f}s ({fetch_time:.1f}s)")
print(f" Inference: {self.inference_start:.1f}s → {self.inference_end:.1f}s ({inference_time:.1f}s)")
print(f" DB Update: {self.db_update_start:.1f}s → {self.db_update_end:.1f}s ({update_time:.1f}s)")
print(f" Gaps:")
print(f" Fetch→Inference: {fetch_inference_gap:.3f}s")
print(f" Inference→Update: {inference_update_gap:.3f}s")
async def update_embeddings(pool, batch_ids: List[int], embeddings: np.ndarray) -> Tuple[float, int]:
"""Async function to update embeddings in the database. Returns (time_taken, rows_updated)"""
start_time = time.time()
chunk_size = CHUNK_SIZE
total_updated = 0
for i in range(0, len(batch_ids), chunk_size):
chunk_ids = batch_ids[i:i + chunk_size]
chunk_embeddings = embeddings[i:i + chunk_size]
async with pool.acquire() as conn:
async with conn.transaction():
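# Build one UPDATE ... FROM (VALUES ...) statement per chunk so each chunk is written in a single
# round trip. Inline string formatting is acceptable here because ids and int8 components are plain integers.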
value_strings = [
f"({id_}, '[{','.join(map(str, emb))}]')"
for id_, emb in zip(chunk_ids, chunk_embeddings)
]
values_clause = ','.join(value_strings)
update_sql = f"""
UPDATE casts AS t
SET
embedding384 = v.embedding::vector,
embedding384_updated_at = NOW()
FROM (VALUES {values_clause}) AS v(id, embedding)
WHERE t.id = v.id
RETURNING t.id
"""
result = await conn.fetch(update_sql)
rows_updated = len(result)
total_updated += rows_updated
if rows_updated != len(chunk_ids):
print(f"\nWarning: Expected to update {len(chunk_ids)} rows but updated {rows_updated}")
return time.time() - start_time, total_updated
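# Downstream usage sketch (assumption: embedding384 is a pgvector column and queries use cosine distance,
# which is scale-invariant for these int8-quantized vectors). A hypothetical nearest-neighbour query:
#
#   SELECT id, text
#   FROM casts
#   ORDER BY embedding384 <=> (SELECT embedding384 FROM casts WHERE id = 123)
#   LIMIT 10;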
async def process_casts(pool, batch_size: int):
stats = ProcessingStats()
stats.total_remaining = await get_unprocessed_estimate(pool)
print(f"Starting processing of {stats.total_remaining:,} rows...")
if stats.total_remaining == 0:
print("\n✨ No rows to process!")
return
# Initialize processor and queues
processor = BatchProcessor(pool, BATCH_SIZE_ROWS)
await processor.start()
# Start DB update worker
db_queue = asyncio.Queue()
db_update_event = asyncio.Event()
update_worker = asyncio.create_task(db_update_worker(pool, db_queue, stats, db_update_event))
embedding_processor = EmbeddingProcessor(batch_size, quiet=True)
batch_stats = BatchStats()
processed_rows = 0
try:
while True:
batch_summary = BatchSummary()
op_timing = OperationTiming()
# Get batch with timing
op_timing.fetch_start = time.time() - op_timing.operation_start
fetch_start = time.time()
result = await processor.get_next_batch()
if result is None:
print("\n✨ Processing completed successfully - no more rows to process!")
break
batch, fetch_time = result
batch_summary.fetch_time = time.time() - fetch_start
op_timing.fetch_end = time.time() - op_timing.operation_start
if not batch:
print("\n✨ Processing completed successfully - no more rows to process!")
break
batch_summary.batch_size = len(batch)
texts = [cast['text'] for cast in batch]
ids = [cast['id'] for cast in batch]
# Process embeddings
op_timing.inference_start = time.time() - op_timing.operation_start
embed_start = time.time()
embeddings, batch_ids, metrics = embedding_processor.process_batch(texts, ids)
batch_summary.embedding_time = time.time() - embed_start
op_timing.inference_end = time.time() - op_timing.operation_start
# Aggregate inference metrics
batch_summary.tokenization_time = sum(m.tokenization_time for m in metrics)
batch_summary.forward_time = sum(m.model_forward_time for m in metrics)
batch_summary.pooling_time = sum(m.pooling_time for m in metrics)
batch_summary.quantization_time = sum(m.quantization_time for m in metrics)
batch_summary.avg_cosine_sim = np.mean([m.cosine_sim for m in metrics])
batch_stats.inference_metrics.extend(metrics)
# Queue DB update and wait for completion
db_update_event.clear()
op_timing.db_update_start = time.time() - op_timing.operation_start
db_start = time.time()
await db_queue.put((batch_ids, embeddings))
await db_update_event.wait()
batch_summary.db_update_time = time.time() - db_start
op_timing.db_update_end = time.time() - op_timing.operation_start
processed_rows += len(batch_ids)
batch_stats.update_processed(len(batch_ids))
# Log summaries
batch_summary.log_summary()
op_timing.log_overlap()
# Log quarter batch analysis
current_batch = processed_rows // batch_size
total_batches = stats.total_remaining // batch_size
batch_stats.log_quartile(current_batch, total_batches, processed_rows, stats.total_remaining)
finally:
# Wait for remaining DB updates and stop worker
await db_queue.put(None)
await update_worker
# Final stats
total_time = time.time() - batch_stats.start_time
print("\nProcessing complete!")
print(f"Total processed: {batch_stats.total_processed:,} rows")
print(f"Overall TPS: {batch_stats.total_processed/total_time:,.1f}")
print(f"Total time: {total_time/60:.1f} minutes")
async def db_update_worker(pool, queue: asyncio.Queue, stats: ProcessingStats, update_event: asyncio.Event):
"""Worker to handle async DB updates."""
last_log = time.time()
updates_since_log = 0
total_update_time = 0.0
while True:
item = await queue.get()
if item is None: # Poison pill
break
batch_ids, embeddings = item
start_time = time.time()
update_time, rows_updated = await update_embeddings(pool, batch_ids, embeddings)
total_update_time += update_time
updates_since_log += rows_updated
# Signal update completion
update_event.set()
# Log DB update stats every 5 seconds
now = time.time()
if now - last_log >= 5:
avg_time = total_update_time / (updates_since_log or 1)
print(f"\nDB Update Stats:")
print(f" Queue Size: {queue.qsize():,} batches")
print(f" Recent Updates: {updates_since_log:,} rows")
print(f" Avg Update Time: {avg_time:.3f}s per batch")
print(f" Total Update Time: {total_update_time:.1f}s")
print(f" Recent Batch Time: {update_time:.3f}s")
# Reset counters
last_log = now
updates_since_log = 0
total_update_time = 0
stats.update(rows_updated)
queue.task_done()
@dataclass
class BatchStats:
"""Track statistics for an entire batch."""
inference_metrics: List[InferenceMetrics] = field(default_factory=list)
db_select_time: float = 0
db_update_time: float = 0
total_time: float = 0
start_time: float = field(default_factory=time.time)
last_log_time: float = field(default_factory=time.time)
total_processed: int = 0
def log_quartile(self, current_batch: int, total_batches: int, processed: int, total: int):
"""Log statistics at quartile boundaries."""
# Skip logging if there are no batches or we're done processing
if total_batches <= 0:
return
if current_batch % max(1, total_batches // 4) == 0:
metrics = self.inference_metrics[-10:] # Last 10 sub-batches
current_time = time.strftime('%Y-%m-%d %H:%M:%S')
elapsed = time.time() - self.start_time
# Calculate key metrics
progress = (processed / total) * 100 if total > 0 else 0
rows_per_sec = processed / elapsed if elapsed > 0 else 0
remaining_rows = total - processed
eta_minutes = (remaining_rows / rows_per_sec) / 60 if rows_per_sec > 0 else 0
print(f"\n{'='*80}")
print(f"Quarter-Batch Analysis ({current_batch}/{total_batches}) - {current_time}")
print(f"{'='*80}")
print("\nProgress Summary:")
print(f" Processed: {processed:,}/{total:,} rows ({progress:.1f}%)")
print(f" Elapsed Time: {elapsed/60:.1f} minutes")
print(f" Overall Speed: {rows_per_sec:.1f} rows/sec")
print(f" ETA: {eta_minutes:.1f} minutes")
print("\nRecent Performance (last 10 batches):")
print(f" Tokenization: {np.mean([m.tokenization_time for m in metrics]):.3f}s / {np.percentile([m.tokenization_time for m in metrics], 95):.3f}s")
print(f" Model Forward: {np.mean([m.model_forward_time for m in metrics]):.3f}s / {np.percentile([m.model_forward_time for m in metrics], 95):.3f}s")
print(f" Pooling: {np.mean([m.pooling_time for m in metrics]):.3f}s / {np.percentile([m.pooling_time for m in metrics], 95):.3f}s")
print(f" Quantization: {np.mean([m.quantization_time for m in metrics]):.3f}s / {np.percentile([m.quantization_time for m in metrics], 95):.3f}s")
print(f" Total: {np.mean([m.total_time for m in metrics]):.3f}s / {np.percentile([m.total_time for m in metrics], 95):.3f}s")
print("\nThroughput Analysis:")
print(f" Recent TPS: {np.mean([m.texts_per_second for m in metrics]):.1f} texts/sec")
print(f" Peak TPS: {np.max([m.texts_per_second for m in metrics]):.1f} texts/sec")
print(f" Current TPS: {rows_per_sec:.1f} rows/sec")
print("\nDatabase Performance:")
print(f" Select Time: {self.db_select_time:.3f}s")
print(f" Update Time: {self.db_update_time:.3f}s")
print("\nMemory Usage:")
print(f" Process: {get_process_memory(os.getpid()):.1f}MB")
print(f" GPU: {get_gpu_memory():.1f}MB")
print(f"{'='*80}\n")
self.last_log_time = time.time()
def update_processed(self, count: int):
"""Update the total processed count."""
self.total_processed += count
async def reset_stale_rows(pool) -> int:
"""Reset stale rows at startup."""
async with pool.acquire() as conn:
async with conn.transaction():
result = await conn.fetch("""
UPDATE casts
SET embedding384_updated_at = NULL
WHERE embedding384 IS NULL
AND embedding384_updated_at IS NOT NULL
AND embedding384_updated_at < NOW() - interval '5 minutes'
RETURNING id
""")
return len(result)
async def main():
"""Main processing orchestration."""
try:
start_time = time.time()
print("\nStarting casts processing...")
await db.initialize_pool()
await db.run_migrations(module_path=str(Path(__file__).parent))
# Reset stale rows before starting
reset_count = await reset_stale_rows(db.pool)
print(f"Reset {reset_count:,} stale rows")
# Get initial count of unprocessed rows
total_unprocessed = await get_unprocessed_estimate(db.pool)
print(f"Estimated unprocessed rows: {total_unprocessed:,}")
await process_casts(db.pool, BATCH_SIZE)
total_duration = time.time() - start_time
if total_unprocessed > 0:
print(f"\nRun complete!")
print(f"Initial unprocessed estimate: {total_unprocessed:,} rows")
print(f"Estimated overall TPS: {total_unprocessed/total_duration:.1f}")
print(f"Total time: {total_duration/60:.1f} minutes")
print(f"\nTotal processing time: {total_duration:.1f}s")
finally:
await db.close_pool()
if __name__ == "__main__":
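# 'spawn' is already the default start method on macOS; it is set explicitly here for the
# (currently unused) quantize_worker multiprocessing path.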
multiprocessing.set_start_method('spawn')
asyncio.run(main())