-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Download from HuggingFace with private token #390
base: main
Are you sure you want to change the base?
Changes from all commits
4bda3db
e2aa785
7de66e9
b69db89
ce06288
8cf49c4
918efa1
a18fd37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import shutil | ||
import warnings | ||
from pathlib import Path | ||
from typing import Any, Optional, Union | ||
|
@@ -65,6 +66,7 @@ | |
path: Union[str, Path], | ||
extensions_directory: Optional[Union[str, Path]] = None, | ||
architecture_name: Optional[str] = None, | ||
**kwargs, | ||
) -> Any: | ||
"""Load checkpoints and exported models from an URL or a local file. | ||
|
||
|
@@ -99,15 +101,48 @@ | |
) | ||
|
||
if Path(path).suffix in [".yaml", ".yml"]: | ||
raise ValueError(f"path '{path}' seems to be a YAML option file and no model") | ||
raise ValueError( | ||
f"path '{path}' seems to be a YAML option file and not a model" | ||
) | ||
|
||
if urlparse(str(path)).scheme: | ||
# Download from HuggingFace with a private token | ||
if kwargs.get("huggingface_api_token"): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we/can we add a test for this on CI? Maybe by creating a cosmo account on HuggingFace and uploading a dummy model + checkpoint there. |
||
try: | ||
from huggingface_hub import hf_hub_download | ||
except ImportError: | ||
raise ImportError( | ||
"To download a model from HuggingFace, please install the " | ||
"`huggingface_hub` package with pip (`pip install " | ||
"huggingface_hub`)." | ||
) | ||
path = str(path) | ||
if not path.startswith("https://huggingface.co/"): | ||
raise ValueError( | ||
f"Invalid URL '{path}'. HuggingFace models should start with " | ||
"'https://huggingface.co/'." | ||
) | ||
# get repo_id and filename | ||
split_path = path.split("/") | ||
repo_id = f"{split_path[3]}/{split_path[4]}" # org/repo | ||
filename = "" | ||
for i in range(5, len(split_path)): | ||
filename += split_path[i] + "/" | ||
filename = filename[:-1] | ||
path = hf_hub_download(repo_id, filename, token=kwargs["huggingface_api_token"]) | ||
# make sure to copy the checkpoint to the current directory | ||
shutil.copy(path, Path.cwd() / str(path).split("/")[-1]) | ||
|
||
elif urlparse(str(path)).scheme: | ||
path, _ = urlretrieve(str(path)) | ||
# make sure to copy the checkpoint to the current directory | ||
shutil.copy(path, Path.cwd() / str(path).split("/")[-1]) | ||
Comment on lines
+137
to
+138
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand why this is required. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And if you want the file to be downloaded directly to the current directory, you can do: path, _ = urlretrieve(
url=str(path),
filename=str(Path.cwd() / str(path).split("/")[-1])) See https://docs.python.org/3/library/urllib.request.html#urllib.request.urlretrieve. |
||
|
||
if is_exported_file(str(path)): | ||
return load_atomistic_model( | ||
str(path), extensions_directory=extensions_directory | ||
) | ||
else: | ||
pass | ||
|
||
path = str(path) | ||
if is_exported_file(path): | ||
return load_atomistic_model(path, extensions_directory=extensions_directory) | ||
else: # model is a checkpoint | ||
if architecture_name is None: | ||
raise ValueError( | ||
|
@@ -117,7 +152,7 @@ | |
architecture = import_architecture(architecture_name) | ||
|
||
try: | ||
return architecture.__model__.load_checkpoint(str(path)) | ||
return architecture.__model__.load_checkpoint(path) | ||
except Exception as err: | ||
raise ValueError( | ||
f"path '{path}' is not a valid model file for the {architecture_name} " | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.