From acb55300a5db28f98c0579d52802c7035f58d533 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Wed, 17 Jul 2024 23:36:40 -0700
Subject: [PATCH] Add LoadPlanner and SavePlanner registries (#1358)

---
 llmfoundry/command_utils/train.py | 27 +++++++++++++++++++++
 llmfoundry/registry.py            | 39 +++++++++++++++++++++++++++++++
 llmfoundry/utils/builders.py      | 39 +++++++++++++++++++++++++++++++
 tests/test_registry.py            |  2 ++
 tests/utils/test_builders.py      | 35 +++++++++++++++++++++++++++
 5 files changed, 142 insertions(+)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index f49fb28801..feed1e9fb1 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -36,8 +36,10 @@
     build_callback,
     build_composer_model,
     build_evaluators,
+    build_load_planner,
     build_logger,
     build_optimizer,
+    build_save_planner,
     build_scheduler,
     build_tokenizer,
 )
@@ -256,6 +258,31 @@ def train(cfg: DictConfig) -> Trainer:
     # Optional fsdp data, fine-tuning, and eval configs
     fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config
 
+    if fsdp_config is not None:
+        if 'load_planner' in fsdp_config:
+            load_planners = fsdp_config['load_planner'].items()
+            if len(load_planners) > 1:
+                raise ValueError(
+                    'Only one load planner can be specified in the config.',
+                )
+            load_planner_name, load_planner_config = list(load_planners)[0]
+            fsdp_config['load_planner'] = build_load_planner(
+                load_planner_name,
+                **load_planner_config,
+            )
+
+        if 'save_planner' in fsdp_config:
+            save_planners = fsdp_config['save_planner'].items()
+            if len(save_planners) > 1:
+                raise ValueError(
+                    'Only one save planner can be specified in the config.',
+                )
+            save_planner_name, save_planner_config = list(save_planners)[0]
+            fsdp_config['save_planner'] = build_save_planner(
+                save_planner_name,
+                **save_planner_config,
+            )
+
     eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders
     icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str
     eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str
diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
index 50481211ac..e31840d3fb 100644
--- a/llmfoundry/registry.py
+++ b/llmfoundry/registry.py
@@ -6,6 +6,7 @@
 from composer.loggers import LoggerDestination
 from composer.models import ComposerModel
 from composer.optim import ComposerScheduler
+from torch.distributed.checkpoint import LoadPlanner, SavePlanner
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader as TorchDataloader
 from torch.utils.data import Dataset
@@ -339,6 +340,42 @@
     description=_config_transforms_description,
 )
 
+_load_planners_description = (
+    """The load_planners registry is used to register classes that implement the LoadPlanner interface.
+
+    The LoadPlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to load distributed checkpoints.
+
+    Returns:
+        LoadPlanner: The load planner.
+    """
+)
+
+load_planners = create_registry(
+    'llmfoundry',
+    'load_planners',
+    generic_type=Type[LoadPlanner],
+    entry_points=True,
+    description=_load_planners_description,
+)
+
+_save_planners_description = (
+    """The save_planners registry is used to register classes that implement the SavePlanner interface.
+
+    The SavePlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to save distributed checkpoints.
+
+    Returns:
+        SavePlanner: The save planner.
+    """
+)
+
+save_planners = create_registry(
+    'llmfoundry',
+    'save_planners',
+    generic_type=Type[SavePlanner],
+    entry_points=True,
+    description=_save_planners_description,
+)
+
 __all__ = [
     'loggers',
     'callbacks',
@@ -363,4 +400,6 @@
     'fcs',
     'icl_datasets',
     'config_transforms',
+    'load_planners',
+    'save_planners',
 ]
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 012a0b704f..0437736f74 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -27,6 +27,7 @@
 from composer.utils import dist
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
+from torch.distributed.checkpoint import LoadPlanner, SavePlanner
 from torch.optim.optimizer import Optimizer
 from torchmetrics import Metric
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -187,6 +188,44 @@ def build_icl_data_and_gauntlet(
     return icl_evaluators, logger_keys, eval_gauntlet_cb
 
 
+def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner:
+    """Builds a load planner from the registry.
+
+    Args:
+        name: Name of the load planner to build.
+
+    Returns:
+        LoadPlanner: The load planner.
+    """
+    return construct_from_registry(
+        name=name,
+        registry=registry.load_planners,
+        partial_function=True,
+        pre_validation_function=LoadPlanner,
+        post_validation_function=None,
+        kwargs=kwargs,
+    )
+
+
+def build_save_planner(name: str, **kwargs: Any) -> SavePlanner:
+    """Builds a save planner from the registry.
+
+    Args:
+        name: Name of the save planner to build.
+
+    Returns:
+        SavePlanner: The save planner.
+    """
+    return construct_from_registry(
+        name=name,
+        registry=registry.save_planners,
+        partial_function=True,
+        pre_validation_function=SavePlanner,
+        post_validation_function=None,
+        kwargs=kwargs,
+    )
+
+
 def build_composer_model(
     name: str,
     cfg: Dict[str, Any],
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 7ee95442c8..aa0c93ee13 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -44,6 +44,8 @@ def test_expected_registries_exist():
         'fcs',
         'icl_datasets',
         'config_transforms',
+        'load_planners',
+        'save_planners',
     }
 
     assert existing_registries == expected_registry_names
diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py
index dfcb5b327c..fb6cb0c5df 100644
--- a/tests/utils/test_builders.py
+++ b/tests/utils/test_builders.py
@@ -13,17 +13,24 @@
 from composer.callbacks import Generate
 from composer.core import Evaluator
 from composer.loggers import WandBLogger
+from torch.distributed.checkpoint.default_planner import (
+    DefaultLoadPlanner,
+    DefaultSavePlanner,
+)
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
+from llmfoundry.registry import load_planners, save_planners
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import (
     add_metrics_to_eval_loaders,
     build_callback,
     build_eval_loaders,
     build_evaluators,
+    build_load_planner,
     build_logger,
     build_optimizer,
+    build_save_planner,
     build_tokenizer,
 )
 
@@ -345,6 +352,34 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch):
     assert eval_loaders2[1].metric_names == []
 
 
+def test_build_load_planner():
+    # Dummy LoadPlanner for testing
+    class DummyLoadPlanner(DefaultLoadPlanner):
+
+        def __init__(self, is_test: bool):
+            self.is_test = is_test
+
+    load_planners.register('dummy', func=DummyLoadPlanner)
+    load_planner = build_load_planner('dummy', is_test=True)
+
+    assert isinstance(load_planner, DummyLoadPlanner)
+    assert load_planner.is_test is True
+
+
+def test_build_save_planner():
+    # Dummy SavePlanner for testing
+    class DummySavePlanner(DefaultSavePlanner):
+
+        def __init__(self, is_test: bool):
+            self.is_test = is_test
+
+    save_planners.register('dummy', func=DummySavePlanner)
+    save_planner = build_save_planner('dummy', is_test=True)
+
+    assert isinstance(save_planner, DummySavePlanner)
+    assert save_planner.is_test is True
+
+
 def test_add_metrics_to_eval_loaders():
     evaluators = [
         Evaluator(