Merge branch 'master' into zhejiang/fix_runtime_dataloader_shuffle
loadams authored Sep 27, 2024
2 parents a946e18 + 047bcf6 commit d40ba2e
Showing 7 changed files with 443 additions and 25 deletions.
41 changes: 39 additions & 2 deletions deepspeed/runtime/engine.py
@@ -18,13 +18,13 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from typing import Callable, Dict, Union, Iterable
from typing import Callable, Dict, Union, Iterable, Container

import deepspeed

from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum
from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
@@ -3681,3 +3681,40 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
@property
def is_compiled(self) -> bool:
return self._is_compiled

def offload_states(self,
include: Container[OffloadStateTypeEnum] = None,
device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
pin_memory: bool = True,
non_blocking: bool = False) -> None:
"""Offload the engine's states to the specified device.
Arguments:
include: Optional. The set of states to offload. If not provided, all states are offloaded.
device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported.
pin_memory: Optional. Whether to pin the memory of the offloaded states.
non_blocking: Optional. Whether to offload the states asynchronously.
"""
assert self.zero_optimization_stage(
) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."

assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."

if device == OffloadDeviceEnum.none:
logger.warning("No device specified for offloading states.")
return

if device == OffloadDeviceEnum.nvme:
raise ValueError("NVMe offload is not supported for offloading states.")

self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)

def reload_states(self, non_blocking: bool = False) -> None:
"""Reload the engine states to the original device.
Arguments:
non_blocking: Optional. Whether to reload the states asynchronously.
"""
assert self.zero_optimization_stage(
) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3."
self.optimizer.reload_states(non_blocking=non_blocking)
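
The two engine methods above are thin wrappers around the stage-3 optimizer's `offload_states`/`reload_states`. A minimal usage sketch (the model, config values, and training setup are illustrative, not part of this change; it assumes a ZeRO stage 3 run under a distributed launcher with FusedAdam as the optimizer):

```python
import torch
import deepspeed
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

# Illustrative model/config; any ZeRO stage 3 + FusedAdam setup works the same way.
model = torch.nn.Linear(8, 8)
ds_config = {
    "train_batch_size": 1,
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},
}
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Move optimizer states and low-precision gradients to pinned CPU memory so the
# accelerator can be used for other work (e.g. inference with a different model).
engine.offload_states(include=[OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.lp_grads],
                      device=OffloadDeviceEnum.cpu,
                      pin_memory=True,
                      non_blocking=True)

# ... accelerator-memory-hungry work happens here ...

engine.reload_states(non_blocking=True)  # restore everything before the next step
```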
36 changes: 36 additions & 0 deletions deepspeed/runtime/utils.py
@@ -1065,3 +1065,39 @@ def to_tensor(v):
total_norm = -1

return total_norm


def _make_offload_state_key(key):
return f"{key}_offload_buffer"


def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
"""Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

def move_key(state, key):
offload_buf_key = _make_offload_state_key(key)
if offload_buf_key not in state:
state[offload_buf_key] = torch.empty_like(state[key], device=device)
if pin_memory:
state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key])
state[offload_buf_key].copy_(state[key], non_blocking=non_blocking)
state[key].data = state[offload_buf_key]

for _, state in optimizer.state.items():
if "exp_avg" in state:
move_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_key(state, "exp_avg_sq")


def reload_adam_states(optimizer, device, non_blocking: bool = False):
"""Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

def move_back_key(state, key):
state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking)

for _, state in optimizer.state.items():
if "exp_avg" in state:
move_back_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_back_key(state, "exp_avg_sq")
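
These helpers keep a persistent `<key>_offload_buffer` entry next to each `exp_avg`/`exp_avg_sq` tensor and re-point the state's `.data` at it, so repeated offload/reload cycles reuse one host allocation. A small standalone sketch of how they can be driven (assumes a CUDA device and a plain `torch.optim.Adam`, whose state uses the same key names):

```python
import torch
from deepspeed.runtime.utils import offload_adam_states, reload_adam_states

p = torch.nn.Parameter(torch.randn(16, 16, device="cuda"))
opt = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.zeros_like(p)
opt.step()  # populates exp_avg / exp_avg_sq in opt.state

offload_adam_states(opt, torch.device("cpu"), pin_memory=True, non_blocking=True)
assert opt.state[p]["exp_avg"].device.type == "cpu"  # moments now live in host memory

reload_adam_states(opt, torch.device("cuda"), non_blocking=True)
torch.cuda.synchronize()  # wait for the async copies before using the states
```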
9 changes: 9 additions & 0 deletions deepspeed/runtime/zero/offload_config.py
@@ -98,3 +98,12 @@ def set_pipeline(self):
pipeline = self.pipeline_read or self.pipeline_write
self.__dict__["pipeline"] = pipeline
return self


class OffloadStateTypeEnum(str, Enum):
""" Enum for internal buffer types """
optim_states = "optim_states"
hp_params = "hp_params"
lp_params = "lp_params"
lp_grads = "lp_grads"
contiguous_grad_buffer = "contiguous_grad_buffer"
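
Each member names one class of internal buffer that `offload_states` can move; callers pass any container of members to restrict the operation. A short illustrative sketch:

```python
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum

# Offload only the Adam moments and the high-precision master parameters; leave
# low-precision params/grads and the contiguous gradient bucket on the accelerator.
include = {OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.hp_params}

# The str mixin makes members compare equal to their string values, which keeps
# JSON/config round-trips simple.
assert OffloadStateTypeEnum.optim_states == "optim_states"
```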
189 changes: 167 additions & 22 deletions deepspeed/runtime/zero/stage3.py
@@ -6,30 +6,31 @@
import sys
import gc
import collections
from typing import Deque, Dict, Tuple
import itertools
from typing import Deque, Dict, Set, Tuple, Container
from contextlib import contextmanager

from deepspeed import comm as dist
from deepspeed.utils import groups
from deepspeed.utils import groups, z3_leaf_parameter

from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.utils import apply_to_tensors_only
from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import z3_leaf_parameter

# Toggle this to true to enable correctness test
# with gradient partitioning and without
@@ -425,6 +426,8 @@ def __init__(

self._link_all_hp_params()

self.offloaded_states: Set[OffloadStateTypeEnum] = set()

if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer", force=True)

@@ -563,21 +566,15 @@ def defragment(tensors: List[Tensor]) -> Tensor:
cpu_buffer = torch.empty(sum(p.numel() for p in tensors),
dtype=get_only_unique_item(t.dtype for t in tensors),
device="cpu")
tensor_infos: List[Tuple[Tensor, int, int]] = []
tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors)
orig_device = get_only_unique_item(t.device for t in tensors)

offset = 0
for tensor in tensors:
tensor_numel = tensor.numel()
for tensor, offset, tensor_numel in tensor_infos:
# move the tensor from device memory to host memory
cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)

# record some data so we can restore the device tensor later
tensor_infos.append((tensor, offset, tensor_numel))

offset += tensor_numel

gc.collect()
get_accelerator().empty_cache()

@@ -725,15 +722,11 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
for sub_group in self.fp16_groups:
for param in sub_group:
parameter_partitions.append(param.ds_tensor)
device_buffer = __class__.defragment(parameter_partitions)

# setup flat buffers per subgroup, these are each just sections of the
# contiguous flat buffer for all parameters that we created earlier
offset = 0
for sub_group in self.fp16_groups:
sub_group_numel = sum(param.partition_numel() for param in sub_group)
self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel))
offset += sub_group_numel
# We need to keep the reference to this buffer to make sure you can free it in `offload_states`
self.lp_param_buffer = __class__.defragment(parameter_partitions)
self._set_fp16_partitioned_groups_flat()

else: # partitioned params offloaded to CPU when not in use
# create a flat CPU memory allocation for each param group
self._create_param_groups_fp16_flat_cpu_memory()
@@ -1008,6 +1001,15 @@ def _partitioned_params_swap_out(self, i):
swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=swap_fp16_params,
src_fp32_params=swap_fp32_params)

def _set_fp16_partitioned_groups_flat(self):
# setup flat buffers per subgroup, these are each just sections of the
# contiguous flat buffer for all parameters that we created earlier
offset = 0
for sub_group in self.fp16_groups:
sub_group_numel = sum(param.partition_numel() for param in sub_group)
self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel))
offset += sub_group_numel

def initialize_optimizer_states(self):
num_subgroups = len(self.fp16_groups)

@@ -2782,6 +2784,149 @@ def checkpoint_event_epilogue(self):
def empty_partition_cache(self):
self.parameter_offload.empty_partition_cache()

def offload_states(self,
include: Container[OffloadStateTypeEnum] = None,
device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
pin_memory: bool = True,
non_blocking: bool = False):
device = device.value

self.empty_partition_cache()

assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, f"Offloading is supported only for DeepSpeed FusedAdam."

def needs_offload(target):
# return True
return target not in self.offloaded_states and (include == None or target in include)

# HP param
if needs_offload(OffloadStateTypeEnum.hp_params):
if pin_memory:
if not hasattr(self, "hp_params_pin_buffers"):
self.hp_params_pin_buffers = [
get_accelerator().pin_memory(torch.empty_like(t, device=device))
for t in self.fp32_partitioned_groups_flat
]

for src_tensor, dest_buf in zip(self.fp32_partitioned_groups_flat, self.hp_params_pin_buffers):
dest_buf.copy_(src_tensor, non_blocking=non_blocking)
src_tensor.data = dest_buf
else:
for buf in self.fp32_partitioned_groups_flat:
buf.data = buf.data.to(device, non_blocking=non_blocking)
self.offloaded_states.add(OffloadStateTypeEnum.hp_params)

# LP param
if needs_offload(OffloadStateTypeEnum.lp_params):
if pin_memory:
if not hasattr(self, "lp_param_contiguous_pin_buffer"):
self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory(
torch.empty_like(self.lp_param_buffer, device=device))
self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking)
cpu_buffer = self.lp_param_contiguous_pin_buffer
else:
cpu_buffer = self.lp_param_buffer.to(device, non_blocking=non_blocking)

self.lp_param_buffer.data = cpu_buffer
for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(
[p.ds_tensor for p in self.module.parameters()]):
tensor.data = cpu_buffer.narrow(0, offset, tensor_numel)

self.fp16_partitioned_groups_flat.clear()
self.offloaded_states.add(OffloadStateTypeEnum.lp_params)

# LP grad
if needs_offload(OffloadStateTypeEnum.lp_grads):
if pin_memory:
if not hasattr(self, "lp_grad_partitions_flat_pin_buffers"):
self.lp_grad_partitions_flat_pin_buffers = get_accelerator().pin_memory(
torch.empty_like(self.grad_partitions_flat_buffer, device=device))
self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer,
non_blocking=non_blocking)
self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers
else:
self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device)
self.averaged_gradients = {}

self.__param_id_to_grad_partition = {}

self.offloaded_states.add(OffloadStateTypeEnum.lp_grads)

# contiguous bucket
if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer):
if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"):
# Record properties like shape, strides, etc. as a meta tensor
self.grad_buffer_meta = self.__ipg_bucket_flat_buffer.to("meta")
self.__ipg_bucket_flat_buffer = None
self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer)

# Adam
if needs_offload(OffloadStateTypeEnum.optim_states):
offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking)
self.offloaded_states.add(OffloadStateTypeEnum.optim_states)

gc.collect()
get_accelerator().empty_cache()

def reload_states(self, non_blocking: bool = False):

device = get_accelerator().current_device_name()

# HP param
if OffloadStateTypeEnum.hp_params in self.offloaded_states:
if hasattr(self, "hp_params_pin_buffers"):
for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat):
dest.data = src.to(device, non_blocking=non_blocking)
else:
for buf in self.fp32_partitioned_groups_flat:
buf.data = buf.data.to(device, non_blocking=non_blocking)
self.offloaded_states.remove(OffloadStateTypeEnum.hp_params)

# LP Param
if OffloadStateTypeEnum.lp_params in self.offloaded_states:
cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr(
self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer
self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking)
self._set_fp16_partitioned_groups_flat()

for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(
[p.ds_tensor for p in self.module.parameters()]):
tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel)
self.offloaded_states.remove(OffloadStateTypeEnum.lp_params)

# LP grad
if OffloadStateTypeEnum.lp_grads in self.offloaded_states:
if hasattr(self, "lp_grad_partitions_flat_pin_buffers"):
self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers.to(
device, non_blocking=non_blocking)
else:
self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(
device, non_blocking=non_blocking)
self.averaged_gradients = {}

offset = 0
all_params = list(itertools.chain.from_iterable(self.fp16_groups))
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()

self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads)

# contiguous bucket
if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states:
self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device)
# self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device)
self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer)

# Adam
if OffloadStateTypeEnum.optim_states in self.offloaded_states:
reload_adam_states(self.optimizer, device, non_blocking=non_blocking)
self.offloaded_states.remove(OffloadStateTypeEnum.optim_states)

if non_blocking:
get_accelerator().synchronize()


def _handle_overflow(cpu_sum, x, i):
import math
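
Throughout `offload_states`/`reload_states` above, the same device-to-host double-buffering pattern recurs: allocate a pinned host buffer once, copy into it asynchronously, then re-point the tensor's `.data` at the host copy so the accelerator storage can be reclaimed. A self-contained sketch of that pattern, independent of DeepSpeed internals (assumes a CUDA device; the helper names are made up for illustration):

```python
import torch

def offload_with_pin(tensor: torch.Tensor, pin_buffers: dict, key: str,
                     non_blocking: bool = True) -> None:
    """Move `tensor`'s storage to a cached pinned CPU buffer, keeping the Python
    object alive so it can later be re-pointed back at device memory."""
    if key not in pin_buffers:  # allocate the host-side buffer only once
        pin_buffers[key] = torch.empty_like(tensor, device="cpu").pin_memory()
    pin_buffers[key].copy_(tensor, non_blocking=non_blocking)
    tensor.data = pin_buffers[key]  # device allocation becomes reclaimable

def reload_from_pin(tensor: torch.Tensor, pin_buffers: dict, key: str,
                    device: str = "cuda", non_blocking: bool = True) -> None:
    tensor.data = pin_buffers[key].to(device, non_blocking=non_blocking)

buffers = {}
w = torch.randn(1024, 1024, device="cuda")
offload_with_pin(w, buffers, "w")
torch.cuda.empty_cache()   # actually release the freed device block
reload_from_pin(w, buffers, "w")
torch.cuda.synchronize()   # wait for async copies before using `w`
```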
15 changes: 14 additions & 1 deletion deepspeed/runtime/zero/utils.py
@@ -4,7 +4,7 @@
# DeepSpeed Team

import os
from typing import List
from typing import List, Tuple

import torch
from deepspeed import comm as dist
@@ -160,3 +160,16 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None):
logger.warning(warning_msg_fn(value))
warned = True
return value


def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]:
tensor_infos: List[Tuple[torch.Tensor, int, int]] = []

offset = 0
for tensor in tensors:
tensor_numel = tensor.numel()
# record some data so we can restore the device tensor later
tensor_infos.append((tensor, offset, tensor_numel))
offset += tensor_numel

return tensor_infos
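
`get_mapping_to_flat_buffer` only computes `(tensor, offset, numel)` triples; `defragment` and the LP-param offload path use them to pack tensors into one flat buffer and alias each tensor to a view of it. A minimal illustrative sketch of that use:

```python
import torch
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer

tensors = [torch.randn(3), torch.randn(5), torch.randn(2)]  # 1-D partitions, as in ZeRO-3
flat = torch.empty(sum(t.numel() for t in tensors), dtype=tensors[0].dtype)

for tensor, offset, numel in get_mapping_to_flat_buffer(tensors):
    flat.narrow(0, offset, numel).copy_(tensor)   # pack into the flat buffer
    tensor.data = flat.narrow(0, offset, numel)   # alias the flat storage

assert tensors[0].data_ptr() == flat.data_ptr()   # first tensor now views the buffer
```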