import logging

import pandas as pd
from simpletransformers.retrieval import RetrievalModel, RetrievalArgs

# Log INFO-level progress from simpletransformers, but silence the noisier
# transformers internals.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# Train on Natural Questions (TSV) and evaluate on the BEIR SciFact dataset.
train_data_path = "../data/nq-train.tsv"
eval_data_path = "../data_beir/data_scifact"

# Load TSV training data into a DataFrame; for any other format, pass the
# path through and let simpletransformers handle loading.
if train_data_path.endswith(".tsv"):
    train_data = pd.read_csv(train_data_path, sep="\t")
else:
    train_data = train_data_path
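# Note (assumption, not part of the original gist): simpletransformers'
# RetrievalModel expects training data with "query_text" and "gold_passage"
# columns, plus an optional "title" column that is used when include_title
# is True. A minimal illustrative frame in that shape:
#
#   toy_train = pd.DataFrame(
#       {
#           "query_text": ["who wrote hamlet"],
#           "gold_passage": ["Hamlet is a tragedy written by William Shakespeare."],
#           "title": ["Hamlet"],
#       }
#   )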
model_args = RetrievalArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.use_cached_eval_features = False
# MS MARCO passages have no usable titles; include titles for other datasets.
model_args.include_title = "msmarco" not in train_data_path
model_args.max_seq_length = 256
model_args.num_train_epochs = 1
model_args.train_batch_size = 16
model_args.eval_batch_size = 300
model_args.use_hf_datasets = True
model_args.learning_rate = 1e-6
model_args.warmup_steps = 5000
model_args.save_steps = -1  # no periodic step checkpoints; keep only the best model
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 1000
model_args.evaluate_during_training_verbose = True
model_args.save_model_every_epoch = False
model_args.save_eval_checkpoints = False
model_args.save_best_model = True
# Pick the best checkpoint by mean reciprocal rank (higher is better).
model_args.early_stopping_metric = "recip_rank"
model_args.early_stopping_metric_minimize = False
model_args.evaluate_each_epoch = False
model_args.wandb_project = "IR2 Demo"
model_args.hard_negatives_in_eval = False
model_args.hard_negatives = False
model_args.n_gpu = 1
model_args.evaluate_with_beir = False
model_args.data_format = "beir"
model_args.wandb_kwargs = {
    "name": f"repro-dpr-epochs-{model_args.num_train_epochs}-batch_size-{model_args.train_batch_size}"
}
model_args.output_dir = (
    f"../models/dpr-epochs-{model_args.num_train_epochs}-batch_size-{model_args.train_batch_size}"
)
model_args.best_model_dir = model_args.output_dir + "/best_model"
# Build a DPR-style bi-encoder from scratch: separate BERT encoders for the
# passage (context) and query (question) sides, instead of loading a
# pretrained DPR checkpoint.
model_type = "custom"
model_name = None
context_name = "bert-base-cased"
question_name = "bert-base-cased"
if __name__ == "__main__":
    # HF datasets workers use the `multiprocess` fork of multiprocessing;
    # "spawn" avoids CUDA state being shared into forked worker processes.
    from multiprocess import set_start_method

    set_start_method("spawn")

    # Create the RetrievalModel
    model = RetrievalModel(
        model_type,
        model_name,
        context_name,
        question_name,
        args=model_args,
    )

    # Train on NQ, evaluating periodically on the BEIR test split.
    model.train_model(
        train_data,
        # clustered_training=True,
        eval_data=eval_data_path,
        eval_set="test",
    )
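# Usage note (assumptions: the file is saved as train_dpr.py and the data
# paths above exist): run from the scripts directory with
#
#   python train_dpr.py
#
# After training, the best checkpoint lands in model_args.best_model_dir and
# can be reloaded for retrieval. A rough sketch, assuming the checkpoint
# layout uses "context_encoder"/"query_encoder" subdirectories and the
# standard predict() API (both are assumptions, verify against your
# simpletransformers version):
#
#   trained = RetrievalModel(
#       "custom",
#       None,
#       model_args.best_model_dir + "/context_encoder",
#       model_args.best_model_dir + "/query_encoder",
#       args=model_args,
#   )
#   # passages_df is a hypothetical DataFrame of candidate passages to index.
#   results = trained.predict(["who wrote hamlet"], prediction_passages=passages_df)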