Created
April 4, 2024 18:45
-
-
Save konradzdunczyk/cdc3041e8cc48df096b599c71f2d3b74 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import uuid

from dotenv import load_dotenv
from langchain_community.document_loaders import TextLoader
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Batch,
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointsBatch,
    PointStruct,
    VectorParams,
)
# Demo script: index the paragraphs of a local text file in a Qdrant
# collection (only when the collection is empty) and then run a single
# similarity search for `query`, printing the raw search result.

MEMORY_PATH = "memory.md"
COLLECTION_NAME = "ai_devs"

# Populate os.environ from a local .env file (if one exists) before any
# client reads its configuration; variables already set are not overridden.
load_dotenv()

qdrant = QdrantClient(url=os.getenv("QDRANT_URL"))
embeddings = OpenAIEmbeddings()
query = "Do you know the name of Adam's dog?"

# Look for an existing collection with our name.
result = qdrant.get_collections()
indexed = next(
    (element for element in result.collections if element.name == COLLECTION_NAME),
    None,
)
print(result)

if not indexed:
    # NOTE(review): size=1536 presumably matches the dimensionality of the
    # default OpenAIEmbeddings model -- confirm if the model is changed.
    qdrant.create_collection(
        COLLECTION_NAME,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE, on_disk=True),
    )

collection_info = qdrant.get_collection(COLLECTION_NAME)

# Index the memory file only when the collection holds no points yet.
if collection_info.points_count == 0:
    memory = TextLoader(MEMORY_PATH).load()

    # One document per blank-line-separated chunk. Whitespace-only chunks are
    # skipped: they carry no content worth embedding or storing.
    documents = []
    for mem in memory:
        for chunk in mem.page_content.split("\n\n"):
            if not chunk.strip():
                continue
            documents.append(
                Document(
                    page_content=chunk,
                    metadata={
                        "source": COLLECTION_NAME,
                        "content": chunk,
                        "uuid": str(uuid.uuid4()),
                    },
                )
            )

    if documents:
        # Embed whole texts in one batched call. Wrapping a string in list()
        # would split it into characters and embed only the first character.
        vectors = embeddings.embed_documents(
            [document.page_content for document in documents]
        )
        points = [
            PointStruct(
                id=document.metadata["uuid"],
                payload=document.metadata,
                vector=vector,
            )
            for document, vector in zip(documents, vectors)
        ]
        # wait=True blocks until the points are persisted, so the search
        # below sees them.
        qdrant.upsert(COLLECTION_NAME, points=points, wait=True)

# Embed the question and fetch the single closest point whose payload
# "source" equals this collection's name.
query_embedding = embeddings.embed_query(text=query)
search = qdrant.search(
    COLLECTION_NAME,
    query_vector=query_embedding,
    limit=1,
    query_filter=Filter(
        must=[
            FieldCondition(
                key="source",
                match=MatchValue(value=COLLECTION_NAME),
            )
        ]
    ),
)
print(search)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment