Created
September 8, 2024 01:22
-
-
Save ColeMurray/1e6bc35f0c0bb46fd87a995f77d741dd to your computer and use it in GitHub Desktop.
Example using OpenHands to create a vector database server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import List, Optional, Dict | |
import faiss | |
import numpy as np | |
import os | |
# FastAPI application instance; route handlers are registered on it below.
app = FastAPI()
class QueryRequest(BaseModel):
    """Request body for the /query endpoint."""

    # Embedding to search with; its length must match the index dimensionality.
    query: List[float]
    # namespace + identifier together select one on-disk database
    # (stored as "{namespace}_{identifier}.index").
    namespace: str
    identifier: str
    # Maximum number of nearest neighbours to return.
    num_results: Optional[int] = 10
import pickle | |
class VectorDatabase:
    """Simple persistent vector store backed by FAISS flat L2 indexes.

    Each (namespace, identifier) pair maps to one index file plus a pickled
    metadata file (raw vectors and caller-supplied ids) under ``base_path``.
    All existing databases are loaded eagerly at construction time.
    """

    def __init__(self, base_path: str = "./vector_dbs"):
        self.base_path = base_path
        # db_key ("{namespace}_{identifier}") -> FAISS index
        self.databases: Dict[str, faiss.IndexFlatL2] = {}
        # db_key -> raw vectors, kept so metadata can be re-pickled on save
        self.vectors: Dict[str, List[np.ndarray]] = {}
        # db_key -> caller-supplied ids, positionally aligned with the index rows
        self.ids: Dict[str, List[int]] = {}
        os.makedirs(base_path, exist_ok=True)
        self.load_all_dbs()

    def db_path(self, namespace: str, identifier: str) -> str:
        """Return the FAISS index file path for (namespace, identifier)."""
        # NOTE(review): namespace/identifier are interpolated verbatim into a
        # filename; callers must not pass path separators (path-traversal
        # risk for values taken from HTTP requests) — TODO validate.
        return os.path.join(self.base_path, f"{namespace}_{identifier}.index")

    def metadata_path(self, namespace: str, identifier: str) -> str:
        """Return the pickled-metadata file path for (namespace, identifier)."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.metadata")

    def load_all_dbs(self):
        """Load every *.index file found under base_path."""
        for filename in os.listdir(self.base_path):
            if filename.endswith(".index"):
                # Split on the FIRST underscore; a namespace that itself
                # contains "_" is mis-parsed here (limitation of the key format).
                namespace, identifier = filename[:-6].split("_", 1)
                self.load_db(namespace, identifier)

    def load_db(self, namespace: str, identifier: str):
        """Load one index (and its metadata, if present) from disk into memory."""
        db_key = f"{namespace}_{identifier}"
        index_path = self.db_path(namespace, identifier)
        metadata_path = self.metadata_path(namespace, identifier)
        print(f"Checking index path: {index_path}")
        print(f"Checking metadata path: {metadata_path}")
        if not os.path.exists(index_path):
            print(f"Index file does not exist: {index_path}")
            return
        self.databases[db_key] = faiss.read_index(index_path)
        if os.path.exists(metadata_path):
            # SECURITY: pickle.load can execute arbitrary code from the file;
            # only load metadata files this process itself wrote.
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
            self.vectors[db_key] = metadata['vectors']
            self.ids[db_key] = metadata['ids']
            print(f"Metadata file contents: {metadata}")
        else:
            # Index without metadata: searchable, but every hit maps to id None.
            print("Metadata file does not exist")
            self.vectors[db_key] = []
            self.ids[db_key] = []
        print(f"Loaded database: {db_key}")
        print(f"Number of vectors: {self.databases[db_key].ntotal}")
        print(f"Number of IDs: {len(self.ids[db_key])}")
        print(f"IDs: {self.ids[db_key]}")

    def save_db(self, namespace: str, identifier: str):
        """Persist one in-memory database (index + metadata) to disk."""
        db_key = f"{namespace}_{identifier}"
        if db_key in self.databases:
            index_path = self.db_path(namespace, identifier)
            metadata_path = self.metadata_path(namespace, identifier)
            faiss.write_index(self.databases[db_key], index_path)
            with open(metadata_path, 'wb') as f:
                pickle.dump({'vectors': self.vectors[db_key], 'ids': self.ids[db_key]}, f)
            print(f"Saved database: {db_key}")

    def create_or_get_db(self, namespace: str, identifier: str, dim: int):
        """Return the index for this key, creating an empty dim-d one if absent."""
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            self.databases[db_key] = faiss.IndexFlatL2(dim)
            self.vectors[db_key] = []
            self.ids[db_key] = []
        return self.databases[db_key]

    def add_vector(self, namespace: str, identifier: str, vector: List[float], id: int):
        """Append one vector with a caller-supplied id and persist the database.

        Raises:
            HTTPException(400): if the vector's length does not match the
                dimensionality of an already-existing index.
        """
        db_key = f"{namespace}_{identifier}"
        db = self.create_or_get_db(namespace, identifier, len(vector))
        # FIX: fail with a clear client error instead of a FAISS-level
        # assertion when the dimension disagrees with the existing index.
        if db.d != len(vector):
            raise HTTPException(
                status_code=400,
                detail=f"Vector dimension {len(vector)} does not match index dimension {db.d}",
            )
        np_vector = np.array([vector], dtype=np.float32)
        db.add(np_vector)
        self.vectors[db_key].append(np_vector)
        self.ids[db_key].append(id)
        # Write-through persistence: every insert rewrites both files.
        self.save_db(namespace, identifier)

    def query(self, namespace: str, identifier: str, query_vector: List[float], k: int):
        """Return up to k nearest neighbours as [{"id": ..., "score": ...}, ...].

        score is the squared L2 distance reported by the index (smaller is
        closer); id is None for padding slots when the index holds fewer
        than k vectors.

        Raises:
            HTTPException(404): if no database exists for this key.
        """
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            raise HTTPException(status_code=404, detail="Database not found")
        db = self.databases[db_key]
        np_query = np.array([query_vector], dtype=np.float32)
        distances, indices = db.search(np_query, k)
        results = []
        for idx, score in zip(indices[0], distances[0]):
            # FIX: FAISS pads `indices` with -1 when fewer than k vectors
            # exist; the old `idx < len(...)` guard let -1 through and
            # wrongly returned the LAST stored id for padding slots.
            if 0 <= idx < len(self.ids[db_key]):
                results.append({"id": self.ids[db_key][idx], "score": float(score)})
            else:
                results.append({"id": None, "score": float(score)})
        return results
# Module-level singleton shared by all route handlers; loads existing
# databases from disk at import time.
vector_db = VectorDatabase()
@app.post("/query")
def query_vector_db(request: QueryRequest):
    """Search one (namespace, identifier) database for nearest neighbours."""
    matches = vector_db.query(
        request.namespace,
        request.identifier,
        request.query,
        request.num_results,
    )
    return {"results": matches}
@app.post("/add_vector")
def add_vector(namespace: str, identifier: str, vector: List[float], id: int):
    """Add one vector under a caller-supplied id and persist the database.

    NOTE(review): these are bare function parameters rather than a Pydantic
    model, so FastAPI presumably reads them from the query string — confirm
    this is the intended API shape, cf. /query which takes a JSON body.
    """
    vector_db.add_vector(namespace, identifier, vector, id)
    return {"message": "Vector added successfully"}
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment