From 047bcf6af6a3721cfac31a13a1ab07c6b5482fb9 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Thu, 26 Sep 2024 22:37:32 -0700
Subject: [PATCH] Add APIs to offload states of model, optimizer, and engine (#6011)

This PR adds the following APIs to offload model, optimizer, and engine states.

```python
def offload_states(self,
                   include: Container[OffloadStateTypeEnum] = None,
                   device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                   pin_memory: bool = True,
                   non_blocking: bool = False) -> None:
    """Move the ZeRO optimizer buffers to the specified device.

    Arguments:
        include: Optional. The set of states to offload. If not provided, all states are offloaded.
        device: Optional. The device to move the ZeRO optimizer buffers to.
        pin_memory: Optional. Whether to pin the memory of the offloaded states.
        non_blocking: Optional. Whether to offload the states asynchronously.
    ...

def reload_states(self, non_blocking: bool = False) -> None:
```

Here is the typical usage.

```python
# Offload after forward, backward, and step
model.offload_states()
# Do something requiring a lot of device memory
...
# Load states back to device memory
model.reload_states()
```

You can selectively offload states to balance the offloading overhead and memory savings.

```python
model.offload_states(include=set([OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states]),
                     device=OffloadDeviceEnum.cpu)
```

Performance (4.3B parameters / 4x A100)
- Environment (4x A100, [benchmark script](https://gist.github.com/tohtana/05d5faba5068cf839abfc7b1e38b85e4))
  - Average device-to-host transfer: 2.45 GB/s, aggregated: 9.79 GB/s
  - Average host-to-device transfer: 11.05 GB/s, aggregated: 44.19 GB/s
- Memory (allocated by PyTorch)
  - Before offload: 18.2 GB
  - After offload: 17.7 MB
- Time ([benchmark script](https://github.com/microsoft/DeepSpeedExamples/tree/tohtana/offload_states/training/offload_states), offloading time / loading time)

| |pin_memory=0 non_blocking=0|pin_memory=0 non_blocking=1|pin_memory=1 non_blocking=0|pin_memory=1 non_blocking=1|
|--:|---------------------------|---------------------------|---------------------------|---------------------------|
| 1|4.34 / 3.42 |4.99 / 2.37 |6.5 / 2.42 |6.0 / 2.39 |
| 2|9.9 / 3.28 |5.1 / 2.34 |6.21 / 2.42 |6.25 / 2.45 |
| 3|9.92 / 3.19 |6.71 / 2.35 |6.33 / 2.38 |5.93 / 2.42 |
| 4|9.55 / 2.82 |7.11 / 2.39 |6.9 / 2.38 |6.5 / 2.43 |
| 5|4.4 / 3.35 |6.04 / 2.41 |6.26 / 2.41 |6.32 / 2.47 |
| 6|4.4 / 3.57 |6.58 / 2.42 |6.88 / 2.4 |6.35 / 2.43 |
| 7|9.51 / 3.12 |6.9 / 2.39 |6.9 / 2.39 |6.46 / 2.4 |
| 8|4.77 / 3.64 |6.69 / 2.39 |7.39 / 2.42 |6.56 / 2.46 |
| 9|9.5 / 3.07 |7.18 / 2.42 |6.67 / 2.39 |7.38 / 2.46 |

TODO:
- Enable offloading to NVMe storage -> NVMe support is non-trivial; I suggest adding it in another PR
- [DONE] Discard the buffer (and recreate it) instead of offloading. We don't need to restore the contiguous buffer for reduce.
- [DONE] Check pin_memory improves performance or not --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/engine.py | 41 +++- deepspeed/runtime/utils.py | 36 ++++ deepspeed/runtime/zero/offload_config.py | 9 + deepspeed/runtime/zero/stage3.py | 189 ++++++++++++++++-- deepspeed/runtime/zero/utils.py | 15 +- docs/code-docs/source/zero3.rst | 53 +++++ .../unit/runtime/zero/test_offload_states.py | 125 ++++++++++++ 7 files changed, 443 insertions(+), 25 deletions(-) create mode 100644 tests/unit/runtime/zero/test_offload_states.py diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 61e6da2663cf..b590ea432658 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -18,13 +18,13 @@ from torch.optim.lr_scheduler import _LRScheduler from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from typing import Callable, Dict, Union, Iterable +from typing import Callable, Dict, Union, Iterable, Container import deepspeed from deepspeed import comm as dist from deepspeed.runtime.utils import see_memory_usage, DummyOptim -from .zero.offload_config import OffloadDeviceEnum +from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException @@ -3681,3 +3681,40 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg @property def is_compiled(self) -> bool: return self._is_compiled + + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False) -> None: + """Offload the engine's states to the specified device. + + Arguments: + include: Optional. The set of states to offload. If not provided, all states are offloaded. + device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported. + pin_memory: Optional. Whether to pin the memory of the offloaded states. + non_blocking: Optional. Whether to offload the states asynchronously. + """ + assert self.zero_optimization_stage( + ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." + + assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters." + + if device == OffloadDeviceEnum.none: + logger.warning("No device specified for offloading states.") + return + + if device == OffloadDeviceEnum.nvme: + raise ValueError("NVMe offload is not supported for offloading states.") + + self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking) + + def reload_states(self, non_blocking: bool = False) -> None: + """Reload the engine states to the original device. + + Arguments: + non_blocking: Optional. Whether to offload the states asynchronously. + """ + assert self.zero_optimization_stage( + ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3." 
+ self.optimizer.reload_states(non_blocking=non_blocking) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 2c01c3475a70..adcadd349803 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1065,3 +1065,39 @@ def to_tensor(v): total_norm = -1 return total_norm + + +def _make_offload_state_key(key): + return f"{key}_offload_buffer" + + +def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_key(state, key): + offload_buf_key = _make_offload_state_key(key) + if offload_buf_key not in state: + state[offload_buf_key] = torch.empty_like(state[key], device=device) + if pin_memory: + state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key]) + state[offload_buf_key].copy_(state[key], non_blocking=non_blocking) + state[key].data = state[offload_buf_key] + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_key(state, "exp_avg_sq") + + +def reload_adam_states(optimizer, device, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_back_key(state, key): + state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking) + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_back_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_back_key(state, "exp_avg_sq") diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index 74a5673bc1bc..ca35d7a7d169 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -98,3 +98,12 @@ def set_pipeline(self): pipeline = self.pipeline_read or self.pipeline_write self.__dict__["pipeline"] = pipeline return self + + +class OffloadStateTypeEnum(str, Enum): + """ Enum for internal buffer types """ + optim_states = "optim_states" + hp_params = "hp_params" + lp_params = "lp_params" + lp_grads = "lp_grads" + contiguous_grad_buffer = "contiguous_grad_buffer" diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 796957a4c6e5..fb75d2bcebd5 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -6,22 +6,24 @@ import sys import gc import collections -from typing import Deque, Dict, Tuple +import itertools +from typing import Deque, Dict, Set, Tuple, Container from contextlib import contextmanager + from deepspeed import comm as dist -from deepspeed.utils import groups +from deepspeed.utils import groups, z3_leaf_parameter from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum -from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from 
deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload -from deepspeed.runtime.zero.utils import apply_to_tensors_only +from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper @@ -29,7 +31,6 @@ from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER from deepspeed.accelerator import get_accelerator -from deepspeed.utils import z3_leaf_parameter # Toggle this to true to enable correctness test # with gradient partitioning and without @@ -425,6 +426,8 @@ def __init__( self._link_all_hp_params() + self.offloaded_states: Set(OffloadDeviceEnum) = set() + if dist.get_rank(group=self.dp_process_group) == 0: see_memory_usage(f"After initializing ZeRO optimizer", force=True) @@ -563,21 +566,15 @@ def defragment(tensors: List[Tensor]) -> Tensor: cpu_buffer = torch.empty(sum(p.numel() for p in tensors), dtype=get_only_unique_item(t.dtype for t in tensors), device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = [] + tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) orig_device = get_only_unique_item(t.device for t in tensors) offset = 0 - for tensor in tensors: - tensor_numel = tensor.numel() + for tensor, offset, tensor_numel in tensor_infos: # move the tensor from device memory to host memory cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - # record some data so we can restore the device tensor later - tensor_infos.append((tensor, offset, tensor_numel)) - - offset += tensor_numel - gc.collect() get_accelerator().empty_cache() @@ -725,15 +722,11 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): for sub_group in self.fp16_groups: for param in sub_group: parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) - # setup flat buffers per subgroup, these are each just sections of the - # contiguous flat buffer for all parameters that we created earlier - offset = 0 - for sub_group in self.fp16_groups: - sub_group_numel = sum(param.partition_numel() for param in sub_group) - self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel)) - offset += sub_group_numel + # We need to keep the reference to this buffer to make sure you can free it in `offload_states` + self.lp_param_buffer = __class__.defragment(parameter_partitions) + self._set_fp16_partitioned_groups_flat() + else: # partitioned params offloaded to CPU when not in use # create a flat CPU memory allocation for each param group self._create_param_groups_fp16_flat_cpu_memory() @@ -1008,6 +1001,15 @@ def _partitioned_params_swap_out(self, i): swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=swap_fp16_params, src_fp32_params=swap_fp32_params) + def _set_fp16_partitioned_groups_flat(self): + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + 
sub_group_numel = sum(param.partition_numel() for param in sub_group) + self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel)) + offset += sub_group_numel + def initialize_optimizer_states(self): num_subgroups = len(self.fp16_groups) @@ -2782,6 +2784,149 @@ def checkpoint_event_epilogue(self): def empty_partition_cache(self): self.parameter_offload.empty_partition_cache() + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False): + device = device.value + + self.empty_partition_cache() + + assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, f"Offloading is supported only for DeepSpeed FusedAdam." + + def needs_offload(target): + # return True + return target not in self.offloaded_states and (include == None or target in include) + + # HP param + if needs_offload(OffloadStateTypeEnum.hp_params): + if pin_memory: + if not hasattr(self, "hp_params_pin_buffers"): + self.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device=device)) + for t in self.fp32_partitioned_groups_flat + ] + + for src_tensor, dest_buf in zip(self.fp32_partitioned_groups_flat, self.hp_params_pin_buffers): + dest_buf.copy_(src_tensor, non_blocking=non_blocking) + src_tensor.data = dest_buf + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.hp_params) + + # LP param + if needs_offload(OffloadStateTypeEnum.lp_params): + if pin_memory: + if not hasattr(self, "lp_param_contiguous_pin_buffer"): + self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( + torch.empty_like(self.lp_param_buffer, device=device)) + self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) + cpu_buffer = self.lp_param_contiguous_pin_buffer + else: + cpu_buffer = self.lp_param_buffer.to(device, non_blocking=non_blocking) + + self.lp_param_buffer.data = cpu_buffer + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer( + [p.ds_tensor for p in self.module.parameters()]): + tensor.data = cpu_buffer.narrow(0, offset, tensor_numel) + + self.fp16_partitioned_groups_flat.clear() + self.offloaded_states.add(OffloadStateTypeEnum.lp_params) + + # LP grad + if needs_offload(OffloadStateTypeEnum.lp_grads): + if pin_memory: + if not hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.lp_grad_partitions_flat_pin_buffers = get_accelerator().pin_memory( + torch.empty_like(self.grad_partitions_flat_buffer, device=device)) + self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer, + non_blocking=non_blocking) + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + self.averaged_gradients = {} + + self.__param_id_to_grad_partition = {} + + self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer): + if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"): + # Record properties like shape, strides, etc. 
as a meta tensor + self.grad_buffer_meta = self.__ipg_bucket_flat_buffer.to("meta") + self.__ipg_bucket_flat_buffer = None + self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if needs_offload(OffloadStateTypeEnum.optim_states): + offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.optim_states) + + gc.collect() + get_accelerator().empty_cache() + + def reload_states(self, non_blocking: bool = False): + + device = get_accelerator().current_device_name() + + # HP param + if OffloadStateTypeEnum.hp_params in self.offloaded_states: + if hasattr(self, "hp_params_pin_buffers"): + for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat): + dest.data = src.to(device, non_blocking=non_blocking) + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) + + # LP Param + if OffloadStateTypeEnum.lp_params in self.offloaded_states: + cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr( + self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer + self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking) + self._set_fp16_partitioned_groups_flat() + + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer( + [p.ds_tensor for p in self.module.parameters()]): + tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel) + self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) + + # LP grad + if OffloadStateTypeEnum.lp_grads in self.offloaded_states: + if hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers.to( + device, non_blocking=non_blocking) + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to( + device, non_blocking=non_blocking) + self.averaged_gradients = {} + + offset = 0 + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + for param in all_params: + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( + 0, offset, param.partition_numel()) + offset += param.partition_numel() + + self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: + self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device) + # self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) + self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if OffloadStateTypeEnum.optim_states in self.offloaded_states: + reload_adam_states(self.optimizer, device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) + + if non_blocking: + get_accelerator().synchronize() + def _handle_overflow(cpu_sum, x, i): import math diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 8f913d065934..2d1cf17962d8 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import os -from typing import List +from typing import List, Tuple import torch from deepspeed import comm as dist @@ -160,3 +160,16 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): logger.warning(warning_msg_fn(value)) warned = True return value + + +def 
get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]: + tensor_infos: List[Tuple[torch.Tensor, int, int]] = [] + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + offset += tensor_numel + + return tensor_infos diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst index 2a6a48ca91db..f0974c08c9f3 100644 --- a/docs/code-docs/source/zero3.rst +++ b/docs/code-docs/source/zero3.rst @@ -456,3 +456,56 @@ The following code snippet illustrates this functionality. # Free GPU memory consumed by model parameters ds_engine.empty_partition_cache() + + +Offload States +-------------- + +The DeepSpeed engine maintains a set of states in device memory (e.g., CUDA memory). The following API allows you to offload these states to a different device (currently, only CPU memory is supported), reducing the memory footprint on the device. + +.. code-block:: python + + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False) -> None: + """Offload the engine's states to the specified device. + + Arguments: + include: Optional. The set of states to offload. If not provided, all states are offloaded. + device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported. + pin_memory: Optional. Whether to pin the memory of the offloaded states. + non_blocking: Optional. Whether to offload the states asynchronously. + """ + +You can selectively offload specific states by specifying the ``OffloadStateTypeEnum`` in the include argument. ``OffloadStateTypeEnum`` is an enum that defines the states that can be offloaded. The following states are supported: + +* ``OffloadStateTypeEnum.optim_states``: Optimizer states. Currently, only states of DeepSpeed's FusedAdam optimizer are supported. +* ``OffloadStateTypeEnum.hp_params``: FP32 parameters. +* ``OffloadStateTypeEnum.lp_params``: BF16/FP16 parameters. +* ``OffloadStateTypeEnum.lp_grads``: BF16/FP16 gradients. +* ``OffloadStateTypeEnum.contiguous_grad_buffer``: The contiguous gradient buffer for reduce operations. + +Note that offloading states comes with a trade-off between memory savings and computational overhead. This API allows states to be reloaded back into device memory when needed. + +.. code-block:: python + + def reload_states(self, non_blocking: bool = False) -> None: + """Reload the engine states to the original device. + + Arguments: + non_blocking: Optional. Whether to offload the states asynchronously. + """ + +Below is an example code snippet demonstrating how to offload FP32 parameters and optimizer states to CPU memory: + +.. code-block:: python + + # Offload after forward, backward, and step + ds_engine.offload_states(include=[OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states]) + + # Do something requiring a lot of device memory + ... + # Load states back to device memory + ds_engine.reload_states() diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py new file mode 100644 index 000000000000..cc60908d3c33 --- /dev/null +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +import torch + +from unit.common import DistributedTest +from unit.simple_model import random_dataloader, SimpleModel + +import deepspeed +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum +from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state + + +def validate_device(model, device: torch.device, include) -> None: + # Make sure the model parameters are offloaded + if include is None or OffloadStateTypeEnum.hp_params in include: + assert all(safe_get_local_fp32_param(p).device == device for p in model.parameters()) + if include is None or OffloadStateTypeEnum.lp_params in include: + assert all(p.ds_tensor.device == device for p in model.parameters()) + if include is None or OffloadStateTypeEnum.lp_grads in include: + assert model.optimizer.grad_partitions_flat_buffer.device == device + if include is None or OffloadStateTypeEnum.optim_states in include: + assert all(safe_get_local_optimizer_state(p, "exp_avg").device == device for p in model.parameters()) + assert all(safe_get_local_optimizer_state(p, "exp_avg_sq").device == device for p in model.parameters()) + + +def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking): + # Currently we only support OffloadDeviceEnum.cpu + offload_device = OffloadDeviceEnum.cpu + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + hp_params_expected = [safe_get_local_fp32_param(p).clone() for p in model.parameters()] + lp_params_expected = [p.ds_tensor.clone() for p in model.parameters()] + lp_grads_expected = model.optimizer.grad_partitions_flat_buffer.clone() + adam_exp_avg_expected = [safe_get_local_optimizer_state(p, "exp_avg").clone() for p in model.parameters()] + adam_exp_avg_sq = [safe_get_local_optimizer_state(p, "exp_avg_sq").clone() for p in model.parameters()] + + # Start offloading + alloc_before_offload = get_accelerator().memory_allocated() + model.offload_states(include=include, device=offload_device, pin_memory=pin_memory, non_blocking=non_blocking) + alloc_after_offload = get_accelerator().memory_allocated() + assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" + + validate_device(model, torch.device(offload_device.value), include) + + # Reload states + model.reload_states() + assert alloc_after_offload < get_accelerator().memory_allocated( + ), f"Allocated memory should increase after offload back" + + # Verify restored states + hp_param_restored = [safe_get_local_fp32_param(p) for p in model.parameters()] + for hp_param_expected, hp_param_restored in zip(hp_params_expected, hp_param_restored): + assert torch.equal(hp_param_expected, hp_param_restored) + + lp_param_restored = [p.ds_tensor for p in model.parameters()] + + for lp_param_expected, lp_param_restored in zip(lp_params_expected, lp_param_restored): + assert torch.equal(lp_param_expected, lp_param_restored) + + assert torch.equal(lp_grads_expected, model.optimizer.grad_partitions_flat_buffer) + + adam_exp_avg_restored = [safe_get_local_optimizer_state(p, "exp_avg") for 
p in model.parameters()] + for adam_exp_avg_expected, adam_exp_avg_restored in zip(adam_exp_avg_expected, adam_exp_avg_restored): + assert torch.equal(adam_exp_avg_expected, adam_exp_avg_restored) + + adam_exp_avg_sq_restored = [safe_get_local_optimizer_state(p, "exp_avg_sq") for p in model.parameters()] + for adam_exp_avg_sq_expected, adam_exp_avg_sq_restored in zip(adam_exp_avg_sq, adam_exp_avg_sq_restored): + assert torch.equal(adam_exp_avg_sq_expected, adam_exp_avg_sq_restored) + + validate_device(model, torch.device(get_accelerator().current_device_name()), include) + + # Needed in ZeRO 3. Not doing so can give memory leak + model.destroy() + + +@pytest.mark.parametrize("included_state", [ + OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.optim_states, + OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.contiguous_grad_buffer, None +]) +@pytest.mark.parametrize("pin_memory", [False, True]) +@pytest.mark.parametrize("non_blocking", [False, True]) +class TestOffloadStates(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + + def test_offload_states(self, included_state, pin_memory, non_blocking): + hidden_dim = 1024 + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + } + } + config_dict["bf16"] = {"enabled": True} + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=4) + + include = None if included_state is None else [included_state] + run_model(model, config_dict, hidden_dim, torch.bfloat16, include, pin_memory, non_blocking)
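
For reference, here is a minimal end-to-end sketch of the intended usage pattern: offload between the optimizer step and a memory-hungry phase (e.g. evaluation or generation), then reload before the next forward/backward. Only `offload_states`, `reload_states`, and the two enums come from this PR; the toy model, config, and dummy loss below are placeholders, and the script is assumed to run under the DeepSpeed launcher with ZeRO stage 3 and no parameter offloading (which this feature requires).

```python
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

hidden = 1024
# Toy stand-in for a real model; any ZeRO-3 setup without parameter offloading applies.
model = torch.nn.Sequential(torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden))
config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},  # maps to FusedAdam, as required
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},
}
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)

for _ in range(3):
    x = torch.randn(1, hidden, dtype=torch.bfloat16, device=engine.device)
    loss = engine(x).float().pow(2).mean()  # dummy loss for illustration
    engine.backward(loss)
    engine.step()

    # After the step, move FP32 params and Adam states to pinned CPU memory.
    before = get_accelerator().memory_allocated()
    engine.offload_states(include=set([OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states]),
                          device=OffloadDeviceEnum.cpu,
                          pin_memory=True)
    print(f"freed {(before - get_accelerator().memory_allocated()) / 2**30:.3f} GiB")

    # ... run something that needs a lot of device memory here ...

    # Bring the states back before the next forward/backward touches them.
    engine.reload_states()
```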
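The offloading itself works by copying each state tensor into a persistent (optionally pinned) host buffer and then repointing `tensor.data` at that buffer, so every existing reference (e.g. the entries in `optimizer.state`) follows the move without rebuilding any data structures; reloading copies back and repoints again. Below is a standalone sketch of that pointer-swap technique, mirroring what `offload_adam_states`/`reload_adam_states` in `deepspeed/runtime/utils.py` do. The helper names and the `id()`-keyed cache are illustrative only, not DeepSpeed APIs, and a GPU is assumed.

```python
import torch

def offload_tensor(t: torch.Tensor, cache: dict, pin_memory: bool = True, non_blocking: bool = False):
    """Move t's storage to CPU in place, keeping every existing reference to t valid."""
    if id(t) not in cache:
        buf = torch.empty_like(t, device="cpu")
        if pin_memory:
            buf = buf.pin_memory()
        cache[id(t)] = buf
    cache[id(t)].copy_(t, non_blocking=non_blocking)
    t.data = cache[id(t)]  # repoint storage; the old device memory becomes free

def reload_tensor(t: torch.Tensor, device: str, non_blocking: bool = False):
    t.data = t.data.to(device, non_blocking=non_blocking)

# Usage: a second reference (as the optimizer state dict would hold) follows the move.
state = {"exp_avg": torch.randn(4, device="cuda")}
holder = state["exp_avg"]
cache = {}
offload_tensor(state["exp_avg"], cache)
assert holder.device.type == "cpu"
reload_tensor(state["exp_avg"], "cuda")
assert holder.device.type == "cuda"
```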