From d1fe3ff71c1722aaec8d3426e4634757de6ac4d4 Mon Sep 17 00:00:00 2001 From: Jan Heinrich Merker Date: Thu, 29 Aug 2024 23:27:38 +0200 Subject: [PATCH] Add neural re-rankers --- trec_biogen/optimization.py | 135 +++++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 8 deletions(-) diff --git a/trec_biogen/optimization.py b/trec_biogen/optimization.py index 0327826..1cdeec5 100644 --- a/trec_biogen/optimization.py +++ b/trec_biogen/optimization.py @@ -1,5 +1,5 @@ -from typing import Callable, Literal, Sequence -from warnings import catch_warnings, simplefilter +from typing import Callable, Literal, Sequence, TypeAlias +from warnings import catch_warnings, simplefilter, filterwarnings from dspy import settings as dspy_settings from optuna import Study, Trial, create_study @@ -9,6 +9,8 @@ from optuna.trial import FrozenTrial from optuna_integration import WeightsAndBiasesCallback from pyterrier.transformer import Transformer +from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker +from pyterrier_dr import TasB, TctColBert, Ance from trec_biogen.answering import IndependentAnsweringModule, RecurrentAnsweringModule from trec_biogen.dspy_generation import ( @@ -27,6 +29,7 @@ from trec_biogen.language_models import LanguageModelName, get_dspy_language_model from trec_biogen.model import Answer from trec_biogen.modules import AnsweringModule, GenerationModule, RetrievalModule +from trec_biogen.pyterrier import CutoffRerank from trec_biogen.pyterrier_pubmed import ( PubMedElasticsearchRetrieve, PubMedSentencePassager, @@ -56,6 +59,21 @@ def _suggest_must_should(trial: Trial, name: str) -> Literal["must", "should"] | raise ValueError(f"Illegal value: {must_should}") + +PointwiseRerankerModel: TypeAlias = Literal[ + "castorini/monot5-base-msmarco", + "castorini/monot5-3b-msmarco", + "castorini/monot5-3b-med-msmarco", + "sentence-transformers/msmarco-distilbert-base-tas-b", + "sentence-transformers/msmarco-roberta-base-ance-firstp", + "castorini/tct_colbert-v2-hnp-msmarco", +] +PairwiseRerankerModel: TypeAlias = Literal[ + "castorini/duot5-base-msmarco", + "castorini/duot5-3b-msmarco", + "castorini/duot5-3b-med-msmarco", +] + def build_retrieval_module( trial: Trial, ) -> RetrievalModule: @@ -207,13 +225,114 @@ def build_retrieval_module( ) pipeline = pipeline >> pubmed_sentence_passager - # TODO: Re-ranking. + # Pointwise re-ranking. + pointwise_reranker_model: PointwiseRerankerModel | None = trial.suggest_categorical( + name="pointwise_reranker_model", + choices=[ + "castorini/monot5-base-msmarco", + "castorini/monot5-3b-msmarco", + "castorini/monot5-3b-med-msmarco", + "sentence-transformers/msmarco-distilbert-base-tas-b", + "sentence-transformers/msmarco-roberta-base-ance-firstp", + "castorini/tct_colbert-v2-hnp-msmarco", + None, + ], + ) # type: ignore + pointwise_reranker: Transformer | None + if pointwise_reranker_model in ( + "castorini/monot5-base-msmarco", + "castorini/monot5-3b-msmarco", + "castorini/monot5-3b-med-msmarco", + ): + pointwise_reranker = MonoT5ReRanker( + model=pointwise_reranker_model, + verbose=True, + ) + elif pointwise_reranker_model in ( + "sentence-transformers/msmarco-distilbert-base-tas-b", + ): + with catch_warnings(): + filterwarnings( + action="ignore", + message="TypedStorage is deprecated", + category=UserWarning, + ) + pointwise_reranker = TasB( + model_name=pointwise_reranker_model, + verbose=True, + ) + elif pointwise_reranker_model in ( + "castorini/tct_colbert-v2-hnp-msmarco", + ): + pointwise_reranker = TctColBert( + model_name=pointwise_reranker_model, + verbose=True, + ) + elif pointwise_reranker_model in ( + "sentence-transformers/msmarco-roberta-base-ance-firstp", + ): + pointwise_reranker = Ance( + model_name=pointwise_reranker_model, + verbose=True, + ) + else: + pointwise_reranker = None + if pointwise_reranker is not None: + pointwise_reranker_cutoff=trial.suggest_categorical( + name="pointwise_reranker_cutoff", + choices=[ + 10, + 50, + 100, + ] + ) + pipeline = CutoffRerank( + candidates=pipeline, + reranker=pointwise_reranker, + cutoff=pointwise_reranker_cutoff, + ) - # max_sentences=trial.suggest_int( - # name="pubmed_sentence_passager_max_sentences", - # low=1, - # high=5, - # ), + # Pairwise re-ranking. + pairwise_reranker_model: PairwiseRerankerModel | None = trial.suggest_categorical( + name="pairwise_reranker_model", + choices=[ + "castorini/duot5-base-msmarco", + "castorini/duot5-3b-msmarco", + "castorini/duot5-3b-med-msmarco", + None, + ], + ) # type: ignore + pairwise_reranker: Transformer | None + if pairwise_reranker_model in ( + "castorini/duot5-base-msmarco", + "castorini/duot5-3b-msmarco", + "castorini/duot5-3b-med-msmarco", + ): + with catch_warnings(): + filterwarnings( + action="ignore", + message="TypedStorage is deprecated", + category=UserWarning, + ) + pairwise_reranker = DuoT5ReRanker( + model=pairwise_reranker_model, + verbose=True, + ) + else: + pairwise_reranker = None + if pairwise_reranker is not None: + pairwise_reranker_cutoff=trial.suggest_categorical( + name="pairwise_reranker_cutoff", + choices=[ + 3, + 5, + ] + ) + pipeline = CutoffRerank( + candidates=pipeline, + reranker=pairwise_reranker, + cutoff=pairwise_reranker_cutoff, + ) retrieval_module = PyterrierRetrievalModule(pipeline, progress=True)