-
Notifications
You must be signed in to change notification settings - Fork 44.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(blocks): Add AI/ML API support to LLM blocks #9163
base: dev
Are you sure you want to change the base?
Changes from 8 commits
48c28c3
a2720f3
3591306
0d6202c
937b477
10fedd6
c8b8086
ff6fa8d
142671e
996f8d4
8ea7081
d164486
8f911a2
41414a7
847ffe5
9d93171
fda0ad1
ed951b0
78a9677
97b7fac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -35,6 +35,7 @@ | |||||
ProviderName.OLLAMA, | ||||||
ProviderName.OPENAI, | ||||||
ProviderName.OPEN_ROUTER, | ||||||
ProviderName.AIML, | ||||||
] | ||||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]] | ||||||
|
||||||
|
@@ -98,6 +99,12 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): | |||||
# Anthropic models | ||||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest" | ||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307" | ||||||
# Aiml models | ||||||
aarushik93 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
AIML_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo" | ||||||
AIML_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct" | ||||||
AIML_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo" | ||||||
AIML_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" | ||||||
AIML_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo" | ||||||
# Groq models | ||||||
LLAMA3_8B = "llama3-8b-8192" | ||||||
LLAMA3_70B = "llama3-70b-8192" | ||||||
|
@@ -154,6 +161,11 @@ def context_window(self) -> int: | |||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385), | ||||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000), | ||||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000), | ||||||
LlmModel.AIML_QWEN2_5_72B: ModelMetadata("aiml", 32000), | ||||||
LlmModel.AIML_LLAMA3_1_70B: ModelMetadata("aiml", 128000), | ||||||
LlmModel.AIML_LLAMA3_3_70B: ModelMetadata("aiml", 128000), | ||||||
LlmModel.AIML_META_LLAMA_3_1_70B: ModelMetadata("aiml", 131000), | ||||||
LlmModel.AIML_LLAMA_3_2_3B: ModelMetadata("aiml", 128000), | ||||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192), | ||||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192), | ||||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768), | ||||||
|
@@ -433,6 +445,23 @@ def llm_call( | |||||
response.usage.prompt_tokens if response.usage else 0, | ||||||
response.usage.completion_tokens if response.usage else 0, | ||||||
) | ||||||
elif provider == "aiml": | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
client = openai.OpenAI( | ||||||
base_url="https://api.aimlapi.com/v2", | ||||||
api_key=credentials.api_key.get_secret_value(), | ||||||
) | ||||||
|
||||||
completion = client.chat.completions.create( | ||||||
model=llm_model.value, | ||||||
messages=prompt, # type: ignore | ||||||
max_tokens=max_tokens, | ||||||
) | ||||||
|
||||||
return ( | ||||||
completion.choices[0].message.content or "", | ||||||
completion.usage.prompt_tokens if completion.usage else 0, | ||||||
completion.usage.completion_tokens if completion.usage else 0, | ||||||
) | ||||||
else: | ||||||
raise ValueError(f"Unsupported LLM provider: {provider}") | ||||||
|
||||||
|
@@ -509,16 +538,19 @@ def parse_response(resp: str) -> tuple[dict[str, Any], str | None]: | |||||
if input_data.expected_format: | ||||||
parsed_dict, parsed_error = parse_response(response_text) | ||||||
if not parsed_error: | ||||||
yield "response", { | ||||||
k: ( | ||||||
json.loads(v) | ||||||
if isinstance(v, str) | ||||||
and v.startswith("[") | ||||||
and v.endswith("]") | ||||||
else (", ".join(v) if isinstance(v, list) else v) | ||||||
) | ||||||
for k, v in parsed_dict.items() | ||||||
} | ||||||
yield ( | ||||||
"response", | ||||||
{ | ||||||
k: ( | ||||||
json.loads(v) | ||||||
if isinstance(v, str) | ||||||
and v.startswith("[") | ||||||
and v.endswith("]") | ||||||
else (", ".join(v) if isinstance(v, list) else v) | ||||||
) | ||||||
for k, v in parsed_dict.items() | ||||||
}, | ||||||
) | ||||||
return | ||||||
else: | ||||||
yield "response", {"response": response_text} | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -49,6 +49,13 @@ | |||||
title="Use Credits for OpenAI", | ||||||
expires_at=None, | ||||||
) | ||||||
aiml_credentials = APIKeyCredentials( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
id="aad82a89-9794-4ebb-977f-d736aa5260a3", | ||||||
provider="aiml", | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
api_key=SecretStr(settings.secrets.aiml_api_key), | ||||||
title="Use Credits for AI/ML", | ||||||
aarushik93 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
expires_at=None, | ||||||
) | ||||||
anthropic_credentials = APIKeyCredentials( | ||||||
id="24e5d942-d9e3-4798-8151-90143ee55629", | ||||||
provider="anthropic", | ||||||
|
@@ -98,6 +105,7 @@ | |||||
ideogram_credentials, | ||||||
replicate_credentials, | ||||||
openai_credentials, | ||||||
aiml_credentials, | ||||||
anthropic_credentials, | ||||||
groq_credentials, | ||||||
did_credentials, | ||||||
|
@@ -145,6 +153,8 @@ def get_all_creds(self, user_id: str) -> list[Credentials]: | |||||
all_credentials.append(replicate_credentials) | ||||||
if settings.secrets.openai_api_key: | ||||||
all_credentials.append(openai_credentials) | ||||||
if settings.secrets.aiml_api_key: | ||||||
all_credentials.append(aiml_credentials) | ||||||
if settings.secrets.anthropic_api_key: | ||||||
all_credentials.append(anthropic_credentials) | ||||||
if settings.secrets.did_api_key: | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -266,6 +266,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): | |||||
) | ||||||
|
||||||
openai_api_key: str = Field(default="", description="OpenAI API key") | ||||||
aiml_api_key: str = Field(default="", description="AI/ML API key") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
anthropic_api_key: str = Field(default="", description="Anthropic API key") | ||||||
groq_api_key: str = Field(default="", description="Groq API key") | ||||||
open_router_api_key: str = Field(default="", description="Open Router API Key") | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -103,6 +103,7 @@ export default function PrivatePage() { | |||||
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate | ||||||
"53c25cb8-e3ee-465c-a4d1-e75a4c899c2a", // OpenAI | ||||||
"24e5d942-d9e3-4798-8151-90143ee55629", // Anthropic | ||||||
"aad82a89-9794-4ebb-977f-d736aa5260a3", // AI/ML | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
"4ec22295-8f97-4dd1-b42b-2c6957a02545", // Groq | ||||||
"7f7b0654-c36b-4565-8fa7-9a52575dfae2", // D-ID | ||||||
"7f26de70-ba0d-494e-ba76-238e65e7b45f", // Jina | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -52,6 +52,7 @@ export const providerIcons: Record< | |||||
github: FaGithub, | ||||||
google: FaGoogle, | ||||||
groq: fallbackIcon, | ||||||
aiml: fallbackIcon, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
notion: NotionLogoIcon, | ||||||
discord: FaDiscord, | ||||||
d_id: fallbackIcon, | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -16,6 +16,7 @@ const CREDENTIALS_PROVIDER_NAMES = Object.values( | |||||
|
||||||
// --8<-- [start:CredentialsProviderNames] | ||||||
const providerDisplayNames: Record<CredentialsProviderName, string> = { | ||||||
aiml: "AI/ML", | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
anthropic: "Anthropic", | ||||||
discord: "Discord", | ||||||
d_id: "D-ID", | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -102,6 +102,7 @@ export type CredentialsType = "api_key" | "oauth2"; | |||||
// --8<-- [start:BlockIOCredentialsSubSchema] | ||||||
export const PROVIDER_NAMES = { | ||||||
ANTHROPIC: "anthropic", | ||||||
AIML: "aiml", | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
D_ID: "d_id", | ||||||
DISCORD: "discord", | ||||||
E2B: "e2b", | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from __future__ import annotations | ||
|
||
import enum | ||
import logging | ||
from typing import Any, Optional | ||
|
||
import tiktoken | ||
from pydantic import SecretStr | ||
|
||
from forge.models.config import UserConfigurable | ||
|
||
from ._openai_base import BaseOpenAIChatProvider | ||
from .schema import ( | ||
ChatModelInfo, | ||
ModelProviderBudget, | ||
ModelProviderConfiguration, | ||
ModelProviderCredentials, | ||
ModelProviderName, | ||
ModelProviderSettings, | ||
ModelTokenizer, | ||
) | ||
|
||
|
||
class AimlModelName(str, enum.Enum): | ||
AIML_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo" | ||
AIML_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct" | ||
AIML_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo" | ||
AIML_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo" | ||
AIML_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" | ||
|
||
|
||
AIML_CHAT_MODELS = { | ||
info.name: info | ||
for info in [ | ||
ChatModelInfo( | ||
name=AimlModelName.AIML_QWEN2_5_72B, | ||
provider_name=ModelProviderName.AIML, | ||
prompt_token_cost=1.26 / 1e6, | ||
completion_token_cost=1.26 / 1e6, | ||
max_tokens=32000, | ||
has_function_call_api=False, | ||
), | ||
ChatModelInfo( | ||
name=AimlModelName.AIML_LLAMA3_1_70B, | ||
provider_name=ModelProviderName.AIML, | ||
prompt_token_cost=0.368 / 1e6, | ||
completion_token_cost=0.42 / 1e6, | ||
max_tokens=128000, | ||
has_function_call_api=False, | ||
), | ||
ChatModelInfo( | ||
name=AimlModelName.AIML_LLAMA3_3_70B, | ||
provider_name=ModelProviderName.AIML, | ||
prompt_token_cost=0.924 / 1e6, | ||
completion_token_cost=0.924 / 1e6, | ||
max_tokens=128000, | ||
has_function_call_api=False, | ||
), | ||
ChatModelInfo( | ||
name=AimlModelName.AIML_META_LLAMA_3_1_70B, | ||
provider_name=ModelProviderName.AIML, | ||
prompt_token_cost=0.063 / 1e6, | ||
completion_token_cost=0.063 / 1e6, | ||
max_tokens=131000, | ||
has_function_call_api=False, | ||
), | ||
ChatModelInfo( | ||
name=AimlModelName.AIML_LLAMA_3_2_3B, | ||
provider_name=ModelProviderName.AIML, | ||
prompt_token_cost=0.924 / 1e6, | ||
completion_token_cost=0.924 / 1e6, | ||
max_tokens=128000, | ||
has_function_call_api=False, | ||
), | ||
] | ||
} | ||
|
||
|
||
class AimlCredentials(ModelProviderCredentials): | ||
"""Credentials for Aiml.""" | ||
|
||
api_key: SecretStr = UserConfigurable(from_env="AIML_API_KEY") # type: ignore | ||
api_base: Optional[SecretStr] = UserConfigurable( | ||
default=None, from_env="AIML_API_BASE_URL" | ||
) | ||
|
||
def get_api_access_kwargs(self) -> dict[str, str]: | ||
return { | ||
k: v.get_secret_value() | ||
for k, v in { | ||
"api_key": self.api_key, | ||
"base_url": self.api_base, | ||
}.items() | ||
if v is not None | ||
} | ||
|
||
|
||
class AimlSettings(ModelProviderSettings): | ||
credentials: Optional[AimlCredentials] # type: ignore | ||
budget: ModelProviderBudget # type: ignore | ||
|
||
|
||
class AimlProvider(BaseOpenAIChatProvider[AimlModelName, AimlSettings]): | ||
CHAT_MODELS = AIML_CHAT_MODELS | ||
MODELS = CHAT_MODELS | ||
|
||
default_settings = AimlSettings( | ||
name="aiml_provider", | ||
description="Provides access to AIML's API.", | ||
configuration=ModelProviderConfiguration(), | ||
credentials=None, | ||
budget=ModelProviderBudget(), | ||
) | ||
|
||
_settings: AimlSettings | ||
_configuration: ModelProviderConfiguration | ||
_credentials: AimlCredentials | ||
_budget: ModelProviderBudget | ||
|
||
def __init__( | ||
self, | ||
settings: Optional[AimlSettings] = None, | ||
logger: Optional[logging.Logger] = None, | ||
): | ||
super(AimlProvider, self).__init__(settings=settings, logger=logger) | ||
|
||
from openai import AsyncOpenAI | ||
|
||
self._client = AsyncOpenAI( | ||
**self._credentials.get_api_access_kwargs() # type: ignore | ||
) | ||
|
||
def get_tokenizer(self, model_name: AimlModelName) -> ModelTokenizer[Any]: | ||
# HACK: No official tokenizer is available for AIML | ||
return tiktoken.encoding_for_model("gpt-3.5-turbo") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same for all other occurrences