Merge branch 'main' into nancy/wrap-uc-upload
nancyhung committed Nov 12, 2024
2 parents b39ccfb + 18da725 commit 7502fc0
Showing 11 changed files with 109 additions and 52 deletions.
46 changes: 37 additions & 9 deletions composer/core/data_spec.py
@@ -14,7 +14,7 @@
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from composer.utils import dist, ensure_tuple
from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple

if TYPE_CHECKING:
from composer.core.types import Batch
@@ -126,16 +126,16 @@ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequ
class DataSpec:
"""Specifications for operating and training on data.
An example of constructing a :class:`DataSpec` object with a ``device_transforms``
An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
callable and then using it with :class:`~.Trainer`:
.. doctest::
>>> # Construct DataSpec and subtract mean from the batch
>>> device_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
>>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
>>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
>>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
>>> # The same function can be used for eval dataloader as well
>>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
>>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
>>> # Use this DataSpec object to construct trainer
>>> trainer = Trainer(
... model=model,
@@ -155,11 +155,20 @@ class DataSpec:
num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
:class:`.Timestamp` (training progress tracker).
device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
batch once it has been moved onto the device. For example, this function can be used for GPU-based
device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch
level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target
device.
batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
batch before it is moved onto the device. For example, this function can be used for CPU-based
normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
the batch is not modified.
microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
microbatch before it is moved onto the device. For example, this function can be used for GPU-based
normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
specified, the microbatch is not modified.
split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
split a batch (the first parameter) into microbatches of a given size (the second parameter). If
the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
Expand All @@ -186,13 +195,32 @@ def __init__(
num_samples: Optional[int] = None,
num_tokens: Optional[int] = None,
device_transforms: Optional[Callable[[Batch], Batch]] = None,
batch_transforms: Optional[Callable[[Batch], Batch]] = None,
microbatch_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
) -> None:
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
self.num_tokens = num_tokens
self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
if device_transforms is not None:
if batch_transforms is not None:
raise ValueError(
'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for '
'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations '
'on target device.',
)
warnings.warn(
VersionedDeprecationWarning(
'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
'device.',
'v0.29.0',
),
)
batch_transforms = device_transforms
self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms
self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms
self.split_batch = default_split_batch if split_batch is None else split_batch
self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
@@ -242,7 +270,7 @@ def __init__(
'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.',
)

def _default_device_transforms(self, batch: Batch):
def _default_transforms(self, batch: Batch):
return batch

def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
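For orientation, here is a minimal sketch (not part of the commit) of how the two new arguments are meant to be used; the dataset, transform functions, and normalization values below are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import DataSpec

# Hypothetical dataloader yielding (inputs, targets) tuples.
train_dataloader = DataLoader(
    TensorDataset(torch.rand(128, 3, 32, 32), torch.randint(0, 10, (128,))),
    batch_size=32,
)

def cpu_scale(batch):
    # batch_transforms: applied to the whole batch on CPU, before the device copy.
    xs, ys = batch
    return xs / 255.0, ys

def device_center(batch):
    # microbatch_transforms: applied to each microbatch after it is on the device.
    xs, ys = batch
    return xs.sub_(xs.mean()), ys

train_dspec = DataSpec(
    train_dataloader,
    batch_transforms=cpu_scale,
    microbatch_transforms=device_center,
)

# The deprecated argument still works until v0.29.0, but emits a
# VersionedDeprecationWarning and is treated as batch_transforms:
# DataSpec(train_dataloader, device_transforms=cpu_scale)
```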
2 changes: 1 addition & 1 deletion composer/trainer/_patch_pytorch.py
@@ -1055,7 +1055,7 @@ def unshard_with_sync(self):

if version.parse(torch.__version__) >= version.parse('2.5.0') and version.parse(
torch.__version__,
) < version.parse('2.5.1'):
) < version.parse('2.5.2'):

# Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
from torch.distributed.fsdp._flat_param import FlatParamHandle
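The guard above now also covers torch 2.5.1. A small, hedged sketch of the same range check in isolation (bounds taken from the hunk; the helper name is invented):

```python
from packaging import version

def needs_flat_param_patch(torch_version: str) -> bool:
    # The FlatParamHandle patch applies to torch >= 2.5.0 and < 2.5.2.
    v = version.parse(torch_version)
    return version.parse('2.5.0') <= v < version.parse('2.5.2')

assert needs_flat_param_patch('2.5.1')      # newly included by this change
assert not needs_flat_param_patch('2.5.2')  # still excluded
```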
12 changes: 7 additions & 5 deletions composer/trainer/trainer.py
@@ -2622,7 +2622,7 @@ def _train_loop(self) -> None:
self._rng_state = None
continue

self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
self.state.batch = self._train_data_spec.batch_transforms(self.state.batch)
rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)

@@ -3034,6 +3034,7 @@ def _train_microbatches(

for microbatch_idx, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = self._train_data_spec.microbatch_transforms(self.state.batch)
is_final_microbatch = microbatch_idx + 1 == len(microbatches)
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)

@@ -3306,11 +3307,11 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
self.engine.run_event(Event.PREDICT_START)

for self.state.batch in self._iter_dataloader(TrainerMode.PREDICT):

# Move the batch onto the device
self.state.batch = data_spec.batch_transforms(self.state.batch)
self.state.batch = self.state.device.batch_to_device(self.state.batch)

# Perform any device transforms
self.state.batch = data_spec.device_transforms(self.state.batch)
self.state.batch = data_spec.microbatch_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3586,7 +3587,7 @@ def _eval_loop(
)

for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
self.state.batch = data_spec.device_transforms(self.state.batch)
self.state.batch = data_spec.batch_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
Expand Down Expand Up @@ -3616,6 +3617,7 @@ def _eval_loop(
microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size)
for i, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = data_spec.microbatch_transforms(self.state.batch)
last_microbatch = i == len(microbatches) - 1
skip_metric_update = False
# Distributed samplers pad batches to be the same size. If using a
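Taken together, the trainer changes order the hooks as: `batch_transforms` on the host batch, then batch splitting, then a device copy and `microbatch_transforms` for each microbatch. A rough sketch of that flow, with invented names rather than the Trainer's actual internals:

```python
def apply_data_spec_transforms(data_spec, device, batch, microbatch_size):
    # 1. Batch-level transforms run on the host, before any device copy.
    batch = data_spec.batch_transforms(batch)

    # 2. The (possibly transformed) batch is split into microbatches.
    microbatches = data_spec.split_batch(batch, microbatch_size)

    processed = []
    for microbatch in microbatches:
        # 3. Each microbatch is moved to the device...
        microbatch = device.batch_to_device(microbatch)
        # 4. ...and only then passed through microbatch_transforms.
        processed.append(data_spec.microbatch_transforms(microbatch))
    return processed  # ready for the forward/backward pass
```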
5 changes: 4 additions & 1 deletion composer/utils/checkpoint.py
@@ -595,7 +595,10 @@ def dist_cp_load(
storage_reader: StorageReader,
load_planner: Optional[LoadPlanner] = None,
):
if version.parse(torch.__version__) >= version.parse('2.4.0'):
if (
version.parse(torch.__version__) >= version.parse('2.4.0') and
version.parse(torch.__version__) < version.parse('2.5.0')
):
from torch.distributed.checkpoint.utils import CheckpointException
try:
dist_cp.load(
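This hunk narrows the `CheckpointException` handling to torch 2.4.x only; newer versions fall through to a plain `dist_cp.load` call. A hedged sketch of the resulting control flow (the `except` body is elided in the diff, so it is only indicated here):

```python
import torch
import torch.distributed.checkpoint as dist_cp
from packaging import version

def dist_cp_load_sketch(state_dict, storage_reader, load_planner=None):
    # Illustrative only: mirrors the version gate, not the full function.
    if version.parse('2.4.0') <= version.parse(torch.__version__) < version.parse('2.5.0'):
        from torch.distributed.checkpoint.utils import CheckpointException
        try:
            dist_cp.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
        except CheckpointException:
            raise  # the real function adds extra handling here, not shown in this hunk
    else:
        dist_cp.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
```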
6 changes: 3 additions & 3 deletions docker/README.md
@@ -30,9 +30,9 @@ To install composer, once inside the image, run `pip install mosaicml`.
<!-- BEGIN_PYTORCH_BUILD_MATRIX -->
| Linux Distro | Flavor | PyTorch Version | CUDA Version | Python Version | Docker Tags |
|----------------|----------|-------------------|---------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Ubuntu 20.04 | Base | 2.5.0 | 12.4.1 (Infiniband) | 3.11 | `mosaicml/pytorch:latest`, `mosaicml/pytorch:2.5.0_cu124-python3.11-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.5.0 | 12.4.1 (EFA) | 3.11 | `mosaicml/pytorch:latest-aws`, `mosaicml/pytorch:2.5.0_cu124-python3.11-ubuntu20.04-aws` |
| Ubuntu 20.04 | Base | 2.5.0 | cpu | 3.11 | `mosaicml/pytorch:latest_cpu`, `mosaicml/pytorch:2.5.0_cpu-python3.11-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.5.1 | 12.4.1 (Infiniband) | 3.11 | `mosaicml/pytorch:latest`, `mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.5.1 | 12.4.1 (EFA) | 3.11 | `mosaicml/pytorch:latest-aws`, `mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04-aws` |
| Ubuntu 20.04 | Base | 2.5.1 | cpu | 3.11 | `mosaicml/pytorch:latest_cpu`, `mosaicml/pytorch:2.5.1_cpu-python3.11-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.4.1 | 12.4.1 (Infiniband) | 3.11 | `mosaicml/pytorch:2.4.1_cu124-python3.11-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.4.1 | 12.4.1 (EFA) | 3.11 | `mosaicml/pytorch:2.4.1_cu124-python3.11-ubuntu20.04-aws` |
| Ubuntu 20.04 | Base | 2.4.1 | cpu | 3.11 | `mosaicml/pytorch:2.4.1_cpu-python3.11-ubuntu20.04` |
38 changes: 19 additions & 19 deletions docker/build_matrix.yaml
@@ -2,54 +2,54 @@
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04
CUDA_VERSION: 12.4.1
IMAGE_NAME: torch-2-5-0-cu124
IMAGE_NAME: torch-2-5-1-cu124
MOFED_VERSION: latest-23.10
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.11'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.5.0
PYTORCH_VERSION: 2.5.1
TAGS:
- mosaicml/pytorch:2.5.0_cu124-python3.11-ubuntu20.04
- ghcr.io/databricks-mosaic/pytorch:2.5.0_cu124-python3.11-ubuntu20.04
- mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
- ghcr.io/databricks-mosaic/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
- mosaicml/pytorch:latest
- ghcr.io/databricks-mosaic/pytorch:latest
TARGET: pytorch_stage
TORCHVISION_VERSION: 0.20.0
TORCHVISION_VERSION: 0.20.1
- AWS_OFI_NCCL_VERSION: v1.11.0-aws
BASE_IMAGE: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04
CUDA_VERSION: 12.4.1
IMAGE_NAME: torch-2-5-0-cu124-aws
IMAGE_NAME: torch-2-5-1-cu124-aws
MOFED_VERSION: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.11'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.5.0
PYTORCH_VERSION: 2.5.1
TAGS:
- mosaicml/pytorch:2.5.0_cu124-python3.11-ubuntu20.04-aws
- ghcr.io/databricks-mosaic/pytorch:2.5.0_cu124-python3.11-ubuntu20.04-aws
- mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04-aws
- ghcr.io/databricks-mosaic/pytorch:2.5.1_cu124-python3.11-ubuntu20.04-aws
- mosaicml/pytorch:latest-aws
- ghcr.io/databricks-mosaic/pytorch:latest-aws
TARGET: pytorch_stage
TORCHVISION_VERSION: 0.20.0
TORCHVISION_VERSION: 0.20.1
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: ubuntu:20.04
CUDA_VERSION: ''
IMAGE_NAME: torch-2-5-0-cpu
IMAGE_NAME: torch-2-5-1-cpu
MOFED_VERSION: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.11'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.5.0
PYTORCH_VERSION: 2.5.1
TAGS:
- mosaicml/pytorch:2.5.0_cpu-python3.11-ubuntu20.04
- ghcr.io/databricks-mosaic/pytorch:2.5.0_cpu-python3.11-ubuntu20.04
- mosaicml/pytorch:2.5.1_cpu-python3.11-ubuntu20.04
- ghcr.io/databricks-mosaic/pytorch:2.5.1_cpu-python3.11-ubuntu20.04
- mosaicml/pytorch:latest_cpu
- ghcr.io/databricks-mosaic/pytorch:latest_cpu
TARGET: pytorch_stage
TORCHVISION_VERSION: 0.20.0
TORCHVISION_VERSION: 0.20.1
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04
CUDA_VERSION: 12.4.1
@@ -176,14 +176,14 @@
PYTHON_VERSION: '3.11'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.5.0
PYTORCH_VERSION: 2.5.1
TAGS:
- mosaicml/composer:0.26.0
- ghcr.io/databricks-mosaic/composer:0.26.0
- mosaicml/composer:latest
- ghcr.io/databricks-mosaic/composer:latest
TARGET: composer_stage
TORCHVISION_VERSION: 0.20.0
TORCHVISION_VERSION: 0.20.1
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: ubuntu:20.04
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.26.0
@@ -194,11 +194,11 @@
PYTHON_VERSION: '3.11'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.5.0
PYTORCH_VERSION: 2.5.1
TAGS:
- mosaicml/composer:0.26.0_cpu
- ghcr.io/databricks-mosaic/composer:0.26.0_cpu
- mosaicml/composer:latest_cpu
- ghcr.io/databricks-mosaic/composer:latest_cpu
TARGET: composer_stage
TORCHVISION_VERSION: 0.20.0
TORCHVISION_VERSION: 0.20.1
10 changes: 5 additions & 5 deletions docker/generate_build_matrix.py
@@ -20,12 +20,12 @@
import yaml

PRODUCTION_PYTHON_VERSION = '3.11'
PRODUCTION_PYTORCH_VERSION = '2.5.0'
PRODUCTION_PYTORCH_VERSION = '2.5.1'


def _get_torchvision_version(pytorch_version: str):
if pytorch_version == '2.5.0':
return '0.20.0'
if pytorch_version == '2.5.1':
return '0.20.1'
if pytorch_version == '2.4.1':
return '0.19.1'
if pytorch_version == '2.3.1':
@@ -45,7 +45,7 @@ def _get_cuda_version(pytorch_version: str, use_cuda: bool):
# From https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/
if not use_cuda:
return ''
if pytorch_version == '2.5.0':
if pytorch_version == '2.5.1':
return '12.4.1'
if pytorch_version == '2.4.1':
return '12.4.1'
@@ -180,7 +180,7 @@ def _write_table(table_tag: str, table_contents: str):


def _main():
python_pytorch_versions = [('3.11', '2.5.0'), ('3.11', '2.4.1'), ('3.11', '2.3.1')]
python_pytorch_versions = [('3.11', '2.5.1'), ('3.11', '2.4.1'), ('3.11', '2.3.1')]
cuda_options = [True, False]
stages = ['pytorch_stage']
interconnects = ['mellanox', 'EFA'] # mellanox is default, EFA needed for AWS
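A quick, hedged sanity check of the updated mappings, assuming the helpers above are imported as written:

```python
# Expected resolutions for the new production PyTorch version (2.5.1).
assert _get_torchvision_version('2.5.1') == '0.20.1'
assert _get_cuda_version('2.5.1', use_cuda=True) == '12.4.1'
assert _get_cuda_version('2.5.1', use_cuda=False) == ''
```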
2 changes: 1 addition & 1 deletion docs/source/composer_model.rst
@@ -150,7 +150,7 @@ A full example of a validation implementation would be:
def get_metrics(self, is_train=False):
# defines which metrics to use in each phase of training
return {'MulticlassAccuracy': self.train_accuracy} if train else {'MulticlassAccuracy': self.val_accuracy}
return {'MulticlassAccuracy': self.train_accuracy} if is_train else {'MulticlassAccuracy': self.val_accuracy}
.. note::

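The docs hunk fixes an undefined name (`train` becomes `is_train`). For context, a minimal, hedged model showing the corrected method; the architecture and metric choices are illustrative only:

```python
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy

from composer.models import ComposerModel

class TinyClassifier(ComposerModel):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Linear(32, num_classes)
        self.train_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, batch):
        xs, _ = batch
        return self.net(xs)

    def loss(self, outputs, batch):
        _, ys = batch
        return F.cross_entropy(outputs, ys)

    def get_metrics(self, is_train=False):
        # Branch on the is_train argument (the bug fixed in this hunk).
        return {'MulticlassAccuracy': self.train_accuracy} if is_train else {'MulticlassAccuracy': self.val_accuracy}

    def update_metric(self, batch, outputs, metric):
        _, ys = batch
        metric.update(outputs, ys)
```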
10 changes: 5 additions & 5 deletions setup.py
@@ -78,17 +78,17 @@ def package_files(prefix: str, directory: str, extension: str):
install_requires = [
'pyyaml>=6.0,<7',
'tqdm>=4.62.3,<5',
'torchmetrics>=1.0,<1.4.1',
'torchmetrics>=1.0,<1.5.3',
'torch_optimizer>=0.3.0,<0.4',
'torchvision>=0.18.0,<0.20.1',
'torch>=2.3.0,<2.5.1',
'torchvision>=0.18.0,<0.20.2',
'torch>=2.3.0,<2.5.2',
'requests>=2.26.0,<3',
'numpy>=1.21.5,<2.2.0',
'psutil>=5.8.0,<7',
'coolname>=1.1.0,<3',
'tabulate==0.9.0', # for auto-generating tables
'py-cpuinfo>=8.0.0,<10',
'packaging>=21.3.0,<24.2',
'packaging>=21.3.0,<24.3',
'importlib-metadata>=5.0.0,<9',
'mosaicml-cli>=0.5.25,<0.7',
'pillow>=10.3.0,<12',
@@ -103,7 +103,7 @@ def package_files(prefix: str, directory: str, extension: str):
# Should manually update dependency versions occasionally.
'custom_inherit==2.4.1',
'junitparser==3.1.2',
'coverage[toml]==7.6.3',
'coverage[toml]==7.6.4',
'fasteners==0.18', # object store tests require fasteners
'pytest==7.4.4',
'ipython==8.11.0',
5 changes: 3 additions & 2 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -514,6 +514,7 @@ def test_fsdp_mixed_with_sync(
'0.23.0',
'0.24.0',
'0.25.0',
'0.26.0',
],
)
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
@@ -534,8 +535,8 @@ def test_fsdp_load_old_checkpoint(
pytest.skip('TODO: This checkpoint is missing')

if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0')) or (
composer_version == '0.24.0' and version.parse(torch.__version__) < version.parse('2.4.0')
) or (composer_version == '0.25.0' and version.parse(torch.__version__) < version.parse('2.5.0')):
composer_version in ['0.24.0', '0.25.0'] and version.parse(torch.__version__) < version.parse('2.4.0')
) or (composer_version in '0.26.0' and version.parse(torch.__version__) < version.parse('2.5.0')):
pytest.skip('Current torch version is older than torch version that checkpoint was written with.')

if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']:
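The updated skip condition now requires torch >= 2.4.0 for checkpoints written with composer 0.24.0/0.25.0 and torch >= 2.5.0 for 0.26.0. The same rule expressed as a standalone sketch (not the test code; only the versions touched by this diff are listed):

```python
from packaging import version

# Minimum torch needed to load checkpoints written by each composer release.
MIN_TORCH_FOR_CHECKPOINT = {
    '0.22.0': '2.3.0',
    '0.23.0': '2.3.0',
    '0.24.0': '2.4.0',
    '0.25.0': '2.4.0',
    '0.26.0': '2.5.0',
}

def should_skip(composer_version: str, torch_version: str) -> bool:
    # Skip when the running torch is older than what the checkpoint requires.
    floor = MIN_TORCH_FOR_CHECKPOINT.get(composer_version)
    return floor is not None and version.parse(torch_version) < version.parse(floor)

assert should_skip('0.26.0', '2.4.1')       # new case added in this commit
assert not should_skip('0.26.0', '2.5.1')
```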