Created February 27, 2024 21:18
-
-
Save tsg/1088379515bfae7b293efcd78e0148ae to your computer and use it in GitHub Desktop.
Hybrid Search using Xata
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from xata.client import XataClient
from sentence_transformers import SentenceTransformer
import sys
import time

# Shared Xata client used by both vector_search and keyword_search.
xata = XataClient()

# Expect the search query as the single command-line argument.
if len(sys.argv) != 2:
    # Usage errors go to stderr. sys.exit is used instead of the site-provided
    # exit(), which does not exist under `python -S` or in embedded interpreters.
    print("Usage: python hybrid_search.py <query>", file=sys.stderr)
    sys.exit(1)
query = sys.argv[1]
# Model used to embed queries. It must be the same model that produced the
# values stored in the "embedding" column, or similarity scores are meaningless.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_embedding_model = None


def _get_embedding_model():
    """Load the sentence-transformers model on first use and reuse it afterwards."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
    return _embedding_model


def vector_search(query):
    """Run a semantic (vector) search over the "docs" table.

    Args:
        query: Free-text query to embed and search for.

    Returns:
        The Xata response object; its "records" entry holds up to 5 matches.

    Raises:
        Exception: if the Xata API reports that the request failed.
    """
    # Embed the query that was passed in, not sys.argv[1], so the function
    # works for callers other than the CLI entry point.
    vector = _get_embedding_model().encode(query)
    results = xata.data().vector_search("docs", {
        # NOTE(review): the embedding is repeated 4 times. all-MiniLM-L6-v2
        # vectors are believed to be 384-dim, which would give 1536 values,
        # presumably to fit an "embedding" column declared with 1536 dims.
        # The stored embeddings must have been padded the same way — confirm
        # before changing.
        "queryVector": vector.tolist() * 4,
        "column": "embedding",
        "size": 5,
    })
    if not results.is_success():
        raise Exception(f"Vector search failed: {results.json()}")
    return results
def keyword_search(query):
    """Run a full-text (keyword) search over the "docs" table.

    The request uses fuzziness 1 and "phrase" prefix matching, and asks for a
    page of at most 5 records.

    Args:
        query: Free-text query string sent to Xata as-is.

    Returns:
        The Xata response object; its "records" entry holds the matches.

    Raises:
        Exception: if the Xata API reports that the request failed.
    """
    search_body = {
        "query": query,
        "fuzziness": 1,
        "prefix": "phrase",
        "page": {"size": 5},
    }
    response = xata.data().search_table("docs", search_body)
    if response.is_success():
        return response
    raise Exception(f"Keyword search failed: {response.json()}")
def rerank_with_rrf(keyword_results, vector_results, k=60):
    """Merge two ranked result lists using reciprocal rank fusion (RRF).

    Each record earns 1 / (k + rank) from every list it appears in, with rank
    starting at 1. Records are returned by descending total score. Equal scores
    keep first-seen order: keyword results first, then new vector results.
    When an id appears in both lists, the record object from vector_results is
    the one returned.

    Args:
        keyword_results: Ranked list of records (dicts with an "id" key).
        vector_results: Ranked list of records (dicts with an "id" key).
        k: Smoothing constant; larger values flatten the influence of rank.

    Returns:
        A list of unique records ordered by fused score, best first.
    """
    record_for = {}
    fused_score = {}
    for ranking in (keyword_results, vector_results):
        for position, record in enumerate(ranking, start=1):
            record_id = record["id"]
            # The latest occurrence wins, so vector records replace keyword ones.
            record_for[record_id] = record
            fused_score[record_id] = fused_score.get(record_id, 0) + 1 / (k + position)
    # sorted() is stable with reverse=True, so ties keep insertion order.
    ranked_ids = sorted(fused_score, key=fused_score.__getitem__, reverse=True)
    return [record_for[record_id] for record_id in ranked_ids]
def _print_records(title, records, show_score=True):
    """Print a section title, then one tab-separated line per record (id, sentence[, score])."""
    print(title)
    for record in records:
        if show_score:
            print(f'{record["id"]}\t{record["sentence"]}\t{record["xata"]["score"]}')
        else:
            print(f'{record["id"]}\t{record["sentence"]}')


def main():
    """Search for the CLI query and print the results.

    Runs a semantic search and a keyword search for the module-level `query`,
    prints both result lists with their scores, then prints the
    reciprocal-rank-fused ordering of the two.
    """
    # Note: in a real application, it would make sense to run these searches in parallel
    vector_results = vector_search(query)
    _print_records("Semantic search results:", vector_results["records"])

    keyword_results = keyword_search(query)
    _print_records("\nKeyword search results:", keyword_results["records"])

    reranked = rerank_with_rrf(keyword_results["records"], vector_results["records"])
    _print_records("\nReranked results:", reranked, show_score=False)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.