Add APIs to offload states of model, optimizer, and engine #6011

Merged
merged 28 commits into master from tohtana/offload_zero_buffers
Sep 27, 2024
Changes from 16 commits
Commits
28 commits
5825104
add apis to offload states of model, optimizer, and engine
tohtana Aug 16, 2024
600c822
update api doc
tohtana Aug 16, 2024
153a482
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 16, 2024
126d9b7
reduce global reference to buffer
tohtana Aug 16, 2024
05df37c
loosen type hint
tohtana Aug 16, 2024
837c06c
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 19, 2024
3f8179d
add option for pin_memory and non blocking copy
tohtana Aug 20, 2024
37ffa02
fix offloading of lp grad
tohtana Aug 20, 2024
93c5a90
add verification in test
tohtana Aug 20, 2024
512e9c9
improve offloading of lp params
tohtana Aug 20, 2024
de2a894
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 21, 2024
c749b05
fix pinning
tohtana Aug 22, 2024
36d6e10
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 22, 2024
af95a37
resolve conflict
tohtana Aug 28, 2024
1ca3a7f
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 3, 2024
2a4733e
fix method name and enum key
tohtana Sep 3, 2024
e9a499e
elimitate duplicated buffer for lp param
tohtana Sep 3, 2024
b8f47c6
simplified offloding of adam states
tohtana Sep 4, 2024
d338079
validate devcies of offload states
tohtana Sep 4, 2024
15ff7b3
add document
tohtana Sep 4, 2024
40427c1
Merge branch 'master' into tohtana/offload_zero_buffers
loadams Sep 4, 2024
e20d827
fix usage example
tohtana Sep 4, 2024
031464d
Merge branch 'tohtana/offload_zero_buffers' of github.com:microsoft/D…
tohtana Sep 4, 2024
60deaf1
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 6, 2024
3f001b6
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 9, 2024
8f81634
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 12, 2024
ff8e1e9
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 21, 2024
02a6a18
Merge branch 'master' into tohtana/offload_zero_buffers
tjruwase Sep 26, 2024
41 changes: 39 additions & 2 deletions deepspeed/runtime/engine.py
@@ -18,13 +18,13 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from typing import Callable, Dict, Union, Iterable
from typing import Callable, Dict, Union, Iterable, Container

import deepspeed

from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum
from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
@@ -3681,3 +3681,40 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
@property
def is_compiled(self) -> bool:
return self._is_compiled

def offload_states(self,
include: Container[OffloadStateTypeEnum] = None,
device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
pin_memory: bool = True,
non_blocking: bool = False) -> None:
"""Move the ZeRO optimizer buffers to the specified device.

Arguments:
include: Optional. The set of states to offload. If not provided, all states are offloaded.
device: Optional. The device to move the ZeRO optimizer buffers to.
pin_memory: Optional. Whether to pin the memory of the offloaded states.
non_blocking: Optional. Whether to offload the states asynchronously.
"""
assert self.zero_optimization_stage(
) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."

assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."

if device == OffloadDeviceEnum.none:
logger.warning("No device specified for offloading states.")
return

if device == OffloadDeviceEnum.nvme:
raise ValueError("NVMe offload is not supported for offloading states.")

self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)

def reload_states(self, non_blocking: bool = False) -> None:
"""Move the ZeRO optimizer buffers back to the original device.

Arguments:
non_blocking: Optional. Whether to reload the states asynchronously.
"""
assert self.zero_optimization_stage(
) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3."
self.optimizer.reload_states(non_blocking=non_blocking)
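
For context, here is a minimal usage sketch of the two engine APIs above (illustrative only; `model_engine` is assumed to be a `deepspeed.initialize()` engine running ZeRO stage 3 without parameter offloading):

from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

# Move the ZeRO buffers (params, grads, optimizer states) to pinned CPU memory
# to free accelerator memory between training phases.
model_engine.offload_states(device=OffloadDeviceEnum.cpu, pin_memory=True, non_blocking=True)

# ... run other accelerator-heavy work here ...

# Restore the buffers to the accelerator before resuming training.
model_engine.reload_states(non_blocking=True)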
35 changes: 35 additions & 0 deletions deepspeed/runtime/utils.py
@@ -1065,3 +1065,38 @@ def to_tensor(v):
total_norm = -1

return total_norm


def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
"""Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

def move_key(state, key):
if pin_memory:
pin_mem_key = f"{key}_pin_memory"
if pin_mem_key not in state:
state[pin_mem_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device=device))
state[pin_mem_key].copy_(state[key], non_blocking=non_blocking)
state[key].data = state[pin_mem_key]
else:
state[key].data = state[key].to(device, non_blocking=non_blocking)

for _, state in optimizer.state.items():
if "exp_avg" in state:
move_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_key(state, "exp_avg_sq")


def offload_adam_states_back(optimizer, device, non_blocking: bool = False):
"""Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

def move_back_key(state, key):
pin_mem_key = f"{key}_pin_memory"
src_key = pin_mem_key if pin_mem_key in state else key
state[key].data = state[src_key].to(device, non_blocking=non_blocking)

for _, state in optimizer.state.items():
if "exp_avg" in state:
move_back_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_back_key(state, "exp_avg_sq")
9 changes: 9 additions & 0 deletions deepspeed/runtime/zero/offload_config.py
@@ -98,3 +98,12 @@ def set_pipeline(self):
pipeline = self.pipeline_read or self.pipeline_write
self.__dict__["pipeline"] = pipeline
return self


class OffloadStateTypeEnum(str, Enum):
""" Enum for internal buffer types """
optim_states = "optim_states"
hp_params = "hp_params"
lp_params = "lp_params"
lp_grads = "lp_grads"
contiguous_grad_buffer = "contiguous_grad_buffer"
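
These enum members are what callers pass through the `include` argument to offload only a subset of states, for example (a sketch, assuming `engine` is a ZeRO stage 3 DeepSpeed engine):

from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum

# Offload only the Adam optimizer states and the low-precision gradients,
# leaving parameters resident on the accelerator.
engine.offload_states(include=[OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.lp_grads])

# Later, bring back everything that was offloaded.
engine.reload_states()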
179 changes: 166 additions & 13 deletions deepspeed/runtime/zero/stage3.py
@@ -6,20 +6,22 @@
import sys
import gc
import collections
from typing import Deque, Dict, Tuple
import itertools
from typing import Deque, Dict, Set, Tuple, Container
from contextlib import contextmanager

from deepspeed import comm as dist
from deepspeed.utils import groups
from deepspeed.utils import groups, z3_leaf_parameter

from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, offload_adam_states_back
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.utils import apply_to_tensors_only
from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -29,7 +31,6 @@
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import z3_leaf_parameter

# Toggle this to true to enable correctness test
# with gradient partitioning and without
@@ -425,6 +426,8 @@ def __init__(

self._link_all_hp_params()

self.offloaded_states: Set[OffloadStateTypeEnum] = set()

if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer", force=True)

@@ -725,15 +728,11 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
for sub_group in self.fp16_groups:
for param in sub_group:
parameter_partitions.append(param.ds_tensor)
device_buffer = __class__.defragment(parameter_partitions)

# setup flat buffers per subgroup, these are each just sections of the
# contiguous flat buffer for all parameters that we created earlier
offset = 0
for sub_group in self.fp16_groups:
sub_group_numel = sum(param.partition_numel() for param in sub_group)
self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel))
offset += sub_group_numel
# Keep a reference to this buffer so that it can be freed in `offload_states`
self.lp_param_buffer = __class__.defragment(parameter_partitions)
self._set_fp16_partitioned_groups_flat()

else: # partitioned params offloaded to CPU when not in use
# create a flat CPU memory allocation for each param group
self._create_param_groups_fp16_flat_cpu_memory()
@@ -1008,6 +1007,15 @@ def _partitioned_params_swap_out(self, i):
swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=swap_fp16_params,
src_fp32_params=swap_fp32_params)

def _set_fp16_partitioned_groups_flat(self):
# setup flat buffers per subgroup, these are each just sections of the
# contiguous flat buffer for all parameters that we created earlier
offset = 0
for sub_group in self.fp16_groups:
sub_group_numel = sum(param.partition_numel() for param in sub_group)
self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel))
offset += sub_group_numel

def initialize_optimizer_states(self):
num_subgroups = len(self.fp16_groups)

@@ -2782,6 +2790,151 @@ def checkpoint_event_epilogue(self):
def empty_partition_cache(self):
self.parameter_offload.empty_partition_cache()

def offload_states(self,
include: Container[OffloadStateTypeEnum] = None,
device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
pin_memory: bool = True,
non_blocking: bool = False):
device = device.value

self.empty_partition_cache()

assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, f"Offloading is supported only for DeepSpeed FusedAdam."

def needs_offload(target):
    return target not in self.offloaded_states and (include is None or target in include)

# HP param
if needs_offload(OffloadStateTypeEnum.hp_params):
if pin_memory:
if not hasattr(self, "hp_params_pin_buffers"):
self.hp_params_pin_buffers = [
get_accelerator().pin_memory(torch.empty_like(t, device=device))
for t in self.fp32_partitioned_groups_flat
]

for src_tensor, dest_buf in zip(self.fp32_partitioned_groups_flat, self.hp_params_pin_buffers):
dest_buf.copy_(src_tensor, non_blocking=non_blocking)
src_tensor.data = dest_buf
else:
for buf in self.fp32_partitioned_groups_flat:
buf.data = buf.data.to(device, non_blocking=non_blocking)
self.offloaded_states.add(OffloadStateTypeEnum.hp_params)

# LP param
if needs_offload(OffloadStateTypeEnum.lp_params):
if pin_memory:
if not hasattr(self, "lp_param_contiguous_pin_buffer"):
self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory(
torch.empty_like(self.lp_param_buffer, device=device))
self.lp_params_pin_buffers = [
get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device))
for p in self.module.parameters()
]
self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking)
self.lp_param_buffer.data = self.lp_param_contiguous_pin_buffer

for p, buf in zip(self.module.parameters(), self.lp_params_pin_buffers):
buf.copy_(p.ds_tensor.data, non_blocking=non_blocking)
p.ds_tensor.data = buf
else:
self.lp_param_buffer.data = self.lp_param_buffer.to(device, non_blocking=non_blocking)
for p in self.module.parameters():
p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking)

self.fp16_partitioned_groups_flat.clear()
self.offloaded_states.add(OffloadStateTypeEnum.lp_params)

# LP grad
if needs_offload(OffloadStateTypeEnum.lp_grads):
if pin_memory:
if not hasattr(self, "lp_grad_partitions_flat_pin_buffers"):
self.lp_grad_partitions_flat_pin_buffers = get_accelerator().pin_memory(
torch.empty_like(self.grad_partitions_flat_buffer, device=device))
self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer,
non_blocking=non_blocking)
self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers
else:
self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device, non_blocking=non_blocking)
self.averaged_gradients = {}

self.__param_id_to_grad_partition = {}

self.offloaded_states.add(OffloadStateTypeEnum.lp_grads)

# contiguous bucket
if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer):
if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"):
# Record properties like shape, strides, etc. as a meta tensor
self.grad_buffer_meta = self.__ipg_bucket_flat_buffer.to("meta")
self.__ipg_bucket_flat_buffer = None
self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer)

# Adam
if needs_offload(OffloadStateTypeEnum.optim_states):
offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking)
self.offloaded_states.add(OffloadStateTypeEnum.optim_states)

gc.collect()
get_accelerator().empty_cache()

def reload_states(self, non_blocking: bool = False):

device = get_accelerator().current_device_name()

# HP param
if OffloadStateTypeEnum.hp_params in self.offloaded_states:
if hasattr(self, "hp_params_pin_buffers"):
for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat):
dest.data = src.to(device, non_blocking=non_blocking)
else:
for buf in self.fp32_partitioned_groups_flat:
buf.data = buf.data.to(device, non_blocking=non_blocking)
self.offloaded_states.remove(OffloadStateTypeEnum.hp_params)

# LP Param
if OffloadStateTypeEnum.lp_params in self.offloaded_states:
self.lp_param_buffer.data = self.lp_param_buffer.data.to(device, non_blocking=non_blocking)
self._set_fp16_partitioned_groups_flat()

for p in self.module.parameters():
p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking)
self.offloaded_states.remove(OffloadStateTypeEnum.lp_params)

# LP grad
if OffloadStateTypeEnum.lp_grads in self.offloaded_states:
if hasattr(self, "lp_grad_partitions_flat_pin_buffers"):
self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers.to(
device, non_blocking=non_blocking)
else:
self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(
device, non_blocking=non_blocking)
self.averaged_gradients = {}

offset = 0
all_params = list(itertools.chain.from_iterable(self.fp16_groups))
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()

self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads)

# contiguous bucket
if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states:
self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device)
# self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device)
self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer)

# Adam
if OffloadStateTypeEnum.optim_states in self.offloaded_states:
offload_adam_states_back(self.optimizer, device, non_blocking=non_blocking)
self.offloaded_states.remove(OffloadStateTypeEnum.optim_states)

if non_blocking:
get_accelerator().synchronize()


def _handle_overflow(cpu_sum, x, i):
import math
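
One detail from the stage 3 changes worth calling out: the contiguous gradient (IPG) bucket is not copied anywhere on offload. Its shape, dtype, and strides are recorded on a zero-memory meta tensor, the buffer reference is dropped, and an uninitialized buffer is re-allocated on reload. A standalone sketch of that pattern (the buffer name and size are illustrative):

import torch
from deepspeed.accelerator import get_accelerator

device = get_accelerator().current_device_name()
bucket = torch.empty(1 << 20, dtype=torch.bfloat16, device=device)  # stand-in for the IPG bucket

# Offload: keep only the metadata; the accelerator memory can then be reclaimed.
bucket_meta = bucket.to("meta")
bucket = None

# Reload: re-materialize an uninitialized buffer with the recorded shape and dtype.
bucket = torch.empty_like(bucket_meta, device=device)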