+
+# Introduction
+
+DeepSpeed is a popular open-source deep learning optimization library that makes distributed training and inference easy, efficient, and effective. Thanks to its rich suite of sophisticated optimizations (e.g., ZeRO, 3D parallelism, MoE), DeepSpeed has been widely used to train a variety of state-of-the-art models, including Phi-3, Megatron-Turing-530B, BLOOM-176B, and Arctic. However, the lack of native support for Microsoft Windows, the most popular operating system, has left DeepSpeed innovations inaccessible to many AI developers and users. To address this problem, we started an effort to make DeepSpeed run natively with full features on Windows, while ensuring the same ease-of-use enjoyed on Linux.
+
+In this blog, we are pleased to announce some early achievements on this journey: DeepSpeed can now be installed on Windows and run natively for single-GPU training, finetuning, and inferencing. Importantly, both the installation and usage experiences are identical to those on Linux. Furthermore, the finetuning and inferencing workloads demonstrate the functioning of three critical DeepSpeed features: HuggingFace Transformers integration, LoRA support, and CPU Offloading. DeepSpeed on Windows is available in DeepSpeed versions 0.14.5 and above. In the rest of this blog, we present examples that demonstrate these achievements.
+
+# Evaluation Environment
+We conducted the experiments on a Surface Laptop Studio 2 running Windows 11 Version 23H2 and Build 22631.3880. The laptop is equipped with a single NVIDIA RTX A2000 GPU with 4GB VRAM. We used PyTorch version 2.3.0 and HuggingFace Transformers version 4.41.2. The example scripts used are from the [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples), so you will need to clone the repo before running any of the following examples.
+
+# Installation
+DeepSpeed can be installed on Windows in one of two ways. The easier way is to use the pip package manager, while the other is to build from source. The prerequisites in both cases are Python 3.x and PyTorch with CUDA support.
+
+## Installing via pip
+To install DeepSpeed, simply run: `pip install deepspeed`. This will install the latest version of DeepSpeed (0.14.5 at this time). Unlike the Linux counterpart, the Windows version comes with all the operators already prebuilt, so there is no need to have a CUDA SDK or C++ compiler installed.
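If you prefer to verify the result programmatically rather than eyeball the pip output, a small stdlib-only sketch like the following can confirm the package is importable (the only assumption here is the package name `deepspeed`):

```python
import importlib.util

def is_importable(pkg: str) -> bool:
    """Return True if `pkg` is installed in the current environment."""
    return importlib.util.find_spec(pkg) is not None

# After a successful `pip install deepspeed`, this should print True.
print(is_importable("deepspeed"))
```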
+
+ pip installation of DeepSpeed on Windows.
+
+## Building from Source
+To build DeepSpeed from source, you need to clone the DeepSpeed repository and run the `build_win.bat` compilation script.
+
+
+## Validating Installation
+Regardless of the installation choice, you can check that the installation was successful by running `ds_report`. The output should look like this:
+
+ ds_report output confirming Windows installation of DeepSpeed.
+
+
+# Pretraining Examples
+We use an image classification model trained on CIFAR10 and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.
+
+## Pretraining CIFAR10
+The scripts and code for the CIFAR10 pretraining example are available in the following path: `DeepSpeedExamples\training\cifar`. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
+
+ Pretraining CIFAR10 model on Windows using DeepSpeed.
+
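DeepSpeed runs are driven by a JSON configuration. As a rough, illustrative sketch (the values below are assumptions for demonstration, not the exact config embedded in `cifar10_deepspeed.py`), a minimal training config looks like:

```json
{
  "train_batch_size": 16,
  "optimizer": {
    "type": "Adam",
    "params": { "lr": 0.001 }
  },
  "fp16": { "enabled": true }
}
```

The same config format is used unchanged on Windows and Linux, which is what makes the launch commands in these examples identical across both platforms.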
+
+## Pretraining BERT
+The scripts and code for the BERT pretraining example are available in the following path: `DeepSpeedExamples\training\HelloDeepSpeed`. You can launch the BERT pretraining experiment using the following command: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`. The final output should look like this:
+
+ Pretraining BERT model on Windows using DeepSpeed.
+
+
+# Fine Tuning Example
+We demonstrate finetuning capability using the supervised finetuning (SFT) step of the DeepSpeed-Chat application. We conduct SFT of the HuggingFace facebook/opt-125m model while enabling the LoRA and CPU offloading memory optimizations. The command line for running this example is as follows:
+`deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output`
+The output should look like this:
+ Supervised Finetuning of facebook/opt-125m model on Windows using DeepSpeed.
+
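The `--zero_stage 2 --offload` flags used above correspond to ZeRO configuration knobs. Expressed directly as a DeepSpeed JSON config (a hedged sketch; the DeepSpeed-Chat script assembles its config internally, and the batch size here is illustrative), ZeRO stage 2 with optimizer-state offloading to CPU looks like:

```json
{
  "train_batch_size": 32,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": { "device": "cpu" }
  }
}
```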
+
+# Inference Example
+We demonstrate inference capability by using ZeRO-Inference for token generation. ZeRO-Inference reduces the hardware cost of inference by offloading model state to CPU or NVMe memory. We use the ZeRO-Inference example scripts in the DeepSpeedExamples repo to run token generation with the Llama-2-7B model from HuggingFace. We offload the model weights to CPU memory since the 4GB VRAM is insufficient to host both the model and the generation working set. We use the following command line to generate 32 tokens from a prompt of 8 tokens:
+`deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload`
+The output will look something like this:
+
+ LLAMA2-7B token generation on Windows using ZeRO-Inference.
+
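A quick back-of-envelope check shows why the weights must be offloaded in this example (assuming roughly 7 billion parameters at 2 bytes each for fp16/bf16 weights):

```python
# Rough fp16/bf16 weight footprint of Llama-2-7B (parameter count is approximate).
params = 7_000_000_000
bytes_per_param = 2  # fp16/bf16
weight_gib = params * bytes_per_param / 1024**3
print(f"~{weight_gib:.1f} GiB of weights vs. 4 GiB of VRAM")
```

At roughly 13 GiB, the weights alone exceed the A2000's 4GB VRAM by more than 3x, before accounting for the KV cache and activation working set of generation.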
+
+# Summary
+Enabling DeepSpeed, a popular deep learning framework, to run natively on Windows, the most popular operating system, is a crucial step towards empowering every person and every organization to benefit from the ongoing AI revolution. In this blog, we have shared early results of our work towards this goal. Although Windows support of DeepSpeed is a work-in-progress, we hope that the above updates are encouraging and already useful to users. The next items on our roadmap include running on multiple GPUs, weight quantization, and performance studies.
+
+# Acknowledgements
+This work is a result of significant contributions from current and former DeepSpeed members including Costin Eseanu, Logan Adams, Elton Zheng, Reza Yazdani Aminabadi, Martin Cai, and Olatunji Ruwase. We also acknowledge the valuable contributions of DeepSpeed users who righteously demanded this feature, provided critical workarounds, partial solutions, and constructive feedback, and most importantly, stuck with us.
diff --git a/blogs/windows/08-2024/media/bert_training.png b/blogs/windows/08-2024/media/bert_training.png
new file mode 100644
index 000000000000..c5935e47747e
Binary files /dev/null and b/blogs/windows/08-2024/media/bert_training.png differ
diff --git a/blogs/windows/08-2024/media/cifar10_training.png b/blogs/windows/08-2024/media/cifar10_training.png
new file mode 100644
index 000000000000..99f3fa25bc70
Binary files /dev/null and b/blogs/windows/08-2024/media/cifar10_training.png differ
diff --git a/blogs/windows/08-2024/media/ds_report.png b/blogs/windows/08-2024/media/ds_report.png
new file mode 100644
index 000000000000..43d82d724ed2
Binary files /dev/null and b/blogs/windows/08-2024/media/ds_report.png differ
diff --git a/blogs/windows/08-2024/media/llama2-7b_inference.png b/blogs/windows/08-2024/media/llama2-7b_inference.png
new file mode 100644
index 000000000000..f5874468a854
Binary files /dev/null and b/blogs/windows/08-2024/media/llama2-7b_inference.png differ
diff --git a/blogs/windows/08-2024/media/opt125m_finetuning.png b/blogs/windows/08-2024/media/opt125m_finetuning.png
new file mode 100644
index 000000000000..ed6d1522e3b3
Binary files /dev/null and b/blogs/windows/08-2024/media/opt125m_finetuning.png differ
diff --git a/blogs/windows/08-2024/media/win_pip_install_deepspeed.png b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png
new file mode 100644
index 000000000000..3b87c95ef144
Binary files /dev/null and b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png differ
diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp
new file mode 100644
index 000000000000..dc820be528d0
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "deepspeed_aio_op_desc.h"
+
+using namespace std;
+
+io_op_desc_t::io_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate)
+ : _read_op(read_op),
+ _buffer(buffer),
+ _fd(fd),
+ _filename(filename),
+ _file_num_bytes(file_num_bytes),
+ _num_threads(num_threads),
+ _num_bytes_per_thread(file_num_bytes / num_threads),
+ _validate(validate)
+{
+}
+
+char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
+
+void io_op_desc_t::finish() {}
+
+void io_op_desc_t::validate() {}
+
+void io_op_desc_t::run(const int tid,
+                       std::unique_ptr<aio_context>& aio_ctxt,
+ deepspeed_aio_config_t* aio_config)
+{
+}
diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h
new file mode 100644
index 000000000000..7305f6920c91
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#ifndef _IO_OP_DESC_T_
+#define _IO_OP_DESC_T_
+#include
+#include
+#include "deepspeed_py_aio.h"
+
+struct io_op_desc_t {
+ const bool _read_op;
+ torch::Tensor _buffer;
+ int _fd;
+ const std::string _filename;
+ const long long int _file_num_bytes;
+ const int _num_threads;
+ const int _num_bytes_per_thread;
+ torch::Tensor _contiguous_buffer;
+ const bool _validate;
+
+ io_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate);
+
+ virtual void run(const int tid,
+                     std::unique_ptr<aio_context>& aio_ctxt,
+ deepspeed_aio_config_t* aio_config);
+
+ virtual char* data_ptr() const;
+
+ virtual void validate();
+
+ virtual void finish();
+};
+#endif // _IO_OP_DESC_T_
diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp
index c852711a28c0..30c3b4914397 100644
--- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp
+++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp
@@ -9,50 +9,8 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include "deepspeed_aio_thread.h"
-#if defined(__ENABLE_CANN__)
-#include "torch_npu/csrc/framework/utils/OpAdapter.h"
-#include "torch_npu/csrc/framework/utils/UtilForOpAdapter.h"
-#endif
-
using namespace std;
-io_op_desc_t::io_op_desc_t(const bool read_op,
- const torch::Tensor& buffer,
- const int fd,
- const char* filename,
- const long long int num_bytes,
- const bool validate)
- : _read_op(read_op),
- _buffer(buffer),
- _fd(fd),
- _filename(filename),
- _num_bytes(num_bytes),
- _validate(validate)
-{
- _cpu_buffer = (_buffer.is_cuda() || _buffer.is_xpu()
-#if defined(__ENABLE_CANN__)
- || torch_npu::utils::is_npu(_buffer)
-#endif
- )
- ? _buffer.to(torch::kCPU).pin_memory()
- : _buffer;
- _contiguous_buffer = _cpu_buffer.contiguous();
-}
-
-char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
-
-void io_op_desc_t::fini()
-{
- if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
- if (_read_op && _buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
-#if defined(__ENABLE_CANN__)
- if (_read_op && torch_npu::utils::is_npu(_buffer)) {
- auto device = at::Device("npu:0");
- _buffer.copy_(_cpu_buffer.to(device));
- }
-#endif
-}
-
deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config)
: _tid(tid),
_aio_config(aio_config),
@@ -79,18 +37,7 @@ void deepspeed_aio_thread_t::run()
}
if (next_io_op) {
- const auto base_offset = next_io_op->_num_bytes * _tid;
-
-            std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(
-                next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr()));
-
- if (_aio_config._overlap_events) {
- do_aio_operation_overlap(
- next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- } else {
- do_aio_operation_sequential(
- next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- }
+ next_io_op->run(_tid, _aio_ctxt, &_aio_config);
{
            std::lock_guard<std::mutex> lock(_complete_sync._mutex);
diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h
index 20799ecbb018..a192804db13d 100644
--- a/csrc/aio/py_lib/deepspeed_aio_thread.h
+++ b/csrc/aio/py_lib/deepspeed_aio_thread.h
@@ -10,28 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include
#include
#include
-#include "deepspeed_py_aio.h"
-
-struct io_op_desc_t {
- const bool _read_op;
- torch::Tensor _buffer;
- int _fd;
- const std::string _filename;
- const long long int _num_bytes;
- torch::Tensor _cpu_buffer;
- torch::Tensor _contiguous_buffer;
- const bool _validate;
-
- io_op_desc_t(const bool read_op,
- const torch::Tensor& buffer,
- const int fd,
- const char* filename,
- const long long int num_bytes,
- const bool validate);
-
- char* data_ptr() const;
- void fini();
-};
+#include "deepspeed_cpu_op.h"
struct thread_sync_t {
std::mutex _mutex;
diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp
new file mode 100644
index 000000000000..41790b99bb88
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "deepspeed_cpu_op.h"
+
+using namespace std;
+
+cpu_op_desc_t::cpu_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate)
+ : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate),
+ _cpu_buffer(buffer)
+{
+ // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory.
+ _use_bounce_buffer = !(_buffer.is_cpu() && _buffer.is_pinned());
+ if (_use_bounce_buffer) {
+ if (_read_op) {
+ auto options = torch::TensorOptions()
+ .dtype(_buffer.dtype())
+ .layout(_buffer.layout())
+ .device(torch::kCPU);
+ _cpu_buffer = torch::empty(_buffer.nbytes(), options).pin_memory();
+ } else {
+ _cpu_buffer = _buffer.to(torch::kCPU).pin_memory();
+ }
+ }
+ _contiguous_buffer = _cpu_buffer.contiguous();
+}
+
+char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
+
+void cpu_op_desc_t::finish()
+{
+ if (_read_op) {
+ if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
+ if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
+#if defined(__ENABLE_CANN__)
+ if (torch_npu::utils::is_npu(_buffer)) {
+ auto device = at::Device("npu:0");
+ _buffer.copy_(_cpu_buffer.to(device));
+ }
+#endif
+ }
+}
+
+void cpu_op_desc_t::validate()
+{
+ validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), _file_num_bytes);
+}
+
+void cpu_op_desc_t::run(const int tid,
+                        std::unique_ptr<aio_context>& aio_ctxt,
+ deepspeed_aio_config_t* aio_config)
+{
+ assert(tid < _num_threads);
+ const auto base_offset = _num_bytes_per_thread * tid;
+
+    std::unique_ptr<io_xfer_ctxt> xfer_ctxt(
+        new io_xfer_ctxt(_fd, base_offset, _num_bytes_per_thread, data_ptr()));
+
+ if (aio_config->_overlap_events) {
+ do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
+ } else {
+ do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
+ }
+}
diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h
new file mode 100644
index 000000000000..da96dd2b1d50
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_cpu_op.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include
+#include
+#include "deepspeed_aio_op_desc.h"
+
+struct cpu_op_desc_t : io_op_desc_t {
+ torch::Tensor _cpu_buffer;
+ bool _use_bounce_buffer;
+
+ cpu_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate);
+
+ void run(const int tid,
+             std::unique_ptr<aio_context>& aio_ctxt,
+ deepspeed_aio_config_t* aio_config);
+
+ char* data_ptr() const;
+
+ void validate();
+
+ void finish();
+};
diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp
index 0556f5aa8168..eac268d33433 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp
@@ -4,9 +4,6 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
diff --git a/csrc/aio/py_lib/deepspeed_py_aio.h b/csrc/aio/py_lib/deepspeed_py_aio.h
index 11d5225de9f1..ba794db5440d 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio.h
+++ b/csrc/aio/py_lib/deepspeed_py_aio.h
@@ -4,10 +4,7 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
-Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+Functionality for swapping tensors to/from (NVMe) storage devices.
*/
#include
diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
index 23ddabe260d4..c7ca5e82afde 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
@@ -4,293 +4,21 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
#include "deepspeed_py_aio_handle.h"
+#include
using namespace std;
-static void _start_aio_thread(std::shared_ptr<deepspeed_aio_thread_t> ctxt) { ctxt->run(); }
-
deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
const bool overlap_events,
const int num_threads)
- : _aio_ctxt(new aio_context(block_size, queue_depth)),
- _single_submit(single_submit),
- _overlap_events(overlap_events),
- _num_threads(num_threads),
- _aio_config(block_size, queue_depth, single_submit, overlap_events, false),
- _num_pending_ops(0),
- _pinned_tensor_mgr(new deepspeed_pin_tensor_t())
-{
- for (auto i = 0; i < num_threads; ++i) {
-        _thread_contexts.push_back(std::make_shared<deepspeed_aio_thread_t>(i, _aio_config));
- }
-
- for (auto& ctxt : _thread_contexts) {
- _threads.push_back(std::thread(_start_aio_thread, ctxt));
- }
-}
-
-deepspeed_aio_handle_t::~deepspeed_aio_handle_t()
-{
- _stop_threads();
- for (auto& thr : _threads) { thr.join(); }
-}
-
-const int deepspeed_aio_handle_t::get_block_size() const
-{
- return _aio_ctxt ? _aio_ctxt->_block_size : -1;
-}
-
-const int deepspeed_aio_handle_t::get_queue_depth() const
-{
- return _aio_ctxt ? _aio_ctxt->_queue_depth : -1;
-}
-
-const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; }
-
-const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; }
-
-const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; }
-
-int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate)
-{
- const auto start_time = std::chrono::high_resolution_clock::now();
-
- assert(_aio_ctxt);
-
- long long num_file_bytes;
- if (-1 == get_file_size(filename, num_file_bytes)) {
- const auto error_code = errno;
- report_file_error(filename, " fstat for read", error_code);
- return -1;
- }
-    assert(static_cast<long long int>(buffer.nbytes()) == num_file_bytes);
-
- const auto fd = open_file(filename, true);
- if (fd == -1) { return -1; }
-
- auto read_buffer = (char*)buffer.data_ptr();
-    std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer));
-
- if (_aio_config._overlap_events) {
- do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- } else {
- do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- }
-
- close(fd);
-    const std::chrono::duration<double> aio_time =
- std::chrono::high_resolution_clock::now() - start_time;
-
- if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); }
-    const std::chrono::duration<double> fn_time =
- std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
- << " call = " << fn_time.count() * 1e6 << std::endl;
- return 0;
-}
-
-int deepspeed_aio_handle_t::write(const torch::Tensor& buffer,
- const char* filename,
- const bool validate)
+ : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads)
{
- assert(_aio_ctxt);
-
- const auto start_time = std::chrono::high_resolution_clock::now();
-
- const auto fd = open_file(filename, false);
- if (fd == -1) { return -1; }
-
- auto write_buffer = (char*)buffer.data_ptr();
-    const auto num_write_bytes = static_cast<long long int>(buffer.nbytes());
-    std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer));
-
- if (_aio_config._overlap_events) {
- do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- } else {
- do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- }
-    const std::chrono::duration<double> aio_time =
- std::chrono::high_resolution_clock::now() - start_time;
-
- close(fd);
-
- if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); }
-
-    const std::chrono::duration<double> fn_time =
- std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
- << " call = " << fn_time.count() * 1e6 << std::endl;
- return 0;
}
-void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr<struct io_op_desc_t> scheduled_op)
-{
- for (auto& ctxt : _thread_contexts) {
- {
-            std::lock_guard<std::mutex> lock(ctxt->_work_sync._mutex);
- ctxt->_work_queue.push(scheduled_op);
- }
- ctxt->_work_sync._cond_var.notify_one();
- }
- _num_pending_ops++;
-}
-
-std::shared_ptr<struct io_op_desc_t> deepspeed_aio_handle_t::_wait_for_aio_work()
-{
-    std::shared_ptr<struct io_op_desc_t> completed_op = nullptr;
- for (auto& ctxt : _thread_contexts) {
-        std::unique_lock<std::mutex> lock(ctxt->_complete_sync._mutex);
- ctxt->_complete_sync._cond_var.wait(lock,
- [ctxt] { return !ctxt->_complete_queue.empty(); });
- completed_op = ctxt->_complete_queue.front();
- ctxt->_complete_queue.pop();
- }
- return completed_op;
-}
-
-void deepspeed_aio_handle_t::_stop_threads()
-{
- assert(0 == _num_pending_ops);
- for (auto& ctxt : _thread_contexts) {
- {
-            std::lock_guard<std::mutex> lock(ctxt->_work_sync._mutex);
- ctxt->_time_to_exit = true;
- }
- ctxt->_work_sync._cond_var.notify_one();
- }
-}
-
-int deepspeed_aio_handle_t::wait()
-{
- assert(_num_pending_ops > 0);
- auto num_completed_ops = 0;
-
- while (_num_pending_ops > 0) {
- auto completed_op = _wait_for_aio_work();
-
- completed_op->fini();
-
- close(completed_op->_fd);
-
- if (completed_op->_validate) {
- validate_aio_operation(completed_op->_read_op,
- completed_op->_filename.c_str(),
- completed_op->data_ptr(),
- _num_threads * completed_op->_num_bytes);
- }
- --_num_pending_ops;
- ++num_completed_ops;
- }
-
- return num_completed_ops;
-}
-
-bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op,
- const long long int num_bytes)
-{
- const auto op_string = read_op ? "Read" : "Write";
- if (num_bytes % get_thread_count()) {
- std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes
- << " not divisible by thread count = " << get_thread_count() << std::endl;
- return false;
- }
-
- return true;
-}
-
-int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async)
-{
- long long num_file_bytes;
- if (-1 == get_file_size(filename, num_file_bytes)) {
- const auto error_code = errno;
- report_file_error(filename, " fstat for read", error_code);
- return -1;
- }
- const auto buffer_bytes = static_cast(buffer.nbytes());
- if (buffer_bytes != num_file_bytes) {
- std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes
- << " != " << num_file_bytes << std::endl;
- }
- assert(static_cast(buffer.nbytes()) == num_file_bytes);
- assert((num_file_bytes % _num_threads) == 0);
-
- if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; }
-
- const auto fd = open_file(filename, true);
- if (fd == -1) { return -1; }
-
-    auto scheduled_op = std::make_shared<io_op_desc_t>(
-        true, buffer, fd, filename, (num_file_bytes / _num_threads), validate);
-
- _schedule_aio_work(scheduled_op);
-
- if (async) { return 0; }
-
- return wait();
-}
-
-int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async)
-{
- const auto num_write_bytes = static_cast(buffer.nbytes());
- assert((num_write_bytes % _num_threads) == 0);
-
- if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }
-
- const auto fd = open_file(filename, false);
- if (fd == -1) { return -1; }
-
-    auto scheduled_op = std::make_shared<io_op_desc_t>(
-        false, buffer, fd, filename, (num_write_bytes / _num_threads), validate);
-
- _schedule_aio_work(scheduled_op);
-
- if (async) { return 0; }
-
- return wait();
-}
-
-int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename)
-{
- return pread(buffer, filename, false, false);
-}
-
-int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename)
-{
- return pwrite(buffer, filename, false, false);
-}
-
-int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename)
-{
- return pread(buffer, filename, false, true);
-}
-
-int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename)
-{
- return pwrite(buffer, filename, false, true);
-}
-
-at::Tensor deepspeed_aio_handle_t::new_cpu_locked_tensor(const size_t num_elem,
- const torch::Tensor& example_tensor)
-{
- return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type());
-}
-
-bool deepspeed_aio_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor)
-{
- return _pinned_tensor_mgr->free(locked_tensor);
-}
+deepspeed_aio_handle_t::~deepspeed_aio_handle_t() {}
diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h
index 3a254c3814a2..eb6b90ea22f0 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h
+++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h
@@ -9,21 +9,9 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include
#include
-#include "deepspeed_aio_thread.h"
-#include "deepspeed_pin_tensor.h"
-
-struct deepspeed_aio_handle_t {
- std::unique_ptr _aio_ctxt;
- const bool _single_submit;
- const bool _overlap_events;
- const int _num_threads;
- deepspeed_aio_config_t _aio_config;
-
-    std::vector<std::shared_ptr<struct deepspeed_aio_thread_t>> _thread_contexts;
-    std::vector<std::thread> _threads;
- int _num_pending_ops;
- std::unique_ptr _pinned_tensor_mgr;
+#include "deepspeed_py_io_handle.h"
+struct deepspeed_aio_handle_t : deepspeed_io_handle_t {
deepspeed_aio_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
@@ -31,47 +19,4 @@ struct deepspeed_aio_handle_t {
const int num_threads);
~deepspeed_aio_handle_t();
-
- const int get_block_size() const;
- const int get_queue_depth() const;
- const bool get_single_submit() const;
- const bool get_overlap_events() const;
- const int get_thread_count() const;
-
- int read(torch::Tensor& buffer, const char* filename, const bool validate);
-
- int write(const torch::Tensor& buffer, const char* filename, const bool validate);
-
- int pread(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async);
-
- int pwrite(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async);
-
- int sync_pread(torch::Tensor& buffer, const char* filename);
-
- int sync_pwrite(const torch::Tensor& buffer, const char* filename);
-
- int async_pread(torch::Tensor& buffer, const char* filename);
-
- int async_pwrite(const torch::Tensor& buffer, const char* filename);
-
- // TODO: Make API's args to be shape and dtype.
- torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor);
-
- bool free_cpu_locked_tensor(torch::Tensor&);
-
- int wait();
-
- void _stop_threads();
-
-    void _schedule_aio_work(std::shared_ptr<struct io_op_desc_t> scheduled_op);
-
-    std::shared_ptr<struct io_op_desc_t> _wait_for_aio_work();
-
- bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes);
};
diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp
index c597b91d05c9..f5480e9d9d83 100644
--- a/csrc/aio/py_lib/deepspeed_py_copy.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp
@@ -4,7 +4,7 @@
// DeepSpeed Team
/*
-Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+Functionality for swapping tensors to/from (NVMe) storage devices.
*/
#include "deepspeed_py_copy.h"
diff --git a/csrc/aio/py_lib/deepspeed_py_copy.h b/csrc/aio/py_lib/deepspeed_py_copy.h
index 19ba28317d00..f443571a3e7b 100644
--- a/csrc/aio/py_lib/deepspeed_py_copy.h
+++ b/csrc/aio/py_lib/deepspeed_py_copy.h
@@ -4,9 +4,6 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp
new file mode 100644
index 000000000000..bdf2a858d797
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp
@@ -0,0 +1,300 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include "deepspeed_py_io_handle.h"
+#include
+
+using namespace std;
+
+static void _start_aio_thread(std::shared_ptr<deepspeed_aio_thread_t> ctxt) { ctxt->run(); }
+
+deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size,
+ const int queue_depth,
+ const bool single_submit,
+ const bool overlap_events,
+ const int num_threads)
+ : _aio_ctxt(new aio_context(block_size, queue_depth)),
+ _single_submit(single_submit),
+ _overlap_events(overlap_events),
+ _num_threads(num_threads),
+ _aio_config(block_size, queue_depth, single_submit, overlap_events, false),
+ _num_pending_ops(0),
+ _pinned_tensor_mgr(new deepspeed_pin_tensor_t())
+{
+ for (auto i = 0; i < num_threads; ++i) {
+        _thread_contexts.push_back(std::make_shared<deepspeed_aio_thread_t>(i, _aio_config));
+ }
+
+ for (auto& ctxt : _thread_contexts) {
+ _threads.push_back(std::thread(_start_aio_thread, ctxt));
+ }
+}
+
+deepspeed_io_handle_t::~deepspeed_io_handle_t()
+{
+ _stop_threads();
+ for (auto& thr : _threads) { thr.join(); }
+}
+
+const int deepspeed_io_handle_t::get_block_size() const
+{
+ return _aio_ctxt ? _aio_ctxt->_block_size : -1;
+}
+
+const int deepspeed_io_handle_t::get_queue_depth() const
+{
+ return _aio_ctxt ? _aio_ctxt->_queue_depth : -1;
+}
+
+const bool deepspeed_io_handle_t::get_single_submit() const { return _single_submit; }
+
+const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; }
+
+const int deepspeed_io_handle_t::get_thread_count() const { return _num_threads; }
+
+int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate)
+{
+ const auto start_time = std::chrono::high_resolution_clock::now();
+
+ assert(_aio_ctxt);
+
+ long long num_file_bytes;
+ if (-1 == get_file_size(filename, num_file_bytes)) {
+ const auto error_code = errno;
+ report_file_error(filename, " fstat for read", error_code);
+ return -1;
+ }
+    assert(static_cast<long long int>(buffer.nbytes()) == num_file_bytes);
+
+ const auto fd = open_file(filename, true);
+ if (fd == -1) { return -1; }
+
+ auto read_buffer = (char*)buffer.data_ptr();
+    std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer));
+
+ if (_aio_config._overlap_events) {
+ do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ } else {
+ do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ }
+
+ close(fd);
+ const std::chrono::duration<double> aio_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+
+ if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); }
+ const std::chrono::duration<double> fn_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
+ return 0;
+}
+
+int deepspeed_io_handle_t::write(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate)
+{
+ assert(_aio_ctxt);
+
+ const auto start_time = std::chrono::high_resolution_clock::now();
+
+ const auto fd = open_file(filename, false);
+ if (fd == -1) { return -1; }
+
+ auto write_buffer = (char*)buffer.data_ptr();
+ const auto num_write_bytes = static_cast<long long int>(buffer.nbytes());
+ std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer));
+
+ if (_aio_config._overlap_events) {
+ do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ } else {
+ do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ }
+ const std::chrono::duration<double> aio_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+
+ close(fd);
+
+ if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); }
+
+ const std::chrono::duration<double> fn_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
+ return 0;
+}
+
+void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr<struct io_op_desc_t> scheduled_op)
+{
+ for (auto& ctxt : _thread_contexts) {
+ {
+ std::lock_guard<std::mutex> lock(ctxt->_work_sync._mutex);
+ ctxt->_work_queue.push(scheduled_op);
+ }
+ ctxt->_work_sync._cond_var.notify_one();
+ }
+ _num_pending_ops++;
+}
+
+std::shared_ptr<struct io_op_desc_t> deepspeed_io_handle_t::_wait_for_aio_work()
+{
+ std::shared_ptr<struct io_op_desc_t> completed_op = nullptr;
+ for (auto& ctxt : _thread_contexts) {
+ std::unique_lock<std::mutex> lock(ctxt->_complete_sync._mutex);
+ ctxt->_complete_sync._cond_var.wait(lock,
+ [ctxt] { return !ctxt->_complete_queue.empty(); });
+ completed_op = ctxt->_complete_queue.front();
+ ctxt->_complete_queue.pop();
+ }
+ return completed_op;
+}
+
+void deepspeed_io_handle_t::_stop_threads()
+{
+ assert(0 == _num_pending_ops);
+ for (auto& ctxt : _thread_contexts) {
+ {
+ std::lock_guard<std::mutex> lock(ctxt->_work_sync._mutex);
+ ctxt->_time_to_exit = true;
+ }
+ ctxt->_work_sync._cond_var.notify_one();
+ }
+}
+
+int deepspeed_io_handle_t::wait()
+{
+ assert(_num_pending_ops > 0);
+ auto num_completed_ops = 0;
+
+ while (_num_pending_ops > 0) {
+ auto completed_op = _wait_for_aio_work();
+
+ if (completed_op->_validate) { completed_op->validate(); }
+
+ completed_op->finish();
+
+ close(completed_op->_fd);
+
+ --_num_pending_ops;
+ ++num_completed_ops;
+ }
+
+ return num_completed_ops;
+}
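The scheduling logic above fans each parallel op descriptor out to every worker thread's queue, and `wait()` later harvests one completion record per thread for each pending op. A minimal Python model of that schedule/harvest pattern (the names here are illustrative only, not the DeepSpeed API):

```python
import queue
import threading

class ThreadContext:
    """Stands in for deepspeed_aio_thread_t's work/completion queues."""

    def __init__(self):
        self.work_queue = queue.Queue()
        self.complete_queue = queue.Queue()

def worker(ctxt):
    # Each worker processes its share of every scheduled op.
    while True:
        op = ctxt.work_queue.get()
        if op is None:  # shutdown sentinel, standing in for _time_to_exit
            break
        ctxt.complete_queue.put(op)

def schedule(contexts, op):
    # Mirrors _schedule_aio_work: the same op descriptor goes to every context.
    for ctxt in contexts:
        ctxt.work_queue.put(op)

def wait_one(contexts):
    # Mirrors _wait_for_aio_work: pop one completion from each context,
    # blocking until every worker has finished its share.
    for ctxt in contexts:
        ctxt.complete_queue.get()

contexts = [ThreadContext() for _ in range(4)]
threads = [threading.Thread(target=worker, args=(c,)) for c in contexts]
for t in threads:
    t.start()

schedule(contexts, "pread#0")
wait_one(contexts)

for c in contexts:
    c.work_queue.put(None)  # stop workers, like _stop_threads()
for t in threads:
    t.join()
```

Note that the C++ code signals `_cond_var.notify_one()` only after the `lock_guard` scope closes, so a woken worker never immediately blocks on the still-held mutex.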
+
+bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op,
+ const long long int num_bytes)
+{
+ const auto op_string = read_op ? "Read" : "Write";
+ if (num_bytes % get_thread_count()) {
+ std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes
+ << " not divisible by thread count = " << get_thread_count() << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
+std::shared_ptr<struct io_op_desc_t> deepspeed_io_handle_t::_create_io_op_desc(
+ const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const bool validate)
+{
+ return std::make_shared<cpu_op_desc_t>(
+ read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate);
+}
+
+int deepspeed_io_handle_t::pread(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async)
+{
+ long long num_file_bytes;
+ if (-1 == get_file_size(filename, num_file_bytes)) {
+ const auto error_code = errno;
+ report_file_error(filename, " fstat for read", error_code);
+ return -1;
+ }
+ const auto buffer_bytes = static_cast<long long int>(buffer.nbytes());
+ if (buffer_bytes != num_file_bytes) {
+ std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes
+ << " != " << num_file_bytes << std::endl;
+ }
+ assert(static_cast<long long int>(buffer.nbytes()) == num_file_bytes);
+ assert((num_file_bytes % _num_threads) == 0);
+
+ if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; }
+
+ const auto fd = open_file(filename, true);
+ if (fd == -1) { return -1; }
+
+ auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate);
+
+ _schedule_aio_work(scheduled_op);
+
+ if (async) { return 0; }
+
+ return wait();
+}
+
+int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async)
+{
+ const auto num_write_bytes = static_cast<long long int>(buffer.nbytes());
+ assert((num_write_bytes % _num_threads) == 0);
+
+ if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }
+
+ const auto fd = open_file(filename, false);
+ if (fd == -1) { return -1; }
+
+ auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate);
+
+ _schedule_aio_work(scheduled_op);
+
+ if (async) { return 0; }
+
+ return wait();
+}
+
+int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, const char* filename)
+{
+ return pread(buffer, filename, false, false);
+}
+
+int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename)
+{
+ return pwrite(buffer, filename, false, false);
+}
+
+int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, const char* filename)
+{
+ return pread(buffer, filename, false, true);
+}
+
+int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename)
+{
+ return pwrite(buffer, filename, false, true);
+}
+
+at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const size_t num_elem,
+ const torch::Tensor& example_tensor)
+{
+ return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type());
+}
+
+bool deepspeed_io_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor)
+{
+ return _pinned_tensor_mgr->free(locked_tensor);
+}
diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h
new file mode 100644
index 000000000000..2974ebe87bfc
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include <condition_variable>
+#include <memory>
+#include "deepspeed_aio_thread.h"
+#include "deepspeed_pin_tensor.h"
+
+struct deepspeed_io_handle_t {
+ std::unique_ptr<struct aio_context> _aio_ctxt;
+ const bool _single_submit;
+ const bool _overlap_events;
+ const int _num_threads;
+ deepspeed_aio_config_t _aio_config;
+
+ std::vector<std::shared_ptr<struct deepspeed_aio_thread_t>> _thread_contexts;
+ std::vector<std::thread> _threads;
+ int _num_pending_ops;
+ std::unique_ptr<struct deepspeed_pin_tensor_t> _pinned_tensor_mgr;
+
+ deepspeed_io_handle_t(const int block_size,
+ const int queue_depth,
+ const bool single_submit,
+ const bool overlap_events,
+ const int num_threads);
+
+ virtual ~deepspeed_io_handle_t() = 0;
+
+ const int get_block_size() const;
+ const int get_queue_depth() const;
+ const bool get_single_submit() const;
+ const bool get_overlap_events() const;
+ const int get_thread_count() const;
+
+ int read(torch::Tensor& buffer, const char* filename, const bool validate);
+
+ int write(const torch::Tensor& buffer, const char* filename, const bool validate);
+
+ int pread(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async);
+
+ int pwrite(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async);
+
+ int sync_pread(torch::Tensor& buffer, const char* filename);
+
+ int sync_pwrite(const torch::Tensor& buffer, const char* filename);
+
+ int async_pread(torch::Tensor& buffer, const char* filename);
+
+ int async_pwrite(const torch::Tensor& buffer, const char* filename);
+
+ // TODO: Make API's args to be shape and dtype.
+ torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor);
+
+ bool free_cpu_locked_tensor(torch::Tensor&);
+
+ int wait();
+
+ void _stop_threads();
+
+ void _schedule_aio_work(std::shared_ptr<struct io_op_desc_t> scheduled_op);
+
+ std::shared_ptr<struct io_op_desc_t> _wait_for_aio_work();
+
+ bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes);
+
+ virtual std::shared_ptr<struct io_op_desc_t> _create_io_op_desc(
+ const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const bool validate);
+};
diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp
old mode 100755
new mode 100644
index 9033549bc0d2..3171d0c6bf3c
--- a/csrc/aio/py_lib/py_ds_aio.cpp
+++ b/csrc/aio/py_lib/py_ds_aio.cpp
@@ -10,6 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include <torch/extension.h>
#include "deepspeed_py_aio_handle.h"
#include "deepspeed_py_copy.h"
+using namespace pybind11::literals;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
@@ -20,7 +21,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy");
py::class_<deepspeed_aio_handle_t>(m, "aio_handle")
- .def(py::init<const int, const int, const bool, const bool, const int>())
+ .def(py::init<const int, const int, const bool, const bool, const int>(),
+ "AIO handle constructor",
+ "block_size"_a = 1024 * 1024,
+ "queue_depth"_a = 128,
+ "single_submit"_a = false,
+ "overlap_events"_a = false,
+ "num_threads"_a = 1)
.def("get_block_size", &deepspeed_aio_handle_t::get_block_size)
.def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth)
@@ -28,19 +35,74 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events)
.def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count)
- .def("read", &deepspeed_aio_handle_t::read)
- .def("write", &deepspeed_aio_handle_t::write)
+ .def("read",
+ &deepspeed_aio_handle_t::read,
+ "Synchronous and non-parallel file read. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a)
- .def("pread", &deepspeed_aio_handle_t::pread)
- .def("pwrite", &deepspeed_aio_handle_t::pwrite)
+ .def("write",
+ &deepspeed_aio_handle_t::write,
+ "Synchronous and non-parallel file write. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a)
- .def("sync_pread", &deepspeed_aio_handle_t::sync_pread)
- .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite)
- .def("async_pread", &deepspeed_aio_handle_t::async_pread)
- .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite)
+ .def("pread",
+ &deepspeed_aio_handle_t::pread,
+ "Parallel file read with optional asynchronous completion. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a,
+ "async"_a)
- .def("new_cpu_locked_tensor", &deepspeed_aio_handle_t::new_cpu_locked_tensor)
- .def("free_cpu_locked_tensor", &deepspeed_aio_handle_t::free_cpu_locked_tensor)
+ .def("pwrite",
+ &deepspeed_aio_handle_t::pwrite,
+ "Parallel file write with optional asynchronous completion. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a,
+ "async"_a)
- .def("wait", &deepspeed_aio_handle_t::wait);
+ .def("sync_pread",
+ &deepspeed_aio_handle_t::sync_pread,
+ "Synchronous parallel file read. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("sync_pwrite",
+ &deepspeed_aio_handle_t::sync_pwrite,
+ "Synchronous parallel file write. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("async_pread",
+ &deepspeed_aio_handle_t::async_pread,
+ "Asynchronous parallel file read. Returns 0 on success, and following "
+ "wait() returns count of completed ops.",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("async_pwrite",
+ &deepspeed_aio_handle_t::async_pwrite,
+ "Asynchronous parallel file write. Returns 0 on success, and following wait() returns "
+ "count of completed ops.",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("new_cpu_locked_tensor",
+ &deepspeed_aio_handle_t::new_cpu_locked_tensor,
+ "Allocate pinned CPU tensor.",
+ "num_elem"_a,
+ "example_tensor"_a)
+
+ .def("free_cpu_locked_tensor",
+ &deepspeed_aio_handle_t::free_cpu_locked_tensor,
+ "Free pinned CPU tensor.",
+ "tensor"_a)
+
+ .def("wait",
+ &deepspeed_aio_handle_t::wait,
+ "Wait for (ongoing) asynchronous operations to complete");
}
diff --git a/csrc/aio/py_test/aio_bench_generate_param.py b/csrc/aio/py_test/aio_bench_generate_param.py
index 09d0e03c7ef6..7a0ab59ed73d 100644
--- a/csrc/aio/py_test/aio_bench_generate_param.py
+++ b/csrc/aio/py_test/aio_bench_generate_param.py
@@ -41,9 +41,9 @@ def convert_to_param(key):
return {
"single_submit": "true" if key[0] == "single" else "false",
"overlap_events": "true" if key[1] == "overlap" else "false",
- "thread_count": int(key[3]),
- "queue_depth": int(key[4]),
- "block_size": int(key[5])
+ "thread_count": int(key[5]),
+ "queue_depth": int(key[3]),
+ "block_size": int(key[4])
}
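The index swap above tracks the reordered log-file naming convention of the reworked sweep (queue depth and block size now precede the thread count in the filename tags). A self-contained copy of the corrected mapping, assuming a key tuple of the shape `(single/block, overlap/sequential, <op>, <queue_depth>, <block_size>, <thread_count>)` — the exact tuple layout is an assumption here:

```python
def convert_to_param(key):
    # key is a tuple of tokens parsed from a benchmark log-file name;
    # after the sweep rework, queue depth (key[3]) and block size (key[4])
    # come before the thread count (key[5]).
    return {
        "single_submit": "true" if key[0] == "single" else "false",
        "overlap_events": "true" if key[1] == "overlap" else "false",
        "thread_count": int(key[5]),
        "queue_depth": int(key[3]),
        "block_size": int(key[4]),
    }

params = convert_to_param(("single", "overlap", "read", "32", "131072", "8"))
```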
diff --git a/csrc/aio/py_test/aio_bench_perf_sweep.py b/csrc/aio/py_test/aio_bench_perf_sweep.py
index 7d55f7ded65c..ba95150b11e1 100644
--- a/csrc/aio/py_test/aio_bench_perf_sweep.py
+++ b/csrc/aio/py_test/aio_bench_perf_sweep.py
@@ -10,75 +10,47 @@
import argparse
import json
import itertools
-import subprocess
import shutil
-from test_ds_aio_utils import refine_integer_value
+from ds_aio_job import Job, run_job
from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
- READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR
+ READ_LOG_DIR, WRITE_LOG_DIR
from deepspeed.ops.op_builder import AsyncIOBuilder
OTHER_OPTIONS = '--handle'
PERF_SCRIPT = 'test_ds_aio.py'
DEFAULT_SWEEP_CONFIG = {
- "block_size": ["128K", "256K"],
- "queue_depth": [4, 16, 32],
- "overlap_events": [True, False],
- "io_parallel": [2, 8],
- "single_submit": [False]
+ "block_size": ["128K", "1M"],
+ "queue_depth": [32, 64, 128],
+ "sequential_requests": [True, False],
+ "single_submit": [False],
+ "io_parallel": [1, 2, 8],
}
-class Job(object):
-
- def __init__(self, cmd_line, output_file=None, work_dir=None):
- self.cmd_line = cmd_line
- self.output_file = output_file
- self.work_dir = work_dir
- self.output_fd = None
-
- def cmd(self):
- return self.cmd_line
-
- def get_stdout(self):
- return self.output_fd
-
- def get_stderr(self):
- return self.output_fd
-
- def get_cwd(self):
- return self.work_dir
-
- def open_output_file(self):
- if self.output_file is not None:
- self.output_fd = open(self.output_file, 'w')
-
- def close_output_file(self):
- if self.output_fd is not None:
- self.output_fd.close()
- self.output_fd = None
-
-
class SweepConfig(object):
def __init__(self, args):
- self.nvme_dir = args.nvme_dir
- self.io_size = args.io_size
+ self.folder_to_device_mapping = get_ftd_map(args.nvme_dir)
self.search_space = get_sweep_config_dict(args.sweep_config)
+ self.search_space.update(self.folder_to_device_mapping)
self.read = not args.no_read
self.write = not args.no_write
self.flush_cache = not args.no_sudo
self.log_dir = args.log_dir
- self.loops = args.loops
- self.other_options = f'{OTHER_OPTIONS} --loops {args.loops}'
+ self.other_options = f'{OTHER_OPTIONS} --loops {args.loops} --io_size {args.io_size}'
+ if args.gpu:
+ self.other_options += ' --gpu'
+ if args.gds:
+ self.other_options += ' --use_gds'
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--nvme_dir',
+ nargs='+',
required=True,
- type=str,
help='Directory in which to perform I/O tests. A writeable directory on a NVMe device.')
parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.')
@@ -92,6 +64,10 @@ def parse_arguments():
default="400M",
help='Number of I/O bytes to read/write for performance measurements.')
+ parser.add_argument('--gpu', action='store_true', help='Test tensor transfers between GPU device and NVME device.')
+
+ parser.add_argument('--gds', action='store_true', help='Run the sweep over NVIDIA GPUDirectStorage operator')
+
parser.add_argument(
'--no_sudo',
action='store_true',
@@ -118,6 +94,12 @@ def dump_cmd_lines(cmd_lines):
print(f'{i}: {cmd}')
+def get_ftd_map(nvme_dir_list):
+ ftd_list = [f'{dir}:{dev}' for dev, dir in enumerate(nvme_dir_list)]
+ ftd_arg = [' '.join(ftd for ftd in ftd_list)]
+ return {'folder_to_device_mapping': ftd_arg}
+
+
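For reference, `get_ftd_map` flattens the `--nvme_dir` folder list into a single space-separated string of `folder:device` pairs, where the device id is just the folder's position in the list. A standalone copy of the helper shows the shape of the generated sweep argument:

```python
def get_ftd_map(nvme_dir_list):
    # Pair each folder with its positional device id, e.g.
    # ['/mnt/nvme0', '/mnt/nvme1'] -> '/mnt/nvme0:0 /mnt/nvme1:1'
    ftd_list = [f'{folder}:{dev}' for dev, folder in enumerate(nvme_dir_list)]
    return {'folder_to_device_mapping': [' '.join(ftd_list)]}

mapping = get_ftd_map(['/mnt/nvme0', '/mnt/nvme1'])
```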
def get_sweep_config_dict(sweep_config_json):
if sweep_config_json is None:
return DEFAULT_SWEEP_CONFIG
@@ -148,16 +130,6 @@ def flatten_options(key, value_list):
return cmd_list
-def run_job(job):
- args = ' '.join(job.cmd())
- print(f'args = {args}')
- job.open_output_file()
- proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
- job.close_output_file()
- assert proc.returncode == 0, \
- f"This command failed: {job.cmd()}"
-
-
def launch_sweep(sweep_jobs, sync_job, flush_cache_job):
for perf_job in sweep_jobs:
if flush_cache_job is not None:
@@ -176,7 +148,12 @@ def create_cmd_tags(cmd_line):
if len(fields) == 1:
tags[fields[0]] = None
elif len(fields) == 2:
- tags[fields[0]] = fields[1]
+ if fields[0] == '--folder_to_device_mapping':
+ tags[fields[0]] = len(fields[1:])
+ else:
+ tags[fields[0]] = fields[1]
+ elif len(fields) > 2:
+ tags[fields[0]] = len(fields[1:])
return tags
@@ -184,16 +161,16 @@ def get_log_file(io_op_desc, cmd_line):
QUEUE_DEPTH = "--queue_depth"
BLOCK_SIZE = "--block_size"
SINGLE_SUBMIT = "--single_submit"
- OVERLAP_EVENTS = "--overlap_events"
- THREAD_COUNT = "--threads"
+ SEQUENTIAL_REQUESTS = "--sequential_requests"
+ FTD_MAP = "--folder_to_device_mapping"
IO_PARALLEL = "--io_parallel"
tag_map = {
QUEUE_DEPTH: "d",
BLOCK_SIZE: "bs",
SINGLE_SUBMIT: "single",
- OVERLAP_EVENTS: "overlap",
- THREAD_COUNT: "t",
+ SEQUENTIAL_REQUESTS: "sequential",
+ FTD_MAP: "ftd",
IO_PARALLEL: "p"
}
@@ -201,14 +178,14 @@ def get_log_file(io_op_desc, cmd_line):
QUEUE_DEPTH: 1,
BLOCK_SIZE: "1M",
SINGLE_SUBMIT: "block",
- OVERLAP_EVENTS: "sequential",
- THREAD_COUNT: 1,
+ SEQUENTIAL_REQUESTS: "overlap",
+ FTD_MAP: 1,
IO_PARALLEL: 1
}
def get_default_value(tag):
value = tag_default[tag]
- if tag in [SINGLE_SUBMIT, OVERLAP_EVENTS]:
+ if tag in [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS]:
return value
return f'{tag_map[tag]}{value}'
@@ -218,7 +195,7 @@ def get_config_value(tag, value):
return tag_key
return f'{tag_key}{value}'
- tag_list = [SINGLE_SUBMIT, OVERLAP_EVENTS, THREAD_COUNT, IO_PARALLEL, QUEUE_DEPTH, BLOCK_SIZE]
+ tag_list = [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS, FTD_MAP, QUEUE_DEPTH, BLOCK_SIZE, IO_PARALLEL]
log_tags = [io_op_desc]
cmd_tags = create_cmd_tags(cmd_line)
for tag in tag_list:
@@ -252,40 +229,14 @@ def async_io_setup():
return AsyncIOBuilder().is_compatible()
-def get_block_size_and_count(io_bytes):
- block_size = 1
- block_count = io_bytes
- bytes_in_KB = 1024
-
- while block_count % bytes_in_KB == 0:
- block_size *= bytes_in_KB
- block_count /= bytes_in_KB
-
- return int(block_size), int(block_count)
-
-
-def create_read_file(sweep_config):
- read_folder = os.path.join(sweep_config.nvme_dir, f'{READ_IO_DIR}')
- os.makedirs(read_folder, exist_ok=True)
- read_file_name = os.path.join(read_folder, f'random_{sweep_config.io_size}B.pt')
- block_size, block_count = get_block_size_and_count(refine_integer_value(sweep_config.io_size))
- dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={read_file_name} bs={block_size} count={block_count}'])
- print(f'[Start] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....')
- run_job(dd_job)
- print(f'[Done] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....')
- return read_folder, read_file_name
-
-
def remove_folder(folder):
assert os.path.isdir(folder), f"Error: cannot remove {folder} - folder not found"
shutil.rmtree(folder)
def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
- read_folder, read_file_name = create_read_file(sweep_config)
- read_option = f'--read_file {read_file_name}'
- read_cmd_lines = [[f'{read_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
- #dump_cmd_lines(read_cmd_lines)
+ read_cmd_lines = [[f'--read {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
+ #dump_cmd_lines(cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}')
os.makedirs(log_folder, exist_ok=True)
@@ -294,15 +245,9 @@ def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job)
- remove_folder(read_folder)
-
def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
- write_folder = os.path.join(sweep_config.nvme_dir, f'{WRITE_IO_DIR}')
- os.makedirs(write_folder, exist_ok=True)
- write_file_name = os.path.join(write_folder, f'random_{sweep_config.io_size}B.pt')
- write_option = f'--write_size {sweep_config.io_size} --write_file {write_file_name}'
- write_cmd_lines = [[f'{write_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
+ write_cmd_lines = [[f'{sweep_config.other_options}'] + cmd for cmd in cmd_lines]
#dump_cmd_lines(write_cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}')
@@ -312,8 +257,6 @@ def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job)
- remove_folder(write_folder)
-
def main():
print("Running performance sweep of deepspeed nvme library")
diff --git a/csrc/aio/py_test/ds_aio_args.py b/csrc/aio/py_test/ds_aio_args.py
new file mode 100644
index 000000000000..346feabe4810
--- /dev/null
+++ b/csrc/aio/py_test/ds_aio_args.py
@@ -0,0 +1,175 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
+"""
+
+import argparse
+import os
+from test_ds_aio_utils import refine_integer_value
+from deepspeed.accelerator import get_accelerator
+
+MAPPING_DELIMITER = ':'
+
+
+def refine_args(args):
+ if args.io_size and isinstance(args.io_size, str):
+ args.io_size = refine_integer_value(args.io_size)
+
+ if args.block_size and isinstance(args.block_size, str):
+ args.block_size = refine_integer_value(args.block_size)
+
+ return args
+
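`refine_integer_value` is imported from `test_ds_aio_utils` and is not shown in this diff; a plausible implementation that expands K/M/G suffixes into byte counts would look like the following (this is a hypothetical sketch of the helper's contract, not the repo's code):

```python
def refine_integer_value(value):
    # Hypothetical sketch: expand a trailing K/M/G suffix into a byte count,
    # so '128K' -> 131072 and '1M' -> 1048576; plain integers pass through.
    suffixes = {'K': 1024, 'M': 1024**2, 'G': 1024**3}
    value = str(value).strip()
    if value and value[-1].upper() in suffixes:
        return int(value[:-1]) * suffixes[value[-1].upper()]
    return int(value)
```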
+
+def _get_mapping_dict(args):
+ if args.folder is not None:
+ d = {i: args.folder for i in range(args.multi_process)}
+ else:
+ d = {}
+ for m in args.folder_to_device_mapping:
+ fields = m.split(MAPPING_DELIMITER)
+ d[fields[1]] = fields[0]
+
+ return d
+
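Note that `_get_mapping_dict` keys the result by the raw device-id string taken from each `folder:device` entry, whereas the `--folder` branch uses integer keys from `range()`, so downstream code should not assume one key type. A minimal sketch of the delimiter parsing, assuming well-formed entries:

```python
MAPPING_DELIMITER = ':'

def parse_mapping(entries):
    # '/mnt/nvme0:0' -> {'0': '/mnt/nvme0'}; device-id keys stay strings,
    # mirroring _get_mapping_dict's use of fields[1] as the key.
    d = {}
    for m in entries:
        folder, device = m.split(MAPPING_DELIMITER)
        d[device] = folder
    return d
```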
+
+def _validate_folder_mapping(args):
+ no_error = True
+ error_messages = []
+ invalid_mappings = [m for m in args.folder_to_device_mapping if MAPPING_DELIMITER not in m]
+ if len(invalid_mappings) > 0:
+ error_messages.append(
+ f'Missing delimiter ({MAPPING_DELIMITER}) in folder_to_device_mapping {invalid_mappings}')
+ no_error = False
+
+ folder_list = [m.split(MAPPING_DELIMITER)[0] for m in args.folder_to_device_mapping]
+ invalid_folders = [d for d in folder_list if not os.path.exists(d)]
+ if len(invalid_folders) > 0:
+ error_messages.append(f'Invalid folders in folder_to_device_mapping: {invalid_folders}')
+ no_error = False
+
+ if args.gpu:
+ device_list = [int(m.split(MAPPING_DELIMITER)[1]) for m in args.folder_to_device_mapping]
+ invalid_device_list = [dev_id for dev_id in device_list if not dev_id < get_accelerator().device_count()]
+ if len(invalid_device_list) > 0:
+ error_messages.append(f'Invalid device ids in folder_to_device_mapping: {invalid_device_list}')
+ no_error = False
+
+ return no_error, error_messages
+
+
+def validate_args(args):
+ no_error = True
+ error_messages = []
+
+ if args.folder is not None and len(args.folder_to_device_mapping) > 0:
+ error_messages.append(f'--folder and --folder_to_device_mapping cannot be specified together.')
+ no_error = False
+ elif args.folder is None and len(args.folder_to_device_mapping) == 0:
+ error_messages.append(f'At least one of --folder or --folder_to_device_mapping must be specified.')
+ no_error = False
+
+ # Validate --folder
+ if args.folder is not None and not os.path.exists(args.folder):
+ no_error = False
+ error_messages.append(f'Invalid folder in --folder: {args.folder} ')
+
+ # Validate --folder_mapping_to_device
+ if len(args.folder_to_device_mapping) > 0:
+ no_mapping_error, mapping_error_messages = _validate_folder_mapping(args)
+ no_error = no_error and no_mapping_error
+ error_messages += mapping_error_messages
+
+ # Validate --gpu, --use_gds
+ if args.use_gds and not args.gpu:
+ error_messages.append(f'--gpu must be set to transfer with --use_gds')
+ no_error = False
+
+ if not no_error:
+ print(f'Found {len(error_messages)} validation errors')
+ for i, msg in enumerate(error_messages):
+ print(f'{i+1}: {msg}')
+
+ return no_error
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--folder', default=None, type=str, help='Folder to use for I/O.')
+
+ parser.add_argument('--folder_to_device_mapping',
+ default=[],
+ nargs='+',
+ help='Specification of mapping of folder to (gpu) device id (ignored for cpu accesses). '
+ 'Can be specified multiple times for multi-process runs, '
+ 'e.g. --folder_to_device_mapping /mnt/nvme0:0 --folder_to_device_mapping /mnt/nvme1:15 --gpu '
+ 'means access /mnt/nvme0 with gpu 0 and /mnt/nvme1 with gpu 15.')
+
+ parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.')
+
+ parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)')
+
+ parser.add_argument('--multi_process',
+ type=int,
+ default=1,
+ help='Number of parallel processes doing I/O (default 1).')
+
+ parser.add_argument('--block_size',
+ type=str,
+ default='1M',
+ help='I/O block size. Can use K, M, or G suffix (default 1M for 1 megabyte).')
+
+ parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth (default 32).')
+
+ parser.add_argument('--single_submit',
+ action='store_true',
+ help='Submit I/O requests one at a time (default is to submit queue_depth requests at once).')
+
+ parser.add_argument(
+ '--sequential_requests',
+ action='store_true',
+ help=
+ 'Delay I/O request submission until completion of prior requests (default is to overlap I/O submission and completion).'
+ )
+
+ parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.')
+
+ parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
+
+ parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions')
+
+ parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism')
+
+ parser.add_argument('--gpu', action='store_true', help='Use GPU memory')
+
+ parser.add_argument('--use_gds', action='store_true', help='Enable GDS AIO')
+
+ parser.add_argument('--slow_bounce_buffer',
+ action='store_true',
+ help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.')
+
+ args = parser.parse_args()
+ print(f'args = {args}')
+ return args
+
+
+def get_validated_args():
+ args = parse_arguments()
+ args = refine_args(args)
+ if not validate_args(args):
+ quit()
+ print(f'Successful validation of command line arguments')
+
+ peer_tag = 'gpu' if args.gpu else 'process'
+ args.mapping_dict = _get_mapping_dict(args)
+ args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()]
+ assert len(args.mapping_dict) == len(args.mapping_list)
+ print(f'Configuring {len(args.mapping_list)} {peer_tag} to folder mapping')
+ for i, (device_id, folder) in enumerate(args.mapping_list):
+ print(f'[{i}]: {peer_tag} {device_id} <----> {folder}')
+
+ return args
diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py
index ad2a4349cd0c..9b3c7cbfc49f 100755
--- a/csrc/aio/py_test/ds_aio_basic.py
+++ b/csrc/aio/py_test/ds_aio_basic.py
@@ -9,10 +9,9 @@
import torch
import os
import time
+from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from test_ds_aio_utils import report_results, task_log, task_barrier
-from deepspeed.accelerator import get_accelerator
-from deepspeed.ops.op_builder import AsyncIOBuilder
def pre_basic(args, tid, read_op):
@@ -21,7 +20,7 @@ def pre_basic(args, tid, read_op):
file = args.read_file if read_op else f'{args.write_file}.{tid}'
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
- buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
+ buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
ctxt = {}
@@ -56,7 +55,7 @@ def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
- args.single_submit, args.overlap_events, args.validate)
+ args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -67,7 +66,7 @@ def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
- args.single_submit, args.overlap_events, args.validate)
+ args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -90,16 +89,17 @@ def get_schedule(args, read_op):
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
+ num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
@@ -107,14 +107,14 @@ def _aio_handle_tasklet(pool_params):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
@@ -125,9 +125,10 @@ def _init_tasklet(b):
def aio_basic_multiprocessing(args, read_op):
- b = Barrier(args.threads)
- pool_params = [(args, p, read_op) for p in range(args.threads)]
- with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
+ num_processes = len(args.mapping_dict)
+ b = Barrier(num_processes)
+ pool_params = [(args, p, read_op) for p in range(num_processes)]
+ with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)
diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py
index d35b2713edae..f4a179deb9ec 100755
--- a/csrc/aio/py_test/ds_aio_handle.py
+++ b/csrc/aio/py_test/ds_aio_handle.py
@@ -10,40 +10,56 @@
import os
import time
from multiprocessing import Pool, Barrier
-from test_ds_aio_utils import report_results, task_log, task_barrier
+from deepspeed.ops.aio import AsyncIOBuilder
+from deepspeed.ops.op_builder import GDSBuilder
+from test_ds_aio_utils import report_results, task_log, task_barrier, create_filename, create_file
from deepspeed.accelerator import get_accelerator
-from deepspeed.ops.op_builder import AsyncIOBuilder
+
+BUFFER = 'buffer'
+BOUNCE_BUFFER = 'bounce_buffer'
def pre_handle(args, tid, read_op):
io_string = "Read" if read_op else "Write"
- num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
- file = args.read_file if read_op else f'{args.write_file}.{tid}'
-
- io_parallel = args.io_parallel if args.io_parallel else 1
- handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
- args.overlap_events, io_parallel)
- task_log(tid, f'Created deepspeed aio handle')
-
+ gds = True if args.use_gds else False
+ device_id, folder = args.mapping_list[tid]
+ filename = create_filename(folder, args.read, args.io_size, tid)
+ if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
+ create_file(filename, args.io_size)
+
+ task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
+ bounce_buffer = None
if args.gpu:
- buffer = torch.empty(num_bytes, dtype=torch.uint8, device=get_accelerator().device_name())
+ device_name = get_accelerator().device_name(device_id)
+ buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name)
+ if not (args.slow_bounce_buffer or gds):
+ bounce_buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8,
+ device='cpu').pin_memory()
else:
- if args.use_accelerator_pin_memory:
- buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
- else:
- buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8))
+ buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device='cpu').pin_memory()
+ task_log(tid,
+ f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
+ force=True)
- task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
+ io_parallel = args.io_parallel if args.io_parallel else 1
+ if gds:
+ handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit,
+ not args.sequential_requests, io_parallel)
+ handle.pin_device_tensor(buffer)
+ else:
+ handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
+ not args.sequential_requests, io_parallel)
+ task_log(tid, f'created deepspeed aio handle')
ctxt = {}
- ctxt['file'] = file
- ctxt['num_bytes'] = num_bytes
+ ctxt['file'] = filename
+ ctxt['num_bytes'] = args.io_size
ctxt['handle'] = handle
- ctxt['buffer'] = buffer
+ ctxt['gds'] = gds
+ ctxt[BUFFER] = buffer
+ ctxt[BOUNCE_BUFFER] = bounce_buffer
ctxt['elapsed_sec'] = 0
- task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
-
return ctxt
@@ -61,8 +77,12 @@ def pre_handle_write(pool_params):
def post_handle(pool_params):
_, _, ctxt = pool_params
- ctxt["buffer"].detach()
- ctxt["buffer"] = None
+ for buf in [BUFFER, BOUNCE_BUFFER]:
+ if ctxt[buf] is not None:
+ if ctxt['gds']:
+ ctxt['handle'].unpin_device_tensor(ctxt[buf])
+ ctxt[buf].detach()
+ ctxt[buf] = None
return ctxt
@@ -71,20 +91,31 @@ def main_parallel_read(pool_params):
handle = ctxt['handle']
start_time = time.time()
- ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True)
+ dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
+ ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
+ if dest_buffer == BOUNCE_BUFFER:
+ ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
-
return ctxt
def main_parallel_write(pool_params):
args, tid, ctxt = pool_params
+ # Avoid overwriting existing files as it could be artificially faster
+ if os.path.isfile(ctxt['file']):
+ os.remove(ctxt['file'])
+
handle = ctxt['handle']
start_time = time.time()
- ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True)
+ if ctxt[BOUNCE_BUFFER] is not None:
+ source_buffer = BOUNCE_BUFFER
+ ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
+ else:
+ source_buffer = BUFFER
+ ret = handle.pwrite(ctxt[source_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
@@ -98,8 +129,11 @@ def main_handle_read(pool_parms):
handle = ctxt['handle']
start_time = time.time()
- ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate)
+ dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
+ ret = handle.read(ctxt[dest_buffer], ctxt['file'], args.validate)
assert ret != -1
+ if dest_buffer == BOUNCE_BUFFER:
+ ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -108,9 +142,18 @@ def main_handle_read(pool_parms):
def main_handle_write(pool_parms):
args, tid, ctxt = pool_parms
+ # Avoid overwriting existing files as it could be artificially faster
+ if os.path.isfile(ctxt['file']):
+ os.remove(ctxt['file'])
+
handle = ctxt['handle']
start_time = time.time()
- ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate)
+ if ctxt[BOUNCE_BUFFER] is not None:
+ source_buffer = BOUNCE_BUFFER
+ ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
+ else:
+ source_buffer = BUFFER
+ ret = handle.write(ctxt[source_buffer], ctxt['file'], args.validate)
assert ret != -1
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -123,27 +166,28 @@ def get_schedule(args, read_op):
if read_op:
schedule['pre'] = pre_handle_read
schedule['post'] = post_handle
- schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read
+ schedule['main'] = main_parallel_read
else:
schedule['pre'] = pre_handle_write
schedule['post'] = post_handle
- schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write
+ schedule['main'] = main_parallel_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
+ num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
@@ -151,14 +195,14 @@ def _aio_handle_tasklet(pool_params):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
@@ -169,9 +213,10 @@ def _init_tasklet(b):
def aio_handle_multiprocessing(args, read_op):
- b = Barrier(args.threads)
- pool_params = [(args, p, read_op) for p in range(args.threads)]
- with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
+ num_processes = len(args.mapping_dict)
+ b = Barrier(num_processes)
+ pool_params = [(args, p, read_op) for p in range(num_processes)]
+ with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)
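The rewritten read path in `ds_aio_handle.py` stages data in a pinned CPU bounce buffer and then copies it into the GPU tensor (`main_parallel_read` / `main_handle_read`). A rough pure-Python illustration of that two-step read, using `bytearray`s in place of torch tensors — the helper name `read_via_bounce_buffer` is hypothetical, not from the patch:

```python
import os
import tempfile


def read_via_bounce_buffer(path, device_buffer):
    # Stage file contents in an intermediate bounce buffer (pinned CPU memory
    # in the real code), then copy into the destination buffer, mirroring the
    # BOUNCE_BUFFER path of main_parallel_read.
    size = os.path.getsize(path)
    bounce = bytearray(size)      # stands in for the pinned CPU tensor
    with open(path, 'rb') as f:
        f.readinto(bounce)
    device_buffer[:size] = bounce  # stands in for BUFFER.data.copy_()
    return size


# Usage: write a small file, then read it back through the bounce buffer.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b'abcd' * 4)
    path = f.name
dest = bytearray(16)
n = read_via_bounce_buffer(path, dest)
os.remove(path)
```

The extra copy is why the patch skips the bounce buffer when GDS is enabled: `gds_handle` pins the device tensor directly and the DMA goes straight to GPU memory.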
diff --git a/csrc/aio/py_test/ds_aio_job.py b/csrc/aio/py_test/ds_aio_job.py
new file mode 100644
index 000000000000..bbddee1bf26d
--- /dev/null
+++ b/csrc/aio/py_test/ds_aio_job.py
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+Functionality of swapping tensors to/from (NVMe) storage devices.
+"""
+import subprocess
+
+
+class Job(object):
+
+ def __init__(self, cmd_line, output_file=None, work_dir=None):
+ self.cmd_line = cmd_line
+ self.output_file = output_file
+ self.work_dir = work_dir
+ self.output_fd = None
+
+ def cmd(self):
+ return self.cmd_line
+
+ def get_stdout(self):
+ return self.output_fd
+
+ def get_stderr(self):
+ return self.output_fd
+
+ def get_cwd(self):
+ return self.work_dir
+
+ def open_output_file(self):
+ if self.output_file is not None:
+ self.output_fd = open(self.output_file, 'w')
+
+ def close_output_file(self):
+ if self.output_fd is not None:
+ self.output_fd.close()
+ self.output_fd = None
+
+
+def run_job(job):
+ args = ' '.join(job.cmd())
+ print(f'args = {args}')
+ job.open_output_file()
+ proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
+ job.close_output_file()
+ assert proc.returncode == 0, \
+ f"This command failed: {job.cmd()}"
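The new `ds_aio_job.py` wraps `subprocess.run` so sweep scripts can launch benchmark commands with optional output redirection. A condensed, self-contained stand-in for `Job`/`run_job` (collapsed into one function for brevity; the temp-file path is illustrative):

```python
import os
import subprocess
import tempfile


def run_job(cmd_line, output_file=None, work_dir=None):
    # Join the argument list into a shell command and run it, sending both
    # stdout and stderr to output_file when one is given, as Job does.
    args = ' '.join(cmd_line)
    fd = open(output_file, 'w') if output_file else None
    proc = subprocess.run(args=args, shell=True, stdout=fd, stderr=fd, cwd=work_dir)
    if fd is not None:
        fd.close()
    assert proc.returncode == 0, f"This command failed: {cmd_line}"


# Usage: capture the output of a trivial command in a file.
out = os.path.join(tempfile.gettempdir(), 'job_output.txt')
run_job(['echo', 'hello'], output_file=out)
```

As in the patch, a nonzero exit status raises immediately, so a failed benchmark run aborts the whole sweep rather than silently producing partial results.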
diff --git a/csrc/aio/py_test/run_read_sweep.sh b/csrc/aio/py_test/run_read_sweep.sh
index b9d7e050454a..59d82996a0e2 100755
--- a/csrc/aio/py_test/run_read_sweep.sh
+++ b/csrc/aio/py_test/run_read_sweep.sh
@@ -1,13 +1,22 @@
#!/bin/bash
-if [[ $# -ne 2 ]]; then
- echo "Usage: $0