Add support for inference providers in gr.load() (#10496)
* changes

* add changeset

* format

* changes

---------

Co-authored-by: gradio-pr-bot <[email protected]>
abidlabs and gradio-pr-bot authored Feb 4, 2025
1 parent 10932a2 commit a9bfbc3
Showing 3 changed files with 31 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .changeset/lucky-cars-follow.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Add support for inference `providers` in `gr.load()`
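
For context, a minimal usage sketch of the new parameter (the model id and provider below are illustrative; any provider accepted by `huggingface_hub.InferenceClient` should work):

```python
import gradio as gr

# Load a model from the Hub, routing inference through a third-party
# provider instead of Hugging Face's own inference API.
demo = gr.load(
    "deepseek-ai/DeepSeek-R1",  # illustrative model id
    src="models",
    provider="sambanova",       # e.g. "replicate", "fal-ai", ...
)
demo.launch()
```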
30 changes: 25 additions & 5 deletions gradio/external.py
@@ -31,6 +31,8 @@
 from gradio.processing_utils import save_base64_to_cache, to_binary
 
 if TYPE_CHECKING:
+    from huggingface_hub.inference._providers import PROVIDER_T
+
     from gradio.blocks import Blocks
     from gradio.chat_interface import ChatInterface
     from gradio.components.chatbot import MessageDict
@@ -46,6 +48,7 @@ def load(
     token: str | None = None,
     hf_token: str | None = None,
     accept_token: bool = False,
+    provider: PROVIDER_T | None = None,
     **kwargs,
 ) -> Blocks:
     """
@@ -56,6 +59,7 @@ def load(
         token: optional token that is passed as the second parameter to the `src` function. If not explicitly provided, will use the HF_TOKEN environment variable or fallback to the locally-saved HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find your HF tokens here: https://huggingface.co/settings/tokens.
         accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space.
         kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor.
+        provider: the name of the third-party (non-Hugging Face) providers to use for model inference (e.g. "replicate", "sambanova", "fal-ai", etc). Should be one of the providers supported by `huggingface_hub.InferenceClient`. This parameter is only used when `src` is "models"
     Returns:
         a Gradio Blocks app for the given model
     Example:
@@ -93,7 +97,7 @@ def load(
         if isinstance(src, Callable):
             return src(name, token, **kwargs)
         return load_blocks_from_huggingface(
-            name=name, src=src, hf_token=token, **kwargs
+            name=name, src=src, hf_token=token, provider=provider, **kwargs
         )
     else:
         import gradio as gr
@@ -150,6 +154,7 @@ def load_blocks_from_huggingface(
     src: str,
     hf_token: str | Literal[False] | None = None,
     alias: str | None = None,
+    provider: PROVIDER_T | None = None,
     **kwargs,
 ) -> Blocks:
     """Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
@@ -168,18 +173,24 @@
 
     if src == "spaces" and hf_token is None:
         hf_token = False  # Since Spaces can read the token, we don't want to pass it in unless the user explicitly provides it
-    blocks: gradio.Blocks = factory_methods[src](name, hf_token, alias, **kwargs)
+    blocks: gradio.Blocks = factory_methods[src](
+        name, hf_token=hf_token, alias=alias, provider=provider, **kwargs
+    )
     return blocks
 
 
 def from_model(
-    model_name: str, hf_token: str | Literal[False] | None, alias: str | None, **kwargs
+    model_name: str,
+    hf_token: str | Literal[False] | None,
+    alias: str | None,
+    provider: PROVIDER_T | None = None,
+    **kwargs,
 ) -> Blocks:
     headers = {"X-Wait-For-Model": "true"}
     if hf_token is False:
         hf_token = None
     client = huggingface_hub.InferenceClient(
-        model=model_name, headers=headers, token=hf_token
+        model=model_name, headers=headers, token=hf_token, provider=provider
     )
     p, tags = external_utils.get_model_info(model_name, hf_token)
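
As the hunk above shows, `provider` is forwarded straight to `huggingface_hub.InferenceClient`. A rough sketch of the client that `from_model()` now constructs (model id, provider, and prompt are illustrative; assumes `huggingface_hub>=0.28.1`):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    headers={"X-Wait-For-Model": "true"},      # same header from_model() sets
    token=None,                                # or an HF token string
    provider="sambanova",                      # third-party inference provider
)
response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```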

@@ -458,8 +469,17 @@ def query_huggingface_inference_endpoints(*data):
 
 
 def from_spaces(
-    space_name: str, hf_token: str | None | Literal[False], alias: str | None, **kwargs
+    space_name: str,
+    hf_token: str | None | Literal[False],
+    alias: str | None,
+    provider: PROVIDER_T | None = None,
+    **kwargs,
 ) -> Blocks:
+    if provider is not None:
+        warnings.warn(
+            "The `provider` parameter is not supported when loading Spaces. It will be ignored."
+        )
+
     space_url = f"https://huggingface.co/spaces/{space_name}"
 
     print(f"Fetching Space from: {space_url}")
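For illustration (not part of the diff): passing `provider` while loading a Space now warns and is otherwise ignored; the Space name below is just an example:

```python
import gradio as gr

# Emits: UserWarning: The `provider` parameter is not supported when
# loading Spaces. It will be ignored.
demo = gr.load("gradio/hello_world", src="spaces", provider="fal-ai")
```
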
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ fastapi>=0.115.2,<1.0
 ffmpy
 gradio_client==1.7.0
 httpx>=0.24.1
-huggingface_hub>=0.25.1
+huggingface_hub>=0.28.1
 Jinja2<4.0
 markupsafe~=2.0
 numpy>=1.0,<3.0
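The floor on `huggingface_hub` rises because the `provider` argument to `InferenceClient` only exists in recent releases. A hedged sketch of a guard a downstream app could add (not part of this commit; assumes the `packaging` library is available):

```python
import huggingface_hub
from packaging.version import Version

# gr.load(..., provider=...) relies on InferenceClient(provider=...),
# which requires huggingface_hub>=0.28.1 per this commit's pin.
if Version(huggingface_hub.__version__) < Version("0.28.1"):
    raise RuntimeError("Upgrade huggingface_hub to >=0.28.1 to use `provider`.")
```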
