Merge pull request #219 from valory-xyz/chore/fix-llm-provider
auto-infer llm provider
0xArdi authored Apr 25, 2024
2 parents 4f39a8a + 3371077 commit 82036d0
Showing 13 changed files with 720 additions and 585 deletions.
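
Every affected tool receives the same core change: instead of requiring callers to pass an explicit llm_provider, each LLMClientManager now infers the provider from the model name. Below is a minimal sketch of that inference, pulled out into a standalone helper purely for illustration (the helper name is hypothetical; in the diffs that follow the logic lives inline in each __init__):

def infer_provider(model: str) -> str:
    # Mirrors the constructor branch added in this commit: OpenAI models
    # contain "gpt", Anthropic models contain "claude", and anything else
    # falls back to OpenRouter.
    if "gpt" in model:
        return "openai"
    elif "claude" in model:
        return "anthropic"
    return "openrouter"


assert infer_provider("gpt-4-0125-preview") == "openai"
assert infer_provider("claude-3-sonnet-20240229") == "anthropic"
assert infer_provider("some-other-model") == "openrouter"  # any unrecognised name routes via OpenRouter
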
@@ -7,12 +7,12 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeibt7f7crtwvmkg7spy3jhscmlqltvyblzp32g6gj44v7tlo5lycuq
-  prediction_request_rag.py: bafybeibtvuddvbhjlyd4sbk7rwz4mcsr4hiigfgrpdhzwa6vn6bhb6fboy
+  prediction_request_rag.py: bafybeicllugnruskdj7ipmrj2vrtlxmjpqtwlk4c3cfjttfzuvkeldp3m4
 fingerprint_ignore_patterns: []
 entry_point: prediction_request_rag.py
 callable: run
 params:
-  default_model: claude-3-sonnet-20240229
+  default_model: gpt-4-0125-preview
 dependencies:
   google-api-python-client:
     version: ==2.95.0
@@ -40,11 +40,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []

@@ -220,6 +225,9 @@ def embeddings(self, model, input):
 }
 ALLOWED_TOOLS = [
     "prediction-request-rag",
+
+    # LEGACY
+    "prediction-request-rag-claude",
 ]
 ALLOWED_MODELS = list(LLM_SETTINGS.keys())
 DEFAULT_NUM_URLS = defaultdict(lambda: 3)

@@ -300,8 +308,8 @@ def multi_queries(
     model: str,
     num_queries: int,
     counter_callback: Optional[Callable[[int, int, str], None]] = None,
-    temperature: Optional[float] = LLM_SETTINGS["gpt-4-0125-preview"]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS["gpt-4-0125-preview"]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["claude-3-sonnet-20240229"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["claude-3-sonnet-20240229"]["default_max_tokens"],
 ) -> List[str]:
     """Generate multiple queries for fetching information from the web."""
     url_query_prompt = URL_QUERY_PROMPT.format(

@@ -544,8 +552,8 @@ def fetch_additional_information(
     source_links: Optional[List[str]] = None,
     num_urls: Optional[int] = DEFAULT_NUM_URLS,
     num_queries: Optional[int] = DEFAULT_NUM_QUERIES,
-    temperature: Optional[float] = LLM_SETTINGS["gpt-4-0125-preview"]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS["gpt-4-0125-preview"]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["claude-3-sonnet-20240229"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["claude-3-sonnet-20240229"]["default_max_tokens"],
 ) -> Tuple[str, Callable[[int, int, str], None]]:
     """Fetch additional information to help answer the user prompt."""
 
@@ -668,13 +676,15 @@ def parser_prediction_response(response: str) -> str:
 
 def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
     """Run the task"""
+    tool = kwargs["tool"]
+    model = kwargs.get("model")
+    if "claude" in tool:  # maintain backwards compatibility
+        model = "claude-3-sonnet-20240229"
+    print(f"MODEL: {model}")
     with LLMClientManager(
-        kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
+        kwargs["api_keys"], model, embedding_provider="openai"
     ):
-        tool = kwargs["tool"]
         prompt = extract_question(kwargs["prompt"])
-        model = kwargs.get("model")
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
@@ -8,7 +8,7 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeiekjzoy2haayvkiwhb2u2epflpqxticud34mma3gdhfzgu36lxwiq
-  prediction_request_rag_cohere.py: bafybeig4oq3tdjuz2la2pz232u5m7347q7gplu5pw4vebbxteuiqw6hh3u
+  prediction_request_rag_cohere.py: bafybeib4jviue2jqktqbxca4gtzxrvvxi5oihhsbvarymiqyp3xkee7soi
 fingerprint_ignore_patterns: []
 entry_point: prediction_request_rag_cohere.py
 callable: run
@@ -40,11 +40,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []

@@ -642,13 +647,13 @@ def parser_prediction_response(response: str) -> str:
 
 def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
     """Run the task"""
+    model = kwargs.get("model")
+    print(f"MODEL: {model}")
     with LLMClientManager(
-        kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
+        kwargs["api_keys"], model, embedding_provider="openai"
     ):
         tool = kwargs["tool"]
-        model = kwargs.get("model")
         prompt = extract_question(kwargs["prompt"])
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
@@ -7,12 +7,12 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeib36ew6vbztldut5xayk5553rylrq7yv4cpqyhwc5ktvd4cx67vwu
-  prediction_request_reasoning.py: bafybeidiabgnlc453spgrdn7rhhl2xc3aa6zqeukkw2bthndbugtjf6bya
+  prediction_request_reasoning.py: bafybeidb43nygtvbhimnsd223ddpoii46dwirb5znmp2g473u4jii36jqa
 fingerprint_ignore_patterns: []
 entry_point: prediction_request_reasoning.py
 callable: run
 params:
-  default_model: claude-3-sonnet-20240229
+  default_model: gpt-4-0125-preview
 dependencies:
   google-api-python-client:
     version: ==2.95.0
@@ -43,11 +43,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []

@@ -223,6 +228,9 @@ def embeddings(self, model, input):
 }
 ALLOWED_TOOLS = [
     "prediction-request-reasoning",
+
+    # LEGACY
+    "prediction-request-reasoning-claude",
 ]
 ALLOWED_MODELS = list(LLM_SETTINGS.keys())
 DEFAULT_NUM_URLS = defaultdict(lambda: 3)

@@ -834,13 +842,15 @@ def extract_question(prompt: str) -> str:
 
 def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
     """Run the task"""
+    tool = kwargs["tool"]
+    model = kwargs.get("model")
+    if "claude" in tool:  # maintain backwards compatibility
+        model = "claude-3-sonnet-20240229"
+    print(f"MODEL: {model}")
     with LLMClientManager(
-        kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
+        kwargs["api_keys"], model, embedding_provider="openai"
     ):
-        tool = kwargs["tool"]
-        model = kwargs.get("model")
         prompt = extract_question(kwargs["prompt"])
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
@@ -7,7 +7,7 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeiflni5dkn5fqe7fnu4lgbqxzfrgochhqfbgzwz3vlf5grijp3nkpm
-  prediction_url_cot.py: bafybeifrkisrrphzyhqnvjqtwxynue7xlsmlqhggm5vcuceod2sl4td7ei
+  prediction_url_cot.py: bafybeihebxfv4xj22nq4mkch6xuddcnu7jv473zec2n5p65oxy63asjudy
 fingerprint_ignore_patterns: []
 entry_point: prediction_url_cot.py
 callable: run
22 changes: 16 additions & 6 deletions packages/napthaai/customs/prediction_url_cot/prediction_url_cot.py
@@ -19,11 +19,16 @@ class LLMClientManager:
     """Client context manager for LLMs."""
 
     def __init__(
-        self, api_keys: List, llm_provider: str = None, embedding_provider: str = None
+        self, api_keys: List, model: str = None, embedding_provider: str = None
     ):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
         self.embedding_provider = embedding_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         clients = []

@@ -180,6 +185,9 @@ def embeddings(self, model, input):
 }
 ALLOWED_TOOLS = [
     "prediction-url-cot",
+
+    # LEGACY
+    "prediction-url-cot-claude",
 ]
 ALLOWED_MODELS = list(LLM_SETTINGS.keys())
 NUM_QUERIES = 5

@@ -583,11 +591,13 @@ def parser_prediction_response(response: str) -> str:
 
 def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
     """Run the task"""
-    with LLMClientManager(kwargs["api_keys"], kwargs["llm_provider"]):
-        tool = kwargs["tool"]
-        model = kwargs.get("model")
+    tool = kwargs["tool"]
+    model = kwargs.get("model")
+    if "claude" in tool:  # maintain backwards compatibility
+        model = "claude-3-sonnet-20240229"
+    print(f"MODEL: {model}")
+    with LLMClientManager(kwargs["api_keys"], model):
         prompt = extract_question(kwargs["prompt"])
-        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
10 changes: 5 additions & 5 deletions packages/packages.json
@@ -5,23 +5,23 @@
         "custom/valory/openai_request/0.1.0": "bafybeihjtddwwkvwzaltk6yhtkk3xxnwnkurdtyuy6ki5tpf7h5htvuxnq",
         "custom/valory/prediction_request_embedding/0.1.0": "bafybeifnz5fzxvzyj3mmjpfsre3nzbdieuyjvnxqxuplopp5taz4qw7ys4",
         "custom/valory/resolve_market/0.1.0": "bafybeiaag2e7rsdr3bwg6mlmfyom4vctsdapohco7z45pxhzjymepz3rya",
-        "custom/valory/prediction_request/0.1.0": "bafybeieq5lu3dtgz7svxr6eelbopyvravg3iiomvvtdv33ej5w7hgbjhja",
+        "custom/valory/prediction_request/0.1.0": "bafybeibnshbgciu6inzdjzxeysrwvsin4iitkgd4fkj7a2omjzbdrga2ue",
         "custom/valory/stability_ai_request/0.1.0": "bafybeicyyteycvzj4lk33p4t7mspfarc5d5ktbysu7oqkv6woo4aouxira",
         "custom/polywrap/prediction_with_research_report/0.1.0": "bafybeiewbcbfyjnyqyp4oou6ianxseakblwjyck22bd2doqojjk37uyxwy",
         "custom/jhehemann/prediction_sum_url_content/0.1.0": "bafybeiby55g53cvc4vpbgww5awrlf6x67h7q7pg5xlhwber75ejdkh4twa",
         "custom/psouranis/optimization_by_prompting/0.1.0": "bafybeihb3pyk5qcbj5ib7377p65tznzdsnwilyyhlkcvaj2scmfcpsh6ru",
         "custom/nickcom007/sme_generation_request/0.1.0": "bafybeibqv4ru4lpufy2hvcb3swqhzuq2kejjxmlyepofx6l6mxce6lhiqq",
         "custom/nickcom007/prediction_request_sme/0.1.0": "bafybeigsszaat6k5m5a3ljyem7xdhjflpcm24imtcscgst3tghpwhamglu",
         "custom/napthaai/resolve_market_reasoning/0.1.0": "bafybeiewdqtfkee3od5kuktrhyzexy7466ea3w3to7vv6qnli6qutfrqaa",
-        "custom/napthaai/prediction_request_rag/0.1.0": "bafybeigb7hfqcuvkvsc54526hxwhl6utfj44dnbwiyabcdbghlr5ctkwuu",
-        "custom/napthaai/prediction_request_reasoning/0.1.0": "bafybeiati546f5fyhtwv6yo7zaq3xwtb635p3jp3h3f546stknpbkkyhou",
+        "custom/napthaai/prediction_request_rag/0.1.0": "bafybeif7ufhrlhpuegm6kpiw6jzye6jmp4fjvxgn3hwcv4vkolrrrmidmy",
+        "custom/napthaai/prediction_request_reasoning/0.1.0": "bafybeifzkvc6j5wbbremt2jqig4ozaackzpz3o5okkoihmm3wdpptpviz4",
         "custom/valory/prepare_tx/0.1.0": "bafybeibjqckeb73df724lr4xkrmeh3woqwas4mswa7au65xnwag2edad2e",
         "custom/valory/short_maker/0.1.0": "bafybeif63rt4lkopu3rc3l7sg6tebrrwg2lxqufjx6dx4hoda5yzax43fa",
-        "custom/napthaai/prediction_url_cot/0.1.0": "bafybeidk6s4nqtow6dxmslhjtxzbbnhzpeogyy33e2zpjmqdjijtqb6rz4",
+        "custom/napthaai/prediction_url_cot/0.1.0": "bafybeic3ch7wfhxqvwgoud7xotuu3khs4xch3ej35kox2gulya2hv65wbu",
         "custom/napthaai/prediction_url_cot_claude/0.1.0": "bafybeicbjywni5hx5ssoiv6tnnjbqzsck6cmtsdpr6m562z6afogz5eh44",
         "custom/napthaai/prediction_request_reasoning_claude/0.1.0": "bafybeihtx2cejxoy42jwk2i5m4evfzz537aic5njuawxnzdzwlo63kdduq",
         "custom/napthaai/prediction_request_rag_claude/0.1.0": "bafybeickr32t7nmapuoymjyo3cf5rr2v2zapksxcivuqsgjr2gn6zo6y7y",
-        "custom/napthaai/prediction_request_rag_cohere/0.1.0": "bafybeig3xsmmb4bgbjong6uzvnurf4mwdisqwp3eidmeuo7hj42wkcbymm",
+        "custom/napthaai/prediction_request_rag_cohere/0.1.0": "bafybeie2qi27usujclcje524qu4w6iv3viouq3pxhs2yft3aw26nnerla4",
         "protocol/valory/acn_data_share/0.1.0": "bafybeih5ydonnvrwvy2ygfqgfabkr47s4yw3uqxztmwyfprulwfsoe7ipq",
         "protocol/valory/websocket_client/0.1.0": "bafybeifjk254sy65rna2k32kynzenutujwqndap2r222afvr3zezi27mx4",
         "contract/valory/agent_mech/0.1.0": "bafybeiah6b5epo2hlvzg5rr2cydgpp2waausoyrpnoarf7oa7bw33rex34",
2 changes: 1 addition & 1 deletion packages/valory/customs/prediction_request/component.yaml
@@ -7,7 +7,7 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi
-  prediction_request.py: bafybeif3s6wd3gotqpg6qdcs7zjszhffkyffeb554r2j5xvtmrbsxy7oca
+  prediction_request.py: bafybeigf5k62mxbmcrvjvsnixpbn3hvxlp2l62sk7jtx5vs7fdg5cgtfxe
 fingerprint_ignore_patterns: []
 entry_point: prediction_request.py
 callable: run
23 changes: 17 additions & 6 deletions packages/valory/customs/prediction_request/prediction_request.py
@@ -44,9 +44,14 @@
 class LLMClientManager:
     """Client context manager for LLMs."""
 
-    def __init__(self, api_keys: List, llm_provider: str = None):
+    def __init__(self, api_keys: List, model: str = None):
         self.api_keys = api_keys
-        self.llm_provider = llm_provider
+        if "gpt" in model:
+            self.llm_provider = "openai"
+        elif "claude" in model:
+            self.llm_provider = "anthropic"
+        else:
+            self.llm_provider = "openrouter"
 
     def __enter__(self):
         global client

@@ -219,6 +224,10 @@ def count_tokens(text: str, model: str) -> int:
     "prediction-offline",
     "prediction-online",
     # "prediction-online-summarized-info",
+
+    # LEGACY
+    "claude-prediction-offline",
+    "claude-prediction-online",
 ]
 ALLOWED_MODELS = list(LLM_SETTINGS.keys())
 # the default number of URLs to fetch online information for

@@ -653,11 +662,13 @@ def adjust_additional_information(
 
 def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
     """Run the task"""
-    with LLMClientManager(kwargs["api_keys"], kwargs["llm_provider"]):
-        tool = kwargs["tool"]
+    tool = kwargs["tool"]
+    engine = kwargs.get("model")
+    if "claude" in tool:  # maintain backwards compatibility
+        engine = "claude-3-sonnet-20240229"
+    print(f"ENGINE: {engine}")
+    with LLMClientManager(kwargs["api_keys"], engine):
         prompt = kwargs["prompt"]
-        engine = kwargs.get("model")
-        print(f"ENGINE: {engine}")
         max_tokens = kwargs.get(
             "max_tokens", LLM_SETTINGS[engine]["default_max_tokens"]
         )
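
For callers, the net effect is that llm_provider no longer needs to be supplied, and the legacy *-claude tool names stay usable because run() forces the Claude model for them. A small, self-contained sketch of that backwards-compatibility rule (the resolve_model helper is hypothetical; most of the run() functions above keep this branch inline):

from typing import Optional


def resolve_model(tool: str, model: Optional[str]) -> Optional[str]:
    # Legacy tool names containing "claude" override whatever model was
    # requested, pinning claude-3-sonnet-20240229 as in the diffs above.
    if "claude" in tool:
        return "claude-3-sonnet-20240229"
    return model


assert resolve_model("prediction-request-rag", "gpt-4-0125-preview") == "gpt-4-0125-preview"
assert resolve_model("prediction-request-rag-claude", "gpt-4-0125-preview") == "claude-3-sonnet-20240229"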