diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml index b1e8c042214f..8658ff5d2348 100644 --- a/.github/workflows/nv-nightly.yml +++ b/.github/workflows/nv-nightly.yml @@ -2,6 +2,9 @@ name: nv-nightly on: workflow_dispatch: + pull_request: + paths: + - '.github/workflows/nv-nightly.yml' schedule: - cron: "0 0 * * *" @@ -25,7 +28,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --index-url https://download.pytorch.org/whl/cu117 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -34,7 +37,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 + git checkout v4.42.4 git rev-parse --short HEAD pip install . @@ -55,7 +58,7 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.7" + pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="2.4" --cuda_ver="11.8" - name: Open GitHub issue if nightly CI fails if: ${{ failure() && (github.event_name == 'schedule') }} diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index a506bb27fda4..72ba8abbd95d 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -36,7 +36,7 @@ jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . - name: DS Report run: | ds_report diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md new file mode 100644 index 000000000000..34e11bd47792 --- /dev/null +++ b/blogs/windows/08-2024/README.md @@ -0,0 +1,101 @@ +
+ +# DeepSpeed on Windows + +
+ +# Introduction + +DeepSpeed is a popular open-source deep learning optimization library that makes distributed training and inference easy, efficient, and effective. DeepSpeed has been widely used to train a variety of state-of-the-art models, including Phi-3, Megatron-Turing-530B, BLOOM-176B, and Arctic because of its rich suite of sophisticated optimizations (e.g., ZeRO, 3D parallelism, MoE, etc.). However, the lack of native support for Microsoft Windows, the most popular operating system, means that DeepSpeed innovations are inaccessible to many AI developers and users. To address this problem, we started an effort to make DeepSpeed run natively with full features on Windows, while ensuring the same ease-of-use enjoyed on Linux. + +In this blog, we are pleased to announce some early achievements on this journey: DeepSpeed can now be installed in Windows and run natively for single-GPU training, finetuning, and inferencing. Importantly, both the installation and usage experiences are identical to those on Linux. Furthermore, the finetuning and inferencing workloads demonstrate the functioning of three critical DeepSpeed features, HuggingFace Transformers integration, LoRA support, and CPU Offloading. DeepSpeed on Windows is available in DeepSpeed versions 0.14.5 and above. In the rest of this blog, we present examples to demonstrate these achievements. + +# Evaluation Environment +We conducted the experiments on a Surface Laptop Studio 2 running Windows 11 Version 23H2 and Build 22631.3880. The laptop is equipped with a single NVIDIA RTX A2000 GPU with 4GB VRAM. We used Pytorch version 2.3.0 and HuggingFace Transformers version 4.41.2. The example scripts used are from the [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples), therefore you need to clone the repo before running any of the following examples. + +# Installation +DeepSpeed can be installed on Windows in one of two ways. The easier way is to use the pip package manager, while the other is to build from source. The prerequisites for in both cases are Python 3.x and Pytorch with CUDA support. + +## Installing via pip +To install DeepSpeed, simply run: `pip install deepspeed`. This will install the latest version of DeepSpeed (0.14.5 at this time). Unlike the Linux counterpart, the Windows version comes with all the operators already prebuilt, so there is no need to have a CUDA SDK or C++ compiler installed. + +
+ +
+ +
+ pip installation of DeepSpeed on Windows. +
+ + +## Building from Source +To build DeepSpeed from source, you need to clone the DeepSpeed repository and run the `build_win.bat` compilation script. + + +## Validating Installation +Regardless of the installation choice, you can check that the installation was successful by running ds_report. The output should look like this: + + +
+ +
+ +
+ ds_report output confirming Windows installation of DeepSpeed. +
+ +# Pretraining Examples +We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed. + +## Pretraining CIFAR10 +The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py –deepspeed`. The final output should look something like this: +
+ +
+ +
+ Pretraining CIFAR10 model on Windows using DeepSpeed. +
+ +## Pretraining BERT +The scripts and codes for the BERT pretraining example are available in the following path: DeepSpeedExamples\training\HelloDeepSpeed. You can launch the BERT pretraining experiment using the following command: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`. The final output should look like this: + +
+ +
+ +
+ Pretraining BERT model on Windows using DeepSpeed. +
+ +# Fine Tuning Example +We demonstrate fine tuning capability by using the supervised fine tuning (SFT) step of DeepSpeed-Chat application. We conduct SFT of the HuggingFace facebook/opt-125m model while enabling LoRA and CPU offloading memory optimizations. The command line for running this example is as follows: +deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output +The output should look like this: + +
+ +
+ +
+ Supervised Finetuning of facebook/opt-125m model on Windows using DeepSpeed. +
+ +# Inference Example +We demonstrate inference capability by using ZeRO-Inference for token generation. ZeRO-Inference reduces hardware cost of inferencing by offloading to CPU or NVMe memories. We use the example scripts here to run token generation using Llama-2-7B model from HuggingFace. We offload the model weights to CPU memory since the 4GB VRAM is insufficient to host both the model and the generation working set. We use the following command line to generate 32 tokens from a prompt of 8 tokens: +deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload +The output will look something like this: + +
+ +
+ +
+ LLAMA2-7B token generation on Windows using ZeRO-Inference. +
+ +# Summary +Enabling DeepSpeed, a popular deep learning framework, to run natively on Windows, the most popular operating system, is a crucial step towards empowering every person and every organization to benefit from the ongoing AI revolution. In this blog, we have shared early results of our work towards this goal. Although Windows support of DeepSpeed is a work-in-progress, we hope that the above updates are encouraging and already useful to users. The next items on our roadmap include running on multiple GPUs, weight quantization, and performance studies. + +# Acknowledgements +This work is a result of significant contributions from current and former DeepSpeed members including Costin Eseanu, Logan Adams, Elton Zheng, Reza Yazdani Aminabadi, Martin Cai, and Olatunji Ruwase. We also acknowledge the valuable contributions of DeepSpeed users who righteously demanded this feature, provided critical workarounds, partial solutions, and constructive feedback, and most importantly, stuck with us. diff --git a/blogs/windows/08-2024/media/bert_training.png b/blogs/windows/08-2024/media/bert_training.png new file mode 100644 index 000000000000..c5935e47747e Binary files /dev/null and b/blogs/windows/08-2024/media/bert_training.png differ diff --git a/blogs/windows/08-2024/media/cifar10_training.png b/blogs/windows/08-2024/media/cifar10_training.png new file mode 100644 index 000000000000..99f3fa25bc70 Binary files /dev/null and b/blogs/windows/08-2024/media/cifar10_training.png differ diff --git a/blogs/windows/08-2024/media/ds_report.png b/blogs/windows/08-2024/media/ds_report.png new file mode 100644 index 000000000000..43d82d724ed2 Binary files /dev/null and b/blogs/windows/08-2024/media/ds_report.png differ diff --git a/blogs/windows/08-2024/media/llama2-7b_inference.png b/blogs/windows/08-2024/media/llama2-7b_inference.png new file mode 100644 index 000000000000..f5874468a854 Binary files /dev/null and b/blogs/windows/08-2024/media/llama2-7b_inference.png differ diff --git a/blogs/windows/08-2024/media/opt125m_finetuning.png b/blogs/windows/08-2024/media/opt125m_finetuning.png new file mode 100644 index 000000000000..ed6d1522e3b3 Binary files /dev/null and b/blogs/windows/08-2024/media/opt125m_finetuning.png differ diff --git a/blogs/windows/08-2024/media/win_pip_install_deepspeed.png b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png new file mode 100644 index 000000000000..3b87c95ef144 Binary files /dev/null and b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png differ diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp new file mode 100644 index 000000000000..dc820be528d0 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepspeed_aio_op_desc.h" + +using namespace std; + +io_op_desc_t::io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int num_threads, + const bool validate) + : _read_op(read_op), + _buffer(buffer), + _fd(fd), + _filename(filename), + _file_num_bytes(file_num_bytes), + _num_threads(num_threads), + _num_bytes_per_thread(file_num_bytes / num_threads), + _validate(validate) +{ +} + +char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void io_op_desc_t::finish() {} + +void io_op_desc_t::validate() {} + +void io_op_desc_t::run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config) +{ +} diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h new file mode 100644 index 000000000000..7305f6920c91 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#ifndef _IO_OP_DESC_T_ +#define _IO_OP_DESC_T_ +#include +#include +#include "deepspeed_py_aio.h" + +struct io_op_desc_t { + const bool _read_op; + torch::Tensor _buffer; + int _fd; + const std::string _filename; + const long long int _file_num_bytes; + const int _num_threads; + const int _num_bytes_per_thread; + torch::Tensor _contiguous_buffer; + const bool _validate; + + io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int num_threads, + const bool validate); + + virtual void run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config); + + virtual char* data_ptr() const; + + virtual void validate(); + + virtual void finish(); +}; +#endif // _IO_OP_DESC_T_ diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index c852711a28c0..30c3b4914397 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -9,50 +9,8 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include "deepspeed_aio_thread.h" -#if defined(__ENABLE_CANN__) -#include "torch_npu/csrc/framework/utils/OpAdapter.h" -#include "torch_npu/csrc/framework/utils/UtilForOpAdapter.h" -#endif - using namespace std; -io_op_desc_t::io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate) - : _read_op(read_op), - _buffer(buffer), - _fd(fd), - _filename(filename), - _num_bytes(num_bytes), - _validate(validate) -{ - _cpu_buffer = (_buffer.is_cuda() || _buffer.is_xpu() -#if defined(__ENABLE_CANN__) - || torch_npu::utils::is_npu(_buffer) -#endif - ) - ? _buffer.to(torch::kCPU).pin_memory() - : _buffer; - _contiguous_buffer = _cpu_buffer.contiguous(); -} - -char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } - -void io_op_desc_t::fini() -{ - if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } - if (_read_op && _buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); } -#if defined(__ENABLE_CANN__) - if (_read_op && torch_npu::utils::is_npu(_buffer)) { - auto device = at::Device("npu:0"); - _buffer.copy_(_cpu_buffer.to(device)); - } -#endif -} - deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config) : _tid(tid), _aio_config(aio_config), @@ -79,18 +37,7 @@ void deepspeed_aio_thread_t::run() } if (next_io_op) { - const auto base_offset = next_io_op->_num_bytes * _tid; - - std::unique_ptr xfer_ctxt(new io_xfer_ctxt( - next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr())); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } + next_io_op->run(_tid, _aio_ctxt, &_aio_config); { std::lock_guard lock(_complete_sync._mutex); diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h index 20799ecbb018..a192804db13d 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.h +++ b/csrc/aio/py_lib/deepspeed_aio_thread.h @@ -10,28 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include #include #include -#include "deepspeed_py_aio.h" - -struct io_op_desc_t { - const bool _read_op; - torch::Tensor _buffer; - int _fd; - const std::string _filename; - const long long int _num_bytes; - torch::Tensor _cpu_buffer; - torch::Tensor _contiguous_buffer; - const bool _validate; - - io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate); - - char* data_ptr() const; - void fini(); -}; +#include "deepspeed_cpu_op.h" struct thread_sync_t { std::mutex _mutex; diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp new file mode 100644 index 000000000000..41790b99bb88 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepspeed_cpu_op.h" + +using namespace std; + +cpu_op_desc_t::cpu_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int num_threads, + const bool validate) + : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate), + _cpu_buffer(buffer) +{ + // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory. + _use_bounce_buffer = !(_buffer.is_cpu() && _buffer.is_pinned()); + if (_use_bounce_buffer) { + if (_read_op) { + auto options = torch::TensorOptions() + .dtype(_buffer.dtype()) + .layout(_buffer.layout()) + .device(torch::kCPU); + _cpu_buffer = torch::empty(_buffer.nbytes(), options).pin_memory(); + } else { + _cpu_buffer = _buffer.to(torch::kCPU).pin_memory(); + } + } + _contiguous_buffer = _cpu_buffer.contiguous(); +} + +char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void cpu_op_desc_t::finish() +{ + if (_read_op) { + if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } + if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); } +#if defined(__ENABLE_CANN__) + if (torch_npu::utils::is_npu(_buffer)) { + auto device = at::Device("npu:0"); + _buffer.copy_(_cpu_buffer.to(device)); + } +#endif + } +} + +void cpu_op_desc_t::validate() +{ + validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), _file_num_bytes); +} + +void cpu_op_desc_t::run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config) +{ + assert(tid < _num_threads); + const auto base_offset = _num_bytes_per_thread * tid; + + std::unique_ptr xfer_ctxt( + new io_xfer_ctxt(_fd, base_offset, _num_bytes_per_thread, data_ptr())); + + if (aio_config->_overlap_events) { + do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr); + } else { + do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr); + } +} diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h new file mode 100644 index 000000000000..da96dd2b1d50 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include "deepspeed_aio_op_desc.h" + +struct cpu_op_desc_t : io_op_desc_t { + torch::Tensor _cpu_buffer; + bool _use_bounce_buffer; + + cpu_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int num_threads, + const bool validate); + + void run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config); + + char* data_ptr() const; + + void validate(); + + void finish(); +}; diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index 0556f5aa8168..eac268d33433 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -4,9 +4,6 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ diff --git a/csrc/aio/py_lib/deepspeed_py_aio.h b/csrc/aio/py_lib/deepspeed_py_aio.h index 11d5225de9f1..ba794db5440d 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.h +++ b/csrc/aio/py_lib/deepspeed_py_aio.h @@ -4,10 +4,7 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +Functionality for swapping tensors to/from (NVMe) storage devices. */ #include diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index 23ddabe260d4..c7ca5e82afde 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -4,293 +4,21 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ #include "deepspeed_py_aio_handle.h" +#include using namespace std; -static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } - deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, const int num_threads) - : _aio_ctxt(new aio_context(block_size, queue_depth)), - _single_submit(single_submit), - _overlap_events(overlap_events), - _num_threads(num_threads), - _aio_config(block_size, queue_depth, single_submit, overlap_events, false), - _num_pending_ops(0), - _pinned_tensor_mgr(new deepspeed_pin_tensor_t()) -{ - for (auto i = 0; i < num_threads; ++i) { - _thread_contexts.push_back(std::make_shared(i, _aio_config)); - } - - for (auto& ctxt : _thread_contexts) { - _threads.push_back(std::thread(_start_aio_thread, ctxt)); - } -} - -deepspeed_aio_handle_t::~deepspeed_aio_handle_t() -{ - _stop_threads(); - for (auto& thr : _threads) { thr.join(); } -} - -const int deepspeed_aio_handle_t::get_block_size() const -{ - return _aio_ctxt ? _aio_ctxt->_block_size : -1; -} - -const int deepspeed_aio_handle_t::get_queue_depth() const -{ - return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; -} - -const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; } - -const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; } - -const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; } - -int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - - assert(_aio_ctxt); - - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto read_buffer = (char*)buffer.data_ptr(); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - - close(fd); - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 - << " call = " << fn_time.count() * 1e6 << std::endl; - return 0; -} - -int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, - const char* filename, - const bool validate) + : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads) { - assert(_aio_ctxt); - - const auto start_time = std::chrono::high_resolution_clock::now(); - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 - << " call = " << fn_time.count() * 1e6 << std::endl; - return 0; } -void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) -{ - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_work_queue.push(scheduled_op); - } - ctxt->_work_sync._cond_var.notify_one(); - } - _num_pending_ops++; -} - -std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work() -{ - std::shared_ptr completed_op = nullptr; - for (auto& ctxt : _thread_contexts) { - std::unique_lock lock(ctxt->_complete_sync._mutex); - ctxt->_complete_sync._cond_var.wait(lock, - [ctxt] { return !ctxt->_complete_queue.empty(); }); - completed_op = ctxt->_complete_queue.front(); - ctxt->_complete_queue.pop(); - } - return completed_op; -} - -void deepspeed_aio_handle_t::_stop_threads() -{ - assert(0 == _num_pending_ops); - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_time_to_exit = true; - } - ctxt->_work_sync._cond_var.notify_one(); - } -} - -int deepspeed_aio_handle_t::wait() -{ - assert(_num_pending_ops > 0); - auto num_completed_ops = 0; - - while (_num_pending_ops > 0) { - auto completed_op = _wait_for_aio_work(); - - completed_op->fini(); - - close(completed_op->_fd); - - if (completed_op->_validate) { - validate_aio_operation(completed_op->_read_op, - completed_op->_filename.c_str(), - completed_op->data_ptr(), - _num_threads * completed_op->_num_bytes); - } - --_num_pending_ops; - ++num_completed_ops; - } - - return num_completed_ops; -} - -bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op, - const long long int num_bytes) -{ - const auto op_string = read_op ? "Read" : "Write"; - if (num_bytes % get_thread_count()) { - std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes - << " not divisible by thread count = " << get_thread_count() << std::endl; - return false; - } - - return true; -} - -int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - const auto buffer_bytes = static_cast(buffer.nbytes()); - if (buffer_bytes != num_file_bytes) { - std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes - << " != " << num_file_bytes << std::endl; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - assert((num_file_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - true, buffer, fd, filename, (num_file_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - const auto num_write_bytes = static_cast(buffer.nbytes()); - assert((num_write_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - false, buffer, fd, filename, (num_write_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, true); -} - -int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, true); -} - -at::Tensor deepspeed_aio_handle_t::new_cpu_locked_tensor(const size_t num_elem, - const torch::Tensor& example_tensor) -{ - return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); -} - -bool deepspeed_aio_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor) -{ - return _pinned_tensor_mgr->free(locked_tensor); -} +deepspeed_aio_handle_t::~deepspeed_aio_handle_t() {} diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index 3a254c3814a2..eb6b90ea22f0 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -9,21 +9,9 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include #include -#include "deepspeed_aio_thread.h" -#include "deepspeed_pin_tensor.h" - -struct deepspeed_aio_handle_t { - std::unique_ptr _aio_ctxt; - const bool _single_submit; - const bool _overlap_events; - const int _num_threads; - deepspeed_aio_config_t _aio_config; - - std::vector> _thread_contexts; - std::vector _threads; - int _num_pending_ops; - std::unique_ptr _pinned_tensor_mgr; +#include "deepspeed_py_io_handle.h" +struct deepspeed_aio_handle_t : deepspeed_io_handle_t { deepspeed_aio_handle_t(const int block_size, const int queue_depth, const bool single_submit, @@ -31,47 +19,4 @@ struct deepspeed_aio_handle_t { const int num_threads); ~deepspeed_aio_handle_t(); - - const int get_block_size() const; - const int get_queue_depth() const; - const bool get_single_submit() const; - const bool get_overlap_events() const; - const int get_thread_count() const; - - int read(torch::Tensor& buffer, const char* filename, const bool validate); - - int write(const torch::Tensor& buffer, const char* filename, const bool validate); - - int pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int sync_pread(torch::Tensor& buffer, const char* filename); - - int sync_pwrite(const torch::Tensor& buffer, const char* filename); - - int async_pread(torch::Tensor& buffer, const char* filename); - - int async_pwrite(const torch::Tensor& buffer, const char* filename); - - // TODO: Make API's args to be shape and dtype. - torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor); - - bool free_cpu_locked_tensor(torch::Tensor&); - - int wait(); - - void _stop_threads(); - - void _schedule_aio_work(std::shared_ptr scheduled_op); - - std::shared_ptr _wait_for_aio_work(); - - bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); }; diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index c597b91d05c9..f5480e9d9d83 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -4,7 +4,7 @@ // DeepSpeed Team /* -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +Functionality for swapping tensors to/from (NVMe) storage devices. */ #include "deepspeed_py_copy.h" diff --git a/csrc/aio/py_lib/deepspeed_py_copy.h b/csrc/aio/py_lib/deepspeed_py_copy.h index 19ba28317d00..f443571a3e7b 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.h +++ b/csrc/aio/py_lib/deepspeed_py_copy.h @@ -4,9 +4,6 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp new file mode 100644 index 000000000000..bdf2a858d797 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_io_handle.h" +#include + +using namespace std; + +static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } + +deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads) + : _aio_ctxt(new aio_context(block_size, queue_depth)), + _single_submit(single_submit), + _overlap_events(overlap_events), + _num_threads(num_threads), + _aio_config(block_size, queue_depth, single_submit, overlap_events, false), + _num_pending_ops(0), + _pinned_tensor_mgr(new deepspeed_pin_tensor_t()) +{ + for (auto i = 0; i < num_threads; ++i) { + _thread_contexts.push_back(std::make_shared(i, _aio_config)); + } + + for (auto& ctxt : _thread_contexts) { + _threads.push_back(std::thread(_start_aio_thread, ctxt)); + } +} + +deepspeed_io_handle_t::~deepspeed_io_handle_t() +{ + _stop_threads(); + for (auto& thr : _threads) { thr.join(); } +} + +const int deepspeed_io_handle_t::get_block_size() const +{ + return _aio_ctxt ? _aio_ctxt->_block_size : -1; +} + +const int deepspeed_io_handle_t::get_queue_depth() const +{ + return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; +} + +const bool deepspeed_io_handle_t::get_single_submit() const { return _single_submit; } + +const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; } + +const int deepspeed_io_handle_t::get_thread_count() const { return _num_threads; } + +int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + + assert(_aio_ctxt); + + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + close(fd); + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +int deepspeed_io_handle_t::write(const torch::Tensor& buffer, + const char* filename, + const bool validate) +{ + assert(_aio_ctxt); + + const auto start_time = std::chrono::high_resolution_clock::now(); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) +{ + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_work_queue.push(scheduled_op); + } + ctxt->_work_sync._cond_var.notify_one(); + } + _num_pending_ops++; +} + +std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() +{ + std::shared_ptr completed_op = nullptr; + for (auto& ctxt : _thread_contexts) { + std::unique_lock lock(ctxt->_complete_sync._mutex); + ctxt->_complete_sync._cond_var.wait(lock, + [ctxt] { return !ctxt->_complete_queue.empty(); }); + completed_op = ctxt->_complete_queue.front(); + ctxt->_complete_queue.pop(); + } + return completed_op; +} + +void deepspeed_io_handle_t::_stop_threads() +{ + assert(0 == _num_pending_ops); + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_time_to_exit = true; + } + ctxt->_work_sync._cond_var.notify_one(); + } +} + +int deepspeed_io_handle_t::wait() +{ + assert(_num_pending_ops > 0); + auto num_completed_ops = 0; + + while (_num_pending_ops > 0) { + auto completed_op = _wait_for_aio_work(); + + if (completed_op->_validate) { completed_op->validate(); } + + completed_op->finish(); + + close(completed_op->_fd); + + --_num_pending_ops; + ++num_completed_ops; + } + + return num_completed_ops; +} + +bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, + const long long int num_bytes) +{ + const auto op_string = read_op ? "Read" : "Write"; + if (num_bytes % get_thread_count()) { + std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes + << " not divisible by thread count = " << get_thread_count() << std::endl; + return false; + } + + return true; +} + +std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const bool validate) +{ + return std::make_shared( + read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate); +} + +int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + const auto buffer_bytes = static_cast(buffer.nbytes()); + if (buffer_bytes != num_file_bytes) { + std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes + << " != " << num_file_bytes << std::endl; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert((num_file_bytes % _num_threads) == 0); + + if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + const auto num_write_bytes = static_cast(buffer.nbytes()); + assert((num_write_bytes % _num_threads) == 0); + + if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, false); +} + +int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, false); +} + +int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, true); +} + +int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, true); +} + +at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const size_t num_elem, + const torch::Tensor& example_tensor) +{ + return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); +} + +bool deepspeed_io_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor) +{ + return _pinned_tensor_mgr->free(locked_tensor); +} diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h new file mode 100644 index 000000000000..2974ebe87bfc --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include "deepspeed_aio_thread.h" +#include "deepspeed_pin_tensor.h" + +struct deepspeed_io_handle_t { + std::unique_ptr _aio_ctxt; + const bool _single_submit; + const bool _overlap_events; + const int _num_threads; + deepspeed_aio_config_t _aio_config; + + std::vector> _thread_contexts; + std::vector _threads; + int _num_pending_ops; + std::unique_ptr _pinned_tensor_mgr; + + deepspeed_io_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads); + + virtual ~deepspeed_io_handle_t() = 0; + + const int get_block_size() const; + const int get_queue_depth() const; + const bool get_single_submit() const; + const bool get_overlap_events() const; + const int get_thread_count() const; + + int read(torch::Tensor& buffer, const char* filename, const bool validate); + + int write(const torch::Tensor& buffer, const char* filename, const bool validate); + + int pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async); + + int pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async); + + int sync_pread(torch::Tensor& buffer, const char* filename); + + int sync_pwrite(const torch::Tensor& buffer, const char* filename); + + int async_pread(torch::Tensor& buffer, const char* filename); + + int async_pwrite(const torch::Tensor& buffer, const char* filename); + + // TODO: Make API's args to be shape and dtype. + torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor); + + bool free_cpu_locked_tensor(torch::Tensor&); + + int wait(); + + void _stop_threads(); + + void _schedule_aio_work(std::shared_ptr scheduled_op); + + std::shared_ptr _wait_for_aio_work(); + + bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); + + virtual std::shared_ptr _create_io_op_desc( + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const bool validate); +}; diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp old mode 100755 new mode 100644 index 9033549bc0d2..3171d0c6bf3c --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -10,6 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include #include "deepspeed_py_aio_handle.h" #include "deepspeed_py_copy.h" +using namespace pybind11::literals; PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -20,7 +21,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy"); py::class_(m, "aio_handle") - .def(py::init()) + .def(py::init(), + "AIO handle constructor", + "block_size"_a = 1024 * 1024, + "queue_depth"_a = 128, + "single_submit"_a = false, + "overlap_events"_a = false, + "num_threads"_a = 1) .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) @@ -28,19 +35,74 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) - .def("read", &deepspeed_aio_handle_t::read) - .def("write", &deepspeed_aio_handle_t::write) + .def("read", + &deepspeed_aio_handle_t::read, + "Synchronous and non-parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a) - .def("pread", &deepspeed_aio_handle_t::pread) - .def("pwrite", &deepspeed_aio_handle_t::pwrite) + .def("write", + &deepspeed_aio_handle_t::write, + "Synchronous and non-parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a) - .def("sync_pread", &deepspeed_aio_handle_t::sync_pread) - .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite) - .def("async_pread", &deepspeed_aio_handle_t::async_pread) - .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite) + .def("pread", + &deepspeed_aio_handle_t::pread, + "Parallel file read with option of parallelism. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a) - .def("new_cpu_locked_tensor", &deepspeed_aio_handle_t::new_cpu_locked_tensor) - .def("free_cpu_locked_tensor", &deepspeed_aio_handle_t::free_cpu_locked_tensor) + .def("pwrite", + &deepspeed_aio_handle_t::pwrite, + "Parallel file write with option of parallelism. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a) - .def("wait", &deepspeed_aio_handle_t::wait); + .def("sync_pread", + &deepspeed_aio_handle_t::sync_pread, + "Synchrononous parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a) + + .def("sync_pwrite", + &deepspeed_aio_handle_t::sync_pwrite, + "Synchronous parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a) + + .def("async_pread", + &deepspeed_aio_handle_t::async_pread, + "Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and " + "following wait() returns count of completed ops.", + "buffer"_a, + "filename"_a) + + .def("async_pwrite", + &deepspeed_aio_handle_t::async_pwrite, + "Asynchronous parallel file write. Returns 0 on success, and following wait() returns " + "count of completed ops.", + "buffer"_a, + "filename"_a) + + .def("new_cpu_locked_tensor", + &deepspeed_aio_handle_t::new_cpu_locked_tensor, + "Allocate pinned CPU tensor.", + "num_elem"_a, + "example_tenosr"_a) + + .def("free_cpu_locked_tensor", + &deepspeed_aio_handle_t::free_cpu_locked_tensor, + "Free pinned CPU tensor.", + "tensor"_a) + + .def("wait", + &deepspeed_aio_handle_t::wait, + "Wait for (ongoing) asynchronous operations to complete"); } diff --git a/csrc/aio/py_test/aio_bench_generate_param.py b/csrc/aio/py_test/aio_bench_generate_param.py index 09d0e03c7ef6..7a0ab59ed73d 100644 --- a/csrc/aio/py_test/aio_bench_generate_param.py +++ b/csrc/aio/py_test/aio_bench_generate_param.py @@ -41,9 +41,9 @@ def convert_to_param(key): return { "single_submit": "true" if key[0] == "single" else "false", "overlap_events": "true" if key[1] == "overlap" else "false", - "thread_count": int(key[3]), - "queue_depth": int(key[4]), - "block_size": int(key[5]) + "thread_count": int(key[5]), + "queue_depth": int(key[3]), + "block_size": int(key[4]) } diff --git a/csrc/aio/py_test/aio_bench_perf_sweep.py b/csrc/aio/py_test/aio_bench_perf_sweep.py index 7d55f7ded65c..ba95150b11e1 100644 --- a/csrc/aio/py_test/aio_bench_perf_sweep.py +++ b/csrc/aio/py_test/aio_bench_perf_sweep.py @@ -10,75 +10,47 @@ import argparse import json import itertools -import subprocess import shutil -from test_ds_aio_utils import refine_integer_value +from ds_aio_job import Job, run_job from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \ - READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR + READ_LOG_DIR, WRITE_LOG_DIR from deepspeed.ops.op_builder import AsyncIOBuilder OTHER_OPTIONS = '--handle' PERF_SCRIPT = 'test_ds_aio.py' DEFAULT_SWEEP_CONFIG = { - "block_size": ["128K", "256K"], - "queue_depth": [4, 16, 32], - "overlap_events": [True, False], - "io_parallel": [2, 8], - "single_submit": [False] + "block_size": ["128K", "1M"], + "queue_depth": [32, 64, 128], + "sequential_requests": [True, False], + "single_submit": [False], + "io_parallel": [1, 2, 8], } -class Job(object): - - def __init__(self, cmd_line, output_file=None, work_dir=None): - self.cmd_line = cmd_line - self.output_file = output_file - self.work_dir = work_dir - self.output_fd = None - - def cmd(self): - return self.cmd_line - - def get_stdout(self): - return self.output_fd - - def get_stderr(self): - return self.output_fd - - def get_cwd(self): - return self.work_dir - - def open_output_file(self): - if self.output_file is not None: - self.output_fd = open(self.output_file, 'w') - - def close_output_file(self): - if self.output_fd is not None: - self.output_fd.close() - self.output_fd = None - - class SweepConfig(object): def __init__(self, args): - self.nvme_dir = args.nvme_dir - self.io_size = args.io_size + self.folder_to_device_mapping = get_ftd_map(args.nvme_dir) self.search_space = get_sweep_config_dict(args.sweep_config) + self.search_space.update(self.folder_to_device_mapping) self.read = not args.no_read self.write = not args.no_write self.flush_cache = not args.no_sudo self.log_dir = args.log_dir - self.loops = args.loops - self.other_options = f'{OTHER_OPTIONS} --loops {args.loops}' + self.other_options = f'{OTHER_OPTIONS} --loops {args.loops} --io_size {args.io_size}' + if args.gpu: + self.other_options += ' --gpu' + if args.gds: + self.other_options += ' --use_gds' def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--nvme_dir', + nargs='+', required=True, - type=str, help='Directory in which to perform I/O tests. A writeable directory on a NVMe device.') parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.') @@ -92,6 +64,10 @@ def parse_arguments(): default="400M", help='Number of I/O bytes to read/write for performance measurements.') + parser.add_argument('--gpu', action='store_true', help='Test tensor transfers between GPU device and NVME device.') + + parser.add_argument('--gds', action='store_true', help='Run the sweep over NVIDIA GPUDirectStorage operator') + parser.add_argument( '--no_sudo', action='store_true', @@ -118,6 +94,12 @@ def dump_cmd_lines(cmd_lines): print(f'{i}: {cmd}') +def get_ftd_map(nvme_dir_list): + ftd_list = [f'{dir}:{dev}' for dev, dir in enumerate(nvme_dir_list)] + ftd_arg = [' '.join(ftd for ftd in ftd_list)] + return {'folder_to_device_mapping': ftd_arg} + + def get_sweep_config_dict(sweep_config_json): if sweep_config_json is None: return DEFAULT_SWEEP_CONFIG @@ -148,16 +130,6 @@ def flatten_options(key, value_list): return cmd_list -def run_job(job): - args = ' '.join(job.cmd()) - print(f'args = {args}') - job.open_output_file() - proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd()) - job.close_output_file() - assert proc.returncode == 0, \ - f"This command failed: {job.cmd()}" - - def launch_sweep(sweep_jobs, sync_job, flush_cache_job): for perf_job in sweep_jobs: if flush_cache_job is not None: @@ -176,7 +148,12 @@ def create_cmd_tags(cmd_line): if len(fields) == 1: tags[fields[0]] = None elif len(fields) == 2: - tags[fields[0]] = fields[1] + if fields[0] == '--folder_to_device_mapping': + tags[fields[0]] = len(fields[1:]) + else: + tags[fields[0]] = fields[1] + elif len(fields) > 2: + tags[fields[0]] = len(fields[1:]) return tags @@ -184,16 +161,16 @@ def get_log_file(io_op_desc, cmd_line): QUEUE_DEPTH = "--queue_depth" BLOCK_SIZE = "--block_size" SINGLE_SUBMIT = "--single_submit" - OVERLAP_EVENTS = "--overlap_events" - THREAD_COUNT = "--threads" + SEQUENTIAL_REQUESTS = "--sequential_requests" + FTD_MAP = "--folder_to_device_mapping" IO_PARALLEL = "--io_parallel" tag_map = { QUEUE_DEPTH: "d", BLOCK_SIZE: "bs", SINGLE_SUBMIT: "single", - OVERLAP_EVENTS: "overlap", - THREAD_COUNT: "t", + SEQUENTIAL_REQUESTS: "sequential", + FTD_MAP: "ftd", IO_PARALLEL: "p" } @@ -201,14 +178,14 @@ def get_log_file(io_op_desc, cmd_line): QUEUE_DEPTH: 1, BLOCK_SIZE: "1M", SINGLE_SUBMIT: "block", - OVERLAP_EVENTS: "sequential", - THREAD_COUNT: 1, + SEQUENTIAL_REQUESTS: "overlap", + FTD_MAP: 1, IO_PARALLEL: 1 } def get_default_value(tag): value = tag_default[tag] - if tag in [SINGLE_SUBMIT, OVERLAP_EVENTS]: + if tag in [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS]: return value return f'{tag_map[tag]}{value}' @@ -218,7 +195,7 @@ def get_config_value(tag, value): return tag_key return f'{tag_key}{value}' - tag_list = [SINGLE_SUBMIT, OVERLAP_EVENTS, THREAD_COUNT, IO_PARALLEL, QUEUE_DEPTH, BLOCK_SIZE] + tag_list = [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS, FTD_MAP, QUEUE_DEPTH, BLOCK_SIZE, IO_PARALLEL] log_tags = [io_op_desc] cmd_tags = create_cmd_tags(cmd_line) for tag in tag_list: @@ -252,40 +229,14 @@ def async_io_setup(): return AsyncIOBuilder().is_compatible() -def get_block_size_and_count(io_bytes): - block_size = 1 - block_count = io_bytes - bytes_in_KB = 1024 - - while block_count % bytes_in_KB == 0: - block_size *= bytes_in_KB - block_count /= bytes_in_KB - - return int(block_size), int(block_count) - - -def create_read_file(sweep_config): - read_folder = os.path.join(sweep_config.nvme_dir, f'{READ_IO_DIR}') - os.makedirs(read_folder, exist_ok=True) - read_file_name = os.path.join(read_folder, f'random_{sweep_config.io_size}B.pt') - block_size, block_count = get_block_size_and_count(refine_integer_value(sweep_config.io_size)) - dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={read_file_name} bs={block_size} count={block_count}']) - print(f'[Start] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....') - run_job(dd_job) - print(f'[Done] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....') - return read_folder, read_file_name - - def remove_folder(folder): assert os.path.isdir(folder), f"Error: cannot remove {folder} - folder not found" shutil.rmtree(folder) def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): - read_folder, read_file_name = create_read_file(sweep_config) - read_option = f'--read_file {read_file_name}' - read_cmd_lines = [[f'{read_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines] - #dump_cmd_lines(read_cmd_lines) + read_cmd_lines = [[f'--read {sweep_config.other_options}'] + cmd for cmd in cmd_lines] + #dump_cmd_lines(cmd_lines) log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}') os.makedirs(log_folder, exist_ok=True) @@ -294,15 +245,9 @@ def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job) - remove_folder(read_folder) - def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): - write_folder = os.path.join(sweep_config.nvme_dir, f'{WRITE_IO_DIR}') - os.makedirs(write_folder, exist_ok=True) - write_file_name = os.path.join(write_folder, f'random_{sweep_config.io_size}B.pt') - write_option = f'--write_size {sweep_config.io_size} --write_file {write_file_name}' - write_cmd_lines = [[f'{write_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines] + write_cmd_lines = [[f'{sweep_config.other_options}'] + cmd for cmd in cmd_lines] #dump_cmd_lines(write_cmd_lines) log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}') @@ -312,8 +257,6 @@ def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job) - remove_folder(write_folder) - def main(): print("Running performance sweep of deepspeed nvme library") diff --git a/csrc/aio/py_test/ds_aio_args.py b/csrc/aio/py_test/ds_aio_args.py new file mode 100644 index 000000000000..346feabe4810 --- /dev/null +++ b/csrc/aio/py_test/ds_aio_args.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import argparse +import os +from test_ds_aio_utils import refine_integer_value +from deepspeed.accelerator import get_accelerator + +MAPPING_DELIMITER = ':' + + +def refine_args(args): + if args.io_size and type(args.io_size) == str: + args.io_size = refine_integer_value(args.io_size) + + if args.block_size and type(args.block_size) == str: + args.block_size = refine_integer_value(args.block_size) + + return args + + +def _get_mapping_dict(args): + if args.folder is not None: + d = {i: args.folder for i in range(args.multi_process)} + else: + d = {} + for m in args.folder_to_device_mapping: + fields = m.split(MAPPING_DELIMITER) + d[fields[1]] = fields[0] + + return d + + +def _validate_folder_mapping(args): + no_error = True + error_messages = [] + invalid_mappings = [m for m in args.folder_to_device_mapping if MAPPING_DELIMITER not in m] + if len(invalid_mappings) > 0: + error_messages.append( + f'Missing delimiter ({MAPPING_DELIMITER}) in folder_to_device_mapping {invalid_mappings}') + no_error = False + + folder_list = [m.split(MAPPING_DELIMITER)[0] for m in args.folder_to_device_mapping] + invalid_folders = [d for d in folder_list if not os.path.exists(d)] + if len(invalid_folders) > 0: + error_messages.append(f'Invalid folders in folder_to_device_mapping: {invalid_folders}') + no_error = False + + if args.gpu: + device_list = [int(m.split(MAPPING_DELIMITER)[1]) for m in args.folder_to_device_mapping] + invalid_device_list = [dev_id for dev_id in device_list if not dev_id < get_accelerator().device_count()] + if len(invalid_device_list) > 0: + error_messages.append(f'Invalid device ids in folder_to_device_mapping: {invalid_device_list}') + no_error = False + + return no_error, error_messages + + +def validate_args(args): + no_error = True + error_messages = [] + + if args.folder is not None and len(args.folder_to_device_mapping) > 0: + error_messages.append(f'--folder and --folder_to_device_mapping cannot be specified together.') + no_error = False + elif args.folder is None and len(args.folder_to_device_mapping) == 0: + error_messages.append(f'At least one of --folder or --folder_to_device_mapping must be specified.') + no_error = False + + # Validate --folder + if args.folder is not None and not os.path.exists(args.folder): + no_error = False + error_messages.append(f'Invalid folder in --folder: {args.folder} ') + + # Validate --folder_mapping_to_device + if len(args.folder_to_device_mapping) > 0: + no_mapping_error, mapping_error_messages = _validate_folder_mapping(args) + no_error = no_error and no_mapping_error + error_messages += mapping_error_messages + + # Validate --gpu, --use_gds + if args.use_gds and not args.gpu: + error_messages.append(f'--gpu must be set to transfer with --use_gds') + no_error = False + + if not no_error: + print(f'Found {len(error_messages)} validation errors') + for i, msg in enumerate(error_messages): + print(f'{i+1}: {msg}') + + return no_error + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--folder', default=None, type=str, help='Folder to use for I/O.') + + parser.add_argument('--folder_to_device_mapping', + default=[], + nargs='+', + help='Specification of mapping of folder to (gpu) device id, (ignored for cpu accesses).' + 'Can be specified multiple times for multi-process runs,' + 'e.g. --folder_to_device_mapping /mnt/nvme0:0 --folder_to_device_mapping /mnt/nvme1:15 --gpu' + 'means access /mnt/nvme0 with gpu 0 and /mnt/nvme1 with gpu 15') + + parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.') + + parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)') + + parser.add_argument('--multi_process', + type=int, + default=1, + help='Number of parallel processes doing I/O (default 1).') + + parser.add_argument('--block_size', + type=str, + default='1M', + help='I/O block size. Can use K, M, or G suffix (default 1M for 1 megabytes).') + + parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth (default 32).') + + parser.add_argument('--single_submit', + action='store_true', + help='Submit I/O requests in singles (default is submit queue_depth amount at once.).') + + parser.add_argument( + '--sequential_requests', + action='store_true', + help= + 'Delay I/O request submission until completion of prior requests (default is overlap I/O submission and completion requests.).' + ) + + parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.') + + parser.add_argument('--handle', action='store_true', help='Use AIO handle.') + + parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions') + + parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism') + + parser.add_argument('--gpu', action='store_true', help='Use GPU memory') + + parser.add_argument('--use_gds', action='store_true', help='Enable GDS AIO') + + parser.add_argument('--slow_bounce_buffer', + action='store_true', + help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.') + + args = parser.parse_args() + print(f'args = {args}') + return args + + +def get_validated_args(): + args = parse_arguments() + args = refine_args(args) + if not validate_args(args): + quit() + print(f'Successful validation of command line arguments') + + peer_tag = 'gpu' if args.gpu else 'process' + args.mapping_dict = _get_mapping_dict(args) + args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()] + assert len(args.mapping_dict) == len(args.mapping_list) + print(f'Configuring {len(args.mapping_list)} {peer_tag} to folder mapping') + for i, (device_id, folder) in enumerate(args.mapping_list): + print(f'[{i}]: {peer_tag} {device_id} <----> {folder}') + + return args diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py index ad2a4349cd0c..9b3c7cbfc49f 100755 --- a/csrc/aio/py_test/ds_aio_basic.py +++ b/csrc/aio/py_test/ds_aio_basic.py @@ -9,10 +9,9 @@ import torch import os import time +from deepspeed.ops.aio import AsyncIOBuilder from multiprocessing import Pool, Barrier from test_ds_aio_utils import report_results, task_log, task_barrier -from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import AsyncIOBuilder def pre_basic(args, tid, read_op): @@ -21,7 +20,7 @@ def pre_basic(args, tid, read_op): file = args.read_file if read_op else f'{args.write_file}.{tid}' task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu')) + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}') ctxt = {} @@ -56,7 +55,7 @@ def main_basic_read(pool_params): args, tid, ctxt = pool_params start_time = time.time() AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth, - args.single_submit, args.overlap_events, args.validate) + args.single_submit, not args.sequential_requests, args.validate) end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time @@ -67,7 +66,7 @@ def main_basic_write(pool_params): args, tid, ctxt = pool_params start_time = time.time() AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth, - args.single_submit, args.overlap_events, args.validate) + args.single_submit, not args.sequential_requests, args.validate) end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time @@ -90,16 +89,17 @@ def get_schedule(args, read_op): def _aio_handle_tasklet(pool_params): args, tid, read_op = pool_params + num_processes = len(args.mapping_dict) # Create schedule schedule = get_schedule(args, read_op) task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) # Run pre task task_log(tid, f'running pre-task') ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) # Run main tasks in a loop ctxt["main_task_sec"] = 0 @@ -107,14 +107,14 @@ def _aio_handle_tasklet(pool_params): task_log(tid, f'running main task {i}') start_time = time.time() ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) stop_time = time.time() ctxt["main_task_sec"] += stop_time - start_time # Run post task task_log(tid, f'running post-task') ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops @@ -125,9 +125,10 @@ def _init_tasklet(b): def aio_basic_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: + num_processes = len(args.mapping_dict) + b = Barrier(num_processes) + pool_params = [(args, p, read_op) for p in range(num_processes)] + with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p: pool_results = p.map(_aio_handle_tasklet, pool_params) report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py index d35b2713edae..f4a179deb9ec 100755 --- a/csrc/aio/py_test/ds_aio_handle.py +++ b/csrc/aio/py_test/ds_aio_handle.py @@ -10,40 +10,56 @@ import os import time from multiprocessing import Pool, Barrier -from test_ds_aio_utils import report_results, task_log, task_barrier +from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.ops.op_builder import GDSBuilder +from test_ds_aio_utils import report_results, task_log, task_barrier, create_filename, create_file from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import AsyncIOBuilder + +BUFFER = 'buffer' +BOUNCE_BUFFER = 'bounce_buffer' def pre_handle(args, tid, read_op): io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' - - io_parallel = args.io_parallel if args.io_parallel else 1 - handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, - args.overlap_events, io_parallel) - task_log(tid, f'Created deepspeed aio handle') - + gds = True if args.use_gds else False + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + bounce_buffer = None if args.gpu: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device=get_accelerator().device_name()) + device_name = get_accelerator().device_name(device_id) + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name) + if not (args.slow_bounce_buffer or gds): + bounce_buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, + device='cpu').pin_memory() else: - if args.use_accelerator_pin_memory: - buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu')) - else: - buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8)) + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device='cpu').pin_memory() + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') + io_parallel = args.io_parallel if args.io_parallel else 1 + if gds: + handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + handle.pin_device_tensor(buffer) + else: + handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + task_log(tid, f'created deepspeed aio handle') ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes + ctxt['file'] = filename + ctxt['num_bytes'] = args.io_size ctxt['handle'] = handle - ctxt['buffer'] = buffer + ctxt['gds'] = gds + ctxt[BUFFER] = buffer + ctxt[BOUNCE_BUFFER] = bounce_buffer ctxt['elapsed_sec'] = 0 - task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}') - return ctxt @@ -61,8 +77,12 @@ def pre_handle_write(pool_params): def post_handle(pool_params): _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None + for buf in [BUFFER, BOUNCE_BUFFER]: + if ctxt[buf] is not None: + if ctxt['gds']: + ctxt['handle'].unpin_device_tensor(ctxt[buf]) + ctxt[buf].detach() + ctxt[buf] = None return ctxt @@ -71,20 +91,31 @@ def main_parallel_read(pool_params): handle = ctxt['handle'] start_time = time.time() - ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True) + dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER + ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True) assert ret != -1 handle.wait() + if dest_buffer == BOUNCE_BUFFER: + ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data) end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time - return ctxt def main_parallel_write(pool_params): args, tid, ctxt = pool_params + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(ctxt['file']): + os.remove(ctxt['file']) + handle = ctxt['handle'] start_time = time.time() - ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True) + if ctxt[BOUNCE_BUFFER] is not None: + source_buffer = BOUNCE_BUFFER + ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data) + else: + source_buffer = BUFFER + ret = handle.pwrite(ctxt[source_buffer], ctxt['file'], args.validate, True) assert ret != -1 handle.wait() end_time = time.time() @@ -98,8 +129,11 @@ def main_handle_read(pool_parms): handle = ctxt['handle'] start_time = time.time() - ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate) + dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER + ret = handle.read(ctxt[dest_buffer], ctxt['file'], args.validate) assert ret != -1 + if dest_buffer == BOUNCE_BUFFER: + ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data) end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time @@ -108,9 +142,18 @@ def main_handle_read(pool_parms): def main_handle_write(pool_parms): args, tid, ctxt = pool_parms + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(ctxt['file']): + os.remove(ctxt['file']) + handle = ctxt['handle'] start_time = time.time() - ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate) + if ctxt[BOUNCE_BUFFER] is not None: + source_buffer = BOUNCE_BUFFER + ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data) + else: + source_buffer = BUFFER + ret = handle.write(ctxt[source_buffer], ctxt['file'], args.validate) assert ret != -1 end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time @@ -123,27 +166,28 @@ def get_schedule(args, read_op): if read_op: schedule['pre'] = pre_handle_read schedule['post'] = post_handle - schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read + schedule['main'] = main_parallel_read else: schedule['pre'] = pre_handle_write schedule['post'] = post_handle - schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write + schedule['main'] = main_parallel_write return schedule def _aio_handle_tasklet(pool_params): args, tid, read_op = pool_params + num_processes = len(args.mapping_dict) # Create schedule schedule = get_schedule(args, read_op) task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) # Run pre task task_log(tid, f'running pre-task') ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) # Run main tasks in a loop ctxt["main_task_sec"] = 0 @@ -151,14 +195,14 @@ def _aio_handle_tasklet(pool_params): task_log(tid, f'running main task {i}') start_time = time.time() ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) stop_time = time.time() ctxt["main_task_sec"] += stop_time - start_time # Run post task task_log(tid, f'running post-task') ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) + task_barrier(aio_barrier, num_processes) return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops @@ -169,9 +213,10 @@ def _init_tasklet(b): def aio_handle_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: + num_processes = len(args.mapping_dict) + b = Barrier(num_processes) + pool_params = [(args, p, read_op) for p in range(num_processes)] + with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p: pool_results = p.map(_aio_handle_tasklet, pool_params) report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/ds_aio_job.py b/csrc/aio/py_test/ds_aio_job.py new file mode 100644 index 000000000000..bbddee1bf26d --- /dev/null +++ b/csrc/aio/py_test/ds_aio_job.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping tensors to/from (NVMe) storage devices. +""" +import subprocess + + +class Job(object): + + def __init__(self, cmd_line, output_file=None, work_dir=None): + self.cmd_line = cmd_line + self.output_file = output_file + self.work_dir = work_dir + self.output_fd = None + + def cmd(self): + return self.cmd_line + + def get_stdout(self): + return self.output_fd + + def get_stderr(self): + return self.output_fd + + def get_cwd(self): + return self.work_dir + + def open_output_file(self): + if self.output_file is not None: + self.output_fd = open(self.output_file, 'w') + + def close_output_file(self): + if self.output_fd is not None: + self.output_fd.close() + self.output_fd = None + + +def run_job(job): + args = ' '.join(job.cmd()) + print(f'args = {args}') + job.open_output_file() + proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd()) + job.close_output_file() + assert proc.returncode == 0, \ + f"This command failed: {job.cmd()}" diff --git a/csrc/aio/py_test/run_read_sweep.sh b/csrc/aio/py_test/run_read_sweep.sh index b9d7e050454a..59d82996a0e2 100755 --- a/csrc/aio/py_test/run_read_sweep.sh +++ b/csrc/aio/py_test/run_read_sweep.sh @@ -1,13 +1,22 @@ #!/bin/bash -if [[ $# -ne 2 ]]; then - echo "Usage: $0 " +if [[ $# -lt 2 ]]; then + echo "Usage: $0 " exit 1 fi +function prep_folder() +{ + folder=$1 + if [[ -d ${folder} ]]; then + rm -f ${folder}/* + else + mkdir -p ${folder} + fi +} function validate_environment() { - validate_cmd="python ./validate_async_io.py" + validate_cmd="TORCH_EXTENSIONS_DIR=./torch_extentions python3 ./validate_async_io.py" eval ${validate_cmd} res=$? if [[ $res != 0 ]]; then @@ -17,18 +26,27 @@ function validate_environment() fi } +function fileExists() { + local file="$1" + if [[ -f "$file" ]]; then + return 0 + else + return 1 + fi +} validate_environment -INPUT_FILE=$1 -if [[ ! -f ${INPUT_FILE} ]]; then - echo "Input file not found: ${INPUT_FILE}" - exit 1 -fi - -LOG_DIR=$2/aio_perf_sweep +IO_SIZE=$1 +LOG_DIR=./aio_perf_sweep +MAP_DIR=$2/aio +GPU_MEM=$3 +USE_GDS=$4 RUN_SCRIPT=./test_ds_aio.py -READ_OPT="--read_file ${INPUT_FILE}" +READ_OPT="--read" + +prep_folder ${MAP_DIR} +prep_folder ${LOG_DIR} if [[ -d ${LOG_DIR} ]]; then rm -f ${LOG_DIR}/* @@ -36,37 +54,60 @@ else mkdir -p ${LOG_DIR} fi -DISABLE_CACHE="sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' " -SYNC="sync" +if [[ ${GPU_MEM} == "gpu" ]]; then + gpu_opt="--gpu" +else + gpu_opt="" +fi +if [[ ${USE_GDS} == "gds" ]]; then + gds_opt="--use_gds" +else + gds_opt="" +fi + +DISABLE_CACHE="sudo sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' " +SYNC="sudo sync" -for sub in single block; do - if [[ $sub == "single" ]]; then - sub_opt="--single_submit" +for xtype in cpu gpu gds; do + if [[ $xtype == "cpu" ]]; then + gpu_opt="" + gds_opt="" + elif [[ $xtype == "gpu" ]]; then + gpu_opt="--gpu" + gds_opt="" else - sub_opt="" + gpu_opt="--gpu" + gds_opt="--use_gds" fi - for ov in overlap sequential; do - if [[ $ov == "overlap" ]]; then - ov_opt="--overlap_events" + for sub in single block; do + if [[ $sub == "single" ]]; then + sub_opt="--single_submit" else - ov_opt="" + sub_opt="" fi - for t in 1 2 4 8; do - for p in 1 ; do - for d in 1 2 4 8 16 32; do - for bs in 128K 256K 512K 1M; do - SCHED_OPTS="${sub_opt} ${ov_opt} --handle --threads ${t}" - OPTS="--io_parallel ${p} --queue_depth ${d} --block_size ${bs}" - LOG="${LOG_DIR}/read_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt" - cmd="python ${RUN_SCRIPT} ${READ_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" - echo ${DISABLE_CACHE} - echo ${cmd} - echo ${SYNC} + for ov in overlap sequential; do + if [[ $ov == "sequential" ]]; then + ov_opt="--sequential_requests" + else + ov_opt="" + fi + for p in 1 2 4 8; do + for t in 1 2 4 8; do + for d in 8 16 32 64 128; do + for bs in 128K 256K 512K 1M 2M 4M 8M 16M; do + SCHED_OPTS="${sub_opt} ${ov_opt} --handle ${gpu_opt} ${gds_opt} --folder_to_device_mapping /mnt/nvme01:0" + OPTS="--queue_depth ${d} --block_size ${bs} --io_size ${IO_SIZE} --io_parallel ${t}" + LOG="${LOG_DIR}/read_${xtype}_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt" + cmd="/usr/bin/time python ${RUN_SCRIPT} ${READ_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" - eval ${DISABLE_CACHE} - eval ${cmd} - eval ${SYNC} - sleep 2 + echo ${DISABLE_CACHE} + echo ${cmd} + echo ${SYNC} + eval ${DISABLE_CACHE} + eval ${cmd} + eval ${SYNC} + sleep 2 + done done done done diff --git a/csrc/aio/py_test/run_write_sweep.sh b/csrc/aio/py_test/run_write_sweep.sh index 99f2113dda6f..a54d1c8d7bed 100755 --- a/csrc/aio/py_test/run_write_sweep.sh +++ b/csrc/aio/py_test/run_write_sweep.sh @@ -25,25 +25,33 @@ function validate_environment() validate_environment -if [[ $# -ne 3 ]]; then - echo "Usage: $0 " - exit 1 -fi - -SIZE="$1M" -WRITE_DIR=$2 -LOG_DIR=$3/aio_perf_sweep +IO_SIZE=$1 +LOG_DIR=$2/aio_perf_sweep +MAP_DIR=$2/aio +GPU_MEM=$3 +USE_GDS=$4 +RUN_SCRIPT=./test_ds_aio.py -OUTPUT_FILE=${WRITE_DIR}/ds_aio_write_${SIZE}B.pt -WRITE_OPT="--write_file ${OUTPUT_FILE} --write_size ${SIZE}" +OUTPUT_FILE=${MAP_DIR}/ds_aio_write_${SIZE}B.pt +WRITE_OPT="" -prep_folder ${WRITE_DIR} +prep_folder ${MAP_DIR} prep_folder ${LOG_DIR} -RUN_SCRIPT=./test_ds_aio.py -DISABLE_CACHE="sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' " +if [[ ${GPU_MEM} == "gpu" ]]; then + gpu_opt="--gpu" +else + gpu_opt="" +fi +if [[ ${USE_GDS} == "gds" ]]; then + gds_opt="--use_gds" +else + gds_opt="" +fi + +DISABLE_CACHE="sync; bash -c 'echo 1 > /proc/sys/vm/drop_caches' " SYNC="sync" for sub in single block; do @@ -53,19 +61,19 @@ for sub in single block; do sub_opt="" fi for ov in overlap sequential; do - if [[ $ov == "overlap" ]]; then - ov_opt="--overlap_events" + if [[ $ov == "sequential" ]]; then + ov_opt="--sequential_requests" else ov_opt="" fi - for t in 1 2 4 8; do - for p in 1; do - for d in 1 2 4 8 16 32; do - for bs in 128K 256K 512K 1M; do - SCHED_OPTS="${sub_opt} ${ov_opt} --handle --threads ${t}" - OPTS="--io_parallel ${p} --queue_depth ${d} --block_size ${bs}" + for p in 1 2 4 8; do + for t in 1 2 4 8; do + for d in 32 64 128; do + for bs in 256K 512K 1M; do + SCHED_OPTS="${sub_opt} ${ov_opt} --handle ${gpu_opt} ${gds_opt} --folder ${MAP_DIR}" + OPTS="--queue_depth ${d} --block_size ${bs} --io_size ${IO_SIZE} --multi_process ${p} --io_parallel ${t}" LOG="${LOG_DIR}/write_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt" - cmd="python ${RUN_SCRIPT} ${WRITE_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" + cmd="python ${RUN_SCRIPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" echo ${DISABLE_CACHE} echo ${cmd} echo ${SYNC} diff --git a/csrc/aio/py_test/test_ds_aio.py b/csrc/aio/py_test/test_ds_aio.py index e6242cb35789..6de72755e9e5 100755 --- a/csrc/aio/py_test/test_ds_aio.py +++ b/csrc/aio/py_test/test_ds_aio.py @@ -6,79 +6,19 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ -import os -import argparse import multiprocessing as mp from ds_aio_basic import aio_basic_multiprocessing from ds_aio_handle import aio_handle_multiprocessing -from test_ds_aio_utils import refine_args - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - parser.add_argument('--read_file', type=str, default=None, help='Read file.') - - parser.add_argument('--write_file', type=str, default=None, help='Write file.') - - parser.add_argument('--write_size', type=str, default=None, help='Number of bytes to write.') - - parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.') - - parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.') - - parser.add_argument('--threads', type=int, default=1, help='Thread parallelism count.') - - parser.add_argument('--single_submit', - action='store_true', - help='Submit I/O requests in singles (default is submit queue_depth amount at once.).') - - parser.add_argument('--overlap_events', - action='store_true', - help='Overlap I/O submission and completion requests.') - - parser.add_argument('--validate', action='store_true', help='Perform validation in library.') - - parser.add_argument('--handle', action='store_true', help='Use AIO handle.') - - parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions') - - parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism') - - parser.add_argument('--gpu', action='store_true', help='Use GPU memory') - - parser.add_argument('--use_accelerator_pin_memory', - action='store_true', - help='Obtain pinned (CPU page-locked) tensors from accelerator') - - args = parser.parse_args() - print(f'args = {args}') - return args - - -def validate_args(args): - if args.read_file and not os.path.isfile(args.read_file): - print(f'args validation error: {args.read_file} not found') - return False - - return True +from ds_aio_args import get_validated_args def main(): print(f'Testing deepspeed_aio python frontend') - args = parse_arguments() - refine_args(args) - if not validate_args(args): - quit() - + args = get_validated_args() mp.set_start_method('spawn') multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing - if args.read_file: - multiprocess_function(args, True) - - if args.write_file: - multiprocess_function(args, False) + multiprocess_function(args, args.read) if __name__ == "__main__": diff --git a/csrc/aio/py_test/test_ds_aio_utils.py b/csrc/aio/py_test/test_ds_aio_utils.py index 6aad114c0bdc..968ff4a60ef9 100755 --- a/csrc/aio/py_test/test_ds_aio_utils.py +++ b/csrc/aio/py_test/test_ds_aio_utils.py @@ -6,12 +6,17 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ +import os +from ds_aio_job import Job, run_job + BYTES_PER_GB = 1024**3 +BYTES_PER_MB = 1024**2 +BYTES_PER_KB = 1024 LOG_TIDS = [0] -def task_log(tid, msg): - if tid in LOG_TIDS: +def task_log(tid, msg, force=False): + if force or tid in LOG_TIDS: print(f'tid {tid}: {msg}') @@ -31,16 +36,29 @@ def report_results(args, read_op, pool_results): total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) task_latency_sec = max([sec for _, sec, _ in pool_results]) - task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB + task_speed_GB = 0 if task_latency_sec == 0 else total_bytes / task_latency_sec / BYTES_PER_GB print(f'Task {io_string} Latency = {task_latency_sec} sec') print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') e2e_latency_sec = max([sec for sec, _, _ in pool_results]) - e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB + e2e_speed_GB = 0 if e2e_latency_sec == 0 else total_bytes / e2e_latency_sec / BYTES_PER_GB print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') +def get_block_size_and_count(io_bytes): + if io_bytes > BYTES_PER_MB and io_bytes % BYTES_PER_MB == 0: + block_size = BYTES_PER_MB + block_size_string = '1M' + else: + assert io_bytes % BYTES_PER_KB == 0 + block_size = BYTES_PER_KB + block_size_string = '1K' + block_count = io_bytes / block_size + + return block_size_string, int(block_count) + + def refine_integer_value(value): unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} @@ -50,9 +68,14 @@ def refine_integer_value(value): return int(value) -def refine_args(args): - if args.write_size and type(args.write_size) == str: - args.write_size = refine_integer_value(args.write_size) +def create_filename(folder, read_op, size, tid): + io_string = "read" if read_op else "write" + return os.path.join(folder, f'_aio_{io_string}_{size}.pt.{tid}') + - if args.block_size and type(args.block_size) == str: - args.block_size = refine_integer_value(args.block_size) +def create_file(filename, num_bytes): + block_size, block_count = get_block_size_and_count(num_bytes) + dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={filename} bs={block_size} count={block_count}']) + print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....') + run_job(dd_job) + print(f'[Done] Create read file of {num_bytes} bytes by running {dd_job.cmd()} ....') diff --git a/csrc/aio/py_test/validate_async_io.py b/csrc/aio/py_test/validate_async_io.py index 019ec05d49d3..10fb638347bc 100644 --- a/csrc/aio/py_test/validate_async_io.py +++ b/csrc/aio/py_test/validate_async_io.py @@ -7,3 +7,4 @@ """ from deepspeed.ops.op_builder import AsyncIOBuilder assert AsyncIOBuilder().is_compatible() +assert AsyncIOBuilder().load() diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp new file mode 100644 index 000000000000..c370a448e5a2 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_gds_op.h" + +using namespace std; + +// For when there is more than 1 device +static std::map> base_ptr_registry; + +static void _safe_handle_register(const int fd, CUfileDescr_t& cf_descr, CUfileHandle_t& cf_handle) +{ + memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); + cf_descr.handle.fd = fd; + cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr); + if (status.err != CU_FILE_SUCCESS) { + std::cerr << "file register error:" << cuFileGetErrorString(status) << std::endl; + close(fd); + exit(EXIT_FAILURE); + } +} + +static void* _find_base_ptr(const int64_t device, char* buf_ptr) +{ + void* base_ptr = nullptr; + int64_t last = -1; + int64_t ptr_diff; + for (const auto& value : base_ptr_registry[device]) { + ptr_diff = buf_ptr - (char*)value; + if (last == -1 && ptr_diff >= 0) { + last = ptr_diff; + base_ptr = value; + } else if (ptr_diff < last && ptr_diff >= 0) { + last = ptr_diff; + base_ptr = value; + } + } + if (!base_ptr || buf_ptr < base_ptr) { + std::cerr << "BASE PTR ERROR :" << base_ptr << " BUF PTR " << (void*)buf_ptr << std::endl; + for (const auto& value : base_ptr_registry[device]) { + std::cerr << "BASE PTR AVAIL :" << value << std::endl; + } + exit(EXIT_FAILURE); + } + + return base_ptr; +} + +void gds_op_desc_t::add_buffer_to_registry(const torch::Tensor& buffer) +{ + const int64_t device = buffer.get_device(); + void* reg_ptr = buffer.data_ptr(); + + // std::cout << "REG PTR " << reg_ptr << std::endl; + // TODO: add checking to make sure pointer isn't already in set + const auto it = base_ptr_registry.find(device); + if (it == base_ptr_registry.end()) { + std::set new_ptr_set; + new_ptr_set.insert(reg_ptr); + base_ptr_registry.insert(std::pair>(device, new_ptr_set)); + } else { + base_ptr_registry[device].insert(reg_ptr); + } + + check_cudaruntimecall(cudaSetDevice(device)); + CUfileError_t status = cuFileBufRegister(reg_ptr, buffer.nbytes(), 0); + if (status.err != CU_FILE_SUCCESS) { + std::cerr << "buffer register failed:" << cuFileGetErrorString(status) << std::endl; + exit(EXIT_FAILURE); + } +} + +void gds_op_desc_t::remove_buffer_from_registry(const torch::Tensor& buffer) +{ + const int64_t device = buffer.get_device(); + void* reg_ptr = buffer.data_ptr(); + + // std::cout << "DEREG PTR " << reg_ptr << std::endl; + check_cudaruntimecall(cudaSetDevice(device)); + cuFileBufDeregister(reg_ptr); + + // Remove from tracked registry + base_ptr_registry[device].erase(reg_ptr); +} + +gds_op_desc_t::gds_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int num_threads, + const bool validate) + : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate) +{ + _contiguous_buffer = _buffer.contiguous(); + const int64_t device = _buffer.get_device(); + check_cudaruntimecall(cudaSetDevice(device)); + _base_ptr = _find_base_ptr(device, (char*)_contiguous_buffer.data_ptr()); + + _safe_handle_register(fd, _cf_descr, _cf_handle); +} + +char* gds_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void gds_op_desc_t::finish() { cuFileHandleDeregister(_cf_handle); } + +void gds_op_desc_t::validate() +{ + check_cudaruntimecall(cudaSetDevice(_buffer.get_device())); + const auto cpu_buffer = _buffer.to(torch::kCPU); + validate_aio_operation( + _read_op, _filename.c_str(), (char*)(cpu_buffer.data_ptr()), _file_num_bytes); +} + +void gds_op_desc_t::run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config) +{ + assert(tid < _num_threads); + check_cudaruntimecall(cudaSetDevice(_buffer.get_device())); + int64_t buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr; + const auto file_offset = _num_bytes_per_thread * tid; + + if (_read_op) { + auto ret = + cuFileRead(_cf_handle, _base_ptr, _num_bytes_per_thread, file_offset, buf_offset); + if (ret < 0) { _report_error(ret, errno, buf_offset); } + } else { + auto ret = + cuFileWrite(_cf_handle, _base_ptr, _num_bytes_per_thread, file_offset, buf_offset); + if (ret < 0) { _report_error(ret, errno, buf_offset); } + } +} + +void gds_op_desc_t::_report_error(const ssize_t return_code, + const int error_num, + const off_t offset) +{ + const auto op_string = _read_op ? "read failed with " : "write failed with "; + const auto error_string = IS_CUFILE_ERR(return_code) ? "cuFile error: " : "posix error: "; + const auto error_code = IS_CUFILE_ERR(return_code) ? cuFileGetErrorString(return_code) + : cuFileGetErrorString(error_num); + std::cerr << op_string << error_string << error_code << " return code = " << return_code + << " filename = " << _filename.c_str() << " num bytes = " << _num_bytes_per_thread + << " offset = " << offset << std::endl; + exit(EXIT_FAILURE); +} diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h new file mode 100644 index 000000000000..b7fab64d4054 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_gds_op.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include +#include + +#include "deepspeed_aio_op_desc.h" +#include "deepspeed_gds_utils.h" + +struct gds_op_desc_t : io_op_desc_t { + CUfileDescr_t _cf_descr; + CUfileHandle_t _cf_handle; + void* _base_ptr; + + gds_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int num_threads, + const bool validate); + + void run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config); + + char* data_ptr() const; + + void validate(); + + void finish(); + + void _report_error(const ssize_t return_code, const int error_num, const off_t offset); + + static void add_buffer_to_registry(const torch::Tensor& buffer); + + static void remove_buffer_from_registry(const torch::Tensor& buffer); +}; diff --git a/csrc/gds/py_lib/deepspeed_gds_utils.h b/csrc/gds/py_lib/deepspeed_gds_utils.h new file mode 100644 index 000000000000..12b014d90988 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_gds_utils.h @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +// CUDA/cuFile includes +#include +#include +#include "cufile.h" + +// Macro for checking cuda errors following a cuda launch or api call +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } + +#define check_cudadrivercall(fn) \ + do { \ + CUresult res = fn; \ + if (res != CUDA_SUCCESS) { \ + const char* str = nullptr; \ + cuGetErrorName(res, &str); \ + std::cerr << "cuda driver api call failed " << #fn << " res : " << res << ", " \ + << __LINE__ << ":" << str << std::endl; \ + std::cerr << "EXITING program!!!" << std::endl; \ + exit(1); \ + } \ + } while (0) + +#define check_cudaruntimecall(fn) \ + do { \ + cudaError_t res = fn; \ + if (res != cudaSuccess) { \ + const char* str = cudaGetErrorName(res); \ + std::cerr << "cuda runtime api call failed " << #fn << __LINE__ << ":" << str \ + << std::endl; \ + std::cerr << "EXITING program!!!" << std::endl; \ + exit(1); \ + } \ + } while (0) + +#define check_cuFileCall(fn, api_msg) \ + do { \ + CUfileError_t status = fn; \ + if (status.err != CU_FILE_SUCCESS) { \ + std::cout << api_msg << " failed with error " << CUFILE_ERRSTR(status.err) \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// +// cuda driver error description +// +static inline const char* GetCuErrorString(CUresult curesult) +{ + const char* descp; + if (cuGetErrorName(curesult, &descp) != CUDA_SUCCESS) descp = "unknown cuda error"; + return descp; +} + +// +// cuFile APIs return both cuFile specific error codes as well as POSIX error codes +// for ease, the below template can be used for getting the error description depending +// on its type. + +// POSIX +template ::value, std::nullptr_t>::type = nullptr> +std::string cuFileGetErrorString(T status) +{ + status = std::abs(status); + return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) + : std::string(std::strerror(status)); +} + +// CUfileError_t +template ::value, std::nullptr_t>::type = nullptr> +std::string cuFileGetErrorString(T status) +{ + std::string errStr = cuFileGetErrorString(static_cast(status.err)); + if (IS_CUDA_ERR(status)) errStr.append(".").append(GetCuErrorString(status.cu_err)); + return errStr; +} diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp new file mode 100644 index 000000000000..3a35ad3145a0 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* + GPUDirect Storage functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_gds_handle.h" +#include +#include "deepspeed_gds_op.h" + +using namespace std; + +int deepspeed_gds_handle_t::s_cuFile_init = 0; + +deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads) + : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads) +{ + _init_cuFile(block_size, queue_depth, num_threads); +} + +deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); } + +void deepspeed_gds_handle_t::_init_cuFile(const int block_size, + const int queue_depth, + const int num_threads) +{ + if (deepspeed_gds_handle_t::s_cuFile_init == 0) { + std::string depthStr = std::to_string(queue_depth); + std::string threadsStr = std::to_string(num_threads); + std::string json1 = R"({"execution": {"max_io_queue_depth": )" + depthStr + ", "; + std::string json2 = R"("max_request_parallelism": )" + threadsStr + ", "; + std::string json3 = R"("max_io_threads": )" + threadsStr + ", "; + std::string json4 = R"("parallel_io": true, "min_io_threshold_size_kb": 8192}})"; + std::ofstream outFile("local_cufile.json"); + if (outFile.is_open()) { + outFile << json1 + json2 + json3 + json4; + outFile.close(); + } else { + std::cerr << "Can't open local cufile" << std::endl; + exit(EXIT_FAILURE); + } + // TODO: Address the following issues with this code + // (1) Fix C++14 warning + // (2) Create file in a different location than PWD + // (3) Handle multi-GPU/multi-rank scenarios: should cufile be shared, is per-rank cufile + // safe? + putenv("CUFILE_ENV_PATH_JSON=$PWD/local_cufile.json"); + cuFileDriverOpen(); + cudaCheckError(); + size_t direct_io_size = (size_t)block_size / 1024; + CUfileError_t status = cuFileDriverSetMaxDirectIOSize(direct_io_size); + if (status.err != CU_FILE_SUCCESS) { + std::cerr << "file register error:" << cuFileGetErrorString(status) << std::endl; + exit(EXIT_FAILURE); + } + } + deepspeed_gds_handle_t::s_cuFile_init++; +} + +void deepspeed_gds_handle_t::_close_cuFile() +{ + deepspeed_gds_handle_t::s_cuFile_init--; + if (deepspeed_gds_handle_t::s_cuFile_init == 0) { cuFileDriverClose(); } +} + +torch::Tensor deepspeed_gds_handle_t::new_pinned_device_tensor(const size_t num_elem, + const torch::Tensor& example_tensor) +{ + auto options = torch::TensorOptions().dtype(example_tensor.scalar_type()).device(torch::kCUDA); + auto dev_tensor = torch::empty(num_elem, options); + pin_device_tensor(dev_tensor); + return dev_tensor; +} + +bool deepspeed_gds_handle_t::free_pinned_device_tensor(torch::Tensor& buffer) +{ + unpin_device_tensor(buffer); + return true; +} + +bool deepspeed_gds_handle_t::pin_device_tensor(const torch::Tensor& buffer) +{ + gds_op_desc_t::add_buffer_to_registry(buffer); + return true; +} + +bool deepspeed_gds_handle_t::unpin_device_tensor(const torch::Tensor& buffer) +{ + gds_op_desc_t::remove_buffer_from_registry(buffer); + return true; +} + +std::shared_ptr deepspeed_gds_handle_t::_create_io_op_desc( + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const bool validate) +{ + if (buffer.is_cuda()) { + return std::make_shared( + read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate); + } + return deepspeed_io_handle_t::_create_io_op_desc( + read_op, buffer, fd, filename, file_num_bytes, validate); +} diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h new file mode 100644 index 000000000000..f324e6b65e80 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include "deepspeed_py_io_handle.h" + +struct deepspeed_gds_handle_t : deepspeed_io_handle_t { + deepspeed_gds_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads); + + ~deepspeed_gds_handle_t(); + + torch::Tensor new_pinned_device_tensor(const size_t num_elem, + const torch::Tensor& example_tensor); + + bool free_pinned_device_tensor(torch::Tensor&); + + bool pin_device_tensor(const torch::Tensor& buffer); + + bool unpin_device_tensor(const torch::Tensor& buffer); + + void _init_cuFile(const int block_size, const int queue_length, const int num_threads); + + void _close_cuFile(); + + std::shared_ptr _create_io_op_desc(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int file_num_bytes, + const bool validate); + + static int s_cuFile_init; +}; diff --git a/csrc/gds/py_lib/py_ds_gds.cpp b/csrc/gds/py_lib/py_ds_gds.cpp new file mode 100644 index 000000000000..66eb34d4ea8c --- /dev/null +++ b/csrc/gds/py_lib/py_ds_gds.cpp @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include "deepspeed_py_gds_handle.h" +using namespace pybind11::literals; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + py::class_(m, "gds_handle") + .def(py::init(), + "GDS handle constructor", + "block_size"_a = 1024 * 1024, + "queue_depth"_a = 128, + "single_submit"_a = false, + "overlap_events"_a = false, + "num_threads"_a = 1) + + .def("get_block_size", &deepspeed_gds_handle_t::get_block_size) + .def("get_queue_depth", &deepspeed_gds_handle_t::get_queue_depth) + .def("get_single_submit", &deepspeed_gds_handle_t::get_single_submit) + .def("get_overlap_events", &deepspeed_gds_handle_t::get_overlap_events) + .def("get_thread_count", &deepspeed_gds_handle_t::get_thread_count) + + .def("read", + &deepspeed_gds_handle_t::read, + "Synchronous and non-parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a) + + .def("write", + &deepspeed_gds_handle_t::write, + "Synchronous and non-parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a) + + .def("pread", + &deepspeed_gds_handle_t::pread, + "Parallel file read with option of parallelism. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a) + + .def("pwrite", + &deepspeed_gds_handle_t::pwrite, + "Parallel file write with option of parallelism. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a) + + .def("sync_pread", + &deepspeed_gds_handle_t::sync_pread, + "Synchrononous parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a) + + .def("sync_pwrite", + &deepspeed_gds_handle_t::sync_pwrite, + "Synchronous parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a) + + .def("async_pread", + &deepspeed_gds_handle_t::async_pread, + "Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and " + "following wait() returns count of completed ops.", + "buffer"_a, + "filename"_a) + + .def("async_pwrite", + &deepspeed_gds_handle_t::async_pwrite, + "Asynchronous parallel file write. Returns 0 on success, and following wait() returns " + "count of completed ops.", + "buffer"_a, + "filename"_a) + + .def("new_cpu_locked_tensor", + &deepspeed_gds_handle_t::new_cpu_locked_tensor, + "Allocate pinned CPU tensor.", + "num_elem"_a, + "example_tenosr"_a) + + .def("free_cpu_locked_tensor", + &deepspeed_gds_handle_t::free_cpu_locked_tensor, + "Free pinned CPU tensor.", + "tensor"_a) + + .def("new_pinned_device_tensor", + &deepspeed_gds_handle_t::new_pinned_device_tensor, + "Allocate pinned device tensor.", + "num_elem"_a, + "example_tenosr"_a) + + .def("free_pinned_device_tensor", + &deepspeed_gds_handle_t::free_pinned_device_tensor, + "Free pinned device tensor.", + "tensor"_a) + + .def("pin_device_tensor", + &deepspeed_gds_handle_t::pin_device_tensor, + "Pin device tensor.", + "tensor"_a) + + .def("unpin_device_tensor", + &deepspeed_gds_handle_t::unpin_device_tensor, + "Unpin device tensor.", + "tensor"_a) + + .def("wait", + &deepspeed_gds_handle_t::wait, + "Wait for (ongoing) asynchronous operations to complete"); +} diff --git a/csrc/gds/py_test/validate_gds.py b/csrc/gds/py_test/validate_gds.py new file mode 100644 index 000000000000..b34b1194f582 --- /dev/null +++ b/csrc/gds/py_test/validate_gds.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" +from deepspeed.ops.op_builder import GDSBuilder +assert GDSBuilder().is_compatible(True) +assert GDSBuilder().load(True) diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py index 46a84c61f884..77dfa7a23b1e 100644 --- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -15,13 +15,13 @@ class HuggingFaceCheckpointEngine(CheckpointEngineBase): - def __init__(self, model_name_or_path: str, auth_token: str = None) -> None: + def __init__(self, model_name_or_path: str, auth_token: str = None, **hf_kwargs) -> None: super().__init__() from transformers import AutoConfig, GenerationConfig self.model_name_or_path = model_name_or_path self.auth_token = auth_token - self.model_config = AutoConfig.from_pretrained(self.model_name_or_path) + self.model_config = AutoConfig.from_pretrained(self.model_name_or_path, **hf_kwargs) # Define this property here so we can use it in the model implementation if not hasattr(self.model_config, "max_seq_length"): if hasattr(self.model_config, "max_position_embeddings"): diff --git a/deepspeed/ops/gds/__init__.py b/deepspeed/ops/gds/__init__.py new file mode 100755 index 000000000000..3c0762c81076 --- /dev/null +++ b/deepspeed/ops/gds/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..op_builder import GDSBuilder diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index c627846b743c..26196ff37ac4 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -5,6 +5,8 @@ from types import MethodType from collections import OrderedDict +from functools import reduce +from operator import mul import torch from deepspeed import comm as dist @@ -40,6 +42,9 @@ PIPE_RECV_INPUT_TIMER = 'pipe_recv_input' PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad' +# The buffer size to store the meta data for each tensor. +TENSOR_META_SIZE = 256 + def is_even(number): return number % 2 == 0 @@ -179,6 +184,7 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): } self.pipe_recv_buf = None self.grad_layer = None + self._grad_layer_buf = [] self.meta_buffer = None @@ -250,6 +256,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self.timers(STEP_MICRO_TIMER).start() self.timers(STEP_MICRO_TIMER).stop() + self.dynamic_shape = self.module.dynamic_shape + def set_has_attention_mask(self, value): assert isinstance(value, bool) self.has_attention_mask = value @@ -318,6 +326,7 @@ def reset_activation_shape(self): self.first_output_send = True self.pipe_recv_buf = None self.grad_layer = None + self._grad_layer_buf = [] self.meta_buffer = None self.pipe_partition_input_meta_cache = None @@ -926,51 +935,38 @@ def _send_tensor_meta(self, buffer, recv_stage): * ndims * shape """ - send_bytes = 0 + meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) if isinstance(buffer, torch.Tensor): - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.send(type_tensor, recv_stage) - send_shape = torch.LongTensor(data=buffer.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(buffer) - elif isinstance(buffer, list): - assert (False) - type_tensor = torch.LongTensor(data=[1]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for tensor in buffer: - assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(tensor) + meta_buf_list = [ + 0, # type of data (0: tensor, 1: list (unused), 2: tuple) + self.DTYPE_TO_ID[buffer.dtype], # dtype + len(buffer.size()) # ndims + ] + meta_buf_list.extend(buffer.size()) + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + elif isinstance(buffer, tuple): - type_tensor = torch.LongTensor(data=[2]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for idx, tensor in enumerate(buffer): + meta_buf_list = [ + 2, # type of data (0: tensor, 1: list (unused), 2: tuple) + len(buffer) # num_tensors + ] + + for tensor in buffer: assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device) - p2p.send(send_dtype, recv_stage) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - # Useful for performance debugging. - ''' - new_bytes = _tensor_bytes(tensor) - send_bytes += _tensor_bytes(tensor) - # Useful for performance debugging. - if self.grid.data_parallel_id == 0: - print( - f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB' - ) - ''' + meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype]) + meta_buf_list.append(len(tensor.size())) + meta_buf_list.extend(tensor.size()) + + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + else: raise NotImplementedError(f'Could not send meta type {type(buffer)}') @@ -983,49 +979,35 @@ def _send_tensor_meta(self, buffer, recv_stage): def _recv_tensor_meta(self, send_stage): """Receive metadata about upcoming p2p transfers and return allocated buffers. - Metadata is communicated in this order: - * type (0: tensor, 1: list) - * num_tensors if type=list - foreach tensor in buffer: - * ndims - * shape - Returns: Allocated buffer for receiving from send_stage. """ + buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) + p2p.recv(buffer, send_stage) - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(type_tensor, send_stage) - recv_type = type_tensor.item() + recv_type = buffer[0].item() # A single tensor will be sent. if recv_type == 0: - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shape = recv_shape.tolist() - return self._allocate_buffer(recv_shape, num_buffers=1)[0] - - # List or tuple of tensors + recv_dtype = self.ID_TO_DTYPE[buffer[1].item()] + recv_ndims = buffer[2].item() + recv_shape = buffer[3:3 + recv_ndims].tolist() + return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype) + + # List or tuple of tensors (recv_type == 1 (list) is currently unused) elif recv_type == 1 or recv_type == 2: - count_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(count_tensor, send_stage) - num_tensors = count_tensor.item() - recv_shapes_and_dtypes = [] + num_tensors = buffer[1].item() + + buffers = [] + offset = 2 for idx in range(num_tensors): - recv_dtype = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_dtype, send_stage) - recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()] - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype)) - - buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0] + recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()] + recv_ndims = buffer[offset + 1].item() + recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist() + offset += 2 + recv_ndims + + buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype)) + # Convert to tuples if requested. if recv_type == 2: buffers = tuple(buffers) @@ -1048,7 +1030,7 @@ def _exec_send_activations(self, buffer_id): outputs[-1] = outputs[-1].half() outputs = tuple(outputs) - if self.first_output_send: + if self.dynamic_shape or self.first_output_send: self.first_output_send = False self._send_tensor_meta(outputs, self.next_stage) @@ -1133,7 +1115,7 @@ def _exec_recv_activations(self, buffer_id): recvd = None # Allocate the buffer if necessary - if self.pipe_recv_buf is None: + if self.dynamic_shape or self.pipe_recv_buf is None: self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage) if isinstance(self.pipe_recv_buf, torch.Tensor): @@ -1188,10 +1170,9 @@ def _exec_recv_grads(self, buffer_id): self.pipe_buffers['outputs'][buffer_id] = outputs # Allocate gradient if necessary - if self.grad_layer is None: + if self.dynamic_shape or self.grad_layer is None: if isinstance(outputs, torch.Tensor): - s = list(outputs.size()) - self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0] + self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype) else: # XXX This is a HACK # When we exchange activations/gradients, the two pipe stages @@ -1213,7 +1194,11 @@ def _exec_recv_grads(self, buffer_id): for t in outputs[2:] if t.is_floating_point()] else: sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()] - self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0] + + self.grad_layer = [ + self._allocate_or_extend_buffers(i, size, dtype) + for i, (size, dtype) in enumerate(sizes_and_dtypes) + ] if isinstance(self.grad_layer, torch.Tensor): p2p.recv(self.grad_layer, self.next_stage) @@ -1294,16 +1279,17 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs): buffers.append(self._allocate_zeros(shape, **kwargs)) return buffers - def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1): - buffers = [] - if num_buffers == -1: - num_buffers = self.num_pipe_buffers - for count in range(num_buffers): - buffer = [] - for shape, dtype in shapes_and_dtypes: - buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad)) - buffers.append(buffer) - return buffers + def _allocate_or_extend_buffers(self, idx, shape, dtype): + numel = reduce(mul, shape) if len(shape) > 0 else 1 + if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel: + new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0] + if len(self._grad_layer_buf) <= idx: + self._grad_layer_buf.append(new_buf) + else: + self._grad_layer_buf[idx] = new_buf + return self._grad_layer_buf[idx] + else: + return self._grad_layer_buf[idx].flatten()[:numel].view(shape) def forward(self, *args, **kwargs): """Disabled for pipeline parallel training. See ``train_batch()``. """ diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 3c25cbee66ec..31fec30be788 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -117,6 +117,7 @@ def forward(self, inputs): activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering. + dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact. """ def __init__(self, @@ -130,7 +131,8 @@ def __init__(self, partition_method='parameters', activation_checkpoint_interval=0, activation_checkpoint_func=checkpointing.checkpoint, - checkpointable_layers=None): + checkpointable_layers=None, + dynamic_shape=False): super().__init__() @@ -213,6 +215,8 @@ def __init__(self, self.tied_comms = self._index_tied_modules() self._synchronize_tied_weights() + self.dynamic_shape = dynamic_shape + def _precompute_checkpointable_values(self): if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval: num_layers = len(self.forward_funcs) diff --git a/deepspeed/runtime/swap_tensor/aio_config.py b/deepspeed/runtime/swap_tensor/aio_config.py index df4a38380089..46c3f2a0c954 100644 --- a/deepspeed/runtime/swap_tensor/aio_config.py +++ b/deepspeed/runtime/swap_tensor/aio_config.py @@ -5,25 +5,33 @@ from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.runtime.swap_tensor.constants import * +from deepspeed.accelerator import get_accelerator AIO_DEFAULT_DICT = { AIO_BLOCK_SIZE: AIO_BLOCK_SIZE_DEFAULT, AIO_QUEUE_DEPTH: AIO_QUEUE_DEPTH_DEFAULT, AIO_THREAD_COUNT: AIO_THREAD_COUNT_DEFAULT, AIO_SINGLE_SUBMIT: AIO_SINGLE_SUBMIT_DEFAULT, - AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT + AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT, + AIO_USE_GDS: AIO_USE_GDS_DEFAULT } def get_aio_config(param_dict): if AIO in param_dict.keys() and param_dict[AIO] is not None: aio_dict = param_dict[AIO] - return { + aio_config = { AIO_BLOCK_SIZE: get_scalar_param(aio_dict, AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT), AIO_QUEUE_DEPTH: get_scalar_param(aio_dict, AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT), AIO_THREAD_COUNT: get_scalar_param(aio_dict, AIO_THREAD_COUNT, AIO_THREAD_COUNT_DEFAULT), AIO_SINGLE_SUBMIT: get_scalar_param(aio_dict, AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT), - AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT) + AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT), + AIO_USE_GDS: get_scalar_param(aio_dict, AIO_USE_GDS, AIO_USE_GDS_DEFAULT) } + if aio_config[AIO_USE_GDS]: + assert get_accelerator().device_name() == 'cuda', 'GDS currently only supported for CUDA accelerator' + + return aio_config + return AIO_DEFAULT_DICT diff --git a/deepspeed/runtime/swap_tensor/constants.py b/deepspeed/runtime/swap_tensor/constants.py index 4c9722bc4e4f..cee20ac7b78c 100644 --- a/deepspeed/runtime/swap_tensor/constants.py +++ b/deepspeed/runtime/swap_tensor/constants.py @@ -11,7 +11,8 @@ "queue_depth": 8, "thread_count": 1, "single_submit": false, - "overlap_events": true + "overlap_events": true, + "use_gds": false } ''' AIO = "aio" @@ -25,3 +26,5 @@ AIO_SINGLE_SUBMIT_DEFAULT = False AIO_OVERLAP_EVENTS = "overlap_events" AIO_OVERLAP_EVENTS_DEFAULT = True +AIO_USE_GDS = "use_gds" +AIO_USE_GDS_DEFAULT = False diff --git a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py index fcc6a272883f..120723fae5ab 100644 --- a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py +++ b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py @@ -13,6 +13,7 @@ from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import AsyncIOBuilder +from deepspeed.ops.op_builder import GDSBuilder from .constants import * from .utils import swap_in_tensors, swap_out_tensors, MIN_AIO_BYTES, AIO_ALIGNED_BYTES, print_object, SwapBufferPool @@ -37,8 +38,6 @@ class AsyncPartitionedParameterSwapper(object): def __init__(self, ds_config, model_dtype): - aio_op = AsyncIOBuilder().load(verbose=False) - self.aio_handle = aio_op.aio_handle self.dtype = model_dtype #set swap buffers, create aio handles @@ -93,6 +92,10 @@ def _configure_aio(self, ds_config): self.aio_config = ds_config.aio_config + self.use_gds = self.aio_config[AIO_USE_GDS] + self.aio_handle = GDSBuilder().load(verbose=False).gds_handle if self.use_gds else AsyncIOBuilder().load( + verbose=False).aio_handle + # Read/Write alignment for each thread during Intra-request parallelism self.min_aio_bytes = max(MIN_AIO_BYTES, self.aio_config[AIO_BLOCK_SIZE]) self.aligned_bytes = AIO_ALIGNED_BYTES * self.aio_config[AIO_THREAD_COUNT] @@ -104,11 +107,6 @@ def _configure_aio(self, ds_config): self.available_buffer_ids = [i for i in range(self.param_buffer_count)] self.reserved_buffer_ids = [] - self.buffers = get_accelerator().pin_memory(torch.empty(int(self.aligned_elements_per_buffer * - self.param_buffer_count), - dtype=self.dtype, - requires_grad=False), - align_bytes=0) self.aio_read_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE], self.aio_config[AIO_QUEUE_DEPTH], self.aio_config[AIO_SINGLE_SUBMIT], self.aio_config[AIO_OVERLAP_EVENTS], @@ -118,6 +116,19 @@ def _configure_aio(self, ds_config): self.aio_config[AIO_SINGLE_SUBMIT], self.aio_config[AIO_OVERLAP_EVENTS], self.aio_config[AIO_THREAD_COUNT]) + if self.use_gds: + self.buffers = torch.empty(int(self.aligned_elements_per_buffer * self.param_buffer_count), + dtype=self.dtype, + device=get_accelerator().device_name(), + requires_grad=False) + self.aio_read_handle.new_device_locked_tensor(self.buffers) + else: + self.buffers = get_accelerator().pin_memory(torch.empty(int(self.aligned_elements_per_buffer * + self.param_buffer_count), + dtype=self.dtype, + requires_grad=False), + align_bytes=0) + self.swap_out_params = [] #Check if partitioned param or numel in a tensor is swappable or not diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py index 1ccca09a9e69..23fcf9ec13fb 100644 --- a/deepspeed/runtime/zero/__init__.py +++ b/deepspeed/runtime/zero/__init__.py @@ -13,3 +13,5 @@ from .tiling import TiledLinearReturnBias from .mics import MiCS_Init + +from .stage3 import unwrap_model_for_generation diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index e9dd78864cde..8c8db60768eb 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -16,6 +16,7 @@ #when implemented outside of torch.autograd.Function import math +import functools import torch from torch import Tensor @@ -33,8 +34,14 @@ def print_rank_0(message, debug=False, force=False): try: - autocast_custom_fwd = get_accelerator().amp().custom_fwd - autocast_custom_bwd = get_accelerator().amp().custom_bwd + # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4 + if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'): + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name()) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name()) + else: + # original implementation + autocast_custom_fwd = get_accelerator().amp().custom_fwd + autocast_custom_bwd = get_accelerator().amp().custom_bwd except (ImportError, AttributeError) as exp: autocast_custom_fwd = noop_decorator autocast_custom_bwd = noop_decorator diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index b0a3ab778f2a..796957a4c6e5 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,6 +7,7 @@ import gc import collections from typing import Deque, Dict, Tuple +from contextlib import contextmanager from deepspeed import comm as dist from deepspeed.utils import groups @@ -69,6 +70,39 @@ def move_to_cpu(tensor_list): tensor.data = tensor.data.cpu() +@contextmanager +def unwrap_model_for_generation(model): + """ + For ZeRO-3 models, we gather the weights once to speed up generation. + """ + with GatheredParameters(model.parameters()): + # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. + + # Remove hooks + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + yield model + + # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model. + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + optimizer_offload._register_hooks_recursively(optimizer_offload.module) + return + + INITIAL_MICRO_STEP_ID = -1 @@ -215,14 +249,12 @@ def __init__( self.module = module self.elastic_checkpoint = elastic_checkpoint - self.inf_or_nan_tracker: Tensor = torch.zeros(1, - dtype=torch.bool, - device=get_accelerator().current_device_name(), - requires_grad=False) + self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu + + self.inf_or_nan_tracker: Tensor = torch.zeros(1, dtype=torch.bool, device=self.device, requires_grad=False) self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam) - self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu ### streams used for overlapping computation with communication self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator( ).Stream() if overlap_comm else get_accelerator().default_stream() @@ -2148,7 +2180,8 @@ def has_overflow(self, partition_gradients=True): self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any() self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 - overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8) + overflow_gpu = self.inf_or_nan_tracker.clone().to(get_accelerator().current_device_name()).to( + torch.uint8) self.inf_or_nan_tracker.zero_() if not get_accelerator().resolves_data_dependency(): diff --git a/docs/index.md b/docs/index.md index 127c7226e6d4..1efdcea132d2 100755 --- a/docs/index.md +++ b/docs/index.md @@ -7,11 +7,12 @@ title: "Latest News" --- DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). +* [2024/08] [DeepSpeed on Windows](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/japanese/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/chinese/README.md)] + * [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/japanese/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/chinese/README.md)] * [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)] * [2024/03] [DeepSpeed-FP6: The Power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)] * [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19) -* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html) @@ -19,12 +20,12 @@ title: "Latest News" More news diff --git a/op_builder/async_io.py b/op_builder/async_io.py index b55c821910b9..e7f16adbf2a3 100644 --- a/op_builder/async_io.py +++ b/op_builder/async_io.py @@ -3,13 +3,14 @@ # DeepSpeed Team +import os import distutils.spawn import subprocess -from .builder import OpBuilder +from .builder import TorchCPUOpBuilder -class AsyncIOBuilder(OpBuilder): +class AsyncIOBuilder(TorchCPUOpBuilder): BUILD_VAR = "DS_BUILD_AIO" NAME = "async_io" @@ -19,44 +20,54 @@ def __init__(self): def absolute_name(self): return f'deepspeed.ops.aio.{self.NAME}_op' - def sources(self): - return [ - 'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/py_ds_aio.cpp', - 'csrc/aio/py_lib/deepspeed_py_aio.cpp', 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', - 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', 'csrc/aio/common/deepspeed_aio_utils.cpp', - 'csrc/aio/common/deepspeed_aio_common.cpp', 'csrc/aio/common/deepspeed_aio_types.cpp', + def lib_sources(self): + src_list = [ + 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', 'csrc/aio/py_lib/deepspeed_py_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', + 'csrc/aio/common/deepspeed_aio_utils.cpp', 'csrc/aio/common/deepspeed_aio_common.cpp', + 'csrc/aio/common/deepspeed_aio_types.cpp', 'csrc/aio/py_lib/deepspeed_cpu_op.cpp', + 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', 'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/deepspeed_pin_tensor.cpp' ] + return src_list + + def sources(self): + return self.lib_sources() + ['csrc/aio/py_lib/py_ds_aio.cpp'] def include_paths(self): - return ['csrc/aio/py_lib', 'csrc/aio/common'] + import torch + if self.build_for_cpu: + CUDA_INCLUDE = [] + elif not self.is_rocm_pytorch(): + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + else: + CUDA_INCLUDE = [ + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), + ] + return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE def cxx_args(self): # -O0 for improved debugging, since performance is bound by I/O - CPU_ARCH = self.cpu_arch() - SIMD_WIDTH = self.simd_width() - import torch # Keep this import here to avoid errors when building DeepSpeed wheel without torch installed + args = super().cxx_args() + import torch TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) - if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1: - CPP_STD = '-std=c++17' - else: - CPP_STD = '-std=c++14' - return [ - '-g', - '-Wall', - '-O0', - CPP_STD, - '-shared', - '-fPIC', - '-Wno-reorder', - CPU_ARCH, - '-fopenmp', - SIMD_WIDTH, - '-laio', - ] + if not (TORCH_MAJOR >= 2 and TORCH_MINOR >= 1): + args.remove('-std=c++17') + args.append('-std=c++14') + args += ['-Wall', '-O0', '-shared', '-fPIC', '-Wno-reorder'] + return args def extra_ldflags(self): - return ['-laio'] + if self.build_for_cpu: + return ['-fopenmp'] + + import torch.utils.cpp_extension + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME + CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") + ldflags = [f'-L{CUDA_HOME}', f'-L{CUDA_LIB64}', '-laio', '-lcuda', '-lcudart'] + return ldflags def check_for_libaio_pkg(self): libs = dict( @@ -79,13 +90,13 @@ def check_for_libaio_pkg(self): break return found - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): # Check for the existence of libaio by using distutils # to compile and link a test program that calls io_submit, # which is a function provided by libaio that is used in the async_io op. # If needed, one can define -I and -L entries in CFLAGS and LDFLAGS # respectively to specify the directories for libaio.h and libaio.so. - aio_compatible = self.has_function('io_pgetevents', ('aio', )) + aio_compatible = self.has_function('io_submit', ('aio', )) if verbose and not aio_compatible: self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.") diff --git a/op_builder/builder.py b/op_builder/builder.py index 8998fc0eddb8..ca4b339e2447 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -305,7 +305,7 @@ def is_compatible(self, verbose=True): def extra_ldflags(self): return [] - def has_function(self, funcname, libraries, verbose=False): + def has_function(self, funcname, libraries, library_dirs=None, verbose=False): ''' Test for existence of a function within a tuple of libraries. @@ -361,7 +361,8 @@ def has_function(self, funcname, libraries, verbose=False): compiler.link_executable(objs, os.path.join(tempdir, 'a.out'), extra_preargs=self.strip_empty_entries(ldflags), - libraries=libraries) + libraries=libraries, + library_dirs=library_dirs) # Compile and link succeeded return True diff --git a/op_builder/gds.py b/op_builder/gds.py new file mode 100644 index 000000000000..e024674e01d8 --- /dev/null +++ b/op_builder/gds.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from .async_io import AsyncIOBuilder + + +class GDSBuilder(AsyncIOBuilder): + BUILD_VAR = "DS_BUILD_GDS" + NAME = "gds" + + def __init__(self): + super().__init__() + + def absolute_name(self): + return f'deepspeed.ops.gds.{self.NAME}_op' + + def lib_sources(self): + src_list = ['csrc/gds/py_lib/deepspeed_py_gds_handle.cpp', 'csrc/gds/py_lib/deepspeed_gds_op.cpp'] + return super().lib_sources() + src_list + + def sources(self): + return self.lib_sources() + ['csrc/gds/py_lib/py_ds_gds.cpp'] + + def cxx_args(self): + return super().cxx_args() + ['-lcufile'] + + def include_paths(self): + import torch + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE + + def extra_ldflags(self): + return super().extra_ldflags() + ['-lcufile'] + + def is_compatible(self, verbose=False): + import torch.utils.cpp_extension + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME + CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") + gds_compatible = self.has_function(funcname="cuFileDriverOpen", + libraries=("cufile", ), + library_dirs=( + CUDA_HOME, + CUDA_LIB64, + ), + verbose=verbose) + + return gds_compatible and super().is_compatible(verbose) diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index 25256d376eeb..dfab28aa7477 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -14,6 +14,7 @@ from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec +from .util import no_child_process_in_deepspeed_io class AlexNet(nn.Module): @@ -125,22 +126,11 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, trainset = cifar_trainset(fp16=fp16) config['local_rank'] = dist.get_rank() - # deepspeed_io defaults to creating a dataloader that uses a - # multiprocessing pool. Our tests use pools and we cannot nest pools in - # python. Therefore we're injecting this kwarg to ensure that no pools - # are used in the dataloader. - old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io - - def new_method(*args, **kwargs): - kwargs["num_local_io_workers"] = 0 - return old_method(*args, **kwargs) - - deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method - - engine, _, _, _ = deepspeed.initialize(config=config, - model=model, - model_parameters=[p for p in model.parameters()], - training_data=trainset) + with no_child_process_in_deepspeed_io(): + engine, _, _, _ = deepspeed.initialize(config=config, + model=model, + model_parameters=[p for p in model.parameters()], + training_data=trainset) losses = [] for step in range(num_steps): diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index f6d175ce67bc..e6927efc3824 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -78,7 +78,7 @@ def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_queue_depth() == QUEUE_DEPTH -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken @pytest.mark.parametrize("single_submit", [True, False]) @pytest.mark.parametrize("overlap_events", [True, False]) class TestRead(DistributedTest): @@ -144,7 +144,7 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap h.free_cpu_locked_tensor(aio_buffer) -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken @pytest.mark.parametrize("single_submit", [True, False]) @pytest.mark.parametrize("overlap_events", [True, False]) class TestWrite(DistributedTest): @@ -213,7 +213,7 @@ def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overla @pytest.mark.sequential -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken @pytest.mark.parametrize("cuda_device", [True, False]) class TestAsyncQueue(DistributedTest): world_size = 1 diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py new file mode 100644 index 000000000000..53655994b560 --- /dev/null +++ b/tests/unit/ops/aio/test_gds.py @@ -0,0 +1,270 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import os +import filecmp +import torch +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import GDSBuilder +from unit.common import DistributedTest + +KILO_BYTE = 1024 * 256 +BLOCK_SIZE = KILO_BYTE +QUEUE_DEPTH = 2 +IO_SIZE = 4 * BLOCK_SIZE +IO_PARALLEL = 2 + +if not deepspeed.ops.__compatible_ops__[GDSBuilder.NAME]: + pytest.skip('Skip tests since gds is not compatible', allow_module_level=True) + + +def _get_local_rank(): + if get_accelerator().is_available(): + return dist.get_rank() + return 0 + + +def _do_ref_write(tmpdir, index=0): + file_suffix = f'{_get_local_rank()}_{index}' + ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') + ref_buffer = os.urandom(IO_SIZE) + with open(ref_file, 'wb') as f: + f.write(ref_buffer) + + return ref_file, ref_buffer + + +def _get_test_write_file(tmpdir, index): + file_suffix = f'{_get_local_rank()}_{index}' + return os.path.join(tmpdir, f'_gds_write_random_{file_suffix}.pt') + + +def _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, gds_handle, index=0): + test_file = _get_test_write_file(tmpdir, index) + test_buffer = get_accelerator().ByteTensor(list(ref_buffer)) + gds_handle.pin_device_tensor(test_buffer) + return test_file, test_buffer + + +def _validate_handle_state(handle, single_submit, overlap_events): + assert handle.get_single_submit() == single_submit + assert handle.get_overlap_events() == overlap_events + assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_block_size() == BLOCK_SIZE + assert handle.get_queue_depth() == QUEUE_DEPTH + + +@pytest.mark.parametrize("single_submit", [True, False]) +@pytest.mark.parametrize("overlap_events", [True, False]) +class TestRead(DistributedTest): + world_size = 1 + reuse_dist_env = True + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + def test_parallel_read(self, tmpdir, single_submit, overlap_events): + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + _validate_handle_state(h, single_submit, overlap_events) + + ref_file, _ = _do_ref_write(tmpdir) + read_status = h.sync_pread(gds_buffer, ref_file) + assert read_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == gds_buffer.tolist() + + h.unpin_device_tensor(gds_buffer) + + def test_async_read(self, tmpdir, single_submit, overlap_events): + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + _validate_handle_state(h, single_submit, overlap_events) + + ref_file, _ = _do_ref_write(tmpdir) + read_status = h.async_pread(gds_buffer, ref_file) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == gds_buffer.tolist() + + h.unpin_device_tensor(gds_buffer) + + +@pytest.mark.parametrize("single_submit", [True, False]) +@pytest.mark.parametrize("overlap_events", [True, False]) +class TestWrite(DistributedTest): + world_size = 1 + reuse_dist_env = True + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + def test_parallel_write(self, tmpdir, single_submit, overlap_events): + + ref_file, ref_buffer = _do_ref_write(tmpdir) + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_file, gds_buffer = _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, h) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.sync_pwrite(gds_buffer, gds_file) + assert write_status == 1 + + h.unpin_device_tensor(gds_buffer) + + assert os.path.isfile(gds_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, gds_file, shallow=False) + + def test_async_write(self, tmpdir, single_submit, overlap_events): + ref_file, ref_buffer = _do_ref_write(tmpdir) + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + gds_file, gds_buffer = _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, h) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.async_pwrite(gds_buffer, gds_file) + assert write_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + h.unpin_device_tensor(gds_buffer) + + assert os.path.isfile(gds_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, gds_file, shallow=False) + + +@pytest.mark.sequential +class TestAsyncQueue(DistributedTest): + world_size = 1 + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize("async_queue", [2, 3]) + def test_read(self, tmpdir, async_queue): + + ref_files = [] + for i in range(async_queue): + f, _ = _do_ref_write(tmpdir, i) + ref_files.append(f) + + single_submit = True + overlap_events = True + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_buffers = [ + torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) for _ in range(async_queue) + ] + for buf in gds_buffers: + h.pin_device_tensor(buf) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pread(gds_buffers[i], ref_files[i]) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for i in range(async_queue): + with open(ref_files[i], 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == gds_buffers[i].tolist() + + for t in gds_buffers: + h.unpin_device_tensor(t) + + @pytest.mark.parametrize("async_queue", [2, 3]) + def test_write(self, tmpdir, async_queue): + ref_files = [] + ref_buffers = [] + for i in range(async_queue): + f, buf = _do_ref_write(tmpdir, i) + ref_files.append(f) + ref_buffers.append(buf) + + single_submit = True + overlap_events = True + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_files = [] + gds_buffers = [] + for i in range(async_queue): + f, buf = _get_test_write_file_and_device_buffer(tmpdir, ref_buffers[i], h, i) + gds_files.append(f) + gds_buffers.append(buf) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pwrite(gds_buffers[i], gds_files[i]) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for t in gds_buffers: + h.unpin_device_tensor(t) + + for i in range(async_queue): + assert os.path.isfile(gds_files[i]) + + filecmp.clear_cache() + assert filecmp.cmp(ref_files[i], gds_files[i], shallow=False) + + +@pytest.mark.parametrize("use_new_api", [True, False]) +class TestLockDeviceTensor(DistributedTest): + world_size = 2 + reuse_dist_env = True + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + def test_pin_device_tensor(self, use_new_api): + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + unpinned_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + if use_new_api: + pinned_buffer = h.new_pinned_device_tensor(unpinned_buffer.numel(), unpinned_buffer) + else: + pinned_buffer = torch.empty_like(unpinned_buffer) + h.pin_device_tensor(pinned_buffer) + + assert unpinned_buffer.device == pinned_buffer.device + assert unpinned_buffer.dtype == pinned_buffer.dtype + assert unpinned_buffer.numel() == pinned_buffer.numel() + + if use_new_api: + h.free_pinned_device_tensor(pinned_buffer) + else: + h.unpin_device_tensor(pinned_buffer) diff --git a/tests/unit/runtime/pipe/test_pipe.py b/tests/unit/runtime/pipe/test_pipe.py index 88e26290b650..f198762c5fcc 100644 --- a/tests/unit/runtime/pipe/test_pipe.py +++ b/tests/unit/runtime/pipe/test_pipe.py @@ -7,12 +7,15 @@ import torch.nn as nn import pytest +import torch + +import deepspeed import deepspeed.comm as dist from deepspeed.runtime.pipe.topology import PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule from unit.alexnet_model import AlexNetPipe, train_cifar from unit.common import DistributedTest -from unit.util import skip_on_arch +from unit.util import skip_on_arch, no_child_process_in_deepspeed_io PipeTopo = PipeDataParallelTopology @@ -155,3 +158,95 @@ def test_pipe_use_reentrant(self, topo_config): # the following check could passed on higher version docker: nvcr.io/nvidia/pytorch:23.07-py3(torch2.1.0 cuda12.1) # Check if models have same weights after training # self._check_model_params_equal(base_model, test_model) + + +class DynamicShapeTestLayer(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.fc = nn.Linear(hidden_size, hidden_size) + self.shapes = set() + + def forward(self, x): + self.shapes.add(x.shape) + y = self.fc(x) + return y + + +class DynamicShapeTestModel(nn.Module): + + def __init__(self, n_layers, hidden_size): + super().__init__() + self.layers = nn.ModuleList([DynamicShapeTestLayer(hidden_size) for _ in range(n_layers)]) + + +@pytest.mark.parametrize('topo_config', [ + { + "num_pp": 1, + "num_dp": 4 + }, + { + "num_pp": 2, + "num_dp": 2 + }, + { + "num_pp": 4, + "num_dp": 1 + }, +]) +class TestPipeDynamicShape(DistributedTest): + world_size = 4 + + def test_pipe_base(self, topo_config): + """This test checks if the pipeline engine can handle dynamic shapes correctly. + We pass inputs of different shapes to the pipeline engine. + """ + + n_iter = 10 + n_layers = 4 + n_samples = 1024 + batch_size = 4 + channel_dims = [8, 16, 32, 64] + hidden_size = 16 + + topo = PipeTopo(**topo_config) + + model = DynamicShapeTestModel(n_layers, hidden_size) + model = PipelineModule(layers=model.layers, topology=topo, loss_fn=nn.MSELoss(), dynamic_shape=True) + + # Each batch has different channel dim but we use the same channel dim in the same batch + xs = [ + torch.randn(channel_dims[(i // batch_size) % len(channel_dims)], hidden_size, dtype=torch.float32) + for i in range(n_samples) + ] + ys = [torch.randn_like(x) for x in xs] + + class CustomDataset(torch.utils.data.Dataset): + + def __init__(self, xs, ys): + self.xs = xs + self.ys = ys + + def __len__(self): + return len(self.xs) + + def __getitem__(self, idx): + return self.xs[idx], self.ys[idx] + + dataset = CustomDataset(xs, ys) + + config_dict["train_batch_size"] = batch_size + + with no_child_process_in_deepspeed_io(): + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=[p for p in model.parameters()], + training_data=dataset) + + for _ in range(n_iter): + _ = engine.train_batch() + + # Check if all layers have seen different shapes + for layer in model.modules(): + if isinstance(layer, DynamicShapeTestLayer): + assert len(layer.shapes) > 1 diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py new file mode 100644 index 000000000000..d75519b67f68 --- /dev/null +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +from deepspeed.runtime.zero import unwrap_model_for_generation +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel + +config = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 1, + "offload_param": { + "device": "cpu", + "pin_memory": True + } + } +} + +if get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} +elif get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} + + +class TestUnwrapModel(DistributedTest): + # gather across more than 1 gpu + world_size = 2 + + def test(self): + + def hooks_exist(engine): + if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"): + optimizer_offload = engine.optimizer.parameter_offload + elif engine.optimizer is not None: + optimizer_offload = engine.optimizer + + hooks = 0 + for hook in optimizer_offload.forward_hooks: + hooks += 1 + if hooks > 0: + return True + return False + + model = SimpleModel(hidden_dim=100) + engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config) + + with unwrap_model_for_generation(engine): + # assert no hooks + assert not hooks_exist(engine) + # assert parameters gathered + assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor" + + # assert hooks + assert hooks_exist(engine) diff --git a/tests/unit/util.py b/tests/unit/util.py index feec326ede6c..dba29ed27a4c 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -5,6 +5,8 @@ import pytest import torch + +import deepspeed from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported from deepspeed.git_version_info import torch_info @@ -67,3 +69,22 @@ def required_amp_check(): return False else: return True + + +class no_child_process_in_deepspeed_io: + + def __enter__(self): + # deepspeed_io defaults to creating a dataloader that uses a + # multiprocessing pool. Our tests use pools and we cannot nest pools in + # python. Therefore we're injecting this kwarg to ensure that no pools + # are used in the dataloader. + self.old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io + + def new_method(*args, **kwargs): + kwargs["num_local_io_workers"] = 0 + return self.old_method(*args, **kwargs) + + deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method + + def __exit__(self, *_): + deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method