import lm_eval.evaluator
import lm_eval.tasks
import lm_eval.models
import lm_eval.base
from lm_eval.base import rf
from lm_eval.metrics import mean
import json
import numpy as np
import random

random.seed(42)

ob = json.load(open("mc_task.json"))
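# Each entry in mc_task.json is assumed (from how docs are used below) to be a
# dict with a "question" string and an "mc1_targets" dict mapping candidate
# answer strings to 1 (the single correct answer) or 0 (incorrect).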
class TruthfulTask(lm_eval.base.Task):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation.

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...} or a tuple like
        (question, answer)
    """

    VERSION = 0
    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    def has_training_docs(self):
        """Whether the task has a training set"""
        return False

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return True

    def has_test_docs(self):
        """Whether the task has a test set"""
        return False

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return ob

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []
    def doc_to_text(self, doc):
        return "Question: " + doc["question"] + "\nAnswer:"

    def doc_to_target(self, doc):
        # Unpack the single correct answer; this raises if there is not exactly one.
        ret, = [k for k, v in doc["mc1_targets"].items() if v == 1]
        return " " + ret

    def construct_requests(self, doc, ctx):
        # Score every candidate answer, putting the correct answer first so that
        # process_results can simply check whether argmax(results) == 0.
        return [rf.loglikelihood(ctx, " " + k)[0] for k, v in doc["mc1_targets"].items() if v == 1] \
             + [rf.loglikelihood(ctx, " " + k)[0] for k, v in doc["mc1_targets"].items() if v == 0]
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # The correct answer's loglikelihood sits at index 0 (see construct_requests),
        # so the prediction counts as correct iff it has the highest loglikelihood.
        pred = np.argmax(results)
        return {
            "acc": pred == 0
        }
    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }

    def fewshot_description(self):
        return ""
class UntruthfulTask(TruthfulTask):
    def doc_to_target(self, doc):
        ret = random.choice([k for k, v in doc["mc1_targets"].items() if v == 0])
        return " " + ret


# Register both tasks so get_task_dict below can resolve them by name.
lm_eval.tasks.TASK_REGISTRY["truthful"] = TruthfulTask
lm_eval.tasks.TASK_REGISTRY["untruthful"] = UntruthfulTask
model = 'gpt2'
model_args = 'pretrained=EleutherAI/gpt-neo-1.3B'
# Note: 'gpt2' names the harness's HuggingFace causal-LM wrapper; the weights
# actually loaded are EleutherAI/gpt-neo-1.3B, per model_args.
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
    'batch_size': 1, 'device': 'cuda'
})
lm = lm_eval.base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
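# CachingLM memoizes each LM request in the on-disk .db file above, so rerunning
# with the same model and arguments reuses cached results instead of recomputing.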
res = lm_eval.evaluator.evaluate(lm, lm_eval.tasks.get_task_dict(["truthful", "untruthful"]),
                                 provide_description=False, num_fewshot=3, limit=None,
                                 bootstrap_iters=100000)
print(res)
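# `res` maps each task to its metric values. In harness versions from around this
# time the layout is assumed to be {"results": {task: {metric: value}}, ...};
# verify against the printed output before relying on it. Hypothetical access:
#   res["results"]["truthful"]["acc"], res["results"]["untruthful"]["acc"]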