diff --git a/src/genesys/data.py b/src/genesys/data.py index ce7e8a7..b2b8c51 100644 --- a/src/genesys/data.py +++ b/src/genesys/data.py @@ -44,8 +44,7 @@ def __init__(self, config: DataConfig, tokenizer: AutoTokenizer): datasets = [load_dataset(path)["train"] for path in self.paths] if config.shuffle: - for dataset in datasets: - dataset = dataset.shuffle() + datasets = [data.shuffle() for data in datasets] if config.ratio is not None: ratio = [float(r) for r in config.ratio.split(",")] diff --git a/src/genesys/schemas.py b/src/genesys/schemas.py index ac6d003..6f7c805 100644 --- a/src/genesys/schemas.py +++ b/src/genesys/schemas.py @@ -2,16 +2,16 @@ from typing import Dict, Optional -class UnscoredResult(BaseModel): +class Response(BaseModel): problem_id: str source: str task_type: str - in_source: Optional[str] + in_source_id: Optional[str] prompt: str gold_standard_solution: Optional[str] verification_info: Dict metadata: Dict - llm_response: str + llm_response: str # llm response string response_id: str model_name: str - generation_config: Dict + generation_config: Dict # sampling parameters diff --git a/src/genesys/verifiers/base_verifier.py b/src/genesys/verifiers/base_verifier.py index fb6e845..0c639f8 100644 --- a/src/genesys/verifiers/base_verifier.py +++ b/src/genesys/verifiers/base_verifier.py @@ -1,5 +1,5 @@ from typing import Dict -from genesys.schemas import UnscoredResult +from genesys.schemas import Response class BaseVerifier: @@ -10,9 +10,12 @@ class BaseVerifier: max_parallel: int = 5 timeout: float = None # None means no timeout - def verify(self, result: UnscoredResult) -> Dict: + def verify(self, result: Response) -> Dict: """Perform the synchronous verification given a single result. Subclasses should override this to implement the actual check. 
""" raise NotImplementedError("Subclasses must implement verify().") + + def terminate(self): + pass diff --git a/src/genesys/verifiers/code_output_prediction_verifier.py b/src/genesys/verifiers/code_output_prediction_verifier.py index 3a94df1..88b0be0 100644 --- a/src/genesys/verifiers/code_output_prediction_verifier.py +++ b/src/genesys/verifiers/code_output_prediction_verifier.py @@ -1,13 +1,13 @@ from genesys.utils import extract_json -from genesys.schemas import UnscoredResult +from genesys.schemas import Response from genesys.verifiers.base_verifier import BaseVerifier class CodeUnderstandingVerifier(BaseVerifier): max_parallel = 30 - timeout = None # No explicit timeout needed + timeout = 20 - def verify(self, result: UnscoredResult): + def verify(self, result: Response): """ Verifies whether the output predicted by the LLM matches the ground truth. diff --git a/src/genesys/verifiers/code_test_verifier.py b/src/genesys/verifiers/code_test_verifier.py index 45955cf..f19d2ae 100644 --- a/src/genesys/verifiers/code_test_verifier.py +++ b/src/genesys/verifiers/code_test_verifier.py @@ -34,7 +34,7 @@ def __init__(self): self.docker_client = docker.from_env() self.containers = {} self._init_containers() - self.timeout = 30 + self.timeout = 120 self.max_parallel = 5 def __del__(self): @@ -288,3 +288,6 @@ def verify(self, result: Dict) -> float: return self._verify_interpreted_code(container, code, test_cases, language) else: raise ValueError("Unsupported language:", language) + + def terminate(self): + self._close_containers() diff --git a/src/genesys/verifiers/llm_judge_verifier.py b/src/genesys/verifiers/llm_judge_verifier.py index 3b3d13c..2e25751 100644 --- a/src/genesys/verifiers/llm_judge_verifier.py +++ b/src/genesys/verifiers/llm_judge_verifier.py @@ -31,7 +31,7 @@ class LlmJudgeVerifier(BaseVerifier): max_parallel = 30 # For concurrency control if needed. 
- timeout = 100 + timeout = 120 def verify(self, result: Dict) -> Tuple[float, str]: """ diff --git a/src/genesys/verifiers/math_verifier.py b/src/genesys/verifiers/math_verifier.py index e086205..34fe525 100644 --- a/src/genesys/verifiers/math_verifier.py +++ b/src/genesys/verifiers/math_verifier.py @@ -65,7 +65,7 @@ class MathVerifier(BaseVerifier): max_parallel = 10 - timeout = 20 + timeout = 60 def verify(self, result: Dict) -> int: """ diff --git a/src/genesys/verify.py b/src/genesys/verify.py index 2acf1ea..bf55df5 100644 --- a/src/genesys/verify.py +++ b/src/genesys/verify.py @@ -4,7 +4,7 @@ from typing import List, Dict, Any, Callable from tqdm.asyncio import tqdm from pydantic_config import BaseConfig, parse_argv -from genesys.schemas import UnscoredResult +from genesys.schemas import Response from genesys.verifiers.registry import VERIFIER_REGISTRY @@ -12,7 +12,7 @@ class VerifyConfig(BaseConfig): file: str -async def run_sync_with_timeout(sync_func: Callable, result: UnscoredResult, timeout=None): +async def run_sync_with_timeout(sync_func: Callable, result: Response, timeout=None): """ Runs a synchronous function in an executor with optional timeout. 
""" @@ -23,7 +23,7 @@ async def run_sync_with_timeout(sync_func: Callable, result: UnscoredResult, tim return await coro -async def verify(results: List[UnscoredResult]) -> List[Any]: +async def verify(results: List[Response]) -> List[Any]: """ Given a list of result dictionaries, dispatch each to the appropriate verifier (as determined by the "task_type" field) and run them concurrently using semaphores @@ -44,16 +44,24 @@ async def process_result(index: int, result: Dict): verifier_obj = verifier_instances[ttype] async with semaphores[ttype]: try: - verification_result = await run_sync_with_timeout(verifier_obj.verify, result, timeout=200) + verification_result = await run_sync_with_timeout( + verifier_obj.verify, result, timeout=verifier_obj.timeout + ) verification_results[index] = verification_result except asyncio.TimeoutError: print(f"Timeout verifying '{ttype}' at index {index}") verification_results[index] = {"score": None, "verification_result_info": {"failure_reason": "timeout"}} + except Exception as e: + print(f"Error verifying '{ttype}' at index {index}") + verification_results[index] = {"score": None, "verification_result_info": {"failure_reason": repr(e)}} tasks = [asyncio.create_task(process_result(i, r)) for i, r in enumerate(results)] for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Verifying"): await task + for ttype in task_types_in_use: + verifier_instances[ttype].terminate() + return verification_results