"""Distributed embedding generation for Farcaster casts.

Splits casts.parquet into chunks, assigns a disjoint set of chunks to each
of TOTAL_INSTANCES machines, and writes each instance's embeddings to its
own parquet file. Run one copy per instance, passing --instance 0..3.
"""
from sentence_transformers import SentenceTransformer
import time
import duckdb
from concurrent.futures import ThreadPoolExecutor
import os
import argparse

MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
BATCH_SIZE = 1600
DEVICE = "cuda"

# Number of DuckDB threads based on CPU cores (cpu_count() can return None)
NUM_THREADS = min(32, (os.cpu_count() or 8) + 4)

# DuckDB memory limit: 70% of the machine's 526 GB of RAM
MEMORY_LIMIT = int(526 * 0.7)

# Total number of instances to distribute work across
TOTAL_INSTANCES = 4


def batch_insert_embeddings(con, batch_data):
    """Insert a batch of embeddings into the DuckDB table."""
    # Grab a per-call cursor so pool threads don't share one connection object
    cursor = con.cursor()
    cursor.executemany("INSERT INTO temp_embeddings VALUES (?, ?)", batch_data)
    cursor.close()


def main():
    parser = argparse.ArgumentParser(description='Distributed embedding generation for Farcaster casts')
    parser.add_argument('--instance', type=int, default=0, choices=range(TOTAL_INSTANCES),
                        help=f'Instance ID (0-{TOTAL_INSTANCES-1})')
    args = parser.parse_args()

    # Set the instance ID from the command-line argument
    instance_id = args.instance
    print(f"Running as instance {instance_id} of {TOTAL_INSTANCES}")

    # Create the output directory
    output_dir = f"instance_{instance_id}_output"
    os.makedirs(output_dir, exist_ok=True)

    # Use an in-memory database, like the original script
    con = duckdb.connect()
    con.execute(f"SET memory_limit='{MEMORY_LIMIT}GB'")
    con.execute(f"SET threads={NUM_THREADS}")

    # Load the model
    print("Loading model...")
    model = SentenceTransformer(MODEL_NAME, device=DEVICE)
    model.max_seq_length = 512

    # Create the schema for the embeddings table
    con.execute("""
        CREATE TABLE temp_embeddings (
            Hash VARCHAR,
            embedding DOUBLE[]
        )
    """)

    # Stream the initial data fetch: first count the rows for progress reporting
    total_rows = con.execute("SELECT COUNT(*) FROM read_parquet('casts.parquet')").fetchone()[0]
    print(f"Total rows in parquet file: {total_rows}")

    # Process in chunks of 100,000 rows
    chunk_size = 100000
    total_chunks = (total_rows + chunk_size - 1) // chunk_size

    # Calculate chunk ranges for this instance, handing the remainder chunks
    # one apiece to the lowest-numbered instances
    chunks_per_instance = total_chunks // TOTAL_INSTANCES
    remainder = total_chunks % TOTAL_INSTANCES

    start_chunk = instance_id * chunks_per_instance
    if instance_id < remainder:
        start_chunk += instance_id
        end_chunk = start_chunk + chunks_per_instance + 1
    else:
        start_chunk += remainder
        end_chunk = start_chunk + chunks_per_instance

    # Build the list of chunk indices assigned to this instance
    instance_chunks = list(range(start_chunk, end_chunk))

    print(f"Instance {instance_id} will process chunks {start_chunk} to {end_chunk-1}")
    print(f"This represents approximately {len(instance_chunks)/total_chunks*100:.2f}% of the total workload")

    valid_texts_count = 0

    # Process the chunks assigned to this instance
    for chunk_idx in instance_chunks:
        start_row = chunk_idx * chunk_size

        # Fetch this chunk. LIMIT/OFFSET carries no ORDER BY, so this assumes
        # read_parquet returns rows in the same order on every instance.
        print(f"Processing chunk {chunk_idx}/{total_chunks-1} (rows {start_row} to {start_row+chunk_size-1})...")
        query = f"""
            SELECT *
            FROM read_parquet('casts.parquet')
            LIMIT {chunk_size} OFFSET {start_row}
        """
        result = con.execute(query)
        column_names = [desc[0] for desc in result.description]
        chunk_data = [dict(zip(column_names, row)) for row in result.fetchall()]

        # Extract text and hash for embedding, skipping empty or non-string text
        cast_texts = []
        valid_hashes = []
        for item in chunk_data:
            if item.get('Text') and isinstance(item['Text'], str) and item['Text'].strip():
                cast_texts.append(item['Text'])
                valid_hashes.append(item['Hash'])

        chunk_valid_count = len(cast_texts)
        valid_texts_count += chunk_valid_count
        print(f"Found {chunk_valid_count} valid texts in this chunk")

        if chunk_valid_count == 0:
            continue

        # Generate embeddings for this chunk
        print(f"Generating embeddings for chunk {chunk_idx}...")
        embeddings = model.encode(
            cast_texts,
            batch_size=BATCH_SIZE,
            device=DEVICE,
            convert_to_numpy=True,
            show_progress_bar=True,
            output_value='sentence_embedding'
        )

        # Prepare batch data for insertion
        batch_data = [(hash_val, emb.tolist()) for hash_val, emb in zip(valid_hashes, embeddings)]

        # Insert in parallel batches to speed up DB operations
        print(f"Inserting {len(batch_data)} embeddings into database...")

        # Split into smaller batches for insertion
        db_batch_size = 10000
        insertion_batches = [batch_data[i:i+db_batch_size] for i in range(0, len(batch_data), db_batch_size)]

        with ThreadPoolExecutor(max_workers=min(8, len(insertion_batches))) as executor:
            # Consume the map iterator so any worker exception propagates
            list(executor.map(lambda batch: batch_insert_embeddings(con, batch), insertion_batches))

        print(f"Completed chunk {chunk_idx}/{total_chunks-1}")

        # Checkpoint progress to a parquet file. This snapshots the whole
        # temp_embeddings table, so each checkpoint is cumulative, not per-chunk.
        intermediate_file = f"{output_dir}/embeddings_chunk_{chunk_idx}.parquet"
        print(f"Saving progress to {intermediate_file}...")
        con.execute(f"""
            COPY temp_embeddings TO '{intermediate_file}'
            (FORMAT 'parquet', COMPRESSION 'ZSTD')
        """)
print(f"Total valid texts processed by instance {instance_id}: {valid_texts_count}") | |
# Save the final embeddings table for this instance | |
final_output = f"{output_dir}/instance_{instance_id}_embeddings.parquet" | |
con.execute(f""" | |
COPY temp_embeddings TO '{final_output}' | |
(FORMAT 'parquet', COMPRESSION 'ZSTD') | |
""") | |
print(f"Saved all embeddings to {final_output}") | |
print(f"To combine results, run the merge script after all instances complete") | |
# Clean up | |
con.execute("DROP TABLE IF EXISTS temp_embeddings") | |
con.close() | |


if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"Total execution time: {time.time() - start_time:.2f} seconds")