diff --git a/g4f/Provider/Blackbox2.py b/g4f/Provider/Blackbox2.py
index ce949b8b30a..f27a25595f8 100644
--- a/g4f/Provider/Blackbox2.py
+++ b/g4f/Provider/Blackbox2.py
@@ -6,7 +6,7 @@
 import json
 from pathlib import Path
 from aiohttp import ClientSession
-from typing import AsyncGenerator
+from typing import AsyncIterator
 from ..typing import AsyncResult, Messages
 from ..image import ImageResponse
@@ -21,12 +21,12 @@ class Blackbox2(AsyncGeneratorProvider, ProviderModelMixin):
         "llama-3.1-70b": "https://www.blackbox.ai/api/improve-prompt",
         "flux": "https://www.blackbox.ai/api/image-generator"
     }
-    
+
     working = True
     supports_system_message = True
     supports_message_history = True
     supports_stream = False
-    
+
     default_model = 'llama-3.1-70b'
     chat_models = ['llama-3.1-70b']
     image_models = ['flux']
@@ -97,15 +97,14 @@ async def create_async_generator(
         messages: Messages,
         prompt: str = None,
         proxy: str = None,
-        prompt: str = None,
         max_retries: int = 3,
         delay: int = 1,
         max_tokens: int = None,
         **kwargs
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncResult:
         if not model:
             model = cls.default_model
-        
+
         if model in cls.chat_models:
             async for result in cls._generate_text(model, messages, proxy, max_retries, delay, max_tokens):
                 yield result
@@ -125,13 +124,13 @@ async def _generate_text(
         max_retries: int = 3,
         delay: int = 1,
         max_tokens: int = None,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncIterator[str]:
         headers = cls._get_headers()
         async with ClientSession(headers=headers) as session:
             license_key = await cls._get_license_key(session)
             api_endpoint = cls.api_endpoints[model]
-            
+
             data = {
                 "messages": messages,
                 "max_tokens": max_tokens,
@@ -162,7 +161,7 @@ async def _generate_image(
         model: str,
         prompt: str,
         proxy: str = None
-    ) -> AsyncGenerator[ImageResponse, None]:
+    ) -> AsyncIterator[ImageResponse]:
         headers = cls._get_headers()
         api_endpoint = cls.api_endpoints[model]
@@ -170,11 +169,11 @@ async def _generate_image(
         data = {
             "query": prompt
         }
-        
+
         async with session.post(api_endpoint, headers=headers, json=data, proxy=proxy) as response:
             response.raise_for_status()
             response_data = await response.json()
-            
+
             if 'markdown' in response_data:
                 image_url = response_data['markdown'].split('(')[1].split(')')[0]
                 yield ImageResponse(images=image_url, alt=prompt)
diff --git a/g4f/Provider/Cloudflare.py b/g4f/Provider/Cloudflare.py
index 7d477d57327..4416f7a315e 100644
--- a/g4f/Provider/Cloudflare.py
+++ b/g4f/Provider/Cloudflare.py
@@ -2,7 +2,6 @@
 
 import asyncio
 import json
-import uuid
 
 from ..typing import AsyncResult, Messages, Cookies
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
@@ -37,18 +36,16 @@ def get_models(cls) -> str:
         if not cls.models:
             if cls._args is None:
                 get_running_loop(check_nested=True)
-                args = get_args_from_nodriver(cls.url, cookies={
-                    '__cf_bm': uuid.uuid4().hex,
-                })
+                args = get_args_from_nodriver(cls.url)
                 cls._args = asyncio.run(args)
             with Session(**cls._args) as session:
                 response = session.get(cls.models_url)
                 cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
                 try:
                     raise_for_status(response)
-                except ResponseStatusError as e:
+                except ResponseStatusError:
                     cls._args = None
-                    raise e
+                    raise
                 json_data = response.json()
                 cls.models = [model.get("name") for model in json_data.get("models")]
         return cls.models
@@ -64,9 +61,9 @@ async def create_async_generator(
         timeout: int = 300,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
         if cls._args is None:
             cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
+        model = cls.get_model(model)
         data = {
             "messages": messages,
             "lora": None,
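
Note on the two file diffs above: in Blackbox2, the duplicated `prompt` parameter was a SyntaxError at import time, and the return annotations now use the project's `AsyncResult`/`AsyncIterator` types, which match the fact that these methods are async generators. In Cloudflare, `get_model` now runs after `cls._args` is populated, so model resolution can reuse the freshly fetched nodriver arguments. A minimal consumption sketch, assuming this patched build; the message payload is illustrative:

    import asyncio
    from g4f.Provider import Blackbox2

    async def main():
        # Chat models stream str chunks; the "flux" image model yields ImageResponse.
        async for chunk in Blackbox2.create_async_generator(
            model="llama-3.1-70b",
            messages=[{"role": "user", "content": "Say hello"}],
        ):
            print(chunk, end="")

    asyncio.run(main())
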
diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py
index 9520674a188..31a7e7e436b 100644
--- a/g4f/Provider/PollinationsAI.py
+++ b/g4f/Provider/PollinationsAI.py
@@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI):
     }
 
     @classmethod
-    def get_models(cls):
+    def get_models(cls, **kwargs):
         if not hasattr(cls, 'image_models'):
             cls.image_models = []
         if not cls.image_models:
diff --git a/g4f/Provider/needs_auth/Cerebras.py b/g4f/Provider/needs_auth/Cerebras.py
index df34db0eec3..86b2dcbda99 100644
--- a/g4f/Provider/needs_auth/Cerebras.py
+++ b/g4f/Provider/needs_auth/Cerebras.py
@@ -16,6 +16,7 @@ class Cerebras(OpenaiAPI):
     models = [
         "llama3.1-70b",
         "llama3.1-8b",
+        "llama-3.3-70b"
     ]
     model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
@@ -29,14 +30,15 @@ async def create_async_generator(
         cookies: Cookies = None,
         **kwargs
     ) -> AsyncResult:
-        if api_key is None and cookies is None:
-            cookies = get_cookies(".cerebras.ai")
-        async with ClientSession(cookies=cookies) as session:
-            async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
-                raise_for_status(response)
-                data = await response.json()
-                if data:
-                    api_key = data.get("user", {}).get("demoApiKey")
+        if api_key is None:
+            if cookies is None:
+                cookies = get_cookies(".cerebras.ai")
+            async with ClientSession(cookies=cookies) as session:
+                async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
+                    await raise_for_status(response)
+                    data = await response.json()
+                    if data:
+                        api_key = data.get("user", {}).get("demoApiKey")
         async for chunk in super().create_async_generator(
             model,
             messages,
             api_base=api_base,
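
Note on the Cerebras hunks: besides adding "llama-3.3-70b" to the model list, the rewrite nests the whole demo-key lookup under `if api_key is None:`, so an explicitly supplied key is no longer overwritten by the cookie-derived `demoApiKey`, and `raise_for_status` is now awaited, as it must be for aiohttp responses. A hedged usage sketch; the import path assumes the `needs_auth` package re-exports the class, and the key value is a placeholder:

    import asyncio
    from g4f.Provider.needs_auth import Cerebras

    async def main():
        # With api_key given, the ".cerebras.ai" cookie/session lookup is skipped.
        async for chunk in Cerebras.create_async_generator(
            model="llama-3.3-70b",  # the newly listed model
            messages=[{"role": "user", "content": "Hi"}],
            api_key="csk-placeholder",  # illustrative value
        ):
            print(chunk, end="")

    asyncio.run(main())
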
diff --git a/g4f/Provider/needs_auth/DeepInfra.py b/g4f/Provider/needs_auth/DeepInfra.py
index 35e7ca7f85e..035effb072c 100644
--- a/g4f/Provider/needs_auth/DeepInfra.py
+++ b/g4f/Provider/needs_auth/DeepInfra.py
@@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI):
     default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
 
     @classmethod
-    def get_models(cls):
+    def get_models(cls, **kwargs):
         if not cls.models:
             url = 'https://api.deepinfra.com/models/featured'
             models = requests.get(url).json()
diff --git a/g4f/Provider/needs_auth/GeminiPro.py b/g4f/Provider/needs_auth/GeminiPro.py
index 36c906563c6..22c9c015885 100644
--- a/g4f/Provider/needs_auth/GeminiPro.py
+++ b/g4f/Provider/needs_auth/GeminiPro.py
@@ -2,30 +2,52 @@
 
 import base64
 import json
+import requests
 from aiohttp import ClientSession, BaseConnector
 
 from ...typing import AsyncResult, Messages, ImagesType
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ...image import to_bytes, is_accepted_format
 from ...errors import MissingAuthError
+from ...requests.raise_for_status import raise_for_status
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_connector
+from ... import debug
 
 class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Google Gemini API"
     url = "https://ai.google.dev"
-    
+    api_base = "https://generativelanguage.googleapis.com/v1beta"
+
     working = True
     supports_message_history = True
     needs_auth = True
-    
+
     default_model = "gemini-1.5-pro"
     default_vision_model = default_model
-    models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
+    fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
         "gemini-flash": "gemini-1.5-flash-8b",
     }
 
+    @classmethod
+    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
+        if not cls.models:
+            try:
+                response = requests.get(f"{api_base}/models?key={api_key}")
+                raise_for_status(response)
+                data = response.json()
+                cls.models = [
+                    model.get("name").split("/").pop()
+                    for model in data.get("models")
+                    if "generateContent" in model.get("supportedGenerationMethods")
+                ]
+                cls.models.sort()
+            except Exception as e:
+                debug.log(e)
+                cls.models = cls.fallback_models
+        return cls.models
+
     @classmethod
     async def create_async_generator(
         cls,
@@ -34,17 +56,17 @@ async def create_async_generator(
         stream: bool = False,
         proxy: str = None,
         api_key: str = None,
-        api_base: str = "https://generativelanguage.googleapis.com/v1beta",
+        api_base: str = api_base,
         use_auth_header: bool = False,
         images: ImagesType = None,
         connector: BaseConnector = None,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
-
         if not api_key:
             raise MissingAuthError('Add a "api_key"')
 
+        model = cls.get_model(model, api_key=api_key, api_base=api_base)
+
         headers = params = None
         if use_auth_header:
             headers = {"Authorization": f"Bearer {api_key}"}
diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py
index ebc4d5192d9..a61115eaab9 100644
--- a/g4f/Provider/needs_auth/OpenaiAPI.py
+++ b/g4f/Provider/needs_auth/OpenaiAPI.py
@@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
     fallback_models = []
 
     @classmethod
-    def get_models(cls, api_key: str = None):
+    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
         if not cls.models:
             try:
                 headers = {}
                 if api_key is not None:
                     headers["authorization"] = f"Bearer {api_key}"
-                response = requests.get(f"{cls.api_base}/models", headers=headers)
+                response = requests.get(f"{api_base}/models", headers=headers)
                 raise_for_status(response)
                 data = response.json()
                 cls.models = [model.get("id") for model in data.get("data")]
@@ -82,7 +82,7 @@ async def create_async_generator(
         ) as session:
             data = filter_none(
                 messages=messages,
-                model=cls.get_model(model),
+                model=cls.get_model(model, api_key=api_key, api_base=api_base),
                 temperature=temperature,
                 max_tokens=max_tokens,
                 top_p=top_p,
@@ -147,4 +147,4 @@ def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
             if api_key is not None else {}
         ),
         **({} if headers is None else headers)
-    }
+    }
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 019761985de..0e25a28d8fa 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -504,7 +504,7 @@ async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
         await cls.login()
         async with StreamSession(
             impersonate="chrome",
-            timeout=900
+            timeout=0
         ) as session:
             async with session.get(
                 f"{cls.url}/backend-api/synthesize",
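
Note on the GeminiPro and OpenaiAPI hunks above: both `get_models` implementations now accept `api_key`/`api_base`, query the remote model list once, and cache it on the class; GeminiPro falls back to `fallback_models` after logging the error via `g4f.debug`. Also visible in the GeminiPro context lines: `model_aliases` defines "gemini-flash" twice, so Python keeps only the 8b mapping; that predates this patch. A hedged lookup sketch with a placeholder key; the import path assumes the package re-export:

    from g4f.Provider.needs_auth import GeminiPro

    # First call hits {api_base}/models?key=<api_key>; later calls reuse cls.models.
    # On failure the error is logged and fallback_models is returned instead.
    models = GeminiPro.get_models(api_key="AIza-placeholder")
    print(models)
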
diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html
index c53284ac615..5ddc1104f2f 100644
--- a/g4f/gui/client/index.html
+++ b/g4f/gui/client/index.html
@@ -150,47 +150,47 @@
 [hunk body unrecoverable: 47 lines of the "Settings" panel markup are replaced on each side, but the HTML tags were stripped during extraction and only stray +/- markers survived. Judging by the chat.v1.js and style.css hunks below, the change reworks the per-provider API-key fields so they can be hidden and shown.]
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css
index 10f20607ce9..a499779a644 100644
--- a/g4f/gui/client/static/css/style.css
+++ b/g4f/gui/client/static/css/style.css
@@ -1022,7 +1022,8 @@ ul {
 }
 
 .settings h3 {
-    padding-left: 50px;
+    padding-left: 54px;
+    padding-top: 18px;
 }
 
 .buttons {
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index f8bd894dc8e..222886e9696 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -1293,6 +1293,7 @@ const load_provider_option = (input, provider_name) => {
         providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach(
             (el) => el.removeAttribute("disabled")
         );
+        settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.remove("hidden");
     } else {
         modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach(
             (el) => {
@@ -1307,6 +1308,7 @@ const load_provider_option = (input, provider_name) => {
         providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach(
             (el) => el.setAttribute("disabled", "disabled")
         );
+        settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.add("hidden");
     }
 };
 
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index 0cdcde90e67..e2c356e338b 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -243,19 +243,20 @@ class ProviderModelMixin:
     last_model: str = None
 
     @classmethod
-    def get_models(cls) -> list[str]:
+    def get_models(cls, **kwargs) -> list[str]:
         if not cls.models and cls.default_model is not None:
             return [cls.default_model]
         return cls.models
 
     @classmethod
-    def get_model(cls, model: str) -> str:
+    def get_model(cls, model: str, **kwargs) -> str:
         if not model and cls.default_model is not None:
             model = cls.default_model
         elif model in cls.model_aliases:
             model = cls.model_aliases[model]
-        elif model not in cls.get_models() and cls.models:
-            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+        else:
+            if model not in cls.get_models(**kwargs) and cls.models:
+                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
         cls.last_model = model
         debug.last_model = model
         return model
\ No newline at end of file
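
Note on the base_provider hunk: forwarding `**kwargs` from `get_model` into `get_models` is the enabling change for the provider-level edits above, where OpenaiAPI and GeminiPro pass `api_key`/`api_base` during model validation. A minimal sketch of a subclass relying on the forwarding; the provider and model names are invented for illustration:

    from g4f.providers.base_provider import ProviderModelMixin

    class ExampleProvider(ProviderModelMixin):
        # Hypothetical provider, not part of this patch.
        default_model = "example-small"

        @classmethod
        def get_models(cls, api_key: str = None, **kwargs) -> list[str]:
            # A real provider would query its /models endpoint here with the
            # forwarded api_key, as OpenaiAPI.get_models does above.
            cls.models = ["example-small", "example-large"]
            return cls.models

    # get_model now passes the keyword through to get_models before validating:
    print(ExampleProvider.get_model("example-large", api_key="sk-placeholder"))
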