@jc4p
Created March 10, 2025 01:39
from sentence_transformers import SentenceTransformer
import time
import duckdb
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import os
import argparse
import glob
import sys

MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
BATCH_SIZE = 1600
DEVICE = "cuda"
# Determine number of threads based on CPU cores
NUM_THREADS = min(32, os.cpu_count() + 4)
# DuckDB memory limit (70% of system RAM)
MEMORY_LIMIT = int(526 * 0.7) # Use 70% of your 526GB RAM
# Total number of instances to distribute work across
TOTAL_INSTANCES = 4
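
# Input assumptions (not stated explicitly in the gist): 'casts.parquet' is expected in the
# working directory and must contain at least a 'Hash' and a 'Text' column; DEVICE = "cuda"
# assumes a CUDA-capable GPU is available to sentence-transformers.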


def batch_insert_embeddings(con, batch_data):
    """Insert a batch of embeddings into the DuckDB table."""
    query = "INSERT INTO temp_embeddings VALUES (?, ?)"
    con.executemany(query, batch_data)


def main():
    parser = argparse.ArgumentParser(description='Distributed embedding generation for Farcaster casts')
    parser.add_argument('--instance', type=int, default=0, choices=range(TOTAL_INSTANCES),
                        help=f'Instance ID (0-{TOTAL_INSTANCES-1})')
    args = parser.parse_args()

    # Set the instance ID from the command-line argument
    instance_id = args.instance
    print(f"Running as instance {instance_id} of {TOTAL_INSTANCES}")

    # Create the output directory for this instance
    output_dir = f"instance_{instance_id}_output"
    os.makedirs(output_dir, exist_ok=True)

    # Use an in-memory database, as in the original script
    con = duckdb.connect()
    con.execute(f"SET memory_limit='{MEMORY_LIMIT}GB'")
    con.execute(f"SET threads={NUM_THREADS}")

    # Load the embedding model
    print("Loading model...")
    model = SentenceTransformer(MODEL_NAME, device=DEVICE)
    model.max_seq_length = 512

    # Create the schema for the embeddings table
    con.execute("""
        CREATE TABLE temp_embeddings (
            Hash VARCHAR,
            embedding DOUBLE[]
        )
    """)

    # Use a streaming approach for the initial data fetch.
    # First, count the rows to give us progress information.
    total_rows = con.execute("SELECT COUNT(*) FROM read_parquet('casts.parquet')").fetchone()[0]
    print(f"Total rows in parquet file: {total_rows}")

    # Process in chunks of 100,000 rows
    chunk_size = 100000
    total_chunks = (total_rows + chunk_size - 1) // chunk_size

    # Calculate the chunk range for this instance, spreading any remainder
    # chunks evenly across the lower-numbered instances
    chunks_per_instance = total_chunks // TOTAL_INSTANCES
    remainder = total_chunks % TOTAL_INSTANCES
    start_chunk = instance_id * chunks_per_instance
    if instance_id < remainder:
        start_chunk += instance_id
        end_chunk = start_chunk + chunks_per_instance + 1
    else:
        start_chunk += remainder
        end_chunk = start_chunk + chunks_per_instance

    instance_chunks = list(range(start_chunk, end_chunk))
    print(f"Instance {instance_id} will process chunks {start_chunk} to {end_chunk-1}")
    print(f"This represents approximately {len(instance_chunks)/total_chunks*100:.2f}% of the total workload")

    valid_texts_count = 0

    # Process the chunks assigned to this instance
    for chunk_idx in instance_chunks:
        start_row = chunk_idx * chunk_size

        # Fetch this chunk of data
        print(f"Processing chunk {chunk_idx}/{total_chunks-1} (rows {start_row} to {start_row+chunk_size-1})...")
        query = f"""
            SELECT *
            FROM read_parquet('casts.parquet')
            LIMIT {chunk_size} OFFSET {start_row}
        """
        result = con.execute(query)
        column_names = [desc[0] for desc in result.description]
        chunk_data = [dict(zip(column_names, row)) for row in result.fetchall()]

        # Extract text and hash for embedding, skipping rows with empty text
        cast_texts = []
        valid_hashes = []
        for item in chunk_data:
            if item.get('Text') and isinstance(item['Text'], str) and item['Text'].strip():
                cast_texts.append(item['Text'])
                valid_hashes.append(item['Hash'])

        chunk_valid_count = len(cast_texts)
        valid_texts_count += chunk_valid_count
        print(f"Found {chunk_valid_count} valid texts in this chunk")
        if chunk_valid_count == 0:
            continue

        # Generate embeddings for this chunk
        print(f"Generating embeddings for chunk {chunk_idx}...")
        embeddings = model.encode(
            cast_texts,
            batch_size=BATCH_SIZE,
            device=DEVICE,
            convert_to_numpy=True,
            show_progress_bar=True,
            output_value='sentence_embedding'
        )

        # Prepare batch data for insertion
        batch_data = [(hash_val, emb.tolist()) for hash_val, emb in zip(valid_hashes, embeddings)]

        # Insert in parallel batches to speed up DB operations
        print(f"Inserting {len(batch_data)} embeddings into database...")

        # Split into smaller batches for insertion
        db_batch_size = 10000
        insertion_batches = [batch_data[i:i+db_batch_size] for i in range(0, len(batch_data), db_batch_size)]
        with ThreadPoolExecutor(max_workers=min(8, len(insertion_batches))) as executor:
            executor.map(lambda batch: batch_insert_embeddings(con, batch), insertion_batches)

        print(f"Completed chunk {chunk_idx}/{total_chunks-1}")

        # Save a snapshot of all embeddings accumulated so far to an intermediate parquet file
        intermediate_file = f"{output_dir}/embeddings_chunk_{chunk_idx}.parquet"
        print(f"Saving progress to {intermediate_file}...")
        con.execute(f"""
            COPY temp_embeddings TO '{intermediate_file}'
            (FORMAT 'parquet', COMPRESSION 'ZSTD')
        """)
print(f"Total valid texts processed by instance {instance_id}: {valid_texts_count}")
# Save the final embeddings table for this instance
final_output = f"{output_dir}/instance_{instance_id}_embeddings.parquet"
con.execute(f"""
COPY temp_embeddings TO '{final_output}'
(FORMAT 'parquet', COMPRESSION 'ZSTD')
""")
print(f"Saved all embeddings to {final_output}")
print(f"To combine results, run the merge script after all instances complete")
# Clean up
con.execute("DROP TABLE IF EXISTS temp_embeddings")
con.close()


if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"Total execution time: {time.time() - start_time:.2f} seconds")