
Commit

Merge branch 'master' into rearrange_ops
loadams authored Sep 27, 2024
2 parents 1ea0c81 + 828ddfb commit dc6e04e
Showing 14 changed files with 710 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nv-accelerate-v100.yml
@@ -55,4 +55,4 @@ jobs:
 # tmp fix: force newer datasets version
 #pip install "datasets>=2.0.0"
 pip list
-pytest $PYTEST_OPTS --color=yes --durations=0 --verbose tests/deepspeed -k "not test_prepare_multiple_models_zero3_inference"
+pytest $PYTEST_OPTS --color=yes --durations=0 --verbose tests/deepspeed
59 changes: 59 additions & 0 deletions .github/workflows/xpu-compile.yml
@@ -0,0 +1,59 @@
name: xpu-compile

on:
  workflow_dispatch:
  schedule:
    - cron: "0 0 * * *"
  pull_request:
    paths:
      - ".github/workflows/xpu-compile.yml"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read
  issues: write

jobs:
  compile-tests:
    runs-on: [self-hosted, intel, xpu]
    container:
      image: intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
      ports:
        - 80
      options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL

    steps:
      - uses: actions/checkout@v4
      - name: Install prerequisite
        run: |
          apt-get update
          apt-get install clinfo libaio-dev python3-pip -y
          pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
          pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
          pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
          pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
          pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl
          pip install py-cpuinfo numpy
          pip install .[dev,autotuning]
      - name: Check container state
        run: |
          ldd --version
          ds_report
          python3 -c "import torch; print('torch:', torch.__version__, torch)"
          python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
          python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
          pip list
      - name: Compile Status
        shell: bash
        run: |
          export FI_HMEM=system
          ulimit -n 1048575
          cd tests/torch_compile
          export ZE_AFFINITY_MASK=0,1
          deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
          cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
51 changes: 51 additions & 0 deletions csrc/xpu/aio/deepspeed_cpu_op.cpp
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepspeed_cpu_op.h"

using namespace std;

cpu_op_desc_t::cpu_op_desc_t(const bool read_op,
                             const torch::Tensor& buffer,
                             const int fd,
                             const char* filename,
                             const long long int file_num_bytes,
                             const int num_threads,
                             const bool validate)
    : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate),
      _cpu_buffer(buffer)
{
    // XPU does not handle the buffer here; see the XPU accelerator's pin_memory.
    _contiguous_buffer = _cpu_buffer.contiguous();
}

char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }

void cpu_op_desc_t::finish()
{
    if (_read_op && _buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
}

void cpu_op_desc_t::validate()
{
    validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), _file_num_bytes);
}

void cpu_op_desc_t::run(const int tid,
                        std::unique_ptr<aio_context>& aio_ctxt,
                        deepspeed_aio_config_t* aio_config)
{
    assert(tid < _num_threads);
    const auto base_offset = _num_bytes_per_thread * tid;

    std::unique_ptr<io_xfer_ctxt> xfer_ctxt(
        new io_xfer_ctxt(_fd, base_offset, _num_bytes_per_thread, data_ptr()));

    if (aio_config->_overlap_events) {
        do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
    } else {
        do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
    }
}
7 changes: 4 additions & 3 deletions deepspeed/inference/engine.py
@@ -68,8 +68,9 @@ def __init__(self, model, config):
         if hasattr(self.module, "config"):
             TransformerPolicy.hf_model_config = self.module.config
 
-        if config.dtype == torch.half and not get_accelerator().is_fp16_supported():
-            raise ValueError("Type fp16 is not supported.")
+        if config.dtype not in get_accelerator().supported_dtypes():
+            raise ValueError(
+                f"Data type {config.dtype} is not supported by {get_accelerator().device_name()} accelerator")
 
         # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
         # todo: this will get changed when Molly's PR on auto injection dict is merged
@@ -315,7 +316,7 @@ def _validate_args(self, mpu, replace_with_kernel_inject):
         if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
             raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
 
-        supported_dtypes = [None, torch.half, torch.int8, torch.float]
+        supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
         if self._config.dtype not in supported_dtypes:
             raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")

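Taken together, the two inference-engine changes replace the fp16-only gate with whatever the active accelerator reports via supported_dtypes(), and add torch.bfloat16 to the dtypes accepted by _validate_args. A minimal usage sketch of what this enables (the model choice, token ids, and keyword arguments below are illustrative assumptions, not part of the diff):

    # Hedged usage sketch: bf16 inference now passes dtype validation when the
    # active accelerator reports bfloat16 in supported_dtypes().
    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model choice

    engine = deepspeed.init_inference(
        model,
        dtype=torch.bfloat16,              # previously rejected by the fp16-only check
        tensor_parallel={"tp_size": 1},
        replace_with_kernel_inject=False,  # kernel injection support varies by model/dtype
    )

    with torch.no_grad():
        out = engine(torch.tensor([[101, 2023, 2003]]))  # dummy token ids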
41 changes: 39 additions & 2 deletions deepspeed/runtime/engine.py
@@ -18,13 +18,13 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
-from typing import Callable, Dict, Union, Iterable
+from typing import Callable, Dict, Union, Iterable, Container
 
 import deepspeed
 
 from deepspeed import comm as dist
 from deepspeed.runtime.utils import see_memory_usage, DummyOptim
-from .zero.offload_config import OffloadDeviceEnum
+from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
@@ -3681,3 +3681,40 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
     @property
     def is_compiled(self) -> bool:
         return self._is_compiled
+
+    def offload_states(self,
+                       include: Container[OffloadStateTypeEnum] = None,
+                       device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
+                       pin_memory: bool = True,
+                       non_blocking: bool = False) -> None:
+        """Offload the engine's states to the specified device.
+
+        Arguments:
+            include: Optional. The set of states to offload. If not provided, all states are offloaded.
+            device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported.
+            pin_memory: Optional. Whether to pin the memory of the offloaded states.
+            non_blocking: Optional. Whether to offload the states asynchronously.
+        """
+        assert self.zero_optimization_stage(
+        ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."
+
+        assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."
+
+        if device == OffloadDeviceEnum.none:
+            logger.warning("No device specified for offloading states.")
+            return
+
+        if device == OffloadDeviceEnum.nvme:
+            raise ValueError("NVMe offload is not supported for offloading states.")
+
+        self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)
+
+    def reload_states(self, non_blocking: bool = False) -> None:
+        """Reload the engine states to the original device.
+
+        Arguments:
+            non_blocking: Optional. Whether to reload the states asynchronously.
+        """
+        assert self.zero_optimization_stage(
+        ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3."
+        self.optimizer.reload_states(non_blocking=non_blocking)
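For orientation, a minimal usage sketch of the new engine-level API follows; the model handle, config file name, and the phase in which memory is freed are assumptions for illustration. Per the asserts above, this only applies to ZeRO stage 3 without parameter offload:

    # Hedged sketch: temporarily move ZeRO-3 engine states to CPU to free
    # accelerator memory, then restore them before training resumes.
    import deepspeed
    from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

    engine, _, _, _ = deepspeed.initialize(model=my_model,                      # assumed model
                                           model_parameters=my_model.parameters(),
                                           config="ds_zero3_config.json")       # assumed ZeRO-3 config

    # Offload only the optimizer states and low-precision parameter shards.
    engine.offload_states(include={OffloadStateTypeEnum.optim_states,
                                   OffloadStateTypeEnum.lp_params},
                          device=OffloadDeviceEnum.cpu,
                          pin_memory=True,
                          non_blocking=True)

    # ... run a memory-hungry phase here (e.g. evaluation or generation) ...

    engine.reload_states(non_blocking=True)  # copy everything back before the next step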
36 changes: 36 additions & 0 deletions deepspeed/runtime/utils.py
@@ -1065,3 +1065,39 @@ def to_tensor(v):
         total_norm = -1
 
     return total_norm
+
+
+def _make_offload_state_key(key):
+    return f"{key}_offload_buffer"
+
+
+def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
+    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""
+
+    def move_key(state, key):
+        offload_buf_key = _make_offload_state_key(key)
+        if offload_buf_key not in state:
+            state[offload_buf_key] = torch.empty_like(state[key], device=device)
+            if pin_memory:
+                state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key])
+        state[offload_buf_key].copy_(state[key], non_blocking=non_blocking)
+        state[key].data = state[offload_buf_key]
+
+    for _, state in optimizer.state.items():
+        if "exp_avg" in state:
+            move_key(state, "exp_avg")
+        if "exp_avg_sq" in state:
+            move_key(state, "exp_avg_sq")
+
+
+def reload_adam_states(optimizer, device, non_blocking: bool = False):
+    """Move optimizer states back to the original device. Note that this assumes the state structure of DeepSpeed Adam."""
+
+    def move_back_key(state, key):
+        state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking)
+
+    for _, state in optimizer.state.items():
+        if "exp_avg" in state:
+            move_back_key(state, "exp_avg")
+        if "exp_avg_sq" in state:
+            move_back_key(state, "exp_avg_sq")
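The helpers above implement the round trip buffer by buffer: each exp_avg / exp_avg_sq tensor gets a cached <key>_offload_buffer companion on the target device, and the live state tensor's .data is repointed at it, so later reloads can reuse the same buffer. A rough standalone illustration follows; running them directly on a plain torch.optim.Adam (which happens to use the same state keys) is my assumption for demonstration, not how DeepSpeed itself invokes them:

    # Hedged round-trip sketch of offload_adam_states / reload_adam_states.
    import torch
    from deepspeed.accelerator import get_accelerator
    from deepspeed.runtime.utils import offload_adam_states, reload_adam_states

    dev = get_accelerator().current_device_name()       # e.g. "cuda:0" or "xpu:0"
    param = torch.nn.Parameter(torch.randn(1024, 1024, device=dev))
    opt = torch.optim.Adam([param])

    param.grad = torch.randn_like(param)
    opt.step()                                          # materializes exp_avg / exp_avg_sq on the accelerator

    offload_adam_states(opt, torch.device("cpu"), pin_memory=True)
    print(opt.state[param]["exp_avg"].device)           # cpu (pinned offload buffer)

    reload_adam_states(opt, dev)                        # copy the cached buffers back
    print(opt.state[param]["exp_avg"].device)           # back on the accelerator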
9 changes: 9 additions & 0 deletions deepspeed/runtime/zero/offload_config.py
@@ -98,3 +98,12 @@ def set_pipeline(self):
         pipeline = self.pipeline_read or self.pipeline_write
         self.__dict__["pipeline"] = pipeline
         return self
+
+
+class OffloadStateTypeEnum(str, Enum):
+    """ Enum for internal buffer types """
+    optim_states = "optim_states"
+    hp_params = "hp_params"
+    lp_params = "lp_params"
+    lp_grads = "lp_grads"
+    contiguous_grad_buffer = "contiguous_grad_buffer"
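The enum members name the groups of ZeRO-3 buffers that offload_states can move selectively. A small sketch of building an include set; the interpretation in the comments is mine, not stated in the diff, and engine is assumed to come from deepspeed.initialize:

    # Hedged sketch: offload everything except the low-precision parameters,
    # which a forward-only phase may still need on-device.
    from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum

    all_states = set(OffloadStateTypeEnum)                 # the enum is iterable over its five members
    keep_on_device = {OffloadStateTypeEnum.lp_params}      # bf16/fp16 parameter shards (my interpretation)
    to_offload = all_states - keep_on_device

    # engine.offload_states(include=to_offload)            # engine assumed from deepspeed.initialize(...)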