From 21cf6aa3e7f3b2c2305ae02ec2971dad09244c0d Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Mon, 4 Nov 2024 14:17:48 -0800 Subject: [PATCH] Use TNT's ManifoldPathHandler for listing checkpoints internally Summary: We've faced multiple issues in the past where users register incompatible implementations of Manifold path handlers to fsspec, causing errors when listing and even loading the checkpoints. Currently [one user is facing an error because of this](https://fb.workplace.com/groups/277527419809135/permalink/1625534775008386/), and it's tricky to debug because the error is not reproducible outside of that particular project (because of the specific dependencies being used) Most internal customers should be using storage optimizations and then use modelstore components within DCP APIs, but we still use fsspec for listing latest and best checkpoints. We need to make sure that our own implementation is used to list the checkpoints. Note we can't modify directly on checkpoint utils because it's OSS, while filesystem is internal. Differential Revision: D65370757 --- tests/utils/test_checkpoint.py | 12 +---- .../framework/callbacks/base_checkpointer.py | 14 ++++- torchtnt/utils/checkpoint.py | 54 ++++++++++++++----- 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index cc2a7c0c20..cba800a986 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -1422,18 +1422,10 @@ def test_does_checkpoint_metadata_exist(self) -> None: dirpath = os.path.join(temp_dir, "checkpoint") Snapshot.take(dirpath, app_state=app_state) - self.assertTrue( - CheckpointManager.does_checkpoint_metadata_exist( - dirpath, SNAPSHOT_METADATA_FNAME - ) - ) + self.assertTrue(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME)) os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) - self.assertFalse( - CheckpointManager.does_checkpoint_metadata_exist( - dirpath, SNAPSHOT_METADATA_FNAME - ) - ) + self.assertFalse(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME)) def test_does_checkpoint_exist(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 43f3657551..3a77fd11fa 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -12,6 +12,8 @@ from datetime import timedelta from typing import Any, cast, Iterable, List, Literal, Optional, Union +import fsspec + import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback @@ -449,6 +451,7 @@ def restore_from_latest( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, **kwargs: Any, ) -> bool: """ @@ -463,12 +466,17 @@ def restore_from_latest( train_dataloader: An optional train dataloader to restore. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) restore_options: Controls what to filter when restoring the state. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. Returns: True if the latest checkpoint directory was found and successfully restored, otherwise False. """ path = get_latest_checkpoint_path( - dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group + dirpath, + metadata_fname=cls.metadata_fnames, + process_group=process_group, + file_system=file_system, ) if path is None: logger.info( @@ -497,6 +505,7 @@ def restore_from_best( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, **kwargs: Any, ) -> bool: """ @@ -512,6 +521,8 @@ def restore_from_best( mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. train_dataloader: An optional train dataloader to restore. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. restore_options: Controls what to filter when restoring the state. Returns: @@ -522,6 +533,7 @@ def restore_from_best( metric_name=metric_name, mode=mode, metadata_fname=cls.metadata_fnames, + file_system=file_system, process_group=process_group, ) diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index b6e1f962ca..c92c4e00ee 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -366,6 +366,7 @@ def __init__( keep_last_n_checkpoints: Optional[int] = None, metadata_fnames: Optional[List[str]] = None, process_group: Optional[dist.ProcessGroup] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, ) -> None: """ Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the @@ -389,6 +390,11 @@ def __init__( self._keep_last_n_checkpoints = keep_last_n_checkpoints self._pg_wrapper = PGWrapper(process_group) + if file_system is None: + file_system, _ = url_to_fs(self.dirpath) + + self._file_system: fsspec.AbstractFileSystem = file_system + if metadata_fnames is None: self._metadata_fnames: List[str] = [] else: @@ -568,8 +574,8 @@ def does_checkpoint_exist( ckpt.path, self._metadata_fnames, process_group=process_group ) - @staticmethod def does_checkpoint_metadata_exist( + self, checkpoint_path: str, metadata_fname: str, ) -> bool: @@ -577,8 +583,7 @@ def does_checkpoint_metadata_exist( Checking whether a checkpoint metadata file exists in the directory. If the checkpointer has that metadata file, this function will returns True. Returns False otherwise. """ - fs, _ = url_to_fs(checkpoint_path) - return _metadata_exists(fs, checkpoint_path, metadata_fname) + return _metadata_exists(self._file_system, checkpoint_path, metadata_fname) @staticmethod @rank_zero_read_and_broadcast @@ -596,9 +601,8 @@ def remove_checkpoint(self) -> None: """ worst_ckpt_path = self._ckpt_paths.pop(0) if self._pg_wrapper.get_rank() == 0: - fs, _ = url_to_fs(self.dirpath) try: - fs.rm(worst_ckpt_path.path, recursive=True) + self._file_system.rm(worst_ckpt_path.path, recursive=True) except Exception as exc: logger.error( ( @@ -612,6 +616,7 @@ def remove_checkpoint(self) -> None: def does_checkpoint_exist( ckpt_path: str, metadata_fname: Union[str, List[str]], + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> bool: """ @@ -622,6 +627,8 @@ def does_checkpoint_exist( Args: ckpt: The checkpoint to check. metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used. """ if not metadata_fname: @@ -631,7 +638,10 @@ def does_checkpoint_exist( [metadata_fname] if isinstance(metadata_fname, str) else metadata_fname ) - fs, _ = url_to_fs(ckpt_path) + fs = file_system + if fs is None: + fs, _ = url_to_fs(ckpt_path) + return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames) @@ -639,6 +649,7 @@ def does_checkpoint_exist( def get_latest_checkpoint_path( dirpath: str, metadata_fname: Optional[Union[str, List[str]]] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> Optional[str]: """ @@ -648,6 +659,8 @@ def get_latest_checkpoint_path( dirpath: parent directory where checkpoints are saved. metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) Raises: @@ -658,14 +671,17 @@ def get_latest_checkpoint_path( gloo process groups are recommended over nccl. """ - return _get_latest_checkpoint_path(dirpath, metadata_fname) + return _get_latest_checkpoint_path(dirpath, metadata_fname, file_system) def _get_latest_checkpoint_path( dirpath: str, metadata_fname: Optional[Union[str, List[str]]] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, ) -> Optional[str]: - candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) + candidate_dirpaths = _retrieve_checkpoint_dirpaths( + dirpath, metadata_fname, file_system=file_system + ) if not candidate_dirpaths: return None @@ -683,6 +699,7 @@ def get_best_checkpoint_path( metric_name: str, mode: Literal["min", "max"], metadata_fname: Optional[Union[str, List[str]]] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> Optional[str]: """ @@ -697,6 +714,8 @@ def get_best_checkpoint_path( mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) Note: @@ -704,7 +723,9 @@ def get_best_checkpoint_path( gloo process groups are recommended over nccl. """ - dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + dirpaths = _retrieve_checkpoint_dirpaths( + dirpath, metadata_fname, metric_name, file_system=file_system + ) if not dirpaths: return None @@ -721,6 +742,7 @@ def get_checkpoint_dirpaths( dirpath: str, metadata_fname: Optional[Union[str, List[str]]] = None, metric_name: Optional[str] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> List[CheckpointPath]: """ @@ -736,6 +758,8 @@ def get_checkpoint_dirpaths( metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. metric_name: fetches all the checkpoint directories containing the metric name only. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) Note: @@ -743,13 +767,16 @@ def get_checkpoint_dirpaths( gloo process groups are recommended over nccl. """ - return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + return _retrieve_checkpoint_dirpaths( + dirpath, metadata_fname, metric_name, file_system=file_system + ) def _retrieve_checkpoint_dirpaths( dirpath: str, metadata_fname: Optional[Union[str, List[str]]], metric_name: Optional[str] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, ) -> List[CheckpointPath]: """ Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories @@ -759,9 +786,12 @@ def _retrieve_checkpoint_dirpaths( metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. metric_name: Name of the metric that must exist in checkpoint name. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. """ - - fs, _ = url_to_fs(dirpath) + fs = file_system + if fs is None: + fs, _ = url_to_fs(dirpath) if not fs.exists(dirpath): logger.warning(f"Input dirpath doesn't exist: {dirpath}")