Created March 9, 2025 05:19
from sentence_transformers import SentenceTransformer
import time
import duckdb
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import os

MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
BATCH_SIZE = 1942
DEVICE = "cuda"

# Determine number of threads based on CPU cores
NUM_THREADS = min(32, os.cpu_count() + 4)

# DuckDB memory limit (70% of system RAM)
MEMORY_LIMIT = int(526 * 0.7)  # Use 70% of your 526GB RAM

def batch_insert_embeddings(con, batch_data):
    """Insert a batch of embeddings into the DuckDB table"""
    # Use a dedicated cursor per call so concurrent worker threads don't
    # share a single cursor on the connection (DuckDB's recommended pattern
    # for multi-threaded use of one connection)
    cursor = con.cursor()
    query = "INSERT INTO temp_embeddings VALUES (?, ?)"
    cursor.executemany(query, batch_data)

def main():
    # Configure DuckDB to use more memory
    con = duckdb.connect()
    con.execute(f"SET memory_limit='{MEMORY_LIMIT}GB'")
    # Enable multithreading in DuckDB
    con.execute(f"SET threads={NUM_THREADS}")

    # Load model
    print("Loading model...")
    model = SentenceTransformer(MODEL_NAME, device=DEVICE)
    model.max_seq_length = 512

    # Process in streaming fashion to avoid memory pressure
    print("Setting up streaming data processing...")

    # Create schema for the embeddings table
    con.execute("""
        CREATE TABLE temp_embeddings (
            Hash VARCHAR,
            embedding DOUBLE[]
        )
    """)

    # Use a streaming approach for the initial data fetch.
    # First, count the rows to give us progress information.
    total_rows = con.execute("SELECT COUNT(*) FROM read_parquet('casts.parquet')").fetchone()[0]
    print(f"Total rows in parquet file: {total_rows}")

    # Process in chunks of 100,000 rows
    chunk_size = 100000
    total_chunks = (total_rows + chunk_size - 1) // chunk_size
    valid_texts_count = 0

    for chunk_idx in range(total_chunks):
        start_row = chunk_idx * chunk_size

        # Fetch chunk of data.
        # Note: there is no ORDER BY, so this assumes DuckDB scans the parquet
        # file in a stable order across the repeated LIMIT/OFFSET queries.
        print(f"Processing chunk {chunk_idx+1}/{total_chunks} (rows {start_row} to {start_row+chunk_size-1})...")
        query = f"""
            SELECT *
            FROM read_parquet('casts.parquet')
            LIMIT {chunk_size} OFFSET {start_row}
        """
        result = con.execute(query)
        column_names = [desc[0] for desc in result.description]
        chunk_data = [dict(zip(column_names, row)) for row in result.fetchall()]

        # Extract text and hash for embedding
        cast_texts = []
        valid_hashes = []
        for item in chunk_data:
            if item.get('Text') and isinstance(item['Text'], str) and item['Text'].strip():
                cast_texts.append(item['Text'])
                valid_hashes.append(item['Hash'])

        chunk_valid_count = len(cast_texts)
        valid_texts_count += chunk_valid_count
        print(f"Found {chunk_valid_count} valid texts in this chunk")

        if chunk_valid_count == 0:
            continue

        # Generate embeddings for this chunk
        print(f"Generating embeddings for chunk {chunk_idx+1}...")
        embeddings = model.encode(
            cast_texts,
            batch_size=BATCH_SIZE,
            device=DEVICE,
            convert_to_numpy=True,
            show_progress_bar=True,
            output_value='sentence_embedding'
        )

        # Prepare batch data for insertion
        batch_data = [(hash_val, emb.tolist()) for hash_val, emb in zip(valid_hashes, embeddings)]

        # Insert in parallel batches to speed up DB operations
        print(f"Inserting {len(batch_data)} embeddings into database...")

        # Split into smaller batches for insertion
        db_batch_size = 10000
        insertion_batches = [batch_data[i:i+db_batch_size] for i in range(0, len(batch_data), db_batch_size)]

        with ThreadPoolExecutor(max_workers=min(8, len(insertion_batches))) as executor:
            # Consume the map iterator so any insertion errors are raised here
            # instead of being silently discarded
            list(executor.map(lambda batch: batch_insert_embeddings(con, batch), insertion_batches))

        print(f"Completed chunk {chunk_idx+1}/{total_chunks}")

    print(f"Total valid texts processed: {valid_texts_count}")

    # Create final dataset with embeddings
    print("Creating final table with original data and embeddings...")

    # First determine embedding dimension (all-mpnet-base-v2 produces 768-dim vectors)
    embedding_dim_query = """
        SELECT ARRAY_LENGTH(embedding)
        FROM temp_embeddings
        LIMIT 1
    """
    embedding_dim = con.execute(embedding_dim_query).fetchone()[0]
    print(f"Embedding dimension: {embedding_dim}")

    # Join with embeddings based on Hash
    con.execute("""
        CREATE TABLE casts_with_embeddings AS
        SELECT o.*, e.embedding
        FROM read_parquet('casts.parquet') o
        LEFT JOIN temp_embeddings e ON o.Hash = e.Hash
    """)

    # Save to parquet with compression
    print("Saving to parquet file...")
    con.execute("""
        COPY casts_with_embeddings TO 'casts_with_embeddings.parquet'
        (FORMAT 'parquet', COMPRESSION 'ZSTD')
    """)

    # Get count of final records
    total_count = con.execute("SELECT COUNT(*) FROM casts_with_embeddings").fetchone()[0]
    with_embedding_count = con.execute("SELECT COUNT(*) FROM casts_with_embeddings WHERE embedding IS NOT NULL").fetchone()[0]
    print(f"Saved {total_count} casts to parquet file")
    print(f"Of these, {with_embedding_count} have embeddings attached")
    print(f"Embedding coverage: {(with_embedding_count/total_count)*100:.2f}%")

    # Clean up
    con.execute("DROP TABLE IF EXISTS temp_embeddings")
    con.execute("DROP TABLE IF EXISTS casts_with_embeddings")
    con.close()

if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"Total execution time: {time.time() - start_time:.2f} seconds")