From 582510444b39b9f0b3d251c6828b4040e1545f78 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Aug 2024 18:17:01 +0000 Subject: [PATCH 01/15] add apis to offload states of model, optimizer, and engine --- deepspeed/runtime/engine.py | 31 ++++- deepspeed/runtime/utils.py | 9 ++ deepspeed/runtime/zero/offload_config.py | 9 ++ deepspeed/runtime/zero/stage3.py | 119 ++++++++++++++++-- .../unit/runtime/zero/test_offload_states.py | 73 +++++++++++ 5 files changed, 232 insertions(+), 9 deletions(-) create mode 100644 tests/unit/runtime/zero/test_offload_states.py diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d2839a8f5d7c..d80eb4cb400f 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -18,13 +18,13 @@ from torch.optim.lr_scheduler import _LRScheduler from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from typing import Callable, Dict, Union, Iterable +from typing import Callable, Dict, Union, Iterable, Set import deepspeed from deepspeed import comm as dist from deepspeed.runtime.utils import see_memory_usage, DummyOptim -from .zero.offload_config import OffloadDeviceEnum +from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException @@ -3672,3 +3672,30 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg @property def is_compiled(self) -> bool: return self._is_compiled + + def offload_states(self, + include: Set[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu) -> None: + """Move the ZeRO optimizer buffers to the specified device. + + Arguments: + device: Required. The device to move the ZeRO optimizer buffers to. + """ + assert self.zero_optimization_stage( + ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." + + if device == OffloadDeviceEnum.none: + logger.warning("No device specified for offloading states.") + return + + if device == OffloadDeviceEnum.nvme: + raise ValueError("NVMe offload is not supported for offloading states.") + + self.optimizer.offload_states(include=include, device=device) + + def offload_states_back(self) -> None: + """Move the ZeRO optimizer buffers back to the original device. + """ + assert self.zero_optimization_stage( + ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3." + self.optimizer.offload_states_back() diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 2c01c3475a70..e7e43279634f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1065,3 +1065,12 @@ def to_tensor(v): total_norm = -1 return total_norm + + +def adam_states_to(optimizer, device): + """Move optimizer states to device. 
Note that this assumes the state structure of DeepSpeed Adam.""" + for _, state in optimizer.state.items(): + if "exp_avg" in state: + state["exp_avg"] = state["exp_avg"].to(device) + if "exp_avg_sq" in state: + state["exp_avg_sq"] = state["exp_avg_sq"].to(device) diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index b7adc13a0ea2..a3241a91b7e5 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -95,3 +95,12 @@ def set_pipeline(cls, field_value, values): ratio: float = Field(1.0, ge=0.0, le=1.0) """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3.""" + + +class OffloadStateTypeEnum(str, Enum): + """ Enum for internal buffer types """ + opt_states = "opt_states" + hp_params = "hp_params" + lp_params = "lp_params" + lp_grads = "lp_grads" + contiguous_grad_buffer = "contiguous_grad_buffer" diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index b0a3ab778f2a..1b0178827499 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -6,19 +6,20 @@ import sys import gc import collections -from typing import Deque, Dict, Tuple +import itertools +from typing import Deque, Dict, Set, Tuple from deepspeed import comm as dist -from deepspeed.utils import groups +from deepspeed.utils import groups, z3_leaf_parameter from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, adam_states_to from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum -from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.utils import apply_to_tensors_only from deepspeed.ops.adam import DeepSpeedCPUAdam @@ -28,7 +29,6 @@ from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER from deepspeed.accelerator import get_accelerator -from deepspeed.utils import z3_leaf_parameter # Toggle this to true to enable correctness test # with gradient partitioning and without @@ -393,6 +393,8 @@ def __init__( self._link_all_hp_params() + self.offloaded_states: Set(OffloadDeviceEnum) = set() + if dist.get_rank(group=self.dp_process_group) == 0: see_memory_usage(f"After initializing ZeRO optimizer", force=True) @@ -489,6 +491,9 @@ def _setup_for_real_optimizer(self): self.__ipg_bucket_flat_buffer: Tensor = torch.empty(self.reduce_bucket_size, dtype=self.dtype, device=get_accelerator().current_device_name()) + print( + f"init contiguous gradients{self.contiguous_gradients} __ipg_bucket_flat_buffer?={hasattr(self, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer')}" + ) self.grad_partitions_flat_buffer = None self.__param_id_to_grad_partition: Dict[int, Tensor] = {} @@ 
-693,14 +698,17 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): for sub_group in self.fp16_groups: for param in sub_group: parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) + + # Store contiguous buffer for low precision parameters. + # Needs to be an instance attribute to be evicted to host memory. + self.lp_param_buffer = __class__.defragment(parameter_partitions) # setup flat buffers per subgroup, these are each just sections of the # contiguous flat buffer for all parameters that we created earlier offset = 0 for sub_group in self.fp16_groups: sub_group_numel = sum(param.partition_numel() for param in sub_group) - self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel)) + self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel)) offset += sub_group_numel else: # partitioned params offloaded to CPU when not in use # create a flat CPU memory allocation for each param group @@ -2749,6 +2757,103 @@ def checkpoint_event_epilogue(self): def empty_partition_cache(self): self.parameter_offload.empty_partition_cache() + def offload_states(self, + include: Set[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu): + device = device.value + + self.empty_partition_cache() + + assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, f"Offloading is supported only for DeepSpeed FusedAdam." + + def needs_offload(target): + # return True + return target not in self.offloaded_states and (include == None or target in include) + + # HP param + if needs_offload(OffloadStateTypeEnum.hp_params): + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device) + self.offloaded_states.add(OffloadStateTypeEnum.hp_params) + + # LP param + if needs_offload(OffloadStateTypeEnum.lp_params): + if hasattr(self, "lp_param_buffer"): + self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) + for p in self.fp16_partitioned_groups_flat: + p.data = p.data.to(device) + for p in self.module.parameters(): + p.ds_tensor.data = p.ds_tensor.data.to(device) + self.offloaded_states.add(OffloadStateTypeEnum.lp_params) + + # LP grad + if needs_offload(OffloadStateTypeEnum.lp_grads): + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + self.averaged_gradients = {} + + for _, g in self.__param_id_to_grad_partition.items(): + g.data = g.data.to(device) + self.__param_id_to_grad_partition = {} + + self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer): + if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"): + self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) + self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if needs_offload(OffloadStateTypeEnum.opt_states): + adam_states_to(self.optimizer, device) + self.offloaded_states.add(OffloadStateTypeEnum.opt_states) + + gc.collect() + get_accelerator().empty_cache() + + def offload_states_back(self): + + device = get_accelerator().current_device_name() + + # HP param + if OffloadStateTypeEnum.hp_params in self.offloaded_states: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device) + self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) + + # LP Param + if OffloadStateTypeEnum.lp_params in self.offloaded_states and 
hasattr(self, "lp_param_buffer"): + self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) + for p in self.fp16_partitioned_groups_flat: + p.data = p.data.to(device) + for p in self.module.parameters(): + p.ds_tensor.data = p.ds_tensor.data.to(device) + self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) + + # LP grad + if OffloadStateTypeEnum.lp_grads in self.offloaded_states: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + self.averaged_gradients = {} + + offset = 0 + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + for param in all_params: + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( + 0, offset, param.partition_numel()) + offset += param.partition_numel() + + self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: + self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) + self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if OffloadStateTypeEnum.opt_states in self.offloaded_states: + adam_states_to(self.optimizer, device) + self.offloaded_states.remove(OffloadStateTypeEnum.opt_states) + def _handle_overflow(cpu_sum, x, i): import math diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py new file mode 100644 index 000000000000..4b603ca96a63 --- /dev/null +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +import torch + +from unit.common import DistributedTest +from unit.simple_model import random_dataloader, SimpleModel + +import deepspeed +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum + + +def run_model(model, config_dict, hidden_dim, dtype, include): + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + alloc_before_offload = get_accelerator().memory_allocated() + model.offload_states(include=include, device=OffloadDeviceEnum.cpu) + alloc_after_offload = get_accelerator().memory_allocated() + assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" + model.offload_states_back() + assert alloc_after_offload < get_accelerator().memory_allocated( + ), f"Allocated memory should increase after offload back" + + # Needed in ZeRO 3. 
Not doing so can give memory leak + model.destroy() + + +@pytest.mark.parametrize("included_state", [ + OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.opt_states, + OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.contiguous_grad_buffer, None +]) +class TestOffloadStates(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + reuse_dist_env = True + + def test_move_buffer(self, included_state): + hidden_dim = 1024 + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + } + } + config_dict["bf16"] = {"enabled": True} + + model = SimpleModel(hidden_dim) + + include = None if included_state is None else set([included_state]) + run_model(model, config_dict, hidden_dim, torch.bfloat16, include) From 600c822628df996a3bb955e1fd395a84739c36b0 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Aug 2024 18:44:41 +0000 Subject: [PATCH 02/15] update api doc --- deepspeed/runtime/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d80eb4cb400f..bd3dd439e637 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3679,7 +3679,8 @@ def offload_states(self, """Move the ZeRO optimizer buffers to the specified device. Arguments: - device: Required. The device to move the ZeRO optimizer buffers to. + include: Optional. The set of states to offload. If not provided, all states are offloaded. + device: Optional. The device to move the ZeRO optimizer buffers to. """ assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." From 126d9b7f13641d5a7402dc13b487629d4b4bc683 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Aug 2024 20:03:46 +0000 Subject: [PATCH 03/15] reduce global reference to buffer --- deepspeed/runtime/zero/stage3.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e39bf39170f9..9363c9fb190e 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -489,9 +489,6 @@ def _setup_for_real_optimizer(self): self.__ipg_bucket_flat_buffer: Tensor = torch.empty(self.reduce_bucket_size, dtype=self.dtype, device=get_accelerator().current_device_name()) - print( - f"init contiguous gradients{self.contiguous_gradients} __ipg_bucket_flat_buffer?={hasattr(self, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer')}" - ) self.grad_partitions_flat_buffer = None self.__param_id_to_grad_partition: Dict[int, Tensor] = {} @@ -697,16 +694,14 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): for param in sub_group: parameter_partitions.append(param.ds_tensor) - # Store contiguous buffer for low precision parameters. - # Needs to be an instance attribute to be evicted to host memory. 
- self.lp_param_buffer = __class__.defragment(parameter_partitions) + device_buffer = __class__.defragment(parameter_partitions) # setup flat buffers per subgroup, these are each just sections of the # contiguous flat buffer for all parameters that we created earlier offset = 0 for sub_group in self.fp16_groups: sub_group_numel = sum(param.partition_numel() for param in sub_group) - self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel)) + self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel)) offset += sub_group_numel else: # partitioned params offloaded to CPU when not in use # create a flat CPU memory allocation for each param group @@ -2777,13 +2772,11 @@ def needs_offload(target): # LP param if needs_offload(OffloadStateTypeEnum.lp_params): - if hasattr(self, "lp_param_buffer"): - self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) - for p in self.fp16_partitioned_groups_flat: - p.data = p.data.to(device) - for p in self.module.parameters(): - p.ds_tensor.data = p.ds_tensor.data.to(device) - self.offloaded_states.add(OffloadStateTypeEnum.lp_params) + for p in self.fp16_partitioned_groups_flat: + p.data = p.data.to(device) + for p in self.module.parameters(): + p.ds_tensor.data = p.ds_tensor.data.to(device) + self.offloaded_states.add(OffloadStateTypeEnum.lp_params) # LP grad if needs_offload(OffloadStateTypeEnum.lp_grads): @@ -2821,8 +2814,8 @@ def offload_states_back(self): self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) # LP Param - if OffloadStateTypeEnum.lp_params in self.offloaded_states and hasattr(self, "lp_param_buffer"): - self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) + if OffloadStateTypeEnum.lp_params in self.offloaded_states: + # self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) for p in self.fp16_partitioned_groups_flat: p.data = p.data.to(device) for p in self.module.parameters(): From 05df37c5a88ab9a3871b4f10ae7d5f85c11d096b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Aug 2024 23:41:20 +0000 Subject: [PATCH 04/15] loosen type hint --- deepspeed/runtime/engine.py | 4 ++-- deepspeed/runtime/zero/stage3.py | 4 ++-- tests/unit/runtime/zero/test_offload_states.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9cc991d0885a..2c09b83decd2 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -18,7 +18,7 @@ from torch.optim.lr_scheduler import _LRScheduler from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from typing import Callable, Dict, Union, Iterable, Set +from typing import Callable, Dict, Union, Iterable, Container import deepspeed @@ -3674,7 +3674,7 @@ def is_compiled(self) -> bool: return self._is_compiled def offload_states(self, - include: Set[OffloadStateTypeEnum] = None, + include: Container[OffloadStateTypeEnum] = None, device: OffloadDeviceEnum = OffloadDeviceEnum.cpu) -> None: """Move the ZeRO optimizer buffers to the specified device. 
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 9363c9fb190e..d79752ad877f 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,7 +7,7 @@ import gc import collections import itertools -from typing import Deque, Dict, Set, Tuple +from typing import Deque, Dict, Set, Tuple, Container from deepspeed import comm as dist from deepspeed.utils import groups, z3_leaf_parameter @@ -2752,7 +2752,7 @@ def empty_partition_cache(self): self.parameter_offload.empty_partition_cache() def offload_states(self, - include: Set[OffloadStateTypeEnum] = None, + include: Container[OffloadStateTypeEnum] = None, device: OffloadDeviceEnum = OffloadDeviceEnum.cpu): device = device.value diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py index 4b603ca96a63..291b915fe595 100644 --- a/tests/unit/runtime/zero/test_offload_states.py +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -69,5 +69,5 @@ def test_move_buffer(self, included_state): model = SimpleModel(hidden_dim) - include = None if included_state is None else set([included_state]) + include = None if included_state is None else [included_state] run_model(model, config_dict, hidden_dim, torch.bfloat16, include) From 3f8179db68ebe2506860c71cf8d5c018b862465b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 20 Aug 2024 16:19:11 +0000 Subject: [PATCH 05/15] add option for pin_memory and non blocking copy --- deepspeed/runtime/engine.py | 15 ++- deepspeed/runtime/utils.py | 32 ++++++- deepspeed/runtime/zero/stage3.py | 92 ++++++++++++++----- .../unit/runtime/zero/test_offload_states.py | 17 ++-- 4 files changed, 122 insertions(+), 34 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 2c09b83decd2..77b077618aba 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3675,12 +3675,16 @@ def is_compiled(self) -> bool: def offload_states(self, include: Container[OffloadStateTypeEnum] = None, - device: OffloadDeviceEnum = OffloadDeviceEnum.cpu) -> None: + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False) -> None: """Move the ZeRO optimizer buffers to the specified device. Arguments: include: Optional. The set of states to offload. If not provided, all states are offloaded. device: Optional. The device to move the ZeRO optimizer buffers to. + pin_memory: Optional. Whether to pin the memory of the offloaded states. + non_blocking: Optional. Whether to offload the states asynchronously. """ assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." @@ -3692,11 +3696,14 @@ def offload_states(self, if device == OffloadDeviceEnum.nvme: raise ValueError("NVMe offload is not supported for offloading states.") - self.optimizer.offload_states(include=include, device=device) + self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking) - def offload_states_back(self) -> None: + def offload_states_back(self, non_blocking: bool = False) -> None: """Move the ZeRO optimizer buffers back to the original device. + + Arguments: + non_blocking: Optional. Whether to offload the states asynchronously. """ assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3." 
- self.optimizer.offload_states_back() + self.optimizer.offload_states_back(non_blocking=non_blocking) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index e7e43279634f..a71731bf47b0 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1067,10 +1067,36 @@ def to_tensor(v): return total_norm -def adam_states_to(optimizer, device): +def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False): """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_key(state, key): + if pin_memory: + pin_mem_key = f"{key}_pin_memory" + if pin_mem_key not in state: + state[pin_mem_key] = torch.empty_like(state[key], device=device).pin_memory() + state[pin_mem_key].copy_(state[key], non_blocking=non_blocking) + state[key].data = state[pin_mem_key] + else: + state[key].data = state[key].to(device, non_blocking=non_blocking) + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_key(state, "exp_avg_sq") + + +def offload_adam_states_back(optimizer, device, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_back_key(state, key): + pin_mem_key = f"{key}_pin_memory" + src_key = pin_mem_key if pin_mem_key in state else key + state[key].data = state[src_key].to(device, non_blocking=non_blocking) + for _, state in optimizer.state.items(): if "exp_avg" in state: - state["exp_avg"] = state["exp_avg"].to(device) + move_back_key(state, "exp_avg") if "exp_avg_sq" in state: - state["exp_avg_sq"] = state["exp_avg_sq"].to(device) + move_back_key(state, "exp_avg_sq") diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ff8efbeec395..1e394b6f4c59 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -18,7 +18,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, adam_states_to +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, offload_adam_states_back from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum @@ -2788,7 +2788,9 @@ def empty_partition_cache(self): def offload_states(self, include: Container[OffloadStateTypeEnum] = None, - device: OffloadDeviceEnum = OffloadDeviceEnum.cpu): + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False): device = device.value self.empty_partition_cache() @@ -2801,25 +2803,58 @@ def needs_offload(target): # HP param if needs_offload(OffloadStateTypeEnum.hp_params): - for buf in self.fp32_partitioned_groups_flat: - buf.data = buf.data.to(device) + if pin_memory: + if not hasattr(self, "hp_params_pin_buffers"): + self.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device=device)) + for t in self.fp32_partitioned_groups_flat + ] + + for src_tensor, dest_buf in zip(self.fp32_partitioned_groups_flat, self.hp_params_pin_buffers): + dest_buf.copy_(src_tensor, non_blocking=non_blocking) + 
src_tensor.data = dest_buf + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) self.offloaded_states.add(OffloadStateTypeEnum.hp_params) # LP param if needs_offload(OffloadStateTypeEnum.lp_params): - for p in self.fp16_partitioned_groups_flat: - p.data = p.data.to(device) - for p in self.module.parameters(): - p.ds_tensor.data = p.ds_tensor.data.to(device) + if pin_memory: + if not hasattr(self, "lp_partitioned_groups_flat_pin_buffers"): + self.lp_partitioned_groups_flat_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device=device)) + for t in self.fp16_partitioned_groups_flat + ] + self.lp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device)) + for p in self.module.parameters() + ] + for p, buf in zip(self.fp16_partitioned_groups_flat, self.lp_partitioned_groups_flat_pin_buffers): + buf.copy_(p, non_blocking=non_blocking) + p.data = buf + for p, buf in zip(self.module.parameters(), self.lp_params_pin_buffers): + buf.copy_(p.ds_tensor.data, non_blocking=non_blocking) + p.ds_tensor.data = buf + else: + for p in self.fp16_partitioned_groups_flat: + p.data = p.data.to(device, non_blocking=non_blocking) + for p in self.module.parameters(): + p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking) self.offloaded_states.add(OffloadStateTypeEnum.lp_params) # LP grad if needs_offload(OffloadStateTypeEnum.lp_grads): - self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + if pin_memory: + if not hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.lp_grad_partitions_flat_pin_buffers = get_accelerator().pin_memory( + torch.empty_like(self.grad_partitions_flat_buffer, device=device)) + self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer, + non_blocking=non_blocking) + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) self.averaged_gradients = {} - for _, g in self.__param_id_to_grad_partition.items(): - g.data = g.data.to(device) self.__param_id_to_grad_partition = {} self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) @@ -2827,39 +2862,50 @@ def needs_offload(target): # contiguous bucket if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer): if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"): - self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) + # Record properties like shape, strides, etc. 
as a meta tensor + self.grad_buffer_meta = self.__ipg_bucket_flat_buffer.to("meta") + self.__ipg_bucket_flat_buffer = None self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer) # Adam if needs_offload(OffloadStateTypeEnum.opt_states): - adam_states_to(self.optimizer, device) + offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) self.offloaded_states.add(OffloadStateTypeEnum.opt_states) gc.collect() get_accelerator().empty_cache() - def offload_states_back(self): + def offload_states_back(self, non_blocking: bool = False): device = get_accelerator().current_device_name() # HP param if OffloadStateTypeEnum.hp_params in self.offloaded_states: - for buf in self.fp32_partitioned_groups_flat: - buf.data = buf.data.to(device) + if hasattr(self, "hp_params_pin_buffers"): + for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat): + dest.data = src.to(device, non_blocking=non_blocking) + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) # LP Param if OffloadStateTypeEnum.lp_params in self.offloaded_states: # self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) for p in self.fp16_partitioned_groups_flat: - p.data = p.data.to(device) + p.data = p.data.to(device, non_blocking=non_blocking) for p in self.module.parameters(): - p.ds_tensor.data = p.ds_tensor.data.to(device) + p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking) self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) # LP grad if OffloadStateTypeEnum.lp_grads in self.offloaded_states: - self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + if hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers.to( + device, non_blocking=non_blocking) + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to( + device, non_blocking=non_blocking) self.averaged_gradients = {} offset = 0 @@ -2873,14 +2919,18 @@ def offload_states_back(self): # contiguous bucket if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: - self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) + self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device) + # self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) # Adam if OffloadStateTypeEnum.opt_states in self.offloaded_states: - adam_states_to(self.optimizer, device) + offload_adam_states_back(self.optimizer, device, non_blocking=non_blocking) self.offloaded_states.remove(OffloadStateTypeEnum.opt_states) + if non_blocking: + get_accelerator().synchronize() + def _handle_overflow(cpu_sum, x, i): import math diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py index 291b915fe595..6c319655f366 100644 --- a/tests/unit/runtime/zero/test_offload_states.py +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -16,7 +16,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum -def run_model(model, config_dict, hidden_dim, dtype, include): +def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking): model, _, _, _ = 
deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -30,7 +30,10 @@ def run_model(model, config_dict, hidden_dim, dtype, include): model.step() alloc_before_offload = get_accelerator().memory_allocated() - model.offload_states(include=include, device=OffloadDeviceEnum.cpu) + model.offload_states(include=include, + device=OffloadDeviceEnum.cpu, + pin_memory=pin_memory, + non_blocking=non_blocking) alloc_after_offload = get_accelerator().memory_allocated() assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" model.offload_states_back() @@ -45,12 +48,13 @@ def run_model(model, config_dict, hidden_dim, dtype, include): OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.opt_states, OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.contiguous_grad_buffer, None ]) +@pytest.mark.parametrize("pin_memory", [False, True]) +@pytest.mark.parametrize("non_blocking", [False, True]) class TestOffloadStates(DistributedTest): # Need multiple gpus to test possible hanging world_size = 2 - reuse_dist_env = True - def test_move_buffer(self, included_state): + def test_offload_states(self, included_state, pin_memory, non_blocking): hidden_dim = 1024 config_dict = { @@ -67,7 +71,8 @@ def test_move_buffer(self, included_state): } config_dict["bf16"] = {"enabled": True} - model = SimpleModel(hidden_dim) + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=4) include = None if included_state is None else [included_state] - run_model(model, config_dict, hidden_dim, torch.bfloat16, include) + run_model(model, config_dict, hidden_dim, torch.bfloat16, include, pin_memory, non_blocking) From 37ffa0211cd1db92e1692fafe1f9bd4c51e5fa3d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 20 Aug 2024 16:48:25 +0000 Subject: [PATCH 06/15] fix offloading of lp grad --- deepspeed/runtime/zero/stage3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 1e394b6f4c59..b7a6d1f4a70a 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2851,6 +2851,7 @@ def needs_offload(target): torch.empty_like(self.grad_partitions_flat_buffer, device=device)) self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer, non_blocking=non_blocking) + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers else: self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) self.averaged_gradients = {} @@ -2919,6 +2920,7 @@ def offload_states_back(self, non_blocking: bool = False): # contiguous bucket if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: + print(f"loading contiguous_grad_buffer") self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device) # self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) From 93c5a9065d12f6016ff5f892ceebaabbdf3df8be Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 20 Aug 2024 17:12:18 +0000 Subject: [PATCH 07/15] add verification in test --- deepspeed/runtime/zero/stage3.py | 1 - .../unit/runtime/zero/test_offload_states.py | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git 
a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index b7a6d1f4a70a..c83f3bcfe9b4 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2920,7 +2920,6 @@ def offload_states_back(self, non_blocking: bool = False): # contiguous bucket if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: - print(f"loading contiguous_grad_buffer") self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device) # self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py index 6c319655f366..f71dbd0acbfe 100644 --- a/tests/unit/runtime/zero/test_offload_states.py +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -14,6 +14,7 @@ import deepspeed from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum +from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking): @@ -29,6 +30,13 @@ def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_bl model.backward(loss) model.step() + hp_params_expected = [safe_get_local_fp32_param(p).clone() for p in model.parameters()] + lp_params_expected = [p.ds_tensor.clone() for p in model.parameters()] + lp_grads_expected = model.optimizer.grad_partitions_flat_buffer.clone() + adam_exp_avg_expected = [safe_get_local_optimizer_state(p, "exp_avg").clone() for p in model.parameters()] + adam_exp_avg_sq = [safe_get_local_optimizer_state(p, "exp_avg_sq").clone() for p in model.parameters()] + + # Start offloading alloc_before_offload = get_accelerator().memory_allocated() model.offload_states(include=include, device=OffloadDeviceEnum.cpu, @@ -36,10 +44,32 @@ def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_bl non_blocking=non_blocking) alloc_after_offload = get_accelerator().memory_allocated() assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" + + # Load offloaded states back model.offload_states_back() assert alloc_after_offload < get_accelerator().memory_allocated( ), f"Allocated memory should increase after offload back" + # Verify restored states + hp_param_restored = [safe_get_local_fp32_param(p) for p in model.parameters()] + for hp_param_expected, hp_param_restored in zip(hp_params_expected, hp_param_restored): + assert torch.equal(hp_param_expected, hp_param_restored) + + lp_param_restored = [p.ds_tensor for p in model.parameters()] + + for lp_param_expected, lp_param_restored in zip(lp_params_expected, lp_param_restored): + assert torch.equal(lp_param_expected, lp_param_restored) + + assert torch.equal(lp_grads_expected, model.optimizer.grad_partitions_flat_buffer) + + adam_exp_avg_restored = [safe_get_local_optimizer_state(p, "exp_avg") for p in model.parameters()] + for adam_exp_avg_expected, adam_exp_avg_restored in zip(adam_exp_avg_expected, adam_exp_avg_restored): + assert torch.equal(adam_exp_avg_expected, adam_exp_avg_restored) + + adam_exp_avg_sq_restored = [safe_get_local_optimizer_state(p, "exp_avg_sq") for p in model.parameters()] + for adam_exp_avg_sq_expected, adam_exp_avg_sq_restored in zip(adam_exp_avg_sq, adam_exp_avg_sq_restored): + assert torch.equal(adam_exp_avg_sq_expected, adam_exp_avg_sq_restored) + # Needed 
in ZeRO 3. Not doing so can give memory leak model.destroy() From 512e9c9fb5018279b0ee0248564357dcd988ccbe Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 20 Aug 2024 21:30:40 +0000 Subject: [PATCH 08/15] improve offloading of lp params --- deepspeed/runtime/engine.py | 2 ++ deepspeed/runtime/zero/stage3.py | 44 +++++++++++++++++--------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 77b077618aba..e12394f2b363 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3689,6 +3689,8 @@ def offload_states(self, assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." + assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters." + if device == OffloadDeviceEnum.none: logger.warning("No device specified for offloading states.") return diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c83f3bcfe9b4..f75043248841 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -729,15 +729,10 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): for param in sub_group: parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) + # We need to keep the reference to this buffer to make sure you can free it in `offload_states` + self.lp_param_buffer = __class__.defragment(parameter_partitions) + self._set_fp16_partitioned_groups_flat() - # setup flat buffers per subgroup, these are each just sections of the - # contiguous flat buffer for all parameters that we created earlier - offset = 0 - for sub_group in self.fp16_groups: - sub_group_numel = sum(param.partition_numel() for param in sub_group) - self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel)) - offset += sub_group_numel else: # partitioned params offloaded to CPU when not in use # create a flat CPU memory allocation for each param group self._create_param_groups_fp16_flat_cpu_memory() @@ -1012,6 +1007,15 @@ def _partitioned_params_swap_out(self, i): swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=swap_fp16_params, src_fp32_params=swap_fp32_params) + def _set_fp16_partitioned_groups_flat(self): + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.partition_numel() for param in sub_group) + self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel)) + offset += sub_group_numel + def initialize_optimizer_states(self): num_subgroups = len(self.fp16_groups) @@ -2821,26 +2825,24 @@ def needs_offload(target): # LP param if needs_offload(OffloadStateTypeEnum.lp_params): if pin_memory: - if not hasattr(self, "lp_partitioned_groups_flat_pin_buffers"): - self.lp_partitioned_groups_flat_pin_buffers = [ - get_accelerator().pin_memory(torch.empty_like(t, device=device)) - for t in self.fp16_partitioned_groups_flat - ] + if not hasattr(self, "lp_param_contiguous_pin_buffer"): + self.lp_param_contiguous_pin_buffer = torch.empty_like(self.lp_param_buffer, device=device) self.lp_params_pin_buffers = [ get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device)) for p in self.module.parameters() ] - for p, 
buf in zip(self.fp16_partitioned_groups_flat, self.lp_partitioned_groups_flat_pin_buffers): - buf.copy_(p, non_blocking=non_blocking) - p.data = buf + self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) + self.lp_param_buffer.data = self.lp_param_contiguous_pin_buffer + for p, buf in zip(self.module.parameters(), self.lp_params_pin_buffers): buf.copy_(p.ds_tensor.data, non_blocking=non_blocking) p.ds_tensor.data = buf else: - for p in self.fp16_partitioned_groups_flat: - p.data = p.data.to(device, non_blocking=non_blocking) + self.lp_param_buffer.data = self.lp_param_buffer.to(device, non_blocking=non_blocking) for p in self.module.parameters(): p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking) + + self.fp16_partitioned_groups_flat.clear() self.offloaded_states.add(OffloadStateTypeEnum.lp_params) # LP grad @@ -2892,9 +2894,9 @@ def offload_states_back(self, non_blocking: bool = False): # LP Param if OffloadStateTypeEnum.lp_params in self.offloaded_states: - # self.lp_param_buffer.data = self.lp_param_buffer.data.to(device) - for p in self.fp16_partitioned_groups_flat: - p.data = p.data.to(device, non_blocking=non_blocking) + self.lp_param_buffer.data = self.lp_param_buffer.data.to(device, non_blocking=non_blocking) + self._set_fp16_partitioned_groups_flat() + for p in self.module.parameters(): p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking) self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) From c749b05cb71cbf9a15978b228b6a1da152960bfd Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 22 Aug 2024 07:48:47 +0000 Subject: [PATCH 09/15] fix pinning --- deepspeed/runtime/utils.py | 2 +- deepspeed/runtime/zero/stage3.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index a71731bf47b0..871fca403e4d 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1074,7 +1074,7 @@ def move_key(state, key): if pin_memory: pin_mem_key = f"{key}_pin_memory" if pin_mem_key not in state: - state[pin_mem_key] = torch.empty_like(state[key], device=device).pin_memory() + state[pin_mem_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device=device)) state[pin_mem_key].copy_(state[key], non_blocking=non_blocking) state[key].data = state[pin_mem_key] else: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index f75043248841..96f8ca65e5ea 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2826,7 +2826,8 @@ def needs_offload(target): if needs_offload(OffloadStateTypeEnum.lp_params): if pin_memory: if not hasattr(self, "lp_param_contiguous_pin_buffer"): - self.lp_param_contiguous_pin_buffer = torch.empty_like(self.lp_param_buffer, device=device) + self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( + torch.empty_like(self.lp_param_buffer, device=device)) self.lp_params_pin_buffers = [ get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device)) for p in self.module.parameters() From 2a4733eedd8986b93a7906035a037b8bd558e8ec Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 3 Sep 2024 21:07:38 +0000 Subject: [PATCH 10/15] fix method name and enum key --- deepspeed/runtime/engine.py | 4 ++-- deepspeed/runtime/zero/offload_config.py | 2 +- deepspeed/runtime/zero/stage3.py | 10 +++++----- tests/unit/runtime/zero/test_offload_states.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff 
--git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 3aa2a829d0b5..e023db4f0bb8 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3709,7 +3709,7 @@ def offload_states(self, self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking) - def offload_states_back(self, non_blocking: bool = False) -> None: + def reload_states(self, non_blocking: bool = False) -> None: """Move the ZeRO optimizer buffers back to the original device. Arguments: @@ -3717,4 +3717,4 @@ def offload_states_back(self, non_blocking: bool = False) -> None: """ assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3." - self.optimizer.offload_states_back(non_blocking=non_blocking) + self.optimizer.reload_states(non_blocking=non_blocking) diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index 0c2dcb651cc6..ca35d7a7d169 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -102,7 +102,7 @@ def set_pipeline(self): class OffloadStateTypeEnum(str, Enum): """ Enum for internal buffer types """ - opt_states = "opt_states" + optim_states = "optim_states" hp_params = "hp_params" lp_params = "lp_params" lp_grads = "lp_grads" diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 96f8ca65e5ea..45e311efa564 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2872,14 +2872,14 @@ def needs_offload(target): self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer) # Adam - if needs_offload(OffloadStateTypeEnum.opt_states): + if needs_offload(OffloadStateTypeEnum.optim_states): offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) - self.offloaded_states.add(OffloadStateTypeEnum.opt_states) + self.offloaded_states.add(OffloadStateTypeEnum.optim_states) gc.collect() get_accelerator().empty_cache() - def offload_states_back(self, non_blocking: bool = False): + def reload_states(self, non_blocking: bool = False): device = get_accelerator().current_device_name() @@ -2928,9 +2928,9 @@ def offload_states_back(self, non_blocking: bool = False): self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) # Adam - if OffloadStateTypeEnum.opt_states in self.offloaded_states: + if OffloadStateTypeEnum.optim_states in self.offloaded_states: offload_adam_states_back(self.optimizer, device, non_blocking=non_blocking) - self.offloaded_states.remove(OffloadStateTypeEnum.opt_states) + self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) if non_blocking: get_accelerator().synchronize() diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py index f71dbd0acbfe..c7429447e243 100644 --- a/tests/unit/runtime/zero/test_offload_states.py +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -46,7 +46,7 @@ def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_bl assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" # Load offloaded states back - model.offload_states_back() + model.reload_states() assert alloc_after_offload < get_accelerator().memory_allocated( ), f"Allocated memory should increase after offload back" @@ -75,7 +75,7 @@ def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_bl 
@pytest.mark.parametrize("included_state", [ - OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.opt_states, + OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.contiguous_grad_buffer, None ]) @pytest.mark.parametrize("pin_memory", [False, True]) From e9a499ef7ebd11c9d80f89fa2346daf1d7759999 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 3 Sep 2024 23:09:58 +0000 Subject: [PATCH 11/15] elimitate duplicated buffer for lp param --- deepspeed/runtime/zero/stage3.py | 36 ++++++++++++-------------------- deepspeed/runtime/zero/utils.py | 15 ++++++++++++- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 45e311efa564..d84ac16331f8 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -23,7 +23,7 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload -from deepspeed.runtime.zero.utils import apply_to_tensors_only +from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper @@ -566,21 +566,15 @@ def defragment(tensors: List[Tensor]) -> Tensor: cpu_buffer = torch.empty(sum(p.numel() for p in tensors), dtype=get_only_unique_item(t.dtype for t in tensors), device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = [] + tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) orig_device = get_only_unique_item(t.device for t in tensors) offset = 0 - for tensor in tensors: - tensor_numel = tensor.numel() + for tensor, offset, tensor_numel in tensor_infos: # move the tensor from device memory to host memory cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - # record some data so we can restore the device tensor later - tensor_infos.append((tensor, offset, tensor_numel)) - - offset += tensor_numel - gc.collect() get_accelerator().empty_cache() @@ -2828,20 +2822,14 @@ def needs_offload(target): if not hasattr(self, "lp_param_contiguous_pin_buffer"): self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( torch.empty_like(self.lp_param_buffer, device=device)) - self.lp_params_pin_buffers = [ - get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device)) - for p in self.module.parameters() - ] self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) - self.lp_param_buffer.data = self.lp_param_contiguous_pin_buffer - - for p, buf in zip(self.module.parameters(), self.lp_params_pin_buffers): - buf.copy_(p.ds_tensor.data, non_blocking=non_blocking) - p.ds_tensor.data = buf + cpu_buffer = self.lp_param_contiguous_pin_buffer else: self.lp_param_buffer.data = self.lp_param_buffer.to(device, non_blocking=non_blocking) - for p in self.module.parameters(): - p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking) + cpu_buffer = self.lp_param_buffer + + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()): + tensor.data = 
cpu_buffer.narrow(0, offset, tensor_numel) self.fp16_partitioned_groups_flat.clear() self.offloaded_states.add(OffloadStateTypeEnum.lp_params) @@ -2895,11 +2883,13 @@ def reload_states(self, non_blocking: bool = False): # LP Param if OffloadStateTypeEnum.lp_params in self.offloaded_states: - self.lp_param_buffer.data = self.lp_param_buffer.data.to(device, non_blocking=non_blocking) + cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr( + self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer + self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking) self._set_fp16_partitioned_groups_flat() - for p in self.module.parameters(): - p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking) + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()): + tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel) self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) # LP grad diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 8f913d065934..2d1cf17962d8 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import os -from typing import List +from typing import List, Tuple import torch from deepspeed import comm as dist @@ -160,3 +160,16 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): logger.warning(warning_msg_fn(value)) warned = True return value + + +def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]: + tensor_infos: List[Tuple[torch.Tensor, int, int]] = [] + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + offset += tensor_numel + + return tensor_infos From b8f47c6c68282d1918134d55fbba4185064d6627 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 4 Sep 2024 00:06:11 +0000 Subject: [PATCH 12/15] simplified offloding of adam states --- deepspeed/runtime/utils.py | 25 +++++++++++++------------ deepspeed/runtime/zero/stage3.py | 4 ++-- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 871fca403e4d..adcadd349803 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1067,18 +1067,21 @@ def to_tensor(v): return total_norm +def _make_offload_state_key(key): + return f"{key}_offload_buffer" + + def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False): """Move optimizer states to device. 
Note that this assumes the state structure of DeepSpeed Adam.""" def move_key(state, key): - if pin_memory: - pin_mem_key = f"{key}_pin_memory" - if pin_mem_key not in state: - state[pin_mem_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device=device)) - state[pin_mem_key].copy_(state[key], non_blocking=non_blocking) - state[key].data = state[pin_mem_key] - else: - state[key].data = state[key].to(device, non_blocking=non_blocking) + offload_buf_key = _make_offload_state_key(key) + if offload_buf_key not in state: + state[offload_buf_key] = torch.empty_like(state[key], device=device) + if pin_memory: + state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key]) + state[offload_buf_key].copy_(state[key], non_blocking=non_blocking) + state[key].data = state[offload_buf_key] for _, state in optimizer.state.items(): if "exp_avg" in state: @@ -1087,13 +1090,11 @@ def move_key(state, key): move_key(state, "exp_avg_sq") -def offload_adam_states_back(optimizer, device, non_blocking: bool = False): +def reload_adam_states(optimizer, device, non_blocking: bool = False): """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" def move_back_key(state, key): - pin_mem_key = f"{key}_pin_memory" - src_key = pin_mem_key if pin_mem_key in state else key - state[key].data = state[src_key].to(device, non_blocking=non_blocking) + state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking) for _, state in optimizer.state.items(): if "exp_avg" in state: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d84ac16331f8..deaa6f4d4fe6 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -18,7 +18,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, offload_adam_states_back +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum @@ -2919,7 +2919,7 @@ def reload_states(self, non_blocking: bool = False): # Adam if OffloadStateTypeEnum.optim_states in self.offloaded_states: - offload_adam_states_back(self.optimizer, device, non_blocking=non_blocking) + reload_adam_states(self.optimizer, device, non_blocking=non_blocking) self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) if non_blocking: From d33807907ab1eb56b12409fe29a2f4367a80d1ea Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 4 Sep 2024 01:25:09 +0000 Subject: [PATCH 13/15] validate devcies of offload states --- deepspeed/runtime/zero/stage3.py | 10 ++++--- .../unit/runtime/zero/test_offload_states.py | 27 +++++++++++++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index deaa6f4d4fe6..fb75d2bcebd5 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2825,10 +2825,11 @@ def needs_offload(target): self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) cpu_buffer 
             else:
-                self.lp_param_buffer.data = self.lp_param_buffer.to(device, non_blocking=non_blocking)
-                cpu_buffer = self.lp_param_buffer
+                cpu_buffer = self.lp_param_buffer.to(device, non_blocking=non_blocking)

-            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()):
+            self.lp_param_buffer.data = cpu_buffer
+            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(
+                    [p.ds_tensor for p in self.module.parameters()]):
                 tensor.data = cpu_buffer.narrow(0, offset, tensor_numel)

             self.fp16_partitioned_groups_flat.clear()
@@ -2888,7 +2889,8 @@ def reload_states(self, non_blocking: bool = False):
             self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking)
             self._set_fp16_partitioned_groups_flat()

-            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()):
+            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(
+                    [p.ds_tensor for p in self.module.parameters()]):
                 tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel)
             self.offloaded_states.remove(OffloadStateTypeEnum.lp_params)

diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py
index c7429447e243..cc60908d3c33 100644
--- a/tests/unit/runtime/zero/test_offload_states.py
+++ b/tests/unit/runtime/zero/test_offload_states.py
@@ -17,7 +17,23 @@
 from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state


+def validate_device(model, device: torch.device, include) -> None:
+    # Make sure the model parameters are offloaded
+    if include is None or OffloadStateTypeEnum.hp_params in include:
+        assert all(safe_get_local_fp32_param(p).device == device for p in model.parameters())
+    if include is None or OffloadStateTypeEnum.lp_params in include:
+        assert all(p.ds_tensor.device == device for p in model.parameters())
+    if include is None or OffloadStateTypeEnum.lp_grads in include:
+        assert model.optimizer.grad_partitions_flat_buffer.device == device
+    if include is None or OffloadStateTypeEnum.optim_states in include:
+        assert all(safe_get_local_optimizer_state(p, "exp_avg").device == device for p in model.parameters())
+        assert all(safe_get_local_optimizer_state(p, "exp_avg_sq").device == device for p in model.parameters())
+
+
 def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):
+    # Currently we only support OffloadDeviceEnum.cpu
+    offload_device = OffloadDeviceEnum.cpu
+
     model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
     data_loader = random_dataloader(model=model,
                                     total_samples=10,
@@ -38,14 +54,13 @@ def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_bl

     # Start offloading
     alloc_before_offload = get_accelerator().memory_allocated()
-    model.offload_states(include=include,
-                         device=OffloadDeviceEnum.cpu,
-                         pin_memory=pin_memory,
-                         non_blocking=non_blocking)
+    model.offload_states(include=include, device=offload_device, pin_memory=pin_memory, non_blocking=non_blocking)
     alloc_after_offload = get_accelerator().memory_allocated()
     assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload"

-    # Load offloaded states back
+    validate_device(model, torch.device(offload_device.value), include)
+
+    # Reload states
     model.reload_states()
     assert alloc_after_offload < get_accelerator().memory_allocated(
     ), f"Allocated memory should increase after offload back"
@@ -70,6 +85,8 @@ def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_bl
     for adam_exp_avg_sq_expected, adam_exp_avg_sq_restored in zip(adam_exp_avg_sq, adam_exp_avg_sq_restored):
         assert torch.equal(adam_exp_avg_sq_expected, adam_exp_avg_sq_restored)

+    validate_device(model, torch.device(get_accelerator().current_device_name()), include)
+
     # Needed in ZeRO 3. Not doing so can give memory leak
     model.destroy()

From 15ff7b384e1e9e26b5769d7bcde927f72c1874a8 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Wed, 4 Sep 2024 23:01:49 +0000
Subject: [PATCH 14/15] add document

---
 deepspeed/runtime/engine.py     |  6 ++--
 docs/code-docs/source/zero3.rst | 53 +++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index e023db4f0bb8..b590ea432658 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3687,11 +3687,11 @@ def offload_states(self,
                        device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                        pin_memory: bool = True,
                        non_blocking: bool = False) -> None:
-        """Move the ZeRO optimizer buffers to the specified device.
+        """Offload the engine's states to the specified device.

         Arguments:
             include: Optional. The set of states to offload. If not provided, all states are offloaded.
-            device: Optional. The device to move the ZeRO optimizer buffers to.
+            device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported.
             pin_memory: Optional. Whether to pin the memory of the offloaded states.
             non_blocking: Optional. Whether to offload the states asynchronously.
         """
@@ -3710,7 +3710,7 @@ def offload_states(self,
         self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)

     def reload_states(self, non_blocking: bool = False) -> None:
-        """Move the ZeRO optimizer buffers back to the original device.
+        """Reload the engine states to the original device.

         Arguments:
             non_blocking: Optional. Whether to offload the states asynchronously.
diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst
index 2a6a48ca91db..9e56bec54f43 100644
--- a/docs/code-docs/source/zero3.rst
+++ b/docs/code-docs/source/zero3.rst
@@ -456,3 +456,56 @@ The following code snippet illustrates this functionality.

     # Free GPU memory consumed by model parameters
     ds_engine.empty_partition_cache()
+
+
+Offload States
+--------------
+
+The DeepSpeed engine maintains a set of states in device memory (e.g., CUDA memory). The following API allows you to offload these states to a different device (currently, only CPU memory is supported), reducing the memory footprint on the device.
+
+.. code-block:: python
+
+    def offload_states(self,
+                       include: Container[OffloadStateTypeEnum] = None,
+                       device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
+                       pin_memory: bool = True,
+                       non_blocking: bool = False) -> None:
+        """Offload the engine's states to the specified device.
+
+        Arguments:
+            include: Optional. The set of states to offload. If not provided, all states are offloaded.
+            device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported.
+            pin_memory: Optional. Whether to pin the memory of the offloaded states.
+            non_blocking: Optional. Whether to offload the states asynchronously.
+        """
+
+You can selectively offload specific states by specifying the ``OffloadStateTypeEnum`` in the include argument. ``OffloadStateTypeEnum`` is an enum that defines the states that can be offloaded. The following states are supported:
+
+* ``OffloadStateTypeEnum.optim_states``: Optimizer states. Currently, only states of DeepSpeed's FusedAdam optimizer are supported.
+* ``OffloadStateTypeEnum.hp_params``: FP32 parameters.
+* ``OffloadStateTypeEnum.lp_params``: BF16/FP16 parameters.
+* ``OffloadStateTypeEnum.lp_grads``: BF16/FP16 gradients.
+* ``OffloadStateTypeEnum.contiguous_grad_buffer``: The contiguous gradient buffer for reduce operations.
+
+Note that offloading states comes with a trade-off between memory savings and computational overhead. This API allows states to be reloaded back into device memory when needed.
+
+.. code-block:: python
+
+    def reload_states(self, non_blocking: bool = False) -> None:
+        """Reload the engine states to the original device.
+
+        Arguments:
+            non_blocking: Optional. Whether to reload the states asynchronously.
+        """
+
+Below is an example code snippet demonstrating how to offload FP32 parameters and optimizer states to CPU memory:
+
+.. code-block:: python
+
+    # Offload after forward, backward, and step
+    ds_engine.offload_states(include=[OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states])
+
+    # Do something requiring a lot of device memory
+    ...
+    # Load states back to device memory
+    model.reload_states()

From e20d827fa4b29865e269253f679685f5d4cef332 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Wed, 4 Sep 2024 23:11:54 +0000
Subject: [PATCH 15/15] fix usage example

---
 docs/code-docs/source/zero3.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst
index 9e56bec54f43..f0974c08c9f3 100644
--- a/docs/code-docs/source/zero3.rst
+++ b/docs/code-docs/source/zero3.rst
@@ -508,4 +508,4 @@ Below is an example code snippet demonstrating how to offload FP32 parameters an
     # Do something requiring a lot of device memory
     ...
     # Load states back to device memory
-    model.reload_states()
+    ds_engine.reload_states()
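
For reference, a more self-contained sketch of the workflow documented in this series is shown below, combining the new engine APIs with a minimal ZeRO stage-3 setup. The toy model, config values, and single-process launch assumption (e.g., started via the ``deepspeed`` launcher) are illustrative only and are not part of the patches above.

.. code-block:: python

    import torch
    import deepspeed
    from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum

    # Minimal ZeRO stage-3 config; values are placeholders, not tuned recommendations.
    ds_config = {
        "train_batch_size": 4,
        "zero_optimization": {"stage": 3},
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "bf16": {"enabled": True},
    }

    # Toy model standing in for a real network.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
    ds_engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

    # One training step so that gradients and optimizer states exist on the device.
    x = torch.randn(4, 64, dtype=torch.bfloat16, device=ds_engine.device)
    loss = ds_engine(x).float().pow(2).mean()
    ds_engine.backward(loss)
    ds_engine.step()

    # Offload FP32 parameters and optimizer states to pinned CPU memory.
    ds_engine.offload_states(include=[OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states],
                             pin_memory=True)

    # ... run whatever needs the freed device memory here ...

    # Reload the states before the next forward/backward/step.
    ds_engine.reload_states()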