auto-infer llm provider
richardblythman committed Apr 24, 2024
commit 13bcfcb · 1 parent 80051bd
Showing 6 changed files with 54 additions and 36 deletions.
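The change is identical across the five tool modules below: LLMClientManager no longer takes an explicit llm_provider argument and instead infers the provider from the model name. A minimal standalone sketch of that substring rule, assuming only what the diffs show (the helper name infer_llm_provider is illustrative; the commit inlines this logic in each __init__):

```python
def infer_llm_provider(model: str) -> str:
    """Map a model name to its provider by substring match (sketch)."""
    if "gpt" in model:        # e.g. "gpt-4-0125-preview"
        return "openai"
    elif "claude" in model:   # e.g. "claude-3-sonnet-20240229"
        return "anthropic"
    return "openrouter"       # any other model is routed via OpenRouter


assert infer_llm_provider("gpt-4-0125-preview") == "openai"
assert infer_llm_provider("claude-3-sonnet-20240229") == "anthropic"
assert infer_llm_provider("databricks/dbrx-instruct:nitro") == "openrouter"
```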
@@ -40,11 +40,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []
@@ -300,8 +305,8 @@ def multi_queries(
     model: str,
     num_queries: int,
     counter_callback: Optional[Callable[[int, int, str], None]] = None,
-    temperature: Optional[float] = LLM_SETTINGS["gpt-4-0125-preview"]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS["gpt-4-0125-preview"]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["claude-3-sonnet-20240229"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["claude-3-sonnet-20240229"]["default_max_tokens"],
 ) -> List[str]:
     """Generate multiple queries for fetching information from the web."""
     url_query_prompt = URL_QUERY_PROMPT.format(
@@ -544,8 +549,8 @@ def fetch_additional_information(
     source_links: Optional[List[str]] = None,
     num_urls: Optional[int] = DEFAULT_NUM_URLS,
     num_queries: Optional[int] = DEFAULT_NUM_QUERIES,
-    temperature: Optional[float] = LLM_SETTINGS["gpt-4-0125-preview"]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS["gpt-4-0125-preview"]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["claude-3-sonnet-20240229"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["claude-3-sonnet-20240229"]["default_max_tokens"],
 ) -> Tuple[str, Callable[[int, int, str], None]]:
     """Fetch additional information to help answer the user prompt."""

@@ -668,13 +673,13 @@ def parser_prediction_response(response: str) -> str:
 
 def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
     """Run the task"""
+    model = kwargs.get("model")
+    print(f"MODEL: {model}")
     with LLMClientManager(
-        kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
+        kwargs["api_keys"], model, embedding_provider="openai"
     ):
         tool = kwargs["tool"]
         prompt = extract_question(kwargs["prompt"])
-        model = kwargs.get("model")
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
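Each run() entry point now reads the model from kwargs before entering the context manager, so callers pass only model and the provider is resolved internally. A self-contained sketch of that pattern under assumed shapes (a placeholder dict client stands in for the real OpenAI/Anthropic/OpenRouter SDK clients the repo constructs):

```python
class SketchLLMClientManager:
    """Assumed shape of the manager pattern in these diffs; not the repo's code."""

    def __init__(self, api_keys: dict, model: str, embedding_provider: str = None):
        self.api_keys = api_keys
        self.embedding_provider = embedding_provider
        # Same substring rule as the diff above.
        if "gpt" in model:
            self.llm_provider = "openai"
        elif "claude" in model:
            self.llm_provider = "anthropic"
        else:
            self.llm_provider = "openrouter"

    def __enter__(self):
        # The real managers build provider SDK clients here.
        self.client = {
            "provider": self.llm_provider,
            "key": self.api_keys.get(self.llm_provider),
        }
        return self.client

    def __exit__(self, exc_type, exc, tb):
        self.client = None  # the real code would close network clients here
        return False


with SketchLLMClientManager({"anthropic": "sk-..."}, "claude-3-sonnet-20240229") as c:
    assert c["provider"] == "anthropic"
```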
@@ -40,11 +40,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []
@@ -642,13 +647,13 @@ def parser_prediction_response(response: str) -> str:
 
 def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
     """Run the task"""
+    model = kwargs.get("model")
+    print(f"MODEL: {model}")
     with LLMClientManager(
-        kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
+        kwargs["api_keys"], model, embedding_provider="openai"
     ):
         tool = kwargs["tool"]
-        model = kwargs.get("model")
         prompt = extract_question(kwargs["prompt"])
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
@@ -43,11 +43,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
    ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []
@@ -834,13 +839,13 @@ def extract_question(prompt: str) -> str:
 
 def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
     """Run the task"""
+    model = kwargs.get("model")
+    print(f"MODEL: {model}")
     with LLMClientManager(
-        kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
+        kwargs["api_keys"], model, embedding_provider="openai"
     ):
         tool = kwargs["tool"]
-        model = kwargs.get("model")
         prompt = extract_question(kwargs["prompt"])
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
15 changes: 10 additions & 5 deletions packages/napthaai/customs/prediction_url_cot/prediction_url_cot.py
@@ -19,11 +19,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []
@@ -583,11 +588,11 @@ def parser_prediction_response(response: str) -> str:
 
 def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
     """Run the task"""
-    with LLMClientManager(kwargs["api_keys"], kwargs["llm_provider"]):
+    model = kwargs.get("model")
+    print(f"MODEL: {model}")
+    with LLMClientManager(kwargs["api_keys"], model):
         tool = kwargs["tool"]
-        model = kwargs.get("model")
         prompt = extract_question(kwargs["prompt"])
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
15 changes: 10 additions & 5 deletions packages/valory/customs/prediction_request/prediction_request.py
@@ -44,9 +44,14 @@
 class LLMClientManager:
     """Client context manager for LLMs."""
 
-    def __init__(self, api_keys: List, llm_provider: str = None):
+    def __init__(self, api_keys: List, model: str = None):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         global client
@@ -653,11 +658,11 @@ def adjust_additional_information(
 
 def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
     """Run the task"""
-    with LLMClientManager(kwargs["api_keys"], kwargs["llm_provider"]):
+    engine = kwargs.get("model")
+    print(f"ENGINE: {engine}")
+    with LLMClientManager(kwargs["api_keys"], engine):
         tool = kwargs["tool"]
         prompt = kwargs["prompt"]
-        engine = kwargs.get("model")
-        print(f"ENGINE: {engine}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[engine]["default_max_tokens"]
         )
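Every run() above falls back to per-model defaults via LLM_SETTINGS[model]. The mapping's structure can be inferred from these lookups; the sketch below uses placeholder values, not the repo's real settings:

```python
# Assumed shape of LLM_SETTINGS implied by the diff's lookups; the
# temperature and token values are placeholders.
LLM_SETTINGS = {
    "claude-3-sonnet-20240229": {"temperature": 0.0, "default_max_tokens": 1000},
    "gpt-4-0125-preview": {"temperature": 0.0, "default_max_tokens": 500},
}


def resolve_max_tokens(kwargs: dict) -> int:
    """Mirror the kwargs.get(...) fallback used in each run()."""
    model = kwargs["model"]
    return kwargs.get("max_tokens", LLM_SETTINGS[model]["default_max_tokens"])


assert resolve_max_tokens({"model": "gpt-4-0125-preview"}) == 500
assert resolve_max_tokens({"model": "gpt-4-0125-preview", "max_tokens": 64}) == 64
```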
7 changes: 0 additions & 7 deletions tests/test_tools.py
@@ -82,19 +82,12 @@ def test_run(self) -> None:
         for model in self.models:
             for tool in self.tools:
                 for prompt in self.prompts:
-                    if "gpt" in model:
-                        llm_provider = "openai"
-                    elif "claude" in model:
-                        llm_provider = "anthropic"
-                    else:
-                        llm_provider = "openrouter"
                     kwargs = dict(
                         prompt=prompt,
                         tool=tool,
                         api_keys=self.keys,
                         counter_callback=TokenCounterCallback(),
                         model=model,
-                        llm_provider=llm_provider,
                     )
                     func = getattr(self.tool_module, self.tool_callable)
                     response = func(**kwargs)
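With the provider branch gone from the test loop, the inference rule itself is the natural unit to check. A hedged pytest sketch over an example model matrix, reusing the infer_llm_provider sketch from the top of this page (the model names are examples, not the repo's actual test matrix):

```python
import pytest  # assumed available in the test environment

@pytest.mark.parametrize(
    "model, expected",
    [
        ("gpt-3.5-turbo-0125", "openai"),
        ("claude-3-haiku-20240307", "anthropic"),
        ("cohere/command-r-plus", "openrouter"),
    ],
)
def test_provider_inference(model: str, expected: str) -> None:
    # infer_llm_provider is the illustrative helper sketched earlier.
    assert infer_llm_provider(model) == expected
```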
