Skip to content

Commit

Permalink
Not clobbering Docs in LFRQAPairwiseEvalEnv (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Feb 27, 2025
1 parent 8044692 commit 8eb4b25
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
6 changes: 2 additions & 4 deletions packages/lfrqa/src/aviary/envs/lfrqa/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any

from lmi import CommonLLMNames, LiteLLMModel, LLMModel
from paperqa.docs import Docs
from paperqa.utils import strip_citations
from pydantic import Field, model_validator

Expand Down Expand Up @@ -207,6 +206,8 @@ async def grade( # type: ignore[override]
class LFRQAPairwiseEvalEnv(GradablePaperQAEnvironment[dict]):
"""Environment to evaluate paperqa's vs human's answers on Long Form RAG QA questions."""

_query: LFRQAQuestion # type: ignore[mutable-override]

def __init__(
self,
query: LFRQAQuestion,
Expand All @@ -215,10 +216,7 @@ def __init__(
**kwargs,
):
kwargs["query"] = query
kwargs["docs"] = Docs()
super().__init__(*args, **kwargs)

self._query: LFRQAQuestion = query # type: ignore[mutable-override]
self.pairwise_eval_llm = pairwise_eval_llm

async def _evaluate_answer(self) -> dict:
Expand Down
9 changes: 8 additions & 1 deletion packages/lfrqa/src/aviary/envs/lfrqa/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Awaitable, Callable

from lmi import CommonLLMNames, LLMModel
from paperqa.settings import Settings
from paperqa import Docs, Settings

from aviary.core import TASK_DATASET_REGISTRY, TaskDataset

Expand All @@ -15,6 +15,7 @@ def __init__(
self,
data: list[LFRQAQuestion],
settings: Settings | dict | None = None,
base_docs: Docs | dict | None = None,
pairwise_eval_llm: LLMModel | str = CommonLLMNames.GPT_4O.value,
evaluation_callback: Callable[[dict], Awaitable] | None = None,
):
Expand All @@ -26,6 +27,11 @@ def __init__(
if isinstance(settings, dict):
settings = Settings(**settings)
self._settings = settings
if base_docs is None:
base_docs = Docs()
if isinstance(base_docs, dict):
base_docs = Docs(**base_docs)
self._base_docs = base_docs
self._rewards = {"win": 1, "tie": 0, "lose": -1}
self._evaluation_callback = evaluation_callback

Expand All @@ -35,6 +41,7 @@ def get_new_env_by_idx(self, idx: int) -> LFRQAPairwiseEvalEnv:
query=self.data[idx],
pairwise_eval_llm=self.pairwise_eval_llm,
settings=self._settings,
docs=self._base_docs.model_copy(),
rewards=self._rewards,
evaluation_callback=self._evaluation_callback,
)
Expand Down

0 comments on commit 8eb4b25

Please sign in to comment.