
Commit

Add EvalAnswerMode to HotPotQAEnv (#102)
Co-authored-by: James Braza <[email protected]>
sidnarayanan and jamesbraza authored Oct 30, 2024
1 parent d09a7ed commit ed46be4
Showing 4 changed files with 60 additions and 10 deletions.
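
For orientation, here is a minimal usage sketch of the new evaluation_mode argument, based only on the signatures and test visible in this diff; the question and answer strings are borrowed from the new test further down, and the expected reward assumes the default correct_reward of 1.0.

import asyncio

from aviary.core import EvalAnswerMode
from aviary.envs.hotpotqa import HotPotQAEnv


async def main() -> None:
    env = HotPotQAEnv(
        question="What is the reddest bridge in San Francisco?",
        correct_answer="Golden Gate Bridge",
        evaluation_mode=EvalAnswerMode.EXACT,  # the default stays EvalAnswerMode.CONTAINS
    )
    # calculate_answer_reward is now a coroutine, since eval_answer may await an LLM judge
    reward = await env.calculate_answer_reward("Golden Gate Bridge")
    print(reward)  # expected: env.correct_reward (1.0 by default)


asyncio.run(main())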
46 changes: 37 additions & 9 deletions packages/hotpotqa/src/aviary/envs/hotpotqa/env.py
@@ -27,9 +27,17 @@
 from pydantic import BaseModel, ConfigDict, Field
 from tenacity import retry, stop_after_attempt, wait_exponential_jitter
 
-from aviary.env import Environment, Frame, TaskDataset
-from aviary.message import Message
-from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage
+from aviary.core import (
+    Environment,
+    EvalAnswerMode,
+    Frame,
+    Message,
+    TaskDataset,
+    Tool,
+    ToolRequestMessage,
+    ToolResponseMessage,
+    eval_answer,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -78,6 +86,8 @@ class HotPotQAEnvState(BaseModel):
     )
     page: str | None = Field(default=None, description="The current Wikipedia page.")
 
+    evaluation_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS
+
 
 def create_tool(function: Callable, name: str) -> Tool:
     """Create a Tool object from a function and set its name.
@@ -176,6 +186,7 @@ def __init__(
         correct_reward: float = 1.0,
         incorrect_reward: float = 0.0,
         tool_failure_reward: float = 0.0,
+        evaluation_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS,
         proxy: str | None = None,
     ):
         super().__init__()
@@ -186,6 +197,14 @@ def __init__(
         self.incorrect_reward = incorrect_reward
         self.tool_failure_reward = tool_failure_reward
         self.proxy = proxy
+        self.evaluation_mode = evaluation_mode
+
+        if evaluation_mode == EvalAnswerMode.LLM_SCORE:
+            raise NotImplementedError(
+                f'{HotPotQAEnv.__name__} does not support "{evaluation_mode}"'
+                " since the environment was built around binary evaluation of the"
+                " answer. Further development is needed for this mode."
+            )
 
         # Title case tool names to match third party demonstration data
         self.tools = [
@@ -198,7 +217,7 @@ def __init__(
     def from_task(cls, task: str) -> "HotPotQAEnv":
         return cls(question=task, correct_answer=0.0)
 
-    def calculate_reward(self, answer: str | None) -> float:
+    async def calculate_answer_reward(self, answer: str | None) -> float:
         """Calculate the reward based on the agent's answer.
 
         Returns:
@@ -207,8 +226,17 @@ def calculate_reward(self, answer: str | None) -> float:
         """
         if answer is None:
             return self.incorrect_reward
-        pred, gt = normalize_answer(answer), self.normalized_correct_answer
-        return self.correct_reward if pred == gt else self.incorrect_reward
+        return (
+            self.correct_reward
+            if (
+                await eval_answer(
+                    normalize_answer(answer),
+                    self.normalized_correct_answer,
+                    self.evaluation_mode,
+                )
+            )
+            else self.incorrect_reward
+        )
 
     async def reset(self) -> tuple[list[Message], list[Tool]]:
         """Reset the HotPotQA environment to an initial state.
@@ -331,8 +359,8 @@ def export_frame(self) -> Frame:
             }
         )
 
-    def finish(self, answer: str) -> str:
-        """Finish the episode.
+    async def finish(self, answer: str) -> str:
+        """Finish the task by submitting an answer to the question.
 
         Args:
             answer: The answer to the question.
@@ -342,7 +370,7 @@ def finish(self, answer: str) -> str:
             return "Finish failed. No answer provided."
 
         self.state.answer = answer
-        self.state.reward += self.calculate_reward(answer)
+        self.state.reward += await self.calculate_answer_reward(answer)
 
         self.state.last_action_is_lookup = False
         return "Finished."
19 changes: 19 additions & 0 deletions packages/hotpotqa/tests/test_hotpotqa_env.py
@@ -4,6 +4,7 @@
 
 from aviary.core import Environment, TaskDataset
 from aviary.envs.hotpotqa import HotPotQAEnv
+from aviary.tools.utils import EvalAnswerMode
 
 
 def test_env_construction() -> None:
@@ -65,3 +66,21 @@ async def test_tool_results() -> None:
 
     # Ensure that the observations are different
     assert obs1 != obs2 != obs3 != obs4 != obs5
+
+
+@pytest.mark.parametrize(
+    "evaluation_mode",
+    [EvalAnswerMode.EXACT, EvalAnswerMode.CONTAINS, EvalAnswerMode.LLM],
+)
+@pytest.mark.asyncio
+async def test_answer_evaluation_mode(evaluation_mode: EvalAnswerMode) -> None:
+    correct_answer = "Golden Gate Bridge"
+    incorrect_answer = "Bay Bridge"
+    env = HotPotQAEnv(
+        question="What is the reddest bridge in San Francisco?",
+        correct_answer=correct_answer,
+        evaluation_mode=evaluation_mode,
+    )
+
+    assert (await env.calculate_answer_reward(correct_answer)) == 1
+    assert (await env.calculate_answer_reward(incorrect_answer)) == 0
2 changes: 2 additions & 0 deletions src/aviary/core.py
@@ -13,6 +13,7 @@
 from aviary.render import Renderer
 from aviary.tools import (
     INVALID_TOOL_NAME,
+    EvalAnswerMode,
     FunctionInfo,
     Messages,
     MessagesAdapter,
@@ -41,6 +42,7 @@
     "EnvStateMessage",
     "Environment",
     "EnvironmentClient",
+    "EvalAnswerMode",
     "Frame",
     "FunctionInfo",
     "MalformedMessageError",
3 changes: 2 additions & 1 deletion src/aviary/tools/__init__.py
@@ -14,10 +14,11 @@
     ToolsAdapter,
     wraps_doc_only,
 )
-from .utils import ToolSelector, ToolSelectorLedger, eval_answer
+from .utils import EvalAnswerMode, ToolSelector, ToolSelectorLedger, eval_answer
 
 __all__ = [
     "INVALID_TOOL_NAME",
+    "EvalAnswerMode",
     "FunctionInfo",
     "Messages",
     "MessagesAdapter",
