Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download from HuggingFace with private token #390

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/src/getting-started/checkpoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ a file path.

mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt

Downloading private HuggingFace models is also supported, by specifying the
corresponding API token with the ``--huggingface_api_token`` flag.
Comment on lines +49 to +50
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(Downloading private models is also supported, only through HuggingFace and using the
``--huggingface_api_token`` flag.)
Downloading private HuggingFace models is also supported, by specifying the corresponding API token with the ``--huggingface_api_token`` flag.


Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have
several dependencies and may become unusable if the corresponding architecture is
updated. In contrast, exported models (``.pt``) act as standalone files. For long-term
Expand Down
20 changes: 18 additions & 2 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,30 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
default="exported-model.pt",
help="Filename of the exported model (default: %(default)s).",
)
parser.add_argument(
"--huggingface_api_token",
dest="huggingface_api_token",
type=str,
required=False,
default="",
help="API token to download a private model from HuggingFace.",
)


def _prepare_export_model_args(args: argparse.Namespace) -> None:
    """Prepare the parsed CLI arguments for a call to ``export_model``.

    Loads the model from ``args.path`` (forwarding any extra CLI options,
    e.g. ``huggingface_api_token``, to ``load_model``) and stores it on
    ``args.model``. Afterwards the namespace is stripped down to exactly
    the keys ``export_model`` accepts, so it can be called as
    ``export_model(**args.__dict__)``.

    :param args: parsed CLI namespace; mutated in place.
    """
    path = args.__dict__.pop("path")
    architecture_name = args.__dict__.pop("architecture_name")
    args.model = load_model(
        path=path,
        architecture_name=architecture_name,
        **args.__dict__,
    )
    # Only these keys are needed for `export_model`; drop everything else
    # (e.g. the download-only options consumed by `load_model` above).
    keys_to_keep = ["model", "output"]
    original_keys = list(args.__dict__.keys())
    for key in original_keys:
        if key not in keys_to_keep:
            args.__dict__.pop(key)


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
Expand Down
23 changes: 15 additions & 8 deletions src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,24 @@ def forward(
def load_checkpoint(cls, path: Union[str, Path]) -> "PET":

checkpoint = torch.load(path, weights_only=False, map_location="cpu")
hypers = checkpoint["hypers"]
if "checkpoint" in checkpoint:
# This is the case when the checkpoint was saved with the Trainer
state_dict = checkpoint["checkpoint"]["model_state_dict"]
model_hypers = checkpoint["hypers"]["ARCHITECTURAL_HYPERS"]
self_contributions = checkpoint["self_contributions"]
elif "model_state_dict" in checkpoint:
# This is the case when the checkpoint was saved for the
# HuggingFace API
state_dict = checkpoint["model_state_dict"]
model_hypers = checkpoint["model_hypers"]
self_contributions = state_dict.pop("pet.self_contributions").numpy()
else:
raise ValueError("Invalid checkpoint format")
dataset_info = checkpoint["dataset_info"]
model = cls(
model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
)

checkpoint = torch.load(path, weights_only=False)
state_dict = checkpoint["checkpoint"]["model_state_dict"]
model = cls(model_hypers=model_hypers, dataset_info=dataset_info)

ARCHITECTURAL_HYPERS = Hypers(model.hypers)
ARCHITECTURAL_HYPERS = Hypers(model_hypers)
raw_pet = RawPET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
if ARCHITECTURAL_HYPERS.USE_LORA_PEFT:
lora_rank = ARCHITECTURAL_HYPERS.LORA_RANK
Expand All @@ -160,7 +168,6 @@ def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
dtype = next(iter(new_state_dict.values())).dtype
raw_pet.to(dtype).load_state_dict(new_state_dict)

self_contributions = checkpoint["self_contributions"]
wrapper = SelfContributionsWrapper(raw_pet, self_contributions)

model.to(dtype).set_trained_model(wrapper)
Expand Down
5 changes: 4 additions & 1 deletion src/metatrain/experimental/pet/utils/update_state_dict.py
Original file line number Diff line number Diff line change
def update_state_dict(state_dict: Dict) -> Dict:
    """Strip the wrapper prefix from the keys of a PET state dict.

    Keys saved by the trainer look like ``<wrapper>.pet_model.<param>``,
    while checkpoints saved for the HuggingFace API use ``pet.model.<param>``.
    Both are normalized to the bare ``<param>`` name so the raw PET model can
    load them. Keys with neither prefix are passed through unchanged.

    :param state_dict: state dict with (possibly) prefixed parameter names.
    :return: a new dict with the prefixes removed; values are unchanged.
    """
    new_state_dict = {}
    for name, value in state_dict.items():
        # Check for the exact prefix (including the dot) that we split on:
        # checking only "pet_model" would raise IndexError for a key that
        # contains the substring without a trailing dot.
        if "pet_model." in name:
            name = name.split("pet_model.")[1]
        else:
            name = name.replace("pet.model.", "")
        new_state_dict[name] = value
    return new_state_dict
49 changes: 42 additions & 7 deletions src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import shutil
import warnings
from pathlib import Path
from typing import Any, Optional, Union
Expand Down Expand Up @@ -65,6 +66,7 @@
path: Union[str, Path],
extensions_directory: Optional[Union[str, Path]] = None,
architecture_name: Optional[str] = None,
**kwargs,
) -> Any:
"""Load checkpoints and exported models from an URL or a local file.

Expand Down Expand Up @@ -99,15 +101,48 @@
)

if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(f"path '{path}' seems to be a YAML option file and no model")
raise ValueError(
f"path '{path}' seems to be a YAML option file and not a model"
)

if urlparse(str(path)).scheme:
# Download from HuggingFace with a private token
if kwargs.get("huggingface_api_token"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we/can we add a test for this on CI? Maybe creating a cosmo account on HuggingFace & uploading a dummy model+checkpoint there.

try:
from huggingface_hub import hf_hub_download
except ImportError:
raise ImportError(

Check warning on line 113 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L110-L113

Added lines #L110 - L113 were not covered by tests
"To download a model from HuggingFace, please install the "
"`huggingface_hub` package with pip (`pip install "
"huggingface_hub`)."
)
path = str(path)
if not path.startswith("https://huggingface.co/"):
raise ValueError(

Check warning on line 120 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L118-L120

Added lines #L118 - L120 were not covered by tests
f"Invalid URL '{path}'. HuggingFace models should start with "
"'https://huggingface.co/'."
)
# get repo_id and filename
split_path = path.split("/")
repo_id = f"{split_path[3]}/{split_path[4]}" # org/repo
filename = ""
for i in range(5, len(split_path)):
filename += split_path[i] + "/"
filename = filename[:-1]
path = hf_hub_download(repo_id, filename, token=kwargs["huggingface_api_token"])

Check warning on line 131 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L125-L131

Added lines #L125 - L131 were not covered by tests
# make sure to copy the checkpoint to the current directory
shutil.copy(path, Path.cwd() / str(path).split("/")[-1])

Check warning on line 133 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L133

Added line #L133 was not covered by tests

elif urlparse(str(path)).scheme:
path, _ = urlretrieve(str(path))
# make sure to copy the checkpoint to the current directory
shutil.copy(path, Path.cwd() / str(path).split("/")[-1])
Comment on lines +137 to +138
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this is required?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if you want the file to be directly downloaded to the current directory you can do

path, _ = urlretrieve(
    url=str(path),
    filename=str(Path.cwd() / str(path).split("/")[-1]))

See https://docs.python.org/3/library/urllib.request.html#urllib.request.urlretrieve.


if is_exported_file(str(path)):
return load_atomistic_model(
str(path), extensions_directory=extensions_directory
)
else:
pass

path = str(path)
if is_exported_file(path):
return load_atomistic_model(path, extensions_directory=extensions_directory)
else: # model is a checkpoint
if architecture_name is None:
raise ValueError(
Expand All @@ -117,7 +152,7 @@
architecture = import_architecture(architecture_name)

try:
return architecture.__model__.load_checkpoint(str(path))
return architecture.__model__.load_checkpoint(path)
except Exception as err:
raise ValueError(
f"path '{path}' is not a valid model file for the {architecture_name} "
Expand Down
5 changes: 4 additions & 1 deletion tests/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def test_is_exported_file():
def test_load_model_checkpoint(path):
    """Checkpoints (``.ckpt``) load into the architecture's model class.

    NOTE(review): ``path`` comes from a parametrize decorator outside this
    view — presumably local-path, ``file:`` and URL variants; verify there.
    """
    model = load_model(path, architecture_name="experimental.soap_bpnn")
    assert type(model) is SoapBpnn
    if str(path).startswith("file:"):
        # test that the checkpoint is also copied to the current directory
        assert Path("model-32-bit.ckpt").exists()


@pytest.mark.parametrize(
Expand All @@ -64,7 +67,7 @@ def test_load_model_exported(path):

@pytest.mark.parametrize("suffix", [".yml", ".yaml"])
def test_load_model_yaml(suffix):
match = f"path 'foo{suffix}' seems to be a YAML option file and no model"
match = f"path 'foo{suffix}' seems to be a YAML option file and not a model"
with pytest.raises(ValueError, match=match):
load_model(
f"foo{suffix}",
Expand Down