Skip to content

Commit

Permalink
fix: cache and imports and o1 models (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
Naman-ntc authored Nov 30, 2024
1 parent 65f9431 commit acc48d4
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/r2e/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from r2e.llms.llm_args import LLMArgs
from r2e.llms.completions import LLMCompletions
4 changes: 2 additions & 2 deletions src/r2e/llms/cache_object.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
import json
from diskcache import Cache
from diskcache import Cache as DiskCache

from r2e.paths import CACHE_PATH, CACHE_DIR


class Cache:
def __init__(self) -> None:
os.makedirs(CACHE_DIR, exist_ok=True)
self.cache_dict = Cache(CACHE_PATH)
self.cache_dict = DiskCache(CACHE_DIR)

@staticmethod
def process_payload(payload):
Expand Down
8 changes: 8 additions & 0 deletions src/r2e/llms/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,12 @@ class LanguageModel:
model_name="gpt-4o",
style=LanguageModelStyle.OpenAI,
),
LanguageModel(
model_name="o1-preview",
style=LanguageModelStyle.OpenAI,
),
LanguageModel(
model_name="o1-mini",
style=LanguageModelStyle.OpenAI,
),
]
26 changes: 16 additions & 10 deletions src/r2e/llms/openai_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@ class OpenAIRunner(BaseRunner):

def __init__(self, args: LLMArgs, model: LanguageModel):
super().__init__(args, model)
self.client_kwargs: dict[str, Any] = {
"model": args.model_name,
"temperature": args.temperature,
"max_tokens": args.max_tokens,
"top_p": args.top_p,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"n": args.n,
"timeout": args.openai_timeout,
}
if "o1" in args.model_name:
self.client_kwargs: dict[str, Any] = {
"model": args.model_name,
"max_completion_tokens": args.max_tokens,
}
else:
self.client_kwargs = {
"model": args.model_name,
"temperature": args.temperature,
"max_tokens": args.max_tokens,
"top_p": args.top_p,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"n": args.n,
"timeout": args.openai_timeout,
}

def config(self):
return self.client_kwargs
Expand Down

0 comments on commit acc48d4

Please sign in to comment.