Skip to content

Commit

Permalink
Added get_id for LitQA, with a test
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Feb 26, 2025
1 parent 105f95b commit 6e2ddde
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
12 changes: 12 additions & 0 deletions packages/litqa/src/aviary/envs/litqa/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ async def step(
truncated,
)

async def get_id(self) -> str | UUID:
if (
isinstance(self._query, str)
or self._query.question_id
== MultipleChoiceQuestion.model_fields["question_id"].default
):
raise ValueError(
"A multiple choice question with a non-default question ID was not"
" configured."
)
return self._query.question_id

def __deepcopy__(self, memo) -> Self:
copy_state = deepcopy(self.state, memo)
# We don't know the side effects of deep copying a litellm.Router,
Expand Down
2 changes: 2 additions & 0 deletions packages/litqa/tests/test_litqa_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ async def test___len__(
# Now let's check we could use the sources in a validation
for i in range(len(task_dataset)):
env = task_dataset.get_new_env_by_idx(i)
env_id = await env.get_id()
assert await env.get_id() == UUID("dbfbae3d-62f6-4710-8d13-8ce4c8485567")
if i == 0 and split == LitQAv2TaskSplit.TRAIN:
# Yes this assertion is somewhat brittle, but it reliably
# checks the seeding's behavior so we keep it
Expand Down

0 comments on commit 6e2ddde

Please sign in to comment.