@7shi
Last active December 7, 2024 09:11
[py] test Ruri text embeddings
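Two scripts follow. The first reads a text file, embeds each non-empty line with a Ruri model (through Ollama or through Sentence Transformers), and saves the vectors beside the input as a .safetensors file; the second reloads those vectors and runs an interactive cosine-similarity search over the original lines.

A usage sketch (the script and input file names below are assumptions, not part of the gist):

python embed_lines.py input.txt            # Sentence Transformers backend; writes input.safetensors
python embed_lines.py input.txt --ollama   # same, but embeds through a local Ollama server
python query_lines.py input.txt            # interactive search over input.safetensors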
# Script 1: embed each non-empty line of a text file and save the vectors
# next to it as <textfile>.safetensors.
import argparse
parser = argparse.ArgumentParser(description='Process text file and create tensor embeddings')
parser.add_argument('textfile', help='Input text file path')
parser.add_argument('--ollama', action='store_true', help='Use Ollama')
arg = parser.parse_args()

import os, torch, safetensors.torch
from tqdm import tqdm

if arg.ollama:
    import ollama
    model = "kun432/cl-nagoya-ruri-large"
    def embed(s):
        return torch.tensor([ollama.embeddings(model=model, prompt=s).embedding])
else:
    from sentence_transformers import SentenceTransformer
    # Download from the 🤗 Hub
    model = SentenceTransformer("cl-nagoya/ruri-base")
    def embed(s):
        return model.encode([s], convert_to_tensor=True)

tensorfile = os.path.splitext(arg.textfile)[0] + ".safetensors"

# Don't forget to add the prefix "クエリ: " for query-side or "文章: " for passage-side texts.
with open(arg.textfile, "r", encoding="utf-8") as f:
    lines = [l for line in f if (l := line.strip())]

# Probe once to get the embedding dimension, then fill one row per line.
test = embed("文章: test")[0]
print("vector size:", len(test))
tensor = torch.zeros(len(lines), len(test), dtype=torch.float32)
for i, line in tqdm(enumerate(lines), total=len(lines)):
    # print(f"{i+1} / {len(lines)} {line}")
    tensor[i, :] = embed(f"文章: {line}")[0]
safetensors.torch.save_file({"lines": tensor}, tensorfile)
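A minimal sketch for inspecting what the first script wrote, assuming the input file was input.txt (a hypothetical name): the .safetensors file holds a single float32 tensor under the key "lines", one row per text line.

import safetensors.torch

tensor = safetensors.torch.load_file("input.safetensors")["lines"]
print(tensor.shape, tensor.dtype)  # (number of lines, embedding dimension), torch.float32

The second script below performs this same load before querying.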
# Script 2: reload the saved embeddings and run an interactive
# cosine-similarity search over the original lines.
import argparse
parser = argparse.ArgumentParser(description='Query tensor embeddings created from a text file')
parser.add_argument('textfile', help='Input text file path')
parser.add_argument('--ollama', action='store_true', help='Use Ollama')
arg = parser.parse_args()

import os, torch, torch.nn.functional as F, safetensors.torch

if arg.ollama:
    import ollama
    model = "kun432/cl-nagoya-ruri-large"
    def embed(s):
        return torch.tensor([ollama.embeddings(model=model, prompt=s).embedding])
else:
    from sentence_transformers import SentenceTransformer
    # Download from the 🤗 Hub
    model = SentenceTransformer("cl-nagoya/ruri-base")
    def embed(s):
        return model.encode([s], convert_to_tensor=True)

tensorfile = os.path.splitext(arg.textfile)[0] + ".safetensors"
with open(arg.textfile, "r", encoding="utf-8") as f:
    lines = [l for line in f if (l := line.strip())]
tensor = safetensors.torch.load_file(tensorfile)["lines"]

# Don't forget to add the prefix "クエリ: " for query-side or "文章: " for passage-side texts.
while True:
    print()
    try:
        q = input("> ")
    except:  # exit on EOF or Ctrl-C
        print()
        break
    embeddings = embed(f"クエリ: {q}")
    similarities = F.cosine_similarity(tensor, embeddings, dim=1)
    # Print the ten most similar lines with their scores and 1-based line numbers.
    for i, (value, index) in enumerate(zip(*torch.topk(similarities, k=10))):
        v, idx = value.item(), index.item()
        print(f"{i+1:2d}: {v:.5f} {idx + 1:4d} {lines[idx]}")