diff --git a/.changeset/lucky-cars-follow.md b/.changeset/lucky-cars-follow.md new file mode 100644 index 0000000000000..b793dec434223 --- /dev/null +++ b/.changeset/lucky-cars-follow.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Add support for inference `providers` in `gr.load()` diff --git a/gradio/external.py b/gradio/external.py index 5fc4ebe78c6c0..463e32aa4a29e 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -31,6 +31,8 @@ from gradio.processing_utils import save_base64_to_cache, to_binary if TYPE_CHECKING: + from huggingface_hub.inference._providers import PROVIDER_T + from gradio.blocks import Blocks from gradio.chat_interface import ChatInterface from gradio.components.chatbot import MessageDict @@ -46,6 +48,7 @@ def load( token: str | None = None, hf_token: str | None = None, accept_token: bool = False, + provider: PROVIDER_T | None = None, **kwargs, ) -> Blocks: """ @@ -56,6 +59,7 @@ def load( token: optional token that is passed as the second parameter to the `src` function. If not explicitly provided, will use the HF_TOKEN environment variable or fallback to the locally-saved HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find your HF tokens here: https://huggingface.co/settings/tokens. accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space. kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor. + provider: the name of the third-party (non-Hugging Face) providers to use for model inference (e.g. "replicate", "sambanova", "fal-ai", etc). Should be one of the providers supported by `huggingface_hub.InferenceClient`. This parameter is only used when `src` is "models" Returns: a Gradio Blocks app for the given model Example: @@ -93,7 +97,7 @@ def load( if isinstance(src, Callable): return src(name, token, **kwargs) return load_blocks_from_huggingface( - name=name, src=src, hf_token=token, **kwargs + name=name, src=src, hf_token=token, provider=provider, **kwargs ) else: import gradio as gr @@ -150,6 +154,7 @@ def load_blocks_from_huggingface( src: str, hf_token: str | Literal[False] | None = None, alias: str | None = None, + provider: PROVIDER_T | None = None, **kwargs, ) -> Blocks: """Creates and returns a Blocks instance from a Hugging Face model or Space repo.""" @@ -168,18 +173,24 @@ def load_blocks_from_huggingface( if src == "spaces" and hf_token is None: hf_token = False # Since Spaces can read the token, we don't want to pass it in unless the user explicitly provides it - blocks: gradio.Blocks = factory_methods[src](name, hf_token, alias, **kwargs) + blocks: gradio.Blocks = factory_methods[src]( + name, hf_token=hf_token, alias=alias, provider=provider, **kwargs + ) return blocks def from_model( - model_name: str, hf_token: str | Literal[False] | None, alias: str | None, **kwargs + model_name: str, + hf_token: str | Literal[False] | None, + alias: str | None, + provider: PROVIDER_T | None = None, + **kwargs, ) -> Blocks: headers = {"X-Wait-For-Model": "true"} if hf_token is False: hf_token = None client = huggingface_hub.InferenceClient( - model=model_name, headers=headers, token=hf_token + model=model_name, headers=headers, token=hf_token, provider=provider ) p, tags = external_utils.get_model_info(model_name, hf_token) @@ -458,8 +469,17 @@ def query_huggingface_inference_endpoints(*data): def from_spaces( - space_name: str, hf_token: str | None | Literal[False], alias: str | None, **kwargs + space_name: str, + hf_token: str | None | Literal[False], + alias: str | None, + provider: PROVIDER_T | None = None, + **kwargs, ) -> Blocks: + if provider is not None: + warnings.warn( + "The `provider` parameter is not supported when loading Spaces. It will be ignored." + ) + space_url = f"https://huggingface.co/spaces/{space_name}" print(f"Fetching Space from: {space_url}") diff --git a/requirements.txt b/requirements.txt index 7dc0829a5852a..c0e4847a20133 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ fastapi>=0.115.2,<1.0 ffmpy gradio_client==1.7.0 httpx>=0.24.1 -huggingface_hub>=0.25.1 +huggingface_hub>=0.28.1 Jinja2<4.0 markupsafe~=2.0 numpy>=1.0,<3.0