Add LoadPlanner and SavePlanner registries (#1358)
irenedea authored Jul 18, 2024
1 parent 006f251 commit acb5530
Showing 5 changed files with 142 additions and 0 deletions.
27 changes: 27 additions & 0 deletions llmfoundry/command_utils/train.py
@@ -36,8 +36,10 @@
    build_callback,
    build_composer_model,
    build_evaluators,
    build_load_planner,
    build_logger,
    build_optimizer,
    build_save_planner,
    build_scheduler,
    build_tokenizer,
)
@@ -256,6 +258,31 @@ def train(cfg: DictConfig) -> Trainer:
    # Optional fsdp data, fine-tuning, and eval configs
    fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config

    if fsdp_config is not None:
        if 'load_planner' in fsdp_config:
            load_planners = list(fsdp_config['load_planner'].items())
            if len(load_planners) > 1:
                raise ValueError(
                    'Only one load planner can be specified in the config.',
                )
            load_planner_name, load_planner_config = load_planners[0]
            fsdp_config['load_planner'] = build_load_planner(
                load_planner_name,
                **load_planner_config,
            )

        if 'save_planner' in fsdp_config:
            save_planners = list(fsdp_config['save_planner'].items())
            if len(save_planners) > 1:
                raise ValueError(
                    'Only one save planner can be specified in the config.',
                )
            save_planner_name, save_planner_config = save_planners[0]
            fsdp_config['save_planner'] = build_save_planner(
                save_planner_name,
                **save_planner_config,
            )

    eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders
    icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str
    eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str
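For reference, the block above expects each planner entry to be a one-item mapping from a registry name to that planner's constructor kwargs. A minimal sketch of that shape, reusing the 'dummy' planner name registered in the tests below (the other keys are illustrative, not prescribed by this commit):

```python
# Sketch of an fsdp_config dict as train() might receive it. The 'dummy'
# registry name matches the planner registered in tests/utils/test_builders.py;
# the other keys here are illustrative only.
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'load_planner': {
        'dummy': {'is_test': True},  # registry name -> constructor kwargs
    },
}

# After the transform runs, fsdp_config['load_planner'] is no longer a
# mapping but an instantiated LoadPlanner object.
```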
39 changes: 39 additions & 0 deletions llmfoundry/registry.py
@@ -6,6 +6,7 @@
from composer.loggers import LoggerDestination
from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataloader
from torch.utils.data import Dataset
@@ -339,6 +340,42 @@
    description=_config_transforms_description,
)

_load_planners_description = (
    """The load_planners registry is used to register classes that implement the LoadPlanner interface.

    The LoadPlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to load distributed checkpoints.

    Returns:
        LoadPlanner: The load planner.
    """
)

load_planners = create_registry(
    'llmfoundry',
    'load_planners',
    generic_type=Type[LoadPlanner],
    entry_points=True,
    description=_load_planners_description,
)

_save_planners_description = (
    """The save_planners registry is used to register classes that implement the SavePlanner interface.

    The SavePlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to save distributed checkpoints.

    Returns:
        SavePlanner: The save planner.
    """
)

save_planners = create_registry(
    'llmfoundry',
    'save_planners',
    generic_type=Type[SavePlanner],
    entry_points=True,
    description=_save_planners_description,
)

__all__ = [
    'loggers',
    'callbacks',
@@ -363,4 +400,6 @@
    'fcs',
    'icl_datasets',
    'config_transforms',
    'load_planners',
    'save_planners',
]
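As a usage sketch (not part of this commit), a project-specific planner can be attached to either registry by subclassing one of torch's default planners and registering the class, mirroring what the tests below do:

```python
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

from llmfoundry.registry import load_planners


# Hypothetical subclass; the name and any custom behavior are illustrative.
class MyLoadPlanner(DefaultLoadPlanner):
    """A LoadPlanner variant a user might register for custom checkpoint loading."""


# Same registration call the tests use: register(name, func=<class>).
load_planners.register('my_load_planner', func=MyLoadPlanner)
```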
39 changes: 39 additions & 0 deletions llmfoundry/utils/builders.py
@@ -27,6 +27,7 @@
from composer.utils import dist
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -187,6 +188,44 @@ def build_icl_data_and_gauntlet(
    return icl_evaluators, logger_keys, eval_gauntlet_cb


def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner:
    """Builds a load planner from the registry.

    Args:
        name: Name of the load planner to build.

    Returns:
        LoadPlanner: The load planner.
    """
    return construct_from_registry(
        name=name,
        registry=registry.load_planners,
        partial_function=True,
        pre_validation_function=LoadPlanner,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_save_planner(name: str, **kwargs: Any) -> SavePlanner:
    """Builds a save planner from the registry.

    Args:
        name: Name of the save planner to build.

    Returns:
        SavePlanner: The save planner.
    """
    return construct_from_registry(
        name=name,
        registry=registry.save_planners,
        partial_function=True,
        pre_validation_function=SavePlanner,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_composer_model(
    name: str,
    cfg: Dict[str, Any],
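And a sketch of the builder in use, continuing the hypothetical 'my_load_planner' registration above: the name selects the registered class, and any keyword arguments are forwarded to its constructor, exactly as the tests below exercise with `is_test=True`.

```python
from llmfoundry.utils.builders import build_load_planner

# Resolves 'my_load_planner' in registry.load_planners and instantiates it;
# any extra kwargs would be passed through to the planner's __init__.
load_planner = build_load_planner('my_load_planner')
```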
2 changes: 2 additions & 0 deletions tests/test_registry.py
@@ -44,6 +44,8 @@ def test_expected_registries_exist():
        'fcs',
        'icl_datasets',
        'config_transforms',
        'load_planners',
        'save_planners',
    }

    assert existing_registries == expected_registry_names
35 changes: 35 additions & 0 deletions tests/utils/test_builders.py
@@ -13,17 +13,24 @@
from composer.callbacks import Generate
from composer.core import Evaluator
from composer.loggers import WandBLogger
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from transformers import PreTrainedTokenizerBase

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.registry import load_planners, save_planners
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.builders import (
    add_metrics_to_eval_loaders,
    build_callback,
    build_eval_loaders,
    build_evaluators,
    build_load_planner,
    build_logger,
    build_optimizer,
    build_save_planner,
    build_tokenizer,
)

@@ -345,6 +352,34 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch):
    assert eval_loaders2[1].metric_names == []


def test_build_load_planner():
    # Dummy LoadPlanner for testing
    class DummyLoadPlanner(DefaultLoadPlanner):

        def __init__(self, is_test: bool):
            self.is_test = is_test

    load_planners.register('dummy', func=DummyLoadPlanner)
    load_planner = build_load_planner('dummy', is_test=True)

    assert isinstance(load_planner, DummyLoadPlanner)
    assert load_planner.is_test is True


def test_build_save_planner():
    # Dummy SavePlanner for testing
    class DummySavePlanner(DefaultSavePlanner):

        def __init__(self, is_test: bool):
            self.is_test = is_test

    save_planners.register('dummy', func=DummySavePlanner)
    save_planner = build_save_planner('dummy', is_test=True)

    assert isinstance(save_planner, DummySavePlanner)
    assert save_planner.is_test is True


def test_add_metrics_to_eval_loaders():
    evaluators = [
        Evaluator(