Skip to content

Commit

Permalink
Merge pull request #4 from janelia-cellmap/load_checkpoints
Browse files Browse the repository at this point in the history
Load checkpoints
  • Loading branch information
rhoadesScholar authored Mar 18, 2024
2 parents 3743cbe + 84be8fe commit 95b791f
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/cellmap_models/pytorch/cellpose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .add_model import add_model
from .load_model import load_model
from .get_model import get_model
from .download_checkpoint import download_checkpoint

models_dict = {
"jrc_mus-epididymis-1_nuc_cp": "https://github.com/janelia-cellmap/cellmap-models/releases/download/2024.03.08/jrc_mus-epididymis-1_nuc_cp",
Expand Down
30 changes: 30 additions & 0 deletions src/cellmap_models/pytorch/cellpose/download_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
from cellmap_models import download_url_to_file


def download_checkpoint(checkpoint_name: str, checkpoint_path: Path):
"""
download models checkpoint from GitHub release resources.
Args:
checkpoint_name (str): Name of the checkpoint file.
local_folder (Path): Local path to save the checkpoint.
return:
checkpoint_path (Path): Path to the downloaded checkpoint.
"""
from . import models_dict, models_list # avoid circular import

# Make sure the checkpoint exists
if checkpoint_name not in models_list:
raise ValueError(
f"Checkpoint {checkpoint_name} not found. Available checkpoints: {models_list}"
)

if not checkpoint_path.exists():
url = models_dict[checkpoint_name]
print(f"Downloading {checkpoint_name} from {url}")
download_url_to_file(url, checkpoint_path)
else:
print(f"Checkpoint {checkpoint_name} found at {checkpoint_path}")

return checkpoint_path
Binary file removed src/cellmap_models/pytorch/cosem/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions src/cellmap_models/pytorch/cosem/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .load_model import load_model
from .download_checkpoint import download_checkpoint

models_dict = {
"setup04/1820500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/1820500",
Expand Down
30 changes: 30 additions & 0 deletions src/cellmap_models/pytorch/cosem/download_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
from cellmap_models import download_url_to_file


def download_checkpoint(checkpoint_name: str, checkpoint_path: Path):
"""
download models checkpoint from s3 bucket.
Args:
checkpoint_name (str): Name of the checkpoint file.
local_folder (Path): Local path to save the checkpoint.
return:
checkpoint_path (Path): Path to the downloaded checkpoint.
"""
from . import models_dict, models_list # avoid circular import

# Make sure the checkpoint exists
if checkpoint_name not in models_list:
raise ValueError(
f"Checkpoint {checkpoint_name} not found. Available checkpoints: {models_list}"
)

if not checkpoint_path.exists():
url = models_dict[checkpoint_name]
print(f"Downloading {checkpoint_name} from {url}")
download_url_to_file(url, checkpoint_path)
else:
print(f"Checkpoint {checkpoint_name} found at {checkpoint_path}")

return checkpoint_path

0 comments on commit 95b791f

Please sign in to comment.