Skip to content

Commit

Permalink
Make sure remote checkpoints are also saved to the current directory
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 14, 2024
1 parent 918efa1 commit a18fd37
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import shutil
import warnings
from pathlib import Path
from typing import Any, Optional, Union
Expand Down Expand Up @@ -128,9 +129,13 @@ def load_model(
filename += split_path[i] + "/"
filename = filename[:-1]
path = hf_hub_download(repo_id, filename, token=kwargs["huggingface_api_token"])

Check warning on line 131 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L125-L131

Added lines #L125 - L131 were not covered by tests
# make sure to copy the checkpoint to the current directory
shutil.copy(path, Path.cwd() / str(path).split("/")[-1])

Check warning on line 133 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L133

Added line #L133 was not covered by tests

elif urlparse(str(path)).scheme:
path, _ = urlretrieve(str(path))
# make sure to copy the checkpoint to the current directory
shutil.copy(path, Path.cwd() / str(path).split("/")[-1])

else:
pass
Expand Down
3 changes: 3 additions & 0 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def test_is_exported_file():
def test_load_model_checkpoint(path):
model = load_model(path, architecture_name="experimental.soap_bpnn")
assert type(model) is SoapBpnn
if str(path).startswith("file:"):
# test that the checkpoint is also copied to the current directory
assert Path("model-32-bit.ckpt").exists()


@pytest.mark.parametrize(
Expand Down

0 comments on commit a18fd37

Please sign in to comment.