Skip to content

Commit

Permalink
fix: shuffle datasets and terminate verification (#34)
Browse files Browse the repository at this point in the history
* fix: shuffle datasets and terminate verification

* fix: use verifier obj timeout

* fix: add terminate function for code verifier, forgot
  • Loading branch information
justusmattern27 authored Feb 3, 2025
1 parent 232d873 commit a2c5e55
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 18 deletions.
3 changes: 1 addition & 2 deletions src/genesys/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def __init__(self, config: DataConfig, tokenizer: AutoTokenizer):
datasets = [load_dataset(path)["train"] for path in self.paths]

if config.shuffle:
for dataset in datasets:
dataset = dataset.shuffle()
datasets = [data.shuffle() for data in datasets]

if config.ratio is not None:
ratio = [float(r) for r in config.ratio.split(",")]
Expand Down
8 changes: 4 additions & 4 deletions src/genesys/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from typing import Dict, Optional


class UnscoredResult(BaseModel):
class Response(BaseModel):
problem_id: str
source: str
task_type: str
in_source: Optional[str]
in_source_id: Optional[str]
prompt: str
gold_standard_solution: Optional[str]
verification_info: Dict
metadata: Dict
llm_response: str
llm_response: str # llm response string
response_id: str
model_name: str
generation_config: Dict
generation_config: Dict # sampling parameters
7 changes: 5 additions & 2 deletions src/genesys/verifiers/base_verifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict
from genesys.schemas import UnscoredResult
from genesys.schemas import Response


class BaseVerifier:
Expand All @@ -10,9 +10,12 @@ class BaseVerifier:
max_parallel: int = 5
timeout: float = None # None means no timeout

def verify(self, result: UnscoredResult) -> Dict:
def verify(self, result: Response) -> Dict:
"""Perform the synchronous verification given a single result.
Subclasses should override this to implement the actual check.
"""
raise NotImplementedError("Subclasses must implement verify().")

def terminate(self):
pass
6 changes: 3 additions & 3 deletions src/genesys/verifiers/code_output_prediction_verifier.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from genesys.utils import extract_json
from genesys.schemas import UnscoredResult
from genesys.schemas import Response
from genesys.verifiers.base_verifier import BaseVerifier


class CodeUnderstandingVerifier(BaseVerifier):
max_parallel = 30
timeout = None # No explicit timeout needed
timeout = 20

def verify(self, result: UnscoredResult):
def verify(self, result: Response):
"""
Verifies whether the output predicted by the LLM matches the ground truth.
Expand Down
5 changes: 4 additions & 1 deletion src/genesys/verifiers/code_test_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self):
self.docker_client = docker.from_env()
self.containers = {}
self._init_containers()
self.timeout = 30
self.timeout = 120
self.max_parallel = 5

def __del__(self):
Expand Down Expand Up @@ -288,3 +288,6 @@ def verify(self, result: Dict) -> float:
return self._verify_interpreted_code(container, code, test_cases, language)
else:
raise ValueError("Unsupported language:", language)

def terminate(self):
self._close_containers()
2 changes: 1 addition & 1 deletion src/genesys/verifiers/llm_judge_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class LlmJudgeVerifier(BaseVerifier):
max_parallel = 30 # For concurrency control if needed.
timeout = 100
timeout = 120

def verify(self, result: Dict) -> Tuple[float, str]:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/genesys/verifiers/math_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

class MathVerifier(BaseVerifier):
max_parallel = 10
timeout = 20
timeout = 60

def verify(self, result: Dict) -> int:
"""
Expand Down
16 changes: 12 additions & 4 deletions src/genesys/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from typing import List, Dict, Any, Callable
from tqdm.asyncio import tqdm
from pydantic_config import BaseConfig, parse_argv
from genesys.schemas import UnscoredResult
from genesys.schemas import Response
from genesys.verifiers.registry import VERIFIER_REGISTRY


class VerifyConfig(BaseConfig):
file: str


async def run_sync_with_timeout(sync_func: Callable, result: UnscoredResult, timeout=None):
async def run_sync_with_timeout(sync_func: Callable, result: Response, timeout=None):
"""
Runs a synchronous function in an executor with optional timeout.
"""
Expand All @@ -23,7 +23,7 @@ async def run_sync_with_timeout(sync_func: Callable, result: UnscoredResult, tim
return await coro


async def verify(results: List[UnscoredResult]) -> List[Any]:
async def verify(results: List[Response]) -> List[Any]:
"""
Given a list of result dictionaries, dispatch each to the appropriate verifier
(as determined by the "task_type" field) and run them concurrently using semaphores
Expand All @@ -44,16 +44,24 @@ async def process_result(index: int, result: Dict):
verifier_obj = verifier_instances[ttype]
async with semaphores[ttype]:
try:
verification_result = await run_sync_with_timeout(verifier_obj.verify, result, timeout=200)
verification_result = await run_sync_with_timeout(
verifier_obj.verify, result, timeout=verifier_obj.timeout
)
verification_results[index] = verification_result
except asyncio.TimeoutError:
print(f"Timeout verifying '{ttype}' at index {index}")
verification_results[index] = {"score": None, "verification_result_info": {"failure_reason": "timeout"}}
except Exception as e:
print(f"Error verifying '{ttype}' at index {index}")
verification_results[index] = {"score": None, "verification_result_info": {"failure_reason": e}}

tasks = [asyncio.create_task(process_result(i, r)) for i, r in enumerate(results)]
for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Verifying"):
await task

for ttype in task_types_in_use:
verifier_instances[ttype].terminate()

return verification_results


Expand Down

0 comments on commit a2c5e55

Please sign in to comment.