diff --git a/r2e/llms/base_runner.py b/r2e/llms/base_runner.py index 75325e6..3d7c8b5 100644 --- a/r2e/llms/base_runner.py +++ b/r2e/llms/base_runner.py @@ -1,3 +1,4 @@ +import json from abc import ABC, abstractmethod from tqdm import tqdm @@ -23,6 +24,10 @@ def save_cache(self): if self.cache is not None: self.cache.save_cache() + @abstractmethod + def config(self) -> dict: + pass + @abstractmethod def _run_single(self, payload) -> list[str]: return [] @@ -36,10 +41,10 @@ def run_single(combined_args) -> list[str]: """ cache: Cache | None call_method: callable # type: ignore - payload, cache, args, call_method = combined_args + payload, cache, args, config, call_method = combined_args if cache is not None: - cache_result = cache.get_from_cache(payload) + cache_result = cache.get_from_cache(json.dumps([payload, config])) if cache_result is not None: return cache_result @@ -50,11 +55,13 @@ def run_single(combined_args) -> list[str]: def run_batch(self, payloads: list) -> list[list[str]]: outputs = [] + config = self.config() arguments = [ ( payload, self.cache, ## pass the cache as argument for cache check self.args, ## pass the args as argument for cache check + config, self._run_single, ## pass the _run_single method as argument because of multiprocessing ) for payload in payloads @@ -79,7 +86,10 @@ def run_batch(self, payloads: list) -> list[list[str]]: if self.cache is not None: for payload, output in zip(payloads, outputs): - self.cache.add_to_cache(payload, output) ## save the output to cache + self.cache.add_to_cache( + json.dumps([payload, config]), output + ) ## save the output to cache + self.save_cache() return outputs @@ -91,7 +101,6 @@ def run_main(self, payloads: list) -> list[list[str]]: payload_batch = payloads[i : i + batch_size] outputs_batch = self.run_batch(payload_batch) outputs.extend(outputs_batch) - self.save_cache() else: outputs = self.run_batch(payloads) return outputs diff --git a/r2e/llms/language_model.py b/r2e/llms/language_model.py 
index 60b77bd..07d4852 100644 --- a/r2e/llms/language_model.py +++ b/r2e/llms/language_model.py @@ -85,4 +85,12 @@ class LanguageModel: model_name="gpt-4-turbo-2024-04-09", style=LanguageModelStyle.OpenAI, ), + LanguageModel( + model_name="gpt-4o-mini", + style=LanguageModelStyle.OpenAI, + ), + LanguageModel( + model_name="gpt-4o", + style=LanguageModelStyle.OpenAI, + ), ] diff --git a/r2e/llms/openai_runner.py b/r2e/llms/openai_runner.py index d06e7bc..56408df 100644 --- a/r2e/llms/openai_runner.py +++ b/r2e/llms/openai_runner.py @@ -29,6 +29,9 @@ def __init__(self, args: LLMArgs, model: LanguageModel): "timeout": args.openai_timeout, } + def config(self) -> dict: + return self.client_kwargs + def _run_single(self, payload: list[dict[str, str]]) -> list[str]: assert isinstance(payload, list)