Created
July 15, 2025 05:03
-
-
Save pszemraj/952f3a26c3ec619ef8b65bf621aca33c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%writefile emoji_search.py | |
#!/usr/bin/env python3 | |
"""
Emoji Semantic Search CLI

Requirements:
    pip install fire sentence-transformers pandas numpy pyarrow

Usage:
    python emoji_search.py "that is flames"
    python emoji_search.py "happy birthday" --top_k=10
    python emoji_search.py "love" --format=json
"""
import json | |
from functools import lru_cache | |
import fire | |
import numpy as np | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.util import semantic_search | |
@lru_cache(maxsize=1)
def load_emoji_data():
    """Read the emoji parquet table from the Hugging Face dataset path.

    The function takes no arguments, so ``lru_cache`` keeps the first
    DataFrame it returns and hands the same object back on later calls.
    """
    parquet_url = (
        "hf://datasets/pszemraj/local-emoji-search-gte/data/train-00000-of-00001.parquet"
    )
    emoji_frame = pd.read_parquet(parquet_url)
    return emoji_frame
@lru_cache(maxsize=1)
def load_model(model_name="Alibaba-NLP/gte-large-en-v1.5"):
    """Build the SentenceTransformer for ``model_name``.

    ``lru_cache(maxsize=1)`` retains only the most recently requested model,
    so repeated calls with the same name reuse the existing instance.
    """
    # The model is constructed with trust_remote_code=True, exactly as before.
    encoder = SentenceTransformer(model_name, trust_remote_code=True)
    return encoder
def search_emojis(query, df, model, top_k=5, num_digits=4):
    """Perform semantic search for emojis.

    Args:
        query: Text to search for; passed to ``model.encode``.
        df: DataFrame with ``emoji``, ``message`` and ``embed`` columns, where
            each ``embed`` entry is one row of the embedding matrix.
        model: Object exposing ``encode(query)`` that returns the query
            embedding.
        top_k: Maximum number of hits requested from ``semantic_search``.
        num_digits: Number of decimal digits each score is rounded to.

    Returns:
        A list of dicts with keys ``emoji``, ``message`` and ``score``, in the
        order the hits are returned by ``semantic_search``. Empty when ``df``
        has no rows.
    """
    # np.vstack cannot stack zero arrays; an empty table simply has no matches.
    if len(df) == 0:
        return []

    query_embed = model.encode(query)
    embeddings_array = np.vstack(df.embed.values, dtype=np.float32)
    hits = semantic_search(query_embed, embeddings_array, top_k=top_k)[0]

    emoji_col = df["emoji"]
    message_col = df["message"]
    results = []
    for hit in hits:
        # corpus_id is a row position in embeddings_array, not an index label,
        # so look it up positionally; .loc would pick the wrong row (or raise)
        # whenever df does not carry the default 0..n-1 index.
        position = hit["corpus_id"]
        results.append(
            {
                "emoji": emoji_col.iloc[position],
                "message": message_col.iloc[position],
                "score": round(hit["score"], num_digits),
            }
        )
    return results
def main(
    query,
    top_k=5,
    format="plain",
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    num_digits=4,
):
    """
    Search for emojis semantically similar to the query.

    Args:
        query: Search query text
        top_k: Number of top results to return
        format: Output format - plain, compact, json, or tsv
        model_name: Name of the sentence transformer model
        num_digits: Number of decimal digits for scores

    Raises:
        ValueError: If ``format`` is not one of the supported output formats.
    """
    # NOTE: `format` shadows the builtin, but it is the public CLI flag
    # (--format), so the parameter name has to stay.
    valid_formats = ("plain", "compact", "json", "tsv")
    # Reject a bad format before the expensive model/dataset load instead of
    # silently printing nothing afterwards.
    if format not in valid_formats:
        raise ValueError(
            f"Unknown format {format!r}; expected one of: {', '.join(valid_formats)}"
        )

    # Load resources
    model = load_model(model_name)
    df = load_emoji_data()

    # Search
    results = search_emojis(query, df, model, top_k, num_digits)

    # Format output
    if format == "plain":
        for i, r in enumerate(results, 1):
            print(f"{i}. {r['emoji']} {r['message']} (score: {r['score']})")
    elif format == "compact":
        print(" ".join(r["emoji"] for r in results))
    elif format == "json":
        print(json.dumps(results, indent=2, ensure_ascii=False))
    else:  # "tsv" is the only remaining valid format
        print("emoji\tmessage\tscore")
        for r in results:
            print(f"{r['emoji']}\t{r['message']}\t{r['score']}")
if __name__ == "__main__":
    # Hand main() to Fire so it is invoked with the command-line arguments.
    fire.Fire(component=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment