fix: improve caching frequency and add config to cache #21

Merged
merged 1 commit on Oct 13, 2024
17 changes: 13 additions & 4 deletions r2e/llms/base_runner.py
@@ -1,3 +1,4 @@
+import json
 from abc import ABC, abstractmethod
 
 from tqdm import tqdm
@@ -23,6 +24,10 @@ def save_cache(self):
         if self.cache is not None:
             self.cache.save_cache()
 
+    @abstractmethod
+    def config(self) -> dict:
+        pass
+
     @abstractmethod
     def _run_single(self, payload) -> list[str]:
         return []
@@ -36,10 +41,10 @@ def run_single(combined_args) -> list[str]:
         """
         cache: Cache | None
         call_method: callable  # type: ignore
-        payload, cache, args, call_method = combined_args
+        payload, cache, args, config, call_method = combined_args
 
         if cache is not None:
-            cache_result = cache.get_from_cache(payload)
+            cache_result = cache.get_from_cache(json.dumps([payload, config]))
             if cache_result is not None:
                 return cache_result
 
@@ -50,11 +55,13 @@ def run_single(combined_args) -> list[str]:
 
     def run_batch(self, payloads: list) -> list[list[str]]:
         outputs = []
+        config = self.config()
         arguments = [
             (
                 payload,
                 self.cache,  ## pass the cache as argument for cache check
                 self.args,  ## pass the args as argument for cache check
+                config,
                 self._run_single,  ## pass the _run_single method as argument because of multiprocessing
             )
             for payload in payloads
@@ -79,7 +86,10 @@ def run_batch(self, payloads: list) -> list[list[str]]:
 
         if self.cache is not None:
             for payload, output in zip(payloads, outputs):
-                self.cache.add_to_cache(payload, output)  ## save the output to cache
+                self.cache.add_to_cache(
+                    json.dumps([payload, config]), output
+                )  ## save the output to cache
+            self.save_cache()
 
         return outputs
 
@@ -91,7 +101,6 @@ def run_main(self, payloads: list) -> list[list[str]]:
                 payload_batch = payloads[i : i + batch_size]
                 outputs_batch = self.run_batch(payload_batch)
                 outputs.extend(outputs_batch)
-                self.save_cache()
         else:
             outputs = self.run_batch(payloads)
         return outputs
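
Net effect of these base-runner changes: the cache key is now the JSON-serialized pair of payload and runner config, so completions cached under one set of client settings are not served for another, and the cache is flushed to disk at the end of every run_batch call (including the non-batched path of run_main) instead of only inside run_main's batching loop. A minimal sketch of the composite-key idea, assuming a hypothetical dict-backed cache in place of the repo's Cache class:

import json

# Hypothetical dict-backed cache standing in for the repo's Cache class,
# just to illustrate why the key includes the config.
class SimpleCache:
    def __init__(self) -> None:
        self.store: dict[str, list[str]] = {}

    def get_from_cache(self, key: str) -> list[str] | None:
        return self.store.get(key)

    def add_to_cache(self, key: str, value: list[str]) -> None:
        self.store[key] = value

cache = SimpleCache()
payload = [{"role": "user", "content": "Write a test for parse()"}]

# With the old payload-only key, the second config would get the first
# config's cached completion; the composite key keeps them separate.
for config in ({"model": "gpt-4o", "temperature": 0.2},
               {"model": "gpt-4o-mini", "temperature": 0.8}):
    key = json.dumps([payload, config])  # same scheme as the PR
    if cache.get_from_cache(key) is None:
        cache.add_to_cache(key, [f"<completion from {config['model']}>"])

print(len(cache.store))  # 2 -- one entry per (payload, config) pair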
8 changes: 8 additions & 0 deletions r2e/llms/language_model.py
@@ -85,4 +85,12 @@ class LanguageModel:
         model_name="gpt-4-turbo-2024-04-09",
         style=LanguageModelStyle.OpenAI,
     ),
+    LanguageModel(
+        model_name="gpt-4o-mini",
+        style=LanguageModelStyle.OpenAI,
+    ),
+    LanguageModel(
+        model_name="gpt-4o",
+        style=LanguageModelStyle.OpenAI,
+    ),
 ]
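
The two new registry entries make gpt-4o and gpt-4o-mini selectable alongside the existing OpenAI models. Assuming runners resolve a model by matching model_name against this list (the actual lookup code is not part of this diff), selection would look roughly like:

def find_model(models: list, name: str):
    # Hypothetical lookup over the registry above; the project's real
    # selection logic is not shown in this PR.
    matches = [m for m in models if m.model_name == name]
    if not matches:
        raise ValueError(f"unknown model: {name}")
    return matches[0]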
3 changes: 3 additions & 0 deletions r2e/llms/openai_runner.py
@@ -29,6 +29,9 @@ def __init__(self, args: LLMArgs, model: LanguageModel):
             "timeout": args.openai_timeout,
         }
 
+    def config(self):
+        return self.client_kwargs
+
     def _run_single(self, payload: list[dict[str, str]]) -> list[str]:
         assert isinstance(payload, list)
 
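
Because config() is now an abstract method on the base runner, every runner subclass must expose its effective settings; the OpenAI runner simply returns its client kwargs, which then feed into the cache key built in run_batch. A rough sketch of what the contract asks of a new runner (the class and fields below are hypothetical, not part of the repo):

class EchoRunner:
    """Hypothetical minimal runner illustrating the new config() contract."""

    def __init__(self, model_name: str, temperature: float = 0.0):
        self.client_kwargs = {"model": model_name, "temperature": temperature}

    def config(self) -> dict:
        # Whatever this returns becomes part of the cache key.
        return self.client_kwargs

    def _run_single(self, payload) -> list[str]:
        # Stand-in for a real API call.
        return [f"echo: {payload}"]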