Use TNT's ManifoldPathHandler for listing checkpoints internally #938

Open · wants to merge 1 commit into base: master
12 changes: 2 additions & 10 deletions tests/utils/test_checkpoint.py
@@ -1422,18 +1422,10 @@ def test_does_checkpoint_metadata_exist(self) -> None:
         dirpath = os.path.join(temp_dir, "checkpoint")
         Snapshot.take(dirpath, app_state=app_state)

-        self.assertTrue(
-            CheckpointManager.does_checkpoint_metadata_exist(
-                dirpath, SNAPSHOT_METADATA_FNAME
-            )
-        )
+        self.assertTrue(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME))

         os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
-        self.assertFalse(
-            CheckpointManager.does_checkpoint_metadata_exist(
-                dirpath, SNAPSHOT_METADATA_FNAME
-            )
-        )
+        self.assertFalse(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME))

     def test_does_checkpoint_exist(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
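The test change above swaps the removed CheckpointManager.does_checkpoint_metadata_exist static method for the module-level does_checkpoint_exist helper. A minimal standalone sketch of that call pattern follows; the ".snapshot_metadata" filename is an assumption standing in for SNAPSHOT_METADATA_FNAME, and writing the file directly stands in for Snapshot.take(...):

# Standalone sketch: the module-level does_checkpoint_exist helper replaces
# the old CheckpointManager.does_checkpoint_metadata_exist static method.
import os
import tempfile

from torchtnt.utils.checkpoint import does_checkpoint_exist

with tempfile.TemporaryDirectory() as temp_dir:
    dirpath = os.path.join(temp_dir, "checkpoint")
    os.makedirs(dirpath)
    # Assumption: ".snapshot_metadata" stands in for SNAPSHOT_METADATA_FNAME.
    metadata_path = os.path.join(dirpath, ".snapshot_metadata")
    open(metadata_path, "w").close()

    assert does_checkpoint_exist(dirpath, ".snapshot_metadata")

    os.remove(metadata_path)
    assert not does_checkpoint_exist(dirpath, ".snapshot_metadata")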
14 changes: 13 additions & 1 deletion torchtnt/framework/callbacks/base_checkpointer.py
@@ -12,6 +12,8 @@
 from datetime import timedelta
 from typing import Any, cast, Iterable, List, Literal, Optional, Union

+import fsspec
+
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torchtnt.framework.callback import Callback
@@ -449,6 +451,7 @@ def restore_from_latest(
         train_dataloader: Optional[Iterable[TTrainData]] = None,
         process_group: Optional[dist.ProcessGroup] = None,
         restore_options: Optional[RestoreOptions] = None,
+        file_system: Optional[fsspec.AbstractFileSystem] = None,
         **kwargs: Any,
     ) -> bool:
         """
@@ -463,12 +466,17 @@
             train_dataloader: An optional train dataloader to restore.
             process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
             restore_options: Controls what to filter when restoring the state.
+            file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+                provided, fsspec is used to infer the file system from the dirpath.

         Returns:
             True if the latest checkpoint directory was found and successfully restored, otherwise False.
         """
         path = get_latest_checkpoint_path(
-            dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group
+            dirpath,
+            metadata_fname=cls.metadata_fnames,
+            process_group=process_group,
+            file_system=file_system,
         )
         if path is None:
             logger.info(
@@ -497,6 +505,7 @@ def restore_from_best(
         train_dataloader: Optional[Iterable[TTrainData]] = None,
         process_group: Optional[dist.ProcessGroup] = None,
         restore_options: Optional[RestoreOptions] = None,
+        file_system: Optional[fsspec.AbstractFileSystem] = None,
         **kwargs: Any,
     ) -> bool:
         """
@@ -512,6 +521,8 @@
             mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
             train_dataloader: An optional train dataloader to restore.
             process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+            file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+                provided, fsspec is used to infer the file system from the dirpath.
             restore_options: Controls what to filter when restoring the state.

         Returns:
@@ -522,6 +533,7 @@
             metric_name=metric_name,
             mode=mode,
             metadata_fname=cls.metadata_fnames,
+            file_system=file_system,
             process_group=process_group,
         )

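For illustration, a hedged usage sketch of the new parameter on the restore entry points. TorchSnapshotSaver is assumed as the concrete BaseCheckpointer subclass, a local fsspec file system stands in for whatever internal implementation (e.g. one backed by TNT's ManifoldPathHandler) a caller would supply, and my_unit is a hypothetical TrainUnit instance:

# Usage sketch: pass a custom fsspec file system instead of letting
# url_to_fs infer one from the dirpath.
import fsspec
from torchtnt.framework.callbacks import TorchSnapshotSaver

custom_fs: fsspec.AbstractFileSystem = fsspec.filesystem("file")  # stand-in

restored = TorchSnapshotSaver.restore_from_latest(
    dirpath="/tmp/run/checkpoints",
    unit=my_unit,  # hypothetical TrainUnit instance
    file_system=custom_fs,  # new kwarg, threaded into get_latest_checkpoint_path
)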
54 changes: 42 additions & 12 deletions torchtnt/utils/checkpoint.py
@@ -366,6 +366,7 @@ def __init__(
         keep_last_n_checkpoints: Optional[int] = None,
         metadata_fnames: Optional[List[str]] = None,
         process_group: Optional[dist.ProcessGroup] = None,
+        file_system: Optional[fsspec.AbstractFileSystem] = None,
     ) -> None:
         """
         Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the
@@ -389,6 +390,11 @@
         self._keep_last_n_checkpoints = keep_last_n_checkpoints
         self._pg_wrapper = PGWrapper(process_group)

+        if file_system is None:
+            file_system, _ = url_to_fs(self.dirpath)
+
+        self._file_system: fsspec.AbstractFileSystem = file_system
+
         if metadata_fnames is None:
             self._metadata_fnames: List[str] = []
         else:
@@ -568,17 +574,16 @@ def does_checkpoint_exist(
             ckpt.path, self._metadata_fnames, process_group=process_group
         )

-    @staticmethod
     def does_checkpoint_metadata_exist(
+        self,
         checkpoint_path: str,
         metadata_fname: str,
     ) -> bool:
         """
         Checks whether a checkpoint metadata file exists in the directory.
         Returns True if the checkpoint directory contains that metadata file, and False otherwise.
         """
-        fs, _ = url_to_fs(checkpoint_path)
-        return _metadata_exists(fs, checkpoint_path, metadata_fname)
+        return _metadata_exists(self._file_system, checkpoint_path, metadata_fname)

     @staticmethod
     @rank_zero_read_and_broadcast
@@ -596,9 +601,8 @@ def remove_checkpoint(self) -> None:
"""
worst_ckpt_path = self._ckpt_paths.pop(0)
if self._pg_wrapper.get_rank() == 0:
fs, _ = url_to_fs(self.dirpath)
try:
fs.rm(worst_ckpt_path.path, recursive=True)
self._file_system.rm(worst_ckpt_path.path, recursive=True)
except Exception as exc:
logger.error(
(
@@ -612,6 +616,7 @@
 def does_checkpoint_exist(
     ckpt_path: str,
     metadata_fname: Union[str, List[str]],
+    file_system: Optional[fsspec.AbstractFileSystem] = None,
     process_group: Optional[dist.ProcessGroup] = None,
 ) -> bool:
     """
@@ -622,6 +627,8 @@
     Args:
         ckpt_path: The checkpoint to check.
         metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
+        file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+            provided, fsspec is used to infer the file system from the dirpath.
         process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
     """
     if not metadata_fname:
@@ -631,14 +638,18 @@
         [metadata_fname] if isinstance(metadata_fname, str) else metadata_fname
     )

-    fs, _ = url_to_fs(ckpt_path)
+    fs = file_system
+    if fs is None:
+        fs, _ = url_to_fs(ckpt_path)

     return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames)


 @rank_zero_read_and_broadcast
 def get_latest_checkpoint_path(
     dirpath: str,
     metadata_fname: Optional[Union[str, List[str]]] = None,
+    file_system: Optional[fsspec.AbstractFileSystem] = None,
     process_group: Optional[dist.ProcessGroup] = None,
 ) -> Optional[str]:
     """
@@ -648,6 +659,8 @@
         dirpath: parent directory where checkpoints are saved.
         metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
             If a list is provided, it will check that at least one of the files is present.
+        file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+            provided, fsspec is used to infer the file system from the dirpath.
         process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)

     Raises:
@@ -658,14 +671,17 @@
         gloo process groups are recommended over nccl.
     """

-    return _get_latest_checkpoint_path(dirpath, metadata_fname)
+    return _get_latest_checkpoint_path(dirpath, metadata_fname, file_system)


 def _get_latest_checkpoint_path(
     dirpath: str,
     metadata_fname: Optional[Union[str, List[str]]] = None,
+    file_system: Optional[fsspec.AbstractFileSystem] = None,
 ) -> Optional[str]:
-    candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname)
+    candidate_dirpaths = _retrieve_checkpoint_dirpaths(
+        dirpath, metadata_fname, file_system=file_system
+    )
     if not candidate_dirpaths:
         return None

@@ -683,6 +699,7 @@ def get_best_checkpoint_path(
     metric_name: str,
     mode: Literal["min", "max"],
     metadata_fname: Optional[Union[str, List[str]]] = None,
+    file_system: Optional[fsspec.AbstractFileSystem] = None,
     process_group: Optional[dist.ProcessGroup] = None,
 ) -> Optional[str]:
     """
@@ -697,14 +714,18 @@
         mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
         metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
             If a list is provided, it will check that at least one of the files is present.
+        file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+            provided, fsspec is used to infer the file system from the dirpath.
         process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)

     Note:
         When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
         gloo process groups are recommended over nccl.
     """

-    dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
+    dirpaths = _retrieve_checkpoint_dirpaths(
+        dirpath, metadata_fname, metric_name, file_system=file_system
+    )
     if not dirpaths:
         return None

@@ -721,6 +742,7 @@ def get_checkpoint_dirpaths(
     dirpath: str,
     metadata_fname: Optional[Union[str, List[str]]] = None,
     metric_name: Optional[str] = None,
+    file_system: Optional[fsspec.AbstractFileSystem] = None,
     process_group: Optional[dist.ProcessGroup] = None,
 ) -> List[CheckpointPath]:
     """
@@ -736,20 +758,25 @@
         metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
             If a list is provided, it will check that at least one of the files is present.
         metric_name: fetches all the checkpoint directories containing the metric name only.
+        file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+            provided, fsspec is used to infer the file system from the dirpath.
         process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)

     Note:
         When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
         gloo process groups are recommended over nccl.
     """

-    return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
+    return _retrieve_checkpoint_dirpaths(
+        dirpath, metadata_fname, metric_name, file_system=file_system
+    )


 def _retrieve_checkpoint_dirpaths(
     dirpath: str,
     metadata_fname: Optional[Union[str, List[str]]],
     metric_name: Optional[str] = None,
+    file_system: Optional[fsspec.AbstractFileSystem] = None,
 ) -> List[CheckpointPath]:
     """
     Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
@@ -759,9 +786,12 @@
         metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
             If a list is provided, it will check that at least one of the files is present.
         metric_name: Name of the metric that must exist in checkpoint name.
+        file_system: An optional custom file system to use when fetching the checkpoint directories. If not
+            provided, fsspec is used to infer the file system from the dirpath.
     """

-    fs, _ = url_to_fs(dirpath)
+    fs = file_system
+    if fs is None:
+        fs, _ = url_to_fs(dirpath)

     if not fs.exists(dirpath):
         logger.warning(f"Input dirpath doesn't exist: {dirpath}")
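Taken together, the utils changes let one file system object be constructed once (or injected) and reused for listing, existence checks, and pruning, instead of being re-derived with url_to_fs on every call. A hedged sketch of the intended call pattern follows; fsspec's in-memory file system, the ".metadata" filename, and the epoch_X_step_Y directory naming are all assumptions standing in for real checkpoint layouts:

# Sketch: inject a custom file system into CheckpointManager and the free
# functions so all internal fs operations reuse it.
import fsspec
from torchtnt.utils.checkpoint import CheckpointManager, get_checkpoint_dirpaths

fs: fsspec.AbstractFileSystem = fsspec.filesystem("memory")  # stand-in
fs.makedirs("/ckpts/epoch_0_step_100", exist_ok=True)
fs.touch("/ckpts/epoch_0_step_100/.metadata")  # assumed metadata filename

manager = CheckpointManager(
    dirpath="/ckpts",
    metadata_fnames=[".metadata"],
    file_system=fs,  # skips the url_to_fs(self.dirpath) fallback
)
print(manager.does_checkpoint_metadata_exist("/ckpts/epoch_0_step_100", ".metadata"))

# The free functions accept the same override:
print(get_checkpoint_dirpaths("/ckpts", metadata_fname=".metadata", file_system=fs))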