Skip to content

Commit

Permalink
add the evaluate method again
Browse files Browse the repository at this point in the history
  • Loading branch information
JoaquinPolonuer committed Feb 27, 2025
1 parent 9861f92 commit 90b7410
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions packages/lfrqa/src/aviary/envs/lfrqa/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,7 @@ def __init__(
self._query: LFRQAQuestion = query # type: ignore[mutable-override]
self.pairwise_eval_llm = pairwise_eval_llm

async def step(
self, action: ToolRequestMessage
) -> tuple[Messages, float, bool, bool]:
messages, reward, done, truncated = await super().step(action)
if not done:
return messages, reward, done, truncated

async def _evaluate_answer(self) -> dict:
evaluation = await self._query.grade(
proposed_answer=self.state.session.answer,
paper_search_ids=[
Expand All @@ -242,7 +236,18 @@ async def step(
)
)
evaluation["reward"] = reward

return evaluation

async def step(
self, action: ToolRequestMessage
) -> tuple[Messages, float, bool, bool]:
messages, reward, done, truncated = await super(
GradablePaperQAEnvironment, self
).step(action)
if not done:
return messages, reward, done, truncated
evaluation = await self._evaluate_answer()
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)

return messages, reward, done, truncated
return messages, evaluation["reward"], done, truncated

0 comments on commit 90b7410

Please sign in to comment.