Merge pull request #20 from JanKalo/master
Added bigscience-LAMA evaluation
StellaAthena authored Apr 29, 2022
2 parents 2e0b659 + 49f117e commit 372ca6f
Showing 2 changed files with 304 additions and 0 deletions.
16 changes: 16 additions & 0 deletions lm_eval/tasks/__init__.py
@@ -5,6 +5,7 @@
import sacrebleu
import lm_eval.base


from . import superglue
from . import glue
from . import arc
@@ -54,12 +55,15 @@
from . import storycloze
from . import hans
from . import gem_webnlg
from . import lama
# from . import e2e_nlg_cleaned
from . import gem_xsum
from . import gem_mlsum
from . import wino_bias
from . import e2e_nlg_cleaned
from . import gem_asset_turk
from . import crows_pairs_multilingual
from . import lama

from . import HuffPost
########################################
@@ -139,6 +143,10 @@
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet
"lama_trex": lama.Trex,
"lama_squad": lama.Squad,
"lama_google_re": lama.google_re,
"lama_concptnet": lama.Conceptnet,
"logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag,
"openbookqa": openbookqa.OpenBookQA,
@@ -162,6 +170,8 @@
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
#"tydiqa_primary" : TyDiQA.Primary, not implemented yet
#"tydiqa_secondary" : TyDiQA.Secondary, not implemented yet
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
@@ -314,6 +324,12 @@
"gem_xsum_challenge_test_nopunc": gem_xsum.GEMXSUMChallgeTestNopunc,
"gem_xsum_challenge_test_covid": gem_xsum.GEMXSUMChallgeTestCovid,

    # LAMA
    "lama-trex": lama.Trex,
    "lama-squad": lama.Squad,
    "lama-google_re": lama.google_re,
    "lama-conceptnet": lama.Conceptnet,
    "bigscience-lama": lama.BigScienceLAMA,
# WinoBias
"wino_bias_type1_pro": wino_bias.WinoBiasType1Pro,
"wino_bias_type1_anti": wino_bias.WinoBiasType1Anti,
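The keys registered above become the task names users pass to the harness at evaluation time. Below is a minimal sketch of how one of the new entries resolves to its class; it assumes this fork keeps the upstream TASK_REGISTRY dict and get_task helper in lm_eval/tasks/__init__.py, which this diff does not show in full.

from lm_eval import tasks

# Hedged sketch: TASK_REGISTRY and get_task mirror the upstream EleutherAI
# harness; their presence in this fork is an assumption.
task_class = tasks.TASK_REGISTRY["bigscience-lama"]   # -> lama.BigScienceLAMA
print(task_class.DATASET_PATH)                        # "janck/bigscience-lama"

# The CLI presumably resolves names the same way through the helper:
assert tasks.get_task("bigscience-lama") is task_class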
288 changes: 288 additions & 0 deletions lm_eval/tasks/lama.py
@@ -0,0 +1,288 @@
"""
https://arxiv.org/abs/1909.01066
https://arxiv.org/abs/2005.04611
LAMA is a probe dataset for testing the factual and commonsense knowledge in language models. The dataset includes a subset of
Google_RE (https://code.google.com/archive/p/relation-extraction-corpus/), T-REx (a subset of Wikidata triples),
ConceptNet (https://github.com/commonsense/conceptnet5/wiki) and SQuAD.
Homepage: https://github.com/facebookresearch/LAMA
"""
from lm_eval.base import PromptSourceTask
import numpy as np
from lm_eval.metrics import mean
from typing import Optional

_CITATION = """
@inproceedings{petroni2019language, title={Language Models as Knowledge Bases?},
author={F. Petroni, T. Rockt{"{a}}schel, A. H. Miller, P. Lewis, A. Bakhtin, Y. Wu and S. Riedel},
booktitle={In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), 2019}, year={2019} }
@inproceedings{petroni2020how,
title={How Context Affects Language Models' Factual Predictions},
author={Fabio Petroni and Patrick Lewis and Aleksandra Piktus and Tim Rockt{"a}schel and Yuxiang Wu and Alexander H. Miller and Sebastian Riedel},
booktitle={Automated Knowledge Base Construction}, year={2020}, url={https://openreview.net/forum?id=025X0zPfn} }
"""



class BigScienceLAMA(PromptSourceTask):
VERSION = 0
DATASET_PATH = "janck/bigscience-lama"
DATASET_NAME = None


    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            return self.dataset["train"]


class Trex(PromptSourceTask):
VERSION = 0
DATASET_PATH = "lama"
DATASET_NAME = "trex"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs

def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]

def test_docs(self):
if self.has_test_docs():
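            # Note: the HF "lama" configs appear to expose only a "train"
            # split, so that split doubles as the evaluation set here.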
return self.dataset["train"]

def process_results(self, doc, results):
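        # Exact-match accuracy: the stripped generation must equal the gold
        # object label taken from the document.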
out = {}
pred = results[0].strip()
target = self.doc_to_target(doc)['obj_label']
out["acc"] = pred == target


if self.save_examples:
example = {
"pred": pred,
"target": target,
}
return out, example

return out

def higher_is_better(self):
return {"acc": True}

def aggregation(self):
return {"acc": mean}

def doc_to_target(self, doc):
return doc


class google_re(PromptSourceTask):
VERSION = 0
DATASET_PATH = "lama"
DATASET_NAME = "google_re"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs

def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]

def test_docs(self):
if self.has_test_docs():
return self.dataset["train"]

def process_results(self, doc, results):
out = {}
pred = results[0].strip()

target = self.doc_to_target(doc)['obj_label']
out["acc"] = pred == target


if self.save_examples:
example = {
"pred": pred,
"target": target,
}
return out, example

return out

def higher_is_better(self):
return {"acc": True}

def aggregation(self):
return {"acc": mean}

def doc_to_target(self, doc):
return doc

class Conceptnet(PromptSourceTask):
VERSION = 0
DATASET_PATH = "lama"
DATASET_NAME = "conceptnet"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True


def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs

def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]

def test_docs(self):
if self.has_test_docs():
return self.dataset["train"]

def process_results(self, doc, results):
out = {}
pred = results[0].strip()

target = self.doc_to_target(doc)['obj_label']
out["acc"] = pred == target


if self.save_examples:
example = {
"pred": pred,
"target": target,
}
return out, example

return out

def higher_is_better(self):
return {"acc": True}

def aggregation(self):
return {"acc": mean}

def doc_to_target(self, doc):
return doc


class Squad(PromptSourceTask):
VERSION = 0
DATASET_PATH = "lama"
DATASET_NAME = "squad"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True


def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs

def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            self._test_docs = list(self.dataset["train"])
            return self._test_docs

def process_results(self, doc, results):
out = {}
pred = results[0].strip()
target = self.doc_to_target(doc)['obj_label']
out["acc"] = pred == target



if self.save_examples:
example = {
"pred": pred,
"target": target,
}
return out, example

return out

def higher_is_better(self):
return {"acc": True}

def aggregation(self):
return {"acc": mean}

def doc_to_target(self, doc):
return doc

    def max_generation_length(self) -> Optional[int]:
        """Denote the max length of the generation if it is obvious from the task."""
        return 5
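Taken together with the registry entries in lm_eval/tasks/__init__.py, the classes above let the new evaluation be driven through the harness's standard entry point. The sketch below assumes this fork keeps the upstream lm_eval.evaluator.simple_evaluate API and a registered "gpt2" model adapter; both names are assumptions, and the exact signature may differ here.

from lm_eval import evaluator

# Hedged usage sketch (assumed API): score GPT-2 on the new BigScience LAMA task.
results = evaluator.simple_evaluate(
    model="gpt2",                   # assumed model adapter name
    tasks=["bigscience-lama"],      # key added to the registry in this PR
    num_fewshot=0,
)
print(results["results"])           # nested dict of per-task metrics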

