Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add modal launch hf-download utility #2744

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion modal/cli/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Any, Optional

from typer import Typer
from typer import Argument, Option, Typer

from ..app import App
from ..exception import _CliUserExecutionError
Expand Down Expand Up @@ -93,3 +93,49 @@ def vscode(
"volume": volume,
}
_launch_program("vscode", "vscode.py", detach, args)


# NOTE: the docstring below is rendered by typer as the command's `--help` text, so
# maintainer-facing notes live in comments rather than in the docstring itself.
#
# The command only collects its CLI options into a JSON-serializable dict and hands them
# to `_launch_program`, which runs the bundled `hf_download.py` program.
@launch_cli.command(name="hf-download")
def hf_download(
    repo_id: str = Argument(help="The Hugging Face repository ID"),
    volume: str = Argument(help="The name of the Modal volume to use for caching"),
    secret: Optional[str] = Option(None, help="The name of a Modal secret with Hugging Face credentials"),
    timeout: int = Option(600, help="Maximum time the download is allowed to run (in seconds)"),
    type: Optional[str] = Option(None, help="The repository type (e.g. 'model' or 'dataset')"),
    revision: Optional[str] = Option(None, help="A specific revision to download"),
    ignore: list[str] = Option([], help="Ignore patterns to skip downloading matching files"),
    allow: list[str] = Option([], help="Allow patterns to selectively download matching files"),
    # The flag name is given explicitly, so the CLI surface is `--detach` regardless of the
    # Python parameter name; the parameter is spelled `detach` to match the `vscode` command.
    detach: bool = Option(False, "--detach", help="Allow the download to continue if the local client disconnects"),
):
    """Download a snapshot from the Hugging Face Hub.

    This command uses Hugging Face's `snapshot_download` function to download a snapshot from a repository
    on Hugging Face's Hub and cache it in a Modal volume. In your Modal applications, if you mount the
    Volume at a location corresponding to the `HF_HUB_CACHE` environment variable, Hugging Face will load
    data from the cache instead of downloading it.

    \b```
    modal launch hf-download microsoft/Phi-3-mini hf-hub-cache --secret "hf-secret" --timeout 600 --detach
    ```

    Then in your Modal App:

    \b```
    volume = modal.Volume.from_name("hf-hub-cache")
    @app.function(volumes={HF_HUB_CACHE: volume})
    def f():
        model = ModelClass.from_pretrained(model_name, cache_dir=HF_HUB_CACHE)
    ```

    """
    # Keys here are read back by name in `programs/hf_download.py`; keep the two in sync.
    args = {
        "volume": volume,
        "secret": secret,
        "timeout": timeout,
        "repo_id": repo_id,
        "type": type,
        "revision": revision,
        "ignore": ignore,
        "allow": allow,
    }
    _launch_program("hf-download", "hf_download.py", detach, args)
43 changes: 43 additions & 0 deletions modal/cli/programs/hf_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright Modal Labs 2025
import json
import os
import time

import modal

# `modal launch` serializes the CLI arguments into the MODAL_LAUNCH_ARGS environment
# variable locally; the same JSON is re-exported to the remote runner through a secret
# (see `secrets` below), so this module parses identical arguments in both places.
_raw_launch_args = os.environ.get("MODAL_LAUNCH_ARGS", "{}")
args: dict = json.loads(_raw_launch_args)

# Mount point of the cache volume; also exported to the image as HF_HUB_CACHE.
CACHE_DIR = "/hf-cache"

# Build the image step by step: base OS/Python, then the hub client, then its environment.
_base_image = modal.Image.debian_slim(python_version="3.12")
_hub_image = _base_image.pip_install("huggingface-hub[hf-transfer]==0.27.1")
image = _hub_image.env({"HF_HUB_CACHE": CACHE_DIR, "HF_HUB_ENABLE_HF_TRANSFER": "1"})

volume = modal.Volume.from_name(str(args.get("volume")))

# The first secret carries the launch arguments; the optional second one is the
# user-named secret passed via `--secret`.
secrets = [modal.Secret.from_dict({"MODAL_LAUNCH_ARGS": json.dumps(args)})]
user_secret = args.get("secret")
if user_secret:
    secrets.append(modal.Secret.from_name(user_secret))

app = modal.App(
    "hf-download",
    image=image,
    secrets=secrets,
    volumes={CACHE_DIR: volume},
)


# NOTE(review): memory=1028 looks like it may have been intended as 1024 — confirm.
@app.function(cpu=4, memory=1028, timeout=int(args.get("timeout", 600)))
def run():
    """Download the requested Hugging Face snapshot into ``CACHE_DIR``.

    All inputs are read from the module-level ``args`` dict (repo id, repo type,
    revision, ignore/allow patterns). Prints the elapsed wall-clock time when the
    download finishes.
    """
    # Imported lazily: the package is installed in the remote image, not necessarily locally.
    from huggingface_hub import snapshot_download

    # The CLI defaults both pattern options to an empty list. huggingface_hub treats any
    # non-None `allow_patterns` as a strict allowlist, so `[]` would match no files and
    # nothing would be downloaded; normalize "no patterns given" to None instead.
    ignore_patterns = args.get("ignore") or None
    allow_patterns = args.get("allow") or None

    t0 = time.monotonic()
    snapshot_download(
        repo_id=args.get("repo_id"),
        repo_type=args.get("type"),
        revision=args.get("revision"),
        ignore_patterns=ignore_patterns,
        allow_patterns=allow_patterns,
        cache_dir=CACHE_DIR,
    )
    print(f"Completed in {time.monotonic() - t0:.2f} seconds")


@app.local_entrypoint()
def main():
    """Local entrypoint of the app: invoke `run` remotely."""
    run.remote()
Loading