Add API to get devices of offload states #6586

Open - wants to merge 33 commits into base: master

Changes from all commits (33 commits):
- 5825104  add apis to offload states of model, optimizer, and engine (tohtana, Aug 16, 2024)
- 600c822  update api doc (tohtana, Aug 16, 2024)
- 153a482  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Aug 16, 2024)
- 126d9b7  reduce global reference to buffer (tohtana, Aug 16, 2024)
- 05df37c  loosen type hint (tohtana, Aug 16, 2024)
- 837c06c  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Aug 19, 2024)
- 3f8179d  add option for pin_memory and non-blocking copy (tohtana, Aug 20, 2024)
- 37ffa02  fix offloading of lp grad (tohtana, Aug 20, 2024)
- 93c5a90  add verification in test (tohtana, Aug 20, 2024)
- 512e9c9  improve offloading of lp params (tohtana, Aug 20, 2024)
- de2a894  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Aug 21, 2024)
- c749b05  fix pinning (tohtana, Aug 22, 2024)
- 36d6e10  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Aug 22, 2024)
- af95a37  resolve conflict (tohtana, Aug 28, 2024)
- 1ca3a7f  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Sep 3, 2024)
- 2a4733e  fix method name and enum key (tohtana, Sep 3, 2024)
- e9a499e  eliminate duplicated buffer for lp param (tohtana, Sep 3, 2024)
- b8f47c6  simplified offloading of adam states (tohtana, Sep 4, 2024)
- d338079  validate devices of offload states (tohtana, Sep 4, 2024)
- 15ff7b3  add document (tohtana, Sep 4, 2024)
- 40427c1  Merge branch 'master' into tohtana/offload_zero_buffers (loadams, Sep 4, 2024)
- e20d827  fix usage example (tohtana, Sep 4, 2024)
- 031464d  Merge branch 'tohtana/offload_zero_buffers' of github.com:microsoft/D… (tohtana, Sep 4, 2024)
- 60deaf1  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Sep 6, 2024)
- 3f001b6  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Sep 9, 2024)
- 8f81634  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Sep 12, 2024)
- ff8e1e9  Merge branch 'master' into tohtana/offload_zero_buffers (tohtana, Sep 21, 2024)
- 02a6a18  Merge branch 'master' into tohtana/offload_zero_buffers (tjruwase, Sep 26, 2024)
- 703333c  refactor to extract API to get devices of offload states (tohtana, Sep 27, 2024)
- 5ecc2c9  Merge branch 'master' (tohtana, Sep 27, 2024)
- 338e4e0  fix import and add license (tohtana, Sep 27, 2024)
- 7102193  fix document (tohtana, Sep 27, 2024)
- fdcc0a4  Merge branch 'master' into tohtana/get_offload_state_api (tjruwase, Oct 2, 2024)
10 changes: 5 additions & 5 deletions deepspeed/runtime/utils.py
@@ -9,28 +9,28 @@
"""

from collections.abc import Iterable
from deepspeed.moe.utils import is_moe_param
import os
import psutil
import gc
from math import sqrt

from numpy import prod

import torch
from deepspeed import comm as dist
from torch.nn import functional as F
try:
from torch._six import inf
except ModuleNotFoundError:
from torch import inf

from deepspeed import comm as dist
from deepspeed.moe.utils import is_moe_param
from deepspeed.utils import groups, logger
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
bwc_pipeline_parallel_group)
from deepspeed.runtime.constants import PIPE_REPLICATED
from numpy import prod
from deepspeed.accelerator import get_accelerator

from deepspeed.module_inject.policy import transpose
from torch.nn import functional as F

torch_memory_reserved = get_accelerator().memory_reserved
torch_max_memory_reserved = get_accelerator().max_memory_reserved
Expand Down
74 changes: 74 additions & 0 deletions deepspeed/runtime/zero/offload_states.py
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Set

import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
from deepspeed.utils.tensor_fragment import safe_get_local_fp32_param, safe_get_local_optimizer_state


def _make_offload_state_key(key):
    return f"{key}_offload_buffer"


def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

    def move_key(state, key):
        offload_buf_key = _make_offload_state_key(key)
        if offload_buf_key not in state:
            # Allocate (and optionally pin) the destination buffer once; it is
            # cached in the state dict and reused on subsequent offloads.
            state[offload_buf_key] = torch.empty_like(state[key], device=device)
            if pin_memory:
                state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key])
        state[offload_buf_key].copy_(state[key], non_blocking=non_blocking)
        # Rebind the state tensor's storage to the offload buffer.
        state[key].data = state[offload_buf_key]

    for _, state in optimizer.state.items():
        if "exp_avg" in state:
            move_key(state, "exp_avg")
        if "exp_avg_sq" in state:
            move_key(state, "exp_avg_sq")


def reload_adam_states(optimizer, device, non_blocking: bool = False):
    """Move offloaded optimizer states back to device. Note that this assumes the state structure of DeepSpeed Adam."""

    def move_back_key(state, key):
        state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking)

    for _, state in optimizer.state.items():
        if "exp_avg" in state:
            move_back_key(state, "exp_avg")
        if "exp_avg_sq" in state:
            move_back_key(state, "exp_avg_sq")


def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
    """Retrieve the devices of the specified state of the model.

    Args:
        model (DeepSpeedEngine): The model whose device allocations are to be checked.
        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.

    Returns:
        Set[torch.device]: A set of devices of the specified state.

    """
    if state == OffloadStateTypeEnum.hp_params:
        return set(safe_get_local_fp32_param(p).device for p in model.parameters())
    elif state == OffloadStateTypeEnum.lp_params:
        return set(p.ds_tensor.device for p in model.parameters())
    elif state == OffloadStateTypeEnum.lp_grads:
        return {model.optimizer.grad_partitions_flat_buffer.device}
    elif state == OffloadStateTypeEnum.optim_states:
        return set(safe_get_local_optimizer_state(p, "exp_avg").device for p in model.parameters()) | \
               set(safe_get_local_optimizer_state(p, "exp_avg_sq").device for p in model.parameters())
    elif state == OffloadStateTypeEnum.contiguous_grad_buffer:
        # The IPG bucket may have been freed by offloading, so report no devices.
        if model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer is None:
            return set()
        return {model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer.device}
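
For illustration, a minimal, self-contained sketch of how the two Adam-state helpers above behave. It is not part of this PR's diff; it assumes a plain torch.optim.Adam (which also stores "exp_avg"/"exp_avg_sq" tensors in optimizer.state) and a CUDA-capable accelerator for pin_memory:

    # Hypothetical usage sketch, not part of the PR.
    import torch
    from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states

    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(4, 8, device="cuda")).sum().backward()
    opt.step()  # materializes exp_avg / exp_avg_sq on the GPU

    # Move both moments into pinned host buffers; the buffers are cached in
    # optimizer.state under "<key>_offload_buffer" and reused on later calls.
    offload_adam_states(opt, torch.device("cpu"), pin_memory=True)
    assert all(s["exp_avg"].device.type == "cpu" for s in opt.state.values())

    # Copy the states back onto the accelerator before resuming training.
    reload_adam_states(opt, torch.device("cuda"))
    assert all(s["exp_avg"].device.type == "cuda" for s in opt.state.values())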
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -18,12 +18,13 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
 from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
+from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
 from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
16 changes: 16 additions & 0 deletions docs/code-docs/source/zero3.rst
@@ -509,3 +509,19 @@ Below is an example code snippet demonstrating how to offload FP32 parameters and
     ...
     # Load states back to device memory
     ds_engine.reload_states()
+
+``deepspeed.runtime.zero.offload_states.get_state_devices`` returns devices of the specified state.
+
+.. code-block:: python
+
+    def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
+        """Retrieve the devices of the specified state of the model.
+
+        Args:
+            model (DeepSpeedEngine): The model whose device allocations are to be checked.
+            state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.
+
+        Returns:
+            Set[torch.device]: A set of devices of the specified state.
+
+        """
23 changes: 13 additions & 10 deletions tests/unit/runtime/zero/test_offload_states.py
@@ -15,19 +15,22 @@
 import deepspeed
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state
+from deepspeed.runtime.zero.offload_states import get_state_devices


 def validate_device(model, device: torch.device, include) -> None:
     # Make sure the model parameters are offloaded
-    if include is None or OffloadStateTypeEnum.hp_params in include:
-        assert all(safe_get_local_fp32_param(p).device == device for p in model.parameters())
-    if include is None or OffloadStateTypeEnum.lp_params in include:
-        assert all(p.ds_tensor.device == device for p in model.parameters())
-    if include is None or OffloadStateTypeEnum.lp_grads in include:
-        assert model.optimizer.grad_partitions_flat_buffer.device == device
-    if include is None or OffloadStateTypeEnum.optim_states in include:
-        assert all(safe_get_local_optimizer_state(p, "exp_avg").device == device for p in model.parameters())
-        assert all(safe_get_local_optimizer_state(p, "exp_avg_sq").device == device for p in model.parameters())
+
+    def compare_device(state) -> bool:
+        devices = get_state_devices(model, state)
+        return len(devices) == 1 and device in devices
+
+    for state in OffloadStateTypeEnum:
+        if include is None or state in include:
+            if state == OffloadStateTypeEnum.contiguous_grad_buffer and device == torch.device("cpu"):
+                assert len(get_state_devices(model, state)) == 0, f"State {state} must be removed after offload_states()"
+            else:
+                assert compare_device(state), f"State {state} is not on device {device}"


 def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):