Skip to content

Commit

Permalink
Prepare DSPy optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
janheinrichmerker committed Aug 29, 2024
1 parent ba07058 commit 91e2c31
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 47 deletions.
2 changes: 1 addition & 1 deletion trec_biogen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def optimize(
print(f"Found {len(answers)} answers.")

best_trials = optimize_answering_module(
ground_truth=answers,
answers=answers,
retrieval_measures=retrieval_measures,
generation_measures=generation_measures,
trials=trials,
Expand Down
108 changes: 62 additions & 46 deletions trec_biogen/optimization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal, Sequence
from warnings import catch_warnings, simplefilter

from dspy import settings as dspy_settings

# from dspy.teleprompt import LabeledFewShot
from dspy import settings as dspy_settings, Example
from dspy.teleprompt import LabeledFewShot
from optuna import Study, Trial, create_study
from optuna.study import StudyDirection
from optuna.exceptions import ExperimentalWarning
Expand Down Expand Up @@ -212,7 +211,15 @@ def _suggest_language_model_name(trial: Trial, name: str) -> LanguageModelName:
return language_model_name


def _as_dspy_examples(
    answers: Sequence[Answer],
) -> Sequence[Example]:
    """
    Intended to turn the given answers into DSPy examples.

    Placeholder: the conversion is not implemented yet, so an empty list is returned regardless of the given answers.
    """
    # TODO: Create one DSPy example per answer.
    examples: list[Example] = []
    return examples

def build_generation_module(
answers: Sequence[Answer],
trial: Trial,
) -> GenerationModule:
"""
Expand Down Expand Up @@ -255,31 +262,34 @@ def build_generation_module(
experimental=True,
)

# # Optimization:
# optimize_generation = trial.suggest_categorical(
# name="optimize_generation",
# choices=[False, True],
# )
# if optimize_generation:
# # TODO (later): Add other DSPy optimizers.
# optimizer_type: Literal["labeled-few-shot"]
# optimizer_type = "labeled-few-shot"
# if optimizer_type == "labeled-few-shot":

# # Tune the generation module with DSPy (to select few-shot examples).
# few_shot_k = trial.suggest_int(
# name="few_shot_k",
# low=1,
# high=10,
# )
# optimizer = LabeledFewShot(k=few_shot_k)
# generation_module = optimizer.compile(
# student=generation_module,
# trainset=NotImplemented,
# sample=True,
# )
# else:
# raise ValueError(f"Unkown optimizer type: {optimizer_type}")
# Optimization:
generation_optimizer: Literal[
"labeled-few-shot",
# TODO (later): Add other DSPy optimizers.
] | None = trial.suggest_categorical(
name="generation_optimizer",
choices=[
"labeled-few-shot",
# None,
],
) # type: ignore
if generation_optimizer == "labeled-few-shot":
# Tune the generation module with DSPy (to select few-shot examples).
few_shot_k = trial.suggest_int(
name="labeled_few_shot_optimizer_k",
low=1,
high=3,
)
optimizer = LabeledFewShot(k=few_shot_k)
predict = optimizer.compile(
student=predict,
trainset=_as_dspy_examples(answers),
sample=True,
)
elif generation_optimizer is None:
pass
else:
raise ValueError(f"Unkown optimizer: {generation_optimizer}")

return DspyGenerationModule(
predict=predict,
Expand All @@ -289,15 +299,16 @@ def build_generation_module(


def build_generation_augmented_retrieval_module(
answers: Sequence[Answer],
trial: Trial,
) -> RetrievalModule:
"""
Build a generation-augmented retrieval module based on hyperparameters drawn from the trial.
"""

# Build simple retrieval and generation modules.
retrieval_module = build_retrieval_module(trial=trial)
generation_module = build_generation_module(trial=trial)
retrieval_module = build_retrieval_module(trial)
generation_module = build_generation_module(answers, trial)

# How often should we "cycle" the generation augmented retrieval?
num_augmentations = trial.suggest_int(
Expand Down Expand Up @@ -325,15 +336,16 @@ def build_generation_augmented_retrieval_module(


def build_retrieval_augmented_generation_module(
answers: Sequence[Answer],
trial: Trial,
) -> GenerationModule:
"""
Build a retrieval-augmented generation module based on hyperparameters drawn from the trial.
"""

# Build simple generation and retrieval modules.
generation_module = build_generation_module(trial=trial)
retrieval_module = build_retrieval_module(trial=trial)
generation_module = build_generation_module(answers, trial)
retrieval_module = build_retrieval_module(trial)

# How often should we "cycle" the retrieval augmented generation?
num_augmentations = trial.suggest_int(
Expand Down Expand Up @@ -361,15 +373,16 @@ def build_retrieval_augmented_generation_module(


def build_answering_module_no_augmentation(
answers: Sequence[Answer],
trial: Trial,
) -> AnsweringModule:
"""
Build an answering module that uses generation and retrieval modules independently without any augmentation.
"""

# Build simple generation and retrieval modules.
generation_module = build_generation_module(trial=trial)
retrieval_module = build_retrieval_module(trial=trial)
generation_module = build_generation_module(answers, trial)
retrieval_module = build_retrieval_module(trial)

# Compose answering module.
return IndependentAnsweringModule(
Expand All @@ -379,15 +392,16 @@ def build_answering_module_no_augmentation(


def build_answering_module_independent_augmentation(
answers: Sequence[Answer],
trial: Trial,
) -> AnsweringModule:
"""
Build an answering module that uses generation and retrieval modules independently while augmenting generation and retrieval individually.
"""

# Build augmented generation and retrieval modules.
generation_module = build_retrieval_augmented_generation_module(trial=trial)
retrieval_module = build_generation_augmented_retrieval_module(trial=trial)
generation_module = build_retrieval_augmented_generation_module(answers, trial)
retrieval_module = build_generation_augmented_retrieval_module(answers, trial)

# Compose answering module.
return IndependentAnsweringModule(
Expand All @@ -397,15 +411,16 @@ def build_answering_module_independent_augmentation(


def build_answering_module_cross_augmentation(
answers: Sequence[Answer],
trial: Trial,
) -> AnsweringModule:
"""
Build an answering module that uses generation and retrieval modules recurrently while feeding back the outputs from the generation module to the retrieval module and vice-versa.
"""

# Build simple generation and retrieval modules.
generation_module = build_generation_module(trial=trial)
retrieval_module = build_retrieval_module(trial=trial)
generation_module = build_generation_module(answers, trial)
retrieval_module = build_retrieval_module(trial)

# Compose answering module.
answering_module = IndependentAnsweringModule(
Expand All @@ -426,6 +441,7 @@ def build_answering_module_cross_augmentation(


def build_answering_module(
answers: Sequence[Answer],
trial: Trial,
) -> AnsweringModule:
"""
Expand All @@ -442,41 +458,41 @@ def build_answering_module(
],
)
if augmentation_type == "no augmentation":
return build_answering_module_no_augmentation(trial)
return build_answering_module_no_augmentation(answers, trial)
elif augmentation_type == "independent augmentation":
return build_answering_module_independent_augmentation(trial)
return build_answering_module_independent_augmentation(answers, trial)
elif augmentation_type == "cross augmentation":
return build_answering_module_cross_augmentation(trial)
return build_answering_module_cross_augmentation(answers, trial)
else:
raise ValueError(f"Unknown augmentation type: {augmentation_type}")


def optimize_answering_module(
ground_truth: Sequence[Answer],
answers: Sequence[Answer],
retrieval_measures: Sequence[RetrievalMeasure],
generation_measures: Sequence[GenerationMeasure],
trials: int | None = None,
timeout: float | None = None,
parallelism: int = 1,
progress: bool = False,
) -> Sequence[FrozenTrial]:
questions = [answer.as_question() for answer in ground_truth]
questions = [answer.as_question() for answer in answers]
contexts = [question.as_partial_answer() for question in questions]

def objective(trial: Trial) -> Sequence[float]:
module = build_answering_module(trial)
module = build_answering_module(answers, trial)
predictions = module.answer_many(contexts)
retrieval_metrics = (
evaluate_retrieval(
ground_truth=ground_truth,
ground_truth=answers,
predictions=predictions,
measure=measure,
)
for measure in retrieval_measures
)
generation_metrics = (
evaluate_generation(
ground_truth=ground_truth,
ground_truth=answers,
predictions=predictions,
measure=measure,
language_model_name=_suggest_language_model_name(
Expand Down

0 comments on commit 91e2c31

Please sign in to comment.