@jc4p
Created March 9, 2025 05:19
from sentence_transformers import SentenceTransformer
import time
import duckdb
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import os

MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
BATCH_SIZE = 1942
DEVICE = "cuda"
# Determine number of threads based on CPU cores
NUM_THREADS = min(32, os.cpu_count() + 4)
# DuckDB memory limit (70% of system RAM)
MEMORY_LIMIT = int(526 * 0.7)  # Use 70% of your 526 GB RAM


def batch_insert_embeddings(con, batch_data):
"""Insert a batch of embeddings into the DuckDB table"""
query = "INSERT INTO temp_embeddings VALUES (?, ?)"
con.executemany(query, batch_data)
def main():
    # Configure DuckDB to use more memory
    con = duckdb.connect()
    con.execute(f"SET memory_limit='{MEMORY_LIMIT}GB'")
    # Enable multithreading in DuckDB
    con.execute(f"SET threads={NUM_THREADS}")

    # Load model
    print("Loading model...")
    model = SentenceTransformer(MODEL_NAME, device=DEVICE)
    model.max_seq_length = 512

    # Process in streaming fashion to avoid memory pressure
    print("Setting up streaming data processing...")

    # Create schema for the embeddings table
    con.execute("""
        CREATE TABLE temp_embeddings (
            Hash VARCHAR,
            embedding DOUBLE[]
        )
    """)
    # Use a streaming approach for the initial data fetch.
    # First, count the rows to give us progress information.
    total_rows = con.execute("SELECT COUNT(*) FROM read_parquet('casts.parquet')").fetchone()[0]
    print(f"Total rows in parquet file: {total_rows}")

    # Process in chunks of 100,000 rows
    chunk_size = 100000
    total_chunks = (total_rows + chunk_size - 1) // chunk_size
    valid_texts_count = 0

    for chunk_idx in range(total_chunks):
        start_row = chunk_idx * chunk_size

        # Fetch chunk of data. Note: LIMIT/OFFSET paging assumes read_parquet returns rows
        # in a stable order across queries; add an ORDER BY on a unique column if it doesn't.
        print(f"Processing chunk {chunk_idx+1}/{total_chunks} (rows {start_row} to {start_row+chunk_size-1})...")
        query = f"""
            SELECT *
            FROM read_parquet('casts.parquet')
            LIMIT {chunk_size} OFFSET {start_row}
        """
        result = con.execute(query)
        column_names = [desc[0] for desc in result.description]
        chunk_data = [dict(zip(column_names, row)) for row in result.fetchall()]
        # Extract text and hash for embedding
        cast_texts = []
        valid_hashes = []
        for item in chunk_data:
            if item.get('Text') and isinstance(item['Text'], str) and item['Text'].strip():
                cast_texts.append(item['Text'])
                valid_hashes.append(item['Hash'])

        chunk_valid_count = len(cast_texts)
        valid_texts_count += chunk_valid_count
        print(f"Found {chunk_valid_count} valid texts in this chunk")

        if chunk_valid_count == 0:
            continue

        # Generate embeddings for this chunk
        print(f"Generating embeddings for chunk {chunk_idx+1}...")
        embeddings = model.encode(
            cast_texts,
            batch_size=BATCH_SIZE,
            device=DEVICE,
            convert_to_numpy=True,
            show_progress_bar=True,
            output_value='sentence_embedding'
        )

        # Prepare batch data for insertion
        batch_data = [(hash_val, emb.tolist()) for hash_val, emb in zip(valid_hashes, embeddings)]

        # Insert in parallel batches to speed up DB operations
        print(f"Inserting {len(batch_data)} embeddings into database...")

        # Split into smaller batches for insertion
        db_batch_size = 10000
        insertion_batches = [batch_data[i:i+db_batch_size] for i in range(0, len(batch_data), db_batch_size)]

        with ThreadPoolExecutor(max_workers=min(8, len(insertion_batches))) as executor:
            # Consume the iterator so any insertion errors are raised here instead of being swallowed
            list(executor.map(lambda batch: batch_insert_embeddings(con, batch), insertion_batches))

        print(f"Completed chunk {chunk_idx+1}/{total_chunks}")
print(f"Total valid texts processed: {valid_texts_count}")
# Create final dataset with embeddings
print("Creating final table with original data and embeddings...")
# First determine embedding dimension
embedding_dim_query = """
SELECT ARRAY_LENGTH(embedding)
FROM temp_embeddings
LIMIT 1
"""
embedding_dim = con.execute(embedding_dim_query).fetchone()[0]
    # Join with embeddings based on Hash
    con.execute("""
        CREATE TABLE casts_with_embeddings AS
        SELECT o.*, e.embedding
        FROM read_parquet('casts.parquet') o
        LEFT JOIN temp_embeddings e ON o.Hash = e.Hash
    """)

    # Save to parquet with compression
    print("Saving to parquet file...")
    con.execute("""
        COPY casts_with_embeddings TO 'casts_with_embeddings.parquet'
        (FORMAT 'parquet', COMPRESSION 'ZSTD')
    """)

    # Get counts of final records
    total_count = con.execute("SELECT COUNT(*) FROM casts_with_embeddings").fetchone()[0]
    with_embedding_count = con.execute("SELECT COUNT(*) FROM casts_with_embeddings WHERE embedding IS NOT NULL").fetchone()[0]
    print(f"Saved {total_count} casts to parquet file")
    print(f"Of these, {with_embedding_count} have embeddings attached")
    print(f"Embedding coverage: {(with_embedding_count/total_count)*100:.2f}%")

    # Clean up
    con.execute("DROP TABLE IF EXISTS temp_embeddings")
    con.execute("DROP TABLE IF EXISTS casts_with_embeddings")
    con.close()


if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"Total execution time: {time.time() - start_time:.2f} seconds")