# (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
import glob
import json
import os
import pickle
import time
from typing import Dict, List

# @manual=//faiss/python:pyfaiss
import faiss
import hydra
import numpy as np
import torch
from dpr_scale.conf.config import MainConfig
from dpr_scale.datamodule.dpr import CSVDataset, QueryCSVDataset, QueryTRECDataset
from omegaconf import open_dict
from pytorch_lightning.trainer import Trainer
from tqdm import trange


def merge_results(
    passages: Dict,
    questions: List,
    top_doc_ids: List,
    scores_list: List,
):
    """Join retrieved passage indices with their text, titles, and scores per question."""
    assert len(top_doc_ids) == len(questions) == len(scores_list)
    merged_data = []
    for i, (question, doc_ids, scores) in enumerate(zip(questions, top_doc_ids, scores_list)):
        ctxs = [
            {
                "id": passages[doc_id]["id"],
                "title": passages[doc_id]["title"],
                "text": passages[doc_id]["text"],
                "score": float(score),
            }
            for doc_id, score in zip(doc_ids, scores)
        ]
        merged_data.append(
            {
                "question": question["question"],
                "answers": question.get("answers", []),
                "ctxs": ctxs,
                "id": question.get("id", i),  # fall back to the row index if no id is given
            }
        )
    return merged_data
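
# Illustrative shape of one merged record (an assumption based on the field
# names used above; actual values depend on the question/passage files):
# {
#     "question": "who wrote hamlet",
#     "answers": ["William Shakespeare"],
#     "ctxs": [{"id": "...", "title": "...", "text": "...", "score": 81.2}, ...],
#     "id": 0,
# }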


def build_index(paths):
    """Build a flat inner-product FAISS index from pickled torch embedding shards."""
    index = None
    vectors = []
    for fname in paths:
        with open(fname, "rb") as f:
            vector = pickle.load(f)  # expected: a 2-D torch.Tensor of passage embeddings
        if index is None:
            index = faiss.IndexFlatIP(vector.size()[1])
        print(f"Adding {vector.size()} vectors from {fname}")
        index.add(vector.numpy())
        vectors.append(vector)
    vectors = torch.cat(vectors, 0)
    return index, vectors
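
# Note: faiss.IndexFlatIP performs exact, exhaustive inner-product search.
# The index built above is only exercised by the commented-out index.search
# call in main(); the default path scores queries with a plain matmul against
# the concatenated vectors returned alongside it.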


def sort(scores, topk):
    """Return the top-k scores per row (descending) and their column indices."""
    top_ids = np.argpartition(scores, -topk, axis=1)[:, -topk:]  # linear-time partition, but unordered
    top_scores = np.take_along_axis(scores, top_ids, axis=1)
    top_subset_ids = np.argsort(-1.0 * top_scores, axis=1)  # sort only the top-k entries
    top_scores = np.take_along_axis(top_scores, top_subset_ids, axis=1)
    top_ids = np.take_along_axis(top_ids, top_subset_ids, axis=1)
    return top_scores, top_ids
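
# Minimal sketch of the two-step top-k above (toy values for illustration):
#   scores = np.array([[0.1, 0.9, 0.4, 0.7]])
#   sort(scores, topk=2)  ->  (array([[0.9, 0.7]]), array([[1, 3]]))
# argpartition selects the top-k set in O(n); the final argsort orders only
# the k selected entries instead of the full row.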


@hydra.main(config_path="conf", config_name="config")
def main(cfg: MainConfig):
    # Temp patch for datamodule refactoring
    cfg.task.datamodule = None
    cfg.task._target_ = (
        "dpr_scale.task.dpr_eval_task.GenerateQueryEmbeddingsTask"  # hack
    )
    # trainer.fit does some setup, so we need to call it even though no training is done
    with open_dict(cfg):
        cfg.trainer.limit_train_batches = 0
        if "plugins" in cfg.trainer:
            cfg.trainer.pop("plugins")  # remove ddp_sharded, because it breaks during loading
    print(cfg)
    task = hydra.utils.instantiate(cfg.task, _recursive_=False)
    transform = hydra.utils.instantiate(cfg.task.transform)
    datamodule = hydra.utils.instantiate(cfg.datamodule, transform=transform)
    trainer = Trainer(**cfg.trainer)
    trainer.fit(task, datamodule=datamodule)
    trainer.test(task, datamodule=datamodule)

    # index all passages
    input_paths = sorted(glob.glob(os.path.join(cfg.task.ctx_embeddings_dir, "reps_*")))
    index, ctx_vectors = build_index(input_paths)

    # reload question embeddings
    print("Loading question vectors.")
    with open(task.query_emb_output_path, "rb") as f:
        q_repr = pickle.load(f)
    if cfg.use_gpu:
        q_repr, ctx_vectors = q_repr.cuda(), ctx_vectors.cuda()
    else:
        q_repr = q_repr.numpy().astype(np.float32)
        ctx_vectors = ctx_vectors.numpy().astype(np.float32)
print("Retrieving results...") | |
retrieval_time = 0 | |
sort_time = 0 | |
all_indexes = [] | |
all_scores = [] | |
# scores, indexes = index.search(q_repr.numpy(), 100) | |
for batch_start in trange(0, len(q_repr), cfg.batch_size): | |
batch_q_repr = q_repr[batch_start: batch_start + cfg.batch_size] | |
tic = time.perf_counter() | |
if cfg.use_gpu: | |
scores = torch.matmul(batch_q_repr, ctx_vectors.T) | |
else: | |
scores = np.matmul(batch_q_repr, ctx_vectors.T) | |
toc = time.perf_counter() | |
retrieval_time += toc - tic | |
tic = time.perf_counter() | |
if cfg.use_gpu: | |
scores, indexes = scores.topk(dim=1, k=cfg.topk) | |
else: | |
scores, indexes = sort(scores, cfg.topk) | |
toc = time.perf_counter() | |
sort_time += toc - tic | |
scores, indexes = scores.tolist(), indexes.tolist() | |
all_scores.extend(scores) | |
all_indexes.extend(indexes) | |
print(f"Retrieval time:{retrieval_time:.2f}s") | |
print(f"Sorting time:{sort_time:.2f}s") | |

    # load questions file
    print(f"Loading questions file {cfg.datamodule.test_path}")
    if "msmarco" in cfg.datamodule.test_path:
        questions = QueryTRECDataset(cfg.datamodule.test_path)
    else:
        questions = QueryCSVDataset(cfg.datamodule.test_path)

    # load all passages
    print(f"Loading passages from {cfg.task.passages}")
    ctxs = CSVDataset(cfg.task.passages)

    # write output file
    print("Merging results...")
    if cfg.datamodule.trec_format:
        trec_data = []
        for question, doc_ids, scores in zip(questions, all_indexes, all_scores):
            topic_id = question["id"]
            # doc_ids here are row indices into the passage file, written as-is
            for rank, (doc_id, score) in enumerate(zip(doc_ids, scores)):
                trec_data.append(f"{topic_id} Q0 {doc_id} {rank + 1} {score:.6f} dpr-scale\n")
        print(f"Writing output to {cfg.task.output_path}")
        os.makedirs(cfg.task.output_path, exist_ok=True)  # output_path is a directory here
        with open(os.path.join(cfg.task.output_path, "retrieval.trec"), "w") as g:
            g.writelines(trec_data)
    else:
        results = merge_results(ctxs, questions, all_indexes, all_scores)
        print(f"Writing output to {cfg.task.output_path}")
        os.makedirs(os.path.dirname(cfg.task.output_path), exist_ok=True)  # output_path is a file here
        with open(cfg.task.output_path, "w") as g:
            g.write(json.dumps(results, indent=4))
            g.write("\n")
if __name__ == "__main__": | |
main() |
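
# Example invocation (a sketch; the script name and paths are placeholders,
# only the config keys are taken from the code above):
#   python run_retrieval.py \
#       task.ctx_embeddings_dir=/path/to/ctx_embeddings \
#       task.passages=/path/to/passages.tsv \
#       task.output_path=/path/to/output/retrieval.json \
#       datamodule.test_path=/path/to/questions.tsv \
#       datamodule.trec_format=False \
#       batch_size=64 topk=100 use_gpu=False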