diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py
index 7aa7b8e47c..0be60a570d 100644
--- a/lm_eval/tasks/__init__.py
+++ b/lm_eval/tasks/__init__.py
@@ -5,6 +5,7 @@
 import sacrebleu
 
 import lm_eval.base
+
 from . import superglue
 from . import glue
 from . import arc
@@ -54,12 +55,13 @@
 from . import storycloze
 from . import hans
 from . import gem_webnlg
+from . import lama
 from . import gem_xsum
 from . import gem_mlsum
 from . import wino_bias
 from . import e2e_nlg_cleaned
 from . import gem_asset_turk
 from . import crows_pairs_multilingual
 from . import HuffPost
 
 ########################################
@@ -139,6 +141,10 @@
     "arc_easy": arc.ARCEasy,
     "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC, # not implemented yet
+    "lama_trex": lama.Trex,
+    "lama_squad": lama.Squad,
+    "lama_google_re": lama.GoogleRE,
+    "lama_conceptnet": lama.Conceptnet,
     "logiqa": logiqa.LogiQA,
     "hellaswag": hellaswag.HellaSwag,
     "openbookqa": openbookqa.OpenBookQA,
@@ -162,6 +168,8 @@
     "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
     "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
     "ethics_virtue": hendrycks_ethics.EthicsVirtue,
+    # "tydiqa_primary": TyDiQA.Primary, # not implemented yet
+    # "tydiqa_secondary": TyDiQA.Secondary, # not implemented yet
     "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
     "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
     # dialogue
@@ -314,6 +322,12 @@
     "gem_xsum_challenge_test_nopunc": gem_xsum.GEMXSUMChallgeTestNopunc,
     "gem_xsum_challenge_test_covid": gem_xsum.GEMXSUMChallgeTestCovid,
 
+    # LAMA
+    "lama-trex": lama.Trex,
+    "lama-squad": lama.Squad,
+    "lama-google_re": lama.GoogleRE,
+    "lama-conceptnet": lama.Conceptnet,
+    "bigscience-lama": lama.BigScienceLAMA,
     # WinoBias
     "wino_bias_type1_pro": wino_bias.WinoBiasType1Pro,
     "wino_bias_type1_anti": wino_bias.WinoBiasType1Anti,
diff --git a/lm_eval/tasks/lama.py b/lm_eval/tasks/lama.py
new file mode 100644
index 0000000000..6297599ed4
--- /dev/null
+++ b/lm_eval/tasks/lama.py
@@ -0,0 +1,271 @@
+"""
+https://arxiv.org/abs/1909.01066
+https://arxiv.org/abs/2005.04611
+
+LAMA is a probe dataset for testing the factual and commonsense knowledge of language models. It includes subsets of
+Google_RE (https://code.google.com/archive/p/relation-extraction-corpus/), T-REx (a subset of Wikidata triples),
+ConceptNet (https://github.com/commonsense/conceptnet5/wiki) and SQuAD.
+
+Homepage: https://github.com/facebookresearch/LAMA
+"""
+from typing import Optional
+
+from lm_eval.base import PromptSourceTask
+from lm_eval.metrics import mean
+
+_CITATION = """
+@inproceedings{petroni2019language,
+    title={Language Models as Knowledge Bases?},
+    author={F. Petroni, T. Rockt{\"{a}}schel, A. H. Miller, P. Lewis, A. Bakhtin, Y. Wu and S. Riedel},
+    booktitle={In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), 2019},
+    year={2019}
+}
+
+@inproceedings{petroni2020how,
+    title={How Context Affects Language Models' Factual Predictions},
+    author={Fabio Petroni and Patrick Lewis and Aleksandra Piktus and Tim Rockt{\"a}schel and Yuxiang Wu and Alexander H.
+    Miller and Sebastian Riedel},
+    booktitle={Automated Knowledge Base Construction},
+    year={2020},
+    url={https://openreview.net/forum?id=025X0zPfn}
+}
+"""
+
+
+class BigScienceLAMA(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "janck/bigscience-lama"
+    DATASET_NAME = None
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self.has_training_docs():
+            return self.dataset["train"]
+
+
+class Trex(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "lama"
+    DATASET_NAME = "trex"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self.has_training_docs():
+            if self._training_docs is None:
+                self._training_docs = list(self.dataset["train"])
+            return self._training_docs
+
+    def validation_docs(self):
+        if self.has_validation_docs():
+            return self.dataset["validation"]
+
+    def test_docs(self):
+        if self.has_test_docs():
+            # The HF `lama` configs only ship a `train` split, which is used as the evaluation set.
+            return self.dataset["train"]
+
+    def process_results(self, doc, results):
+        out = {}
+        # Exact-match accuracy: the stripped generation must equal the gold object label.
+        pred = results[0].strip()
+        target = self.doc_to_target(doc)["obj_label"]
+        out["acc"] = pred == target
+
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+            }
+            return out, example
+
+        return out
+
+    def higher_is_better(self):
+        return {"acc": True}
+
+    def aggregation(self):
+        return {"acc": mean}
+
+    def doc_to_target(self, doc):
+        # Return the full document; the gold answer is read from its `obj_label` field.
+        return doc
+
+
+class GoogleRE(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "lama"
+    DATASET_NAME = "google_re"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self.has_training_docs():
+            if self._training_docs is None:
+                self._training_docs = list(self.dataset["train"])
+            return self._training_docs
+
+    def validation_docs(self):
+        if self.has_validation_docs():
+            return self.dataset["validation"]
+
+    def test_docs(self):
+        if self.has_test_docs():
+            return self.dataset["train"]
+
+    def process_results(self, doc, results):
+        out = {}
+        pred = results[0].strip()
+        target = self.doc_to_target(doc)["obj_label"]
+        out["acc"] = pred == target
+
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+            }
+            return out, example
+
+        return out
+
+    def higher_is_better(self):
+        return {"acc": True}
+
+    def aggregation(self):
+        return {"acc": mean}
+
+    def doc_to_target(self, doc):
+        return doc
+
+
+class Conceptnet(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "lama"
+    DATASET_NAME = "conceptnet"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self.has_training_docs():
+            if self._training_docs is None:
+                self._training_docs = list(self.dataset["train"])
+            return self._training_docs
+
+    def validation_docs(self):
+        if self.has_validation_docs():
+            return self.dataset["validation"]
+
+    def test_docs(self):
+        if self.has_test_docs():
+            return self.dataset["train"]
+
+    def process_results(self, doc, results):
+        out = {}
+        pred = results[0].strip()
+        target = self.doc_to_target(doc)["obj_label"]
+        out["acc"] = pred == target
+
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+            }
+            return out, example
+
+        return out
+
+    def higher_is_better(self):
+        return {"acc": True}
+
+    def aggregation(self):
+        return {"acc": mean}
+
+    def doc_to_target(self, doc):
+        return doc
+
+
+class Squad(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "lama"
+    DATASET_NAME = "squad"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self.has_training_docs():
+            if self._training_docs is None:
+                self._training_docs = list(self.dataset["train"])
+            return self._training_docs
+
+    def validation_docs(self):
+        if self.has_validation_docs():
+            return self.dataset["validation"]
+
+    def test_docs(self):
+        if self.has_test_docs():
+            self._test_docs = list(self.dataset["train"])
+            return self._test_docs
+
+    def process_results(self, doc, results):
+        out = {}
+        pred = results[0].strip()
+        target = self.doc_to_target(doc)["obj_label"]
+        out["acc"] = pred == target
+
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+            }
+            return out, example
+
+        return out
+
+    def higher_is_better(self):
+        return {"acc": True}
+
+    def aggregation(self):
+        return {"acc": mean}
+
+    def doc_to_target(self, doc):
+        return doc
+
+    def max_generation_length(self) -> Optional[int]:
+        """Denote the maximum generation length when it is obvious from the task."""
+        return 5
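
For context on how the new tasks score predictions, the short sketch below (illustration only, not part of the patch) mirrors the exact-match rule implemented by each `process_results` above: the model's first generation is stripped of surrounding whitespace and compared, case-sensitively, against the gold `obj_label` field, and the per-document accuracies are then averaged via `aggregation()`. The document and generation used here are fabricated for the example.

# Illustration of the exact-match scoring used by the LAMA tasks in lama.py.
# `obj_label` is the gold-answer field of the HF "lama" configs; the doc and the
# generation below are made up for this example.
doc = {"obj_label": "Paris"}
results = [" Paris\n"]             # first (and only) generation returned by the model

pred = results[0].strip()          # strip whitespace, as process_results() does
acc = pred == doc["obj_label"]     # exact, case-sensitive string match
print({"acc": acc})                # -> {'acc': True}; the mean over docs is the task accuracy

Because the comparison is exact, long free-form completions would rarely match the short object labels, which is presumably why `Squad` caps `max_generation_length()` at 5.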