diff --git a/cmake/Find/Findibverbs.cmake b/cmake/Find/Findibverbs.cmake new file mode 100644 index 000000000..7e06cd66b --- /dev/null +++ b/cmake/Find/Findibverbs.cmake @@ -0,0 +1,17 @@ +find_path(IBVERBS_INCLUDE_DIRS + NAMES infiniband/verbs.h + HINTS + ${IBVERBS_INCLUDE_DIR} + ${IBVERBS_ROOT_DIR} + ${IBVERBS_ROOT_DIR}/include) + +find_library(IBVERBS_LIBRARIES + NAMES ibverbs + HINTS + ${IBVERBS_LIB_DIR} + ${IBVERBS_ROOT_DIR} + ${IBVERBS_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(ibverbs DEFAULT_MSG IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES) +mark_as_advanced(IBVERBS_INCLUDE_DIR IBVERBS_LIBRARIES) diff --git a/include/ylt/standalone/rdmapp/cq.h b/include/ylt/standalone/rdmapp/cq.h new file mode 100644 index 000000000..782d4a959 --- /dev/null +++ b/include/ylt/standalone/rdmapp/cq.h @@ -0,0 +1,162 @@ +#pragma once + +#include + +#include +#include +#include + +#include "device.h" +#include "error.h" +#include "fcntl.h" + +namespace rdmapp { + +class qp; +class cq; +typedef cq *cq_ptr; + +class comp_channel { + struct ibv_comp_channel *comp_channel_; + + public: + comp_channel(device_ptr device) { + comp_channel_ = ::ibv_create_comp_channel(device->ctx()); + check_ptr(comp_channel_, "failed to create comp channel"); + } + + void set_non_blocking() { + int flags = ::fcntl(comp_channel_->fd, F_GETFL); + if (flags < 0) { + check_errno(errno, "failed to get flags"); + } + int ret = ::fcntl(comp_channel_->fd, F_SETFL, flags | O_NONBLOCK); + if (ret < 0) { + check_errno(errno, "failed to set flags"); + } + } + + cq_ptr get_event() { + struct ibv_cq *cq; + void *ev_ctx; + check_rc(::ibv_get_cq_event(comp_channel_, &cq, &ev_ctx), + "failed to get event"); + auto cq_obj_ptr = reinterpret_cast(ev_ctx); + return cq_obj_ptr; + } + + int fd() const { return comp_channel_->fd; } + + ~comp_channel() { + if (comp_channel_ == nullptr) [[unlikely]] { + return; + } + if (auto rc = ::ibv_destroy_comp_channel(comp_channel_); rc != 0) { + } 
+ else { + } + } + + struct ibv_comp_channel *channel() const { return comp_channel_; } +}; + +typedef comp_channel *comp_channel_ptr; + +/** + * @brief This class is an abstraction of a Completion Queue. + * + */ +class cq { + device_ptr device_; + struct ibv_cq *cq_; + friend class qp; + + public: + cq(cq const &) = delete; + cq &operator=(cq const &) = delete; + + /** + * @brief Construct a new cq object. + * + * @param device The device to use. + * @param num_cqe The number of completion entries to allocate. + * @param channel If not null, assign this cq to the completion channel + */ + cq(device_ptr device, size_t nr_cqe = 128, comp_channel_ptr channel = nullptr) + : device_(device) { + cq_ = ::ibv_create_cq(device->ctx_, nr_cqe, this, + channel ? channel->channel() : nullptr, 0); + check_ptr(cq_, "failed to create cq"); + } + + void request_notify() { + check_rc(::ibv_req_notify_cq(cq_, 0), "failed to request notify"); + } + + void ack_event(unsigned int nr_events = 1) { + ::ibv_ack_cq_events(cq_, nr_events); + } + + /** + * @brief Poll the completion queue. + * + * @param wc If any, this will be filled with a completion entry. + * @return true If there is a completion entry. + * @return false If there is no completion entry. + * @exception std::runtime_exception Error occured while polling the + * completion queue. + */ + bool poll(struct ibv_wc &wc) { + if (auto rc = ::ibv_poll_cq(cq_, 1, &wc); rc < 0) [[unlikely]] { + check_rc(-rc, "failed to poll cq"); + } + else if (rc == 0) { + return false; + } + else { + return true; + } + return false; + } + + /** + * @brief Poll the completion queue. + * + * @param wc_vec If any, this will be filled with completion entries up to the + * size of the vector. + * @return size_t The number of completion entries. 0 means no completion + * entry. + * @exception std::runtime_exception Error occured while polling the + * completion queue. 
+ */ + size_t poll(std::vector &wc_vec) { + return poll(&wc_vec[0], wc_vec.size()); + } + + template + size_t poll(It wc, int count) { + int rc = ::ibv_poll_cq(cq_, count, wc); + if (rc < 0) { + throw_with("failed to poll cq: %s (rc=%d)", strerror(rc), rc); + } + return rc; + } + + template + size_t poll(std::array &wc_array) { + return poll(&wc_array[0], N); + } + + ~cq() { + if (cq_ == nullptr) [[unlikely]] { + return; + } + + if (auto rc = ::ibv_destroy_cq(cq_); rc != 0) [[unlikely]] { + } + else { + } + } +}; + +} // namespace rdmapp \ No newline at end of file diff --git a/include/ylt/standalone/rdmapp/detail/serdes.h b/include/ylt/standalone/rdmapp/detail/serdes.h new file mode 100644 index 000000000..cd494e10b --- /dev/null +++ b/include/ylt/standalone/rdmapp/detail/serdes.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace rdmapp { +namespace detail { + +static inline uint16_t ntoh(uint16_t const &value) { return ::be16toh(value); } + +static inline uint32_t ntoh(uint32_t const &value) { return ::be32toh(value); } + +static inline uint64_t ntoh(uint64_t const &value) { return ::be64toh(value); } + +static inline uint16_t hton(uint16_t const &value) { return ::htobe16(value); } + +static inline uint32_t hton(uint32_t const &value) { return ::htobe32(value); } + +static inline uint64_t hton(uint64_t const &value) { return ::htobe64(value); } + +template +typename std::enable_if::value>::type +serialize(T const &value, It &it) { + T nvalue = hton(value); + std::copy_n(reinterpret_cast(&nvalue), sizeof(T), it); +} + +template +typename std::enable_if::value>::type +serialize(T const &value, It &it) { + std::copy_n(reinterpret_cast(&value), sizeof(T), it); +} + +template +typename std::enable_if::value>::type +deserialize(It &it, T &value) { + std::copy_n(it, sizeof(T), reinterpret_cast(&value)); + it += sizeof(T); + value = ntoh(value); +} + +template +typename std::enable_if::value>::type 
/**
 * @brief Forward iterator over the device pointers held by a device_list.
 *
 * Dereferencing yields a reference to the struct ibv_device pointer stored at
 * the current index. Two iterators compare equal when their indices match.
 */
class iterator {
  friend class device_list;
  size_t i_;
  struct ibv_device **devices_;
  // Private: only the befriended device_list can create iterators.
  iterator(struct ibv_device **devices, size_t i)
      : i_(i), devices_(devices) {}

 public:
  // iterator traits: describe the element type actually returned by
  // operator*, not an unrelated integer type.
  using difference_type = long;
  using value_type = struct ibv_device *;
  using pointer = struct ibv_device **;
  using reference = struct ibv_device *&;
  using iterator_category = std::forward_iterator_tag;

  reference operator*() { return devices_[i_]; }

  bool operator==(iterator const &other) const { return i_ == other.i_; }

  bool operator!=(iterator const &other) const { return i_ != other.i_; }

  /// Pre-increment: advance, then return this iterator.
  iterator &operator++() {
    ++i_;
    return *this;
  }

  /// Post-increment: advance, but return a copy of the iterator as it was
  /// before the advance.
  iterator operator++(int) {
    iterator const previous = *this;
    ++i_;
    return previous;
  }
};
std::runtime_error("no Infiniband devices found"); + } + check_ptr(devices_, "failed to get Infiniband devices"); + nr_devices_ = nr_devices; + } + + ~device_list() { + if (devices_ != nullptr) { + ::ibv_free_device_list(devices_); + } + } + + size_t size() { return nr_devices_; } + + struct ibv_device *at(size_t i) { + if (i >= nr_devices_) { + throw std::out_of_range("out of range"); + } + return devices_[i]; + } + + iterator begin() { return iterator(devices_, 0); } + + iterator end() { return iterator(devices_, nr_devices_); } +}; + +/** + * @brief This class is an abstraction of an Infiniband device. + * + */ +class device { + struct ibv_device *device_; + struct ibv_context *ctx_; + struct ibv_port_attr port_attr_; + struct ibv_device_attr_ex device_attr_ex_; + union ibv_gid gid_; + + int gid_index_; + uint16_t port_num_; + friend class pd; + friend class cq; + friend class qp; + friend class srq; + void open_device(struct ibv_device *target, uint16_t port_num) { + device_ = target; + port_num_ = port_num; + ctx_ = ::ibv_open_device(device_); + check_ptr(ctx_, "failed to open device"); + check_rc(::ibv_query_port(ctx_, port_num_, &port_attr_), + "failed to query port"); + struct ibv_query_device_ex_input query = {}; + check_rc(::ibv_query_device_ex(ctx_, &query, &device_attr_ex_), + "failed to query extended attributes"); + + gid_index_ = 0; + check_rc(::ibv_query_gid(ctx_, port_num, gid_index_, &gid_), + "failed to query gid"); + + auto link_layer = [&]() { + switch (port_attr_.link_layer) { + case IBV_LINK_LAYER_ETHERNET: + return "ethernet"; + case IBV_LINK_LAYER_INFINIBAND: + return "infiniband"; + } + return "unspecified"; + }(); + auto const gid_str = gid_hex_string(gid_); + // RDMAPP_LOG_DEBUG("opened Infiniband device gid=%s lid=%d link_layer=%s", + // gid_str.c_str(), port_attr_.lid, link_layer); + } + + public: + device(device const &) = delete; + device &operator=(device const &) = delete; + + /** + * @brief Construct a new device object. 
+ * + * @param target The target device. + * @param port_num The port number of the target device. + */ + device(struct ibv_device *target, uint16_t port_num = 1) { + assert(target != nullptr); + open_device(target, port_num); + } + + /** + * @brief Construct a new device object. + * + * @param device_name The name of the target device. + * @param port_num The port number of the target device. + */ + device(std::string const &device_name, uint16_t port_num = 1) + : device_(nullptr), port_num_(0) { + auto devices = device_list(); + for (auto target : devices) { + if (::ibv_get_device_name(target) == device_name) { + open_device(target, port_num); + return; + } + } + throw_with("no device named %s found", device_name.c_str()); + } + + /** + * @brief Construct a new device object. + * + * @param device_num The index of the target device. + * @param port_num The port number of the target device. + */ + device(uint16_t device_num = 0, uint16_t port_num = 1) + : device_(nullptr), port_num_(0) { + auto devices = device_list(); + if (device_num >= devices.size()) { + char buffer[kErrorStringBufferSize] = {0}; + ::snprintf( + buffer, sizeof(buffer), + "requested device number %d out of range, %lu devices available", + device_num, devices.size()); + throw std::invalid_argument(buffer); + } + open_device(devices.at(device_num), port_num); + } + + /** + * @brief Get the device port number. + * + * @return uint16_t The port number. + */ + uint16_t port_num() const { return port_num_; } + + /** + * @brief Get the lid of the device. + * + * @return uint16_t The lid. + */ + uint16_t lid() const { return port_attr_.lid; } + + /** + * @brief Get the gid of the device. + * + * @return union ibv_gid The gid. + */ + union ibv_gid gid() const { + union ibv_gid gid_copied; + ::memcpy(&gid_copied, &gid_, sizeof(union ibv_gid)); + return gid_copied; + } + + /** + * @brief Checks if the device supports fetch and add. + * + * @return true Supported. + * @return false Not supported. 
+ */ + bool is_fetch_and_add_supported() const { + return device_attr_ex_.orig_attr.atomic_cap != IBV_ATOMIC_NONE; + } + + /** + * @brief Checks if the device supports compare and swap. + * + * @return true Supported. + * @return false Not supported. + */ + bool is_compare_and_swap_supported() const { + return device_attr_ex_.orig_attr.atomic_cap != IBV_ATOMIC_NONE; + } + + int gid_index() const { return gid_index_; } + + struct ibv_context *ctx() const { return ctx_; } + static std::string gid_hex_string(union ibv_gid const &gid) { + std::string gid_str; + char buf[16] = {0}; + const static size_t kGidLength = 16; + for (size_t i = 0; i < kGidLength; ++i) { + ::snprintf(buf, 16, "%02x", gid.raw[i]); + gid_str += i == 0 ? buf : std::string(":") + buf; + } + + return gid_str; + } + + ~device() { + if (ctx_ == nullptr) [[unlikely]] { + return; + } + + auto const gid_str = gid_hex_string(gid_); + + if (auto rc = ::ibv_close_device(ctx_); rc != 0) [[unlikely]] { + // RDMAPP_LOG_ERROR("failed to close device gid=%s lid=%d: %s", + // gid_str.c_str(), port_attr_.lid, ::strerror(rc)); + } + else { + // RDMAPP_LOG_DEBUG("closed device gid=%s lid=%d", gid_str.c_str(), + // port_attr_.lid); + } + } +}; + +typedef device *device_ptr; + +} // namespace rdmapp \ No newline at end of file diff --git a/include/ylt/standalone/rdmapp/error.h b/include/ylt/standalone/rdmapp/error.h new file mode 100644 index 000000000..031e9690d --- /dev/null +++ b/include/ylt/standalone/rdmapp/error.h @@ -0,0 +1,102 @@ +#pragma once +#include +#include +#include +#include +#include + +#include + +namespace rdmapp { + +constexpr size_t kErrorStringBufferSize = 1024; + +static inline void throw_with(const char *message) { + throw std::runtime_error(message); +} + +template +static inline void throw_with(const char *format, Args... 
args) { + char buffer[kErrorStringBufferSize]; + ::snprintf(buffer, sizeof(buffer), format, args...); + throw std::runtime_error(buffer); +} + +static inline void check_rc(int rc, const char *message) { + if (rc != 0) [[unlikely]] { + throw_with("%s: %s (rc=%d)", message, ::strerror(rc), rc); + } +} + +static inline void check_wc_status(enum ibv_wc_status status, + const char *message) { + if (status != IBV_WC_SUCCESS) [[unlikely]] { + auto errorstr = [status]() { + switch (status) { + case IBV_WC_SUCCESS: + return "IBV_WC_SUCCESS"; + case IBV_WC_LOC_LEN_ERR: + return "IBV_WC_LOC_LEN_ERR"; + case IBV_WC_LOC_QP_OP_ERR: + return "IBV_WC_LOC_QP_OP_ERR"; + case IBV_WC_LOC_EEC_OP_ERR: + return "IBV_WC_LOC_EEC_OP_ERR"; + case IBV_WC_LOC_PROT_ERR: + return "IBV_WC_LOC_PROT_ERR"; + case IBV_WC_WR_FLUSH_ERR: + return "IBV_WC_WR_FLUSH_ERR"; + case IBV_WC_MW_BIND_ERR: + return "IBV_WC_MW_BIND_ERR"; + case IBV_WC_BAD_RESP_ERR: + return "IBV_WC_BAD_RESP_ERR"; + case IBV_WC_LOC_ACCESS_ERR: + return "IBV_WC_LOC_ACCESS_ERR"; + case IBV_WC_REM_INV_REQ_ERR: + return "IBV_WC_REM_INV_REQ_ERR"; + case IBV_WC_REM_ACCESS_ERR: + return "IBV_WC_REM_ACCESS_ERR"; + case IBV_WC_REM_OP_ERR: + return "IBV_WC_REM_OP_ERR"; + case IBV_WC_RETRY_EXC_ERR: + return "IBV_WC_RETRY_EXC_ERR"; + case IBV_WC_RNR_RETRY_EXC_ERR: + return "IBV_WC_RNR_RETRY_EXC_ERR"; + case IBV_WC_LOC_RDD_VIOL_ERR: + return "IBV_WC_LOC_RDD_VIOL_ERR"; + case IBV_WC_REM_INV_RD_REQ_ERR: + return "IBV_WC_REM_INV_RD_REQ_ERR"; + case IBV_WC_REM_ABORT_ERR: + return "IBV_WC_REM_ABORT_ERR"; + case IBV_WC_INV_EECN_ERR: + return "IBV_WC_INV_EECN_ERR"; + case IBV_WC_INV_EEC_STATE_ERR: + return "IBV_WC_INV_EEC_STATE_ERR"; + case IBV_WC_FATAL_ERR: + return "IBV_WC_FATAL_ERR"; + case IBV_WC_RESP_TIMEOUT_ERR: + return "IBV_WC_RESP_TIMEOUT_ERR"; + case IBV_WC_GENERAL_ERR: + return "IBV_WC_GENERAL_ERR"; + case IBV_WC_TM_ERR: + return "IBV_WC_TM_ERR"; + case IBV_WC_TM_RNDV_INCOMPLETE: + return "IBV_WC_TM_RNDV_INCOMPLETE"; + } + return 
"UNKNOWN_ERROR"; + }(); + throw_with("%s: %s (status=%d)", message, errorstr, status); + } +} +static inline void check_ptr(void *ptr, const char *message) { + if (ptr == nullptr) [[unlikely]] { + throw_with("%s: %s (errno=%d)", message, ::strerror(errno), errno); + } +} + +static inline void check_errno(int rc, const char *message) { + if (rc < 0) [[unlikely]] { + throw_with("%s: %s (errno=%d)", message, ::strerror(errno), errno); + } +} + +} // namespace rdmapp \ No newline at end of file diff --git a/include/ylt/standalone/rdmapp/mr.h b/include/ylt/standalone/rdmapp/mr.h new file mode 100644 index 000000000..ee9611fa3 --- /dev/null +++ b/include/ylt/standalone/rdmapp/mr.h @@ -0,0 +1,211 @@ +#pragma once + +#include + +#include +#include +#include + +#include "detail/serdes.h" + +namespace rdmapp { + +namespace tags { +namespace mr { +struct local {}; +struct remote {}; +} // namespace mr +} // namespace tags + +class pd; + +/** + * @brief A remote or local memory region. + * + * @tparam Tag Either `tags::mr::local` or `tags::mr::remote`. + */ +template +class mr; + +/** + * @brief Represents a local memory region. + * + */ +template <> +class mr { + struct ibv_mr *mr_; + pd *pd_; + + public: + mr(mr const &) = delete; + mr &operator=(mr const &) = delete; + + /** + * @brief Construct a new mr object + * + * @param pd The protection domain to use. + * @param mr The ibverbs memory region handle. + */ + mr(pd *pd, struct ibv_mr *mr) : mr_(mr), pd_(pd) {} + + /** + * @brief Move construct a new mr object + * + * @param other The other mr object to move from. + */ + mr(mr &&other) + : mr_(std::exchange(other.mr_, nullptr)), pd_(std::move(other.pd_)) {} + + /** + * @brief Move assignment operator. + * + * @param other The other mr to move from. + * @return mr& This mr. + */ + mr &operator=(mr &&other) { + mr_ = other.mr_; + pd_ = std::move(other.pd_); + other.mr_ = nullptr; + return *this; + } + + /** + * @brief Destroy the mr object and deregister the memory region. 
+ * + */ + ~mr() { + if (mr_ == nullptr) [[unlikely]] { + // This mr is moved. + return; + } + if (auto rc = ::ibv_dereg_mr(mr_); rc != 0) [[unlikely]] { + } + else { + } + } + + /** + * @brief Serialize the memory region handle to be sent to a remote peer. + * + * @return std::vector The serialized memory region handle. + */ + std::vector serialize() const { + std::vector buffer; + auto it = std::back_inserter(buffer); + detail::serialize(reinterpret_cast(mr_->addr), it); + detail::serialize(mr_->length, it); + detail::serialize(mr_->rkey, it); + return buffer; + } + + /** + * @brief Get the address of the memory region. + * + * @return void* The address of the memory region. + */ + void *addr() const { return mr_->addr; } + + /** + * @brief Get the length of the memory region. + * + * @return size_t The length of the memory region. + */ + size_t length() const { return mr_->length; } + + /** + * @brief Get the remote key of the memory region. + * + * @return uint32_t The remote key of the memory region. + */ + uint32_t rkey() const { return mr_->rkey; } + + /** + * @brief Get the local key of the memory region. + * + * @return uint32_t The local key of the memory region. + */ + uint32_t lkey() const { return mr_->lkey; } +}; + +/** + * @brief Represents a remote memory region. + * + */ +template <> +class mr { + void *addr_; + size_t length_; + uint32_t rkey_; + + public: + /** + * @brief The size of a serialized remote memory region. + * + */ + static constexpr size_t kSerializedSize = + sizeof(addr_) + sizeof(length_) + sizeof(rkey_); + + mr() = default; + + /** + * @brief Construct a new remote mr object + * + * @param addr The address of the remote memory region. + * @param length The length of the remote memory region. + * @param rkey The remote key of the remote memory region. 
+ */ + mr(void *addr, uint32_t length, uint32_t rkey) + : addr_(addr), length_(length), rkey_(rkey) {} + + /** + * @brief Construct a new remote mr object copied from another + * + * @param other The other remote mr object to copy from. + */ + mr(mr const &other) = default; + + /** + * @brief Get the address of the remote memory region. + * + * @return void* The address of the remote memory region. + */ + void *addr() const { return addr_; } + + /** + * @brief Get the length of the remote memory region. + * + * @return uint32_t The length of the remote memory region. + */ + uint32_t length() const { return length_; } + + /** + * @brief Get the remote key of the memory region. + * + * @return uint32_t The remote key of the memory region. + */ + uint32_t rkey() const { return rkey_; } + + /** + * @brief Deserialize a remote memory region handle. + * + * @tparam It The iterator type. + * @param it The iterator to deserialize from. + * @return mr The deserialized remote memory region handle. + */ + template + static mr deserialize(It it) { + mr remote_mr; + detail::deserialize(it, remote_mr.addr_); + detail::deserialize(it, remote_mr.length_); + detail::deserialize(it, remote_mr.rkey_); + return remote_mr; + } +}; + +using local_mr = mr; +using remote_mr = mr; + +typedef local_mr *local_mr_ptr; +typedef remote_mr *remote_mr_ptr; + +} // namespace rdmapp \ No newline at end of file diff --git a/include/ylt/standalone/rdmapp/pd.h b/include/ylt/standalone/rdmapp/pd.h new file mode 100644 index 000000000..121a1884b --- /dev/null +++ b/include/ylt/standalone/rdmapp/pd.h @@ -0,0 +1,77 @@ +#pragma once + +#include + +#include "device.h" +#include "error.h" +#include "mr.h" + +namespace rdmapp { + +class qp; + +/** + * @brief This class is an abstraction of a Protection Domain. 
+ * + */ +class pd { + device_ptr device_; + struct ibv_pd *pd_; + friend class qp; + friend class srq; + + public: + pd(pd const &) = delete; + pd &operator=(pd const &) = delete; + + /** + * @brief Construct a new pd object + * + * @param device The device to use. + */ + pd(device_ptr device) : device_(device) { + pd_ = ::ibv_alloc_pd(device->ctx_); + check_ptr(pd_, "failed to alloc pd"); + } + + /** + * @brief Get the device object pointer. + * + * @return device_ptr The device object pointer. + */ + device_ptr device() const { return device_; } + + /** + * @brief Register a local memory region. + * + * @param addr The address of the memory region. + * @param length The length of the memory region. + * @param flags The access flags to use. + * @return local_mr The local memory region handle. + */ + local_mr reg_mr(void *addr, size_t length, + int flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_ATOMIC) { + auto mr = ::ibv_reg_mr(pd_, addr, length, flags); + check_ptr(mr, "failed to reg mr"); + return rdmapp::local_mr(this, mr); + } + /** + * @brief Destroy the pd object and the associated protection domain. 
+ * + */ + ~pd() { + if (pd_ == nullptr) [[unlikely]] { + return; + } + if (auto rc = ::ibv_dealloc_pd(pd_); rc != 0) [[unlikely]] { + } + else { + } + } +}; + +typedef pd *pd_ptr; + +} // namespace rdmapp \ No newline at end of file diff --git a/include/ylt/standalone/rdmapp/qp.h b/include/ylt/standalone/rdmapp/qp.h new file mode 100644 index 000000000..1692edcfb --- /dev/null +++ b/include/ylt/standalone/rdmapp/qp.h @@ -0,0 +1,763 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "cq.h" +#include "detail/serdes.h" +#include "device.h" +#include "pd.h" +#include "srq.h" + +namespace rdmapp { + +static inline struct ibv_sge fill_local_sge(void *addr, uint32_t lkey, + uint32_t length) { + struct ibv_sge sge = {}; + sge.addr = reinterpret_cast(addr); + sge.length = length; + sge.lkey = lkey; + return sge; +} + +struct deserialized_qp { + struct qp_header { + static constexpr size_t kSerializedSize = + sizeof(uint16_t) + 3 * sizeof(uint32_t) + sizeof(union ibv_gid); + uint16_t lid; + uint32_t qp_num; + uint32_t sq_psn; + uint32_t user_data_size; + union ibv_gid gid; + } header; + template + static deserialized_qp deserialize(It it) { + deserialized_qp des_qp; + detail::deserialize(it, des_qp.header.lid); + detail::deserialize(it, des_qp.header.qp_num); + detail::deserialize(it, des_qp.header.sq_psn); + detail::deserialize(it, des_qp.header.user_data_size); + detail::deserialize(it, des_qp.header.gid); + return des_qp; + } + std::vector user_data; +}; + +class qp; +typedef qp *qp_ptr; + +/** + * @brief This class is an abstraction of an Infiniband Queue Pair. 
+ * + */ +class qp { + static inline std::atomic next_sq_psn{1}; + static uint32_t get_next_sq_psn() { return next_sq_psn.fetch_add(1); } + struct ibv_qp *qp_; + struct ibv_srq *raw_srq_; + uint32_t sq_psn_; + void (qp::*post_recv_fn)(struct ibv_recv_wr const &recv_wr, + struct ibv_recv_wr *&bad_recv_wr) const; + + pd_ptr pd_; + cq_ptr recv_cq_; + cq_ptr send_cq_; + srq_ptr srq_; + std::vector user_data_; + + /** + * @brief Creates a new Queue Pair. The Queue Pair will be in the RESET state. + * + */ + void create() { + struct ibv_qp_init_attr qp_init_attr = {}; + ::bzero(&qp_init_attr, sizeof(qp_init_attr)); + qp_init_attr.qp_type = IBV_QPT_RC; + qp_init_attr.recv_cq = recv_cq_->cq_; + qp_init_attr.send_cq = send_cq_->cq_; + qp_init_attr.cap.max_recv_sge = 1; + qp_init_attr.cap.max_send_sge = 1; + qp_init_attr.cap.max_recv_wr = 128; + qp_init_attr.cap.max_send_wr = 128; + qp_init_attr.sq_sig_all = 0; + qp_init_attr.qp_context = this; + + if (srq_ != nullptr) { + qp_init_attr.srq = srq_->srq_; + raw_srq_ = srq_->srq_; + post_recv_fn = &qp::post_recv_srq; + } + else { + post_recv_fn = &qp::post_recv_rq; + } + + qp_ = ::ibv_create_qp(pd_->pd_, &qp_init_attr); + check_ptr(qp_, "failed to create qp"); + sq_psn_ = get_next_sq_psn(); + // RDMAPP_LOG_TRACE("created qp %p lid=%u qpn=%u psn=%u", + // reinterpret_cast(qp_), pd_->device()->lid(), + // qp_->qp_num, sq_psn_); + } + + /** + * @brief Initializes the Queue Pair. The Queue Pair will be in the INIT + * state. 
+ * + */ + void init() { + struct ibv_qp_attr qp_attr = {}; + ::bzero(&qp_attr, sizeof(qp_attr)); + qp_attr.qp_state = IBV_QPS_INIT; + qp_attr.pkey_index = 0; + qp_attr.port_num = pd_->device()->port_num(); + qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_ATOMIC; + try { + check_rc(::ibv_modify_qp(qp_, &qp_attr, + IBV_QP_STATE | IBV_QP_PORT | + IBV_QP_ACCESS_FLAGS | IBV_QP_PKEY_INDEX), + "failed to transition qp to init state"); + } catch (const std::exception &e) { + // RDMAPP_LOG_ERROR("%s", e.what()); + qp_ = nullptr; + destroy(); + throw; + } + } + + void destroy() { + if (qp_ == nullptr) [[unlikely]] { + return; + } + + if (auto rc = ::ibv_destroy_qp(qp_); rc != 0) [[unlikely]] { + // RDMAPP_LOG_ERROR("failed to destroy qp %p: %s", + // reinterpret_cast(qp_), strerror(errno)); + } + else { + // RDMAPP_LOG_TRACE("destroyed qp %p", reinterpret_cast(qp_)); + } + } + + public: + class send_awaitable { + struct ibv_wc wc_; + std::coroutine_handle<> h_; + qp_ptr qp_; + void *local_addr_; + size_t local_length_; + uint32_t lkey_; + std::exception_ptr exception_; + void *remote_addr_; + size_t remote_length_; + uint32_t rkey_; + uint64_t compare_add_; + uint64_t swap_; + uint32_t imm_; + const enum ibv_wr_opcode opcode_; + + public: + send_awaitable(qp_ptr qp, void *local_addr, size_t local_length, + uint32_t lkey, enum ibv_wr_opcode opcode) + : wc_(), + qp_(qp), + local_addr_(local_addr), + local_length_(local_length), + lkey_(lkey), + opcode_(opcode) {} + + send_awaitable(qp_ptr qp, void *local_addr, size_t local_length, + uint32_t lkey, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr) + : qp_(qp), + local_addr_(local_addr), + local_length_(local_length), + lkey_(lkey), + opcode_(opcode), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()) {} + + send_awaitable(qp_ptr qp, void *local_addr, size_t local_length, + uint32_t lkey, enum 
ibv_wr_opcode opcode, + remote_mr const &remote_mr, uint32_t imm) + : qp_(qp), + local_addr_(local_addr), + local_length_(local_length), + lkey_(lkey), + opcode_(opcode), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + imm_(imm) {} + + send_awaitable(qp_ptr qp, void *local_addr, size_t local_length, + uint32_t lkey, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr, uint64_t add) + : qp_(qp), + local_addr_(local_addr), + local_length_(local_length), + lkey_(lkey), + opcode_(opcode), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + compare_add_(add) {} + + send_awaitable(qp_ptr qp, void *local_addr, size_t local_length, + uint32_t lkey, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr, uint64_t compare, uint64_t swap) + : qp_(qp), + local_addr_(local_addr), + local_length_(local_length), + lkey_(lkey), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + compare_add_(compare), + swap_(swap), + opcode_(opcode) {} + + send_awaitable(qp_ptr qp, local_mr_ptr local_mr, enum ibv_wr_opcode opcode) + : wc_(), + qp_(qp), + local_addr_(local_mr->addr()), + local_length_(local_mr->length()), + lkey_(local_mr->lkey()), + opcode_(opcode) {} + send_awaitable(qp_ptr qp, local_mr_ptr local_mr, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr) + : qp_(qp), + local_addr_(local_mr->addr()), + local_length_(local_mr->length()), + lkey_(local_mr->lkey()), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + opcode_(opcode) {} + + send_awaitable(qp_ptr qp, local_mr_ptr local_mr, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr, uint32_t imm) + : qp_(qp), + local_addr_(local_mr->addr()), + local_length_(local_mr->length()), + lkey_(local_mr->lkey()), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + imm_(imm), + 
opcode_(opcode) {} + send_awaitable(qp_ptr qp, local_mr_ptr local_mr, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr, uint64_t add) + : qp_(qp), + local_addr_(local_mr->addr()), + local_length_(local_mr->length()), + lkey_(local_mr->lkey()), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + compare_add_(add), + opcode_(opcode) {} + + send_awaitable(qp_ptr qp, local_mr_ptr local_mr, enum ibv_wr_opcode opcode, + remote_mr const &remote_mr, uint64_t compare, uint64_t swap) + : qp_(qp), + local_addr_(local_mr->addr()), + local_length_(local_mr->length()), + lkey_(local_mr->lkey()), + remote_addr_(remote_mr.addr()), + remote_length_(remote_mr.length()), + rkey_(remote_mr.rkey()), + compare_add_(compare), + swap_(swap), + opcode_(opcode) {} + + void complete(struct ibv_wc const &wc) { + wc_ = wc; + h_.resume(); + } + bool await_ready() const noexcept { return false; } + bool await_suspend(std::coroutine_handle<> h) noexcept { + h_ = h; + auto send_sge = fill_local_sge(local_addr_, lkey_, local_length_); + + struct ibv_send_wr send_wr = {}; + struct ibv_send_wr *bad_send_wr = nullptr; + send_wr.opcode = opcode_; + send_wr.next = nullptr; + send_wr.num_sge = 1; + send_wr.wr_id = reinterpret_cast(this); + send_wr.send_flags = IBV_SEND_SIGNALED; + send_wr.sg_list = &send_sge; + if (is_rdma()) { + assert(remote_addr_ != nullptr); + send_wr.wr.rdma.remote_addr = reinterpret_cast(remote_addr_); + send_wr.wr.rdma.rkey = rkey_; + if (opcode_ == IBV_WR_RDMA_WRITE_WITH_IMM) { + send_wr.imm_data = imm_; + } + } + else if (is_atomic()) { + assert(remote_addr_ != nullptr); + send_wr.wr.atomic.remote_addr = + reinterpret_cast(remote_addr_); + send_wr.wr.atomic.rkey = rkey_; + send_wr.wr.atomic.compare_add = compare_add_; + if (opcode_ == IBV_WR_ATOMIC_CMP_AND_SWP) { + send_wr.wr.atomic.swap = swap_; + } + } + + try { + qp_->post_send(send_wr, bad_send_wr); + } catch (std::runtime_error &e) { + exception_ = 
std::make_exception_ptr(e); + return false; + } + return true; + } + struct ibv_wc await_resume() const { + if (exception_) [[unlikely]] { + std::rethrow_exception(exception_); + } + check_wc_status(wc_.status, "failed to send"); + return wc_; + } + + constexpr bool is_rdma() const { + return opcode_ == IBV_WR_RDMA_READ || opcode_ == IBV_WR_RDMA_WRITE || + opcode_ == IBV_WR_RDMA_WRITE_WITH_IMM; + } + + constexpr bool is_atomic() const { + return opcode_ == IBV_WR_ATOMIC_CMP_AND_SWP || + opcode_ == IBV_WR_ATOMIC_FETCH_AND_ADD; + } + }; + + class recv_awaitable { + struct ibv_wc wc_; + std::coroutine_handle<> h_; + qp *qp_; + void *local_addr_; + size_t local_length_; + uint32_t lkey_; + std::exception_ptr exception_; + enum ibv_wr_opcode opcode_; + + public: + recv_awaitable(qp_ptr qp, local_mr_ptr local_mr) + : wc_(), + qp_(qp), + local_addr_(local_mr->addr()), + local_length_(local_mr->length()), + lkey_(local_mr->lkey()) {} + + recv_awaitable(qp_ptr qp, void *local_addr, size_t local_length, + uint32_t lkey) + : wc_(), + qp_(qp), + local_addr_(local_addr), + local_length_(local_length), + lkey_(lkey) {} + + void complete(struct ibv_wc const &wc) { + wc_ = wc; + h_.resume(); + } + + bool await_ready() const noexcept { return false; } + bool await_suspend(std::coroutine_handle<> h) noexcept { + h_ = h; + auto recv_sge = fill_local_sge(local_addr_, lkey_, local_length_); + + struct ibv_recv_wr recv_wr = {}; + struct ibv_recv_wr *bad_recv_wr = nullptr; + recv_wr.next = nullptr; + recv_wr.num_sge = 1; + recv_wr.wr_id = reinterpret_cast(this); + recv_wr.sg_list = &recv_sge; + + try { + qp_->post_recv(recv_wr, bad_recv_wr); + } catch (std::runtime_error &e) { + exception_ = std::make_exception_ptr(e); + return false; + } + return true; + } + + std::pair> await_resume() const { + if (exception_) [[unlikely]] { + std::rethrow_exception(exception_); + } + check_wc_status(wc_.status, "failed to recv"); + if (wc_.wc_flags & IBV_WC_WITH_IMM) { + return 
std::make_pair(wc_.byte_len, wc_.imm_data); + } + return std::make_pair(wc_.byte_len, std::nullopt); + } + }; + + /** + * @brief Construct a new qp object. The Queue Pair will be created with the + * given remote Queue Pair parameters. Once constructed, the Queue Pair will + * be in the RTS state. + * + * @param remote_lid The LID of the remote Queue Pair. + * @param remote_qpn The QPN of the remote Queue Pair. + * @param remote_psn The PSN of the remote Queue Pair. + * @param pd The protection domain of the new Queue Pair. + * @param cq The completion queue of both send and recv work completions. + * @param srq (Optional) If set, all recv work requests will be posted to this + * SRQ. + */ + qp(const uint16_t remote_lid, const uint32_t remote_qpn, + const uint32_t remote_psn, const union ibv_gid remote_gid, pd_ptr pd, + cq_ptr cq, srq_ptr srq = nullptr) + : qp(remote_lid, remote_qpn, remote_psn, remote_gid, pd, cq, cq, srq) {} + + /** + * @brief Construct a new qp object. The Queue Pair will be created with the + * given remote Queue Pair parameters. Once constructed, the Queue Pair will + * be in the RTS state. + * + * @param remote_lid The LID of the remote Queue Pair. + * @param remote_qpn The QPN of the remote Queue Pair. + * @param remote_psn The PSN of the remote Queue Pair. + * @param pd The protection domain of the new Queue Pair. + * @param recv_cq The completion queue of recv work completions. + * @param send_cq The completion queue of send work completions. + * @param srq (Optional) If set, all recv work requests will be posted to this + * SRQ. + */ + qp(const uint16_t remote_lid, const uint32_t remote_qpn, + const uint32_t remote_psn, const union ibv_gid remote_gid, pd_ptr pd, + cq_ptr recv_cq, cq_ptr send_cq, srq_ptr srq = nullptr) + : qp(pd, recv_cq, send_cq, srq) { + rtr(remote_lid, remote_qpn, remote_psn, remote_gid); + rts(); + } + + /** + * @brief Construct a new qp object. The constructed Queue Pair will be in + * INIT state. 
+   *
+   * @param pd The protection domain of the new Queue Pair.
+   * @param cq The completion queue of both send and recv work completions.
+   * @param srq (Optional) If set, all recv work requests will be posted to this
+   * SRQ.
+   */
+  qp(pd_ptr pd, cq_ptr cq, srq_ptr srq = nullptr) : qp(pd, cq, cq, srq) {}
+
+  /**
+   * @brief Construct a new qp object. The constructed Queue Pair will be in
+   * INIT state.
+   *
+   * @param pd The protection domain of the new Queue Pair.
+   * @param recv_cq The completion queue of recv work completions.
+   * @param send_cq The completion queue of send work completions.
+   * @param srq (Optional) If set, all recv work requests will be posted to this
+   * SRQ.
+   */
+  qp(pd_ptr pd, cq_ptr recv_cq, cq_ptr send_cq, srq_ptr srq = nullptr)
+      : qp_(nullptr), pd_(pd), recv_cq_(recv_cq), send_cq_(send_cq), srq_(srq) {
+    // qp_ starts out null; create() and init() run in that order.
+    create();
+    init();
+  }
+
+  /**
+   * @brief This function is used to post a send work request to the Queue Pair.
+   * The return code of ibv_post_send is handed to check_rc together with the
+   * message "failed to post send".
+   *
+   * @param send_wr The work request to post.
+   * @param bad_send_wr A pointer to a work request that will be set to the
+   * first work request that failed to post.
+   */
+  void post_send(struct ibv_send_wr const &send_wr,
+                 struct ibv_send_wr *&bad_send_wr) {
+    // RDMAPP_LOG_TRACE("post send wr_id=%p addr=%p",
+    //                  reinterpret_cast(send_wr.wr_id),
+    //                  reinterpret_cast(send_wr.sg_list->addr));
+    // send_wr is taken by const reference, so constness is cast away for the
+    // verbs call.
+    check_rc(::ibv_post_send(qp_, const_cast(&send_wr),
+                             &bad_send_wr),
+             "failed to post send");
+  }
+
+  /**
+   * @brief This function is used to post a recv work request to the Queue Pair.
+   * It will be posted to either RQ or SRQ depending on whether or not SRQ is
+   * set.
+   *
+   * @param recv_wr The work request to post.
+   * @param bad_recv_wr A pointer to a work request that will be set to the
+   * first work request that failed to post.
+   */
+  void post_recv(struct ibv_recv_wr const &recv_wr,
+                 struct ibv_recv_wr *&bad_recv_wr) const {
+    // Dispatch through the post_recv_fn pointer-to-member.
+    (this->*(post_recv_fn))(recv_wr, bad_recv_wr);
+  }
+
+  /**
+   * @brief This function sends a registered local memory region to remote.
+   *
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @return send_awaitable An awaitable whose co_await result is the
+   * `struct ibv_wc` work completion of the posted request.
+   */
+  [[nodiscard]] send_awaitable send(local_mr_ptr local_mr) {
+    return qp::send_awaitable(this, local_mr, ibv_wr_opcode::IBV_WR_SEND);
+  }
+
+  /**
+   * @brief This function writes a registered local memory region to remote.
+   *
+   * @param remote_mr Remote memory region handle.
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @return send_awaitable An awaitable whose co_await result is the
+   * `struct ibv_wc` work completion of the posted request.
+   */
+  [[nodiscard]] send_awaitable write(remote_mr const &remote_mr,
+                                     local_mr_ptr local_mr) {
+    return qp::send_awaitable(this, local_mr, IBV_WR_RDMA_WRITE, remote_mr);
+  }
+
+  /**
+   * @brief This function writes a registered local memory region to remote with
+   * an immediate value.
+   *
+   * @param remote_mr Remote memory region handle.
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @param imm The immediate value.
+   * @return send_awaitable An awaitable whose co_await result is the
+   * `struct ibv_wc` work completion of the posted request.
+   */
+  [[nodiscard]] send_awaitable write_with_imm(remote_mr const &remote_mr,
+                                              local_mr_ptr local_mr,
+                                              uint32_t imm) {
+    return qp::send_awaitable(this, local_mr, IBV_WR_RDMA_WRITE_WITH_IMM,
+                              remote_mr, imm);
+  }
+
+  /**
+   * @brief This function reads to local memory region from remote.
+   *
+   * @param remote_mr Remote memory region handle.
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @return send_awaitable An awaitable whose co_await result is the
+   * `struct ibv_wc` work completion of the posted request.
+   */
+  [[nodiscard]] send_awaitable read(remote_mr const &remote_mr,
+                                    local_mr_ptr local_mr) {
+    return qp::send_awaitable(this, local_mr, IBV_WR_RDMA_READ, remote_mr);
+  }
+
+  /**
+   * @brief This function performs an atomic fetch-and-add operation on the
+   * given remote memory region.
+   *
+   * @param remote_mr Remote memory region handle.
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @param add The delta.
+   * @return send_awaitable An awaitable whose co_await result is the
+   * `struct ibv_wc` work completion of the posted request.
+   */
+  [[nodiscard]] send_awaitable fetch_and_add(remote_mr const &remote_mr,
+                                             local_mr_ptr local_mr,
+                                             uint64_t add) {
+    // Device capability is only checked by assert, i.e. not in NDEBUG builds.
+    assert(pd_->device()->is_fetch_and_add_supported());
+    return qp::send_awaitable(this, local_mr, IBV_WR_ATOMIC_FETCH_AND_ADD,
+                              remote_mr, add);
+  }
+
+  /**
+   * @brief This function performs an atomic compare-and-swap operation on the
+   * given remote memory region.
+   *
+   * @param remote_mr Remote memory region handle.
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @param compare The expected old value.
+   * @param swap The desired new value.
+   * @return send_awaitable An awaitable whose co_await result is the
+   * `struct ibv_wc` work completion of the posted request.
+   */
+  [[nodiscard]] send_awaitable compare_and_swap(remote_mr const &remote_mr,
+                                                local_mr_ptr local_mr,
+                                                uint64_t compare,
+                                                uint64_t swap) {
+    // Device capability is only checked by assert, i.e. not in NDEBUG builds.
+    assert(pd_->device()->is_compare_and_swap_supported());
+    return qp::send_awaitable(this, local_mr, IBV_WR_ATOMIC_CMP_AND_SWP,
+                              remote_mr, compare, swap);
+  }
+
+  /**
+   * @brief This function posts a recv request on the queue pair. The buffer
+   * will be filled with data received.
+   *
+   * @param local_mr Registered local memory region, whose lifetime is
+   * controlled by a smart pointer.
+   * @return recv_awaitable An awaitable whose co_await result is a pair, with
+   * first indicating the length of received data (wc.byte_len), and second
+   * holding the immediate value when the completion has IBV_WC_WITH_IMM set
+   * and std::nullopt otherwise.
+ */ + [[nodiscard]] recv_awaitable recv(local_mr_ptr local_mr) { + return qp::recv_awaitable(this, local_mr); + } + + /** + * @brief This function serializes a Queue Pair prepared to be sent to a + * buffer. + * + * @return std::vector The serialized QP. + */ + std::vector serialize() const { + std::vector buffer; + auto it = std::back_inserter(buffer); + detail::serialize(pd_->device()->lid(), it); + detail::serialize(qp_->qp_num, it); + detail::serialize(sq_psn_, it); + detail::serialize(static_cast(user_data_.size()), it); + detail::serialize(pd_->device()->gid(), it); + std::copy(user_data_.cbegin(), user_data_.cend(), it); + return buffer; + } + + /** + * @brief This function provides access to the extra user data of the Queue + * Pair. + * + * @return std::vector& The extra user data. + */ + std::vector &user_data() { return user_data_; } + + /** + * @brief This function provides access to the Protection Domain of the Queue + * Pair. + * + * @return pd_ptr Pointer to the PD. + */ + pd_ptr pd() const { return pd_; } + + ~qp() { destroy(); } + + /** + * @brief This function transitions the Queue Pair to the RTR state. + * + * @param remote_lid The remote LID. + * @param remote_qpn The remote QPN. + * @param remote_psn The remote PSN. + * @param remote_gid The remote GID. 
+ */ + void rtr(uint16_t remote_lid, uint32_t remote_qpn, uint32_t remote_psn, + union ibv_gid remote_gid) { + struct ibv_qp_attr qp_attr = {}; + ::bzero(&qp_attr, sizeof(qp_attr)); + qp_attr.qp_state = IBV_QPS_RTR; + qp_attr.path_mtu = IBV_MTU_4096; + qp_attr.dest_qp_num = remote_qpn; + qp_attr.rq_psn = remote_psn; + qp_attr.max_dest_rd_atomic = 16; + qp_attr.min_rnr_timer = 12; + qp_attr.ah_attr.is_global = 1; + qp_attr.ah_attr.grh.dgid = remote_gid; + qp_attr.ah_attr.grh.sgid_index = pd_->device_->gid_index_; + qp_attr.ah_attr.grh.hop_limit = 16; + qp_attr.ah_attr.dlid = remote_lid; + qp_attr.ah_attr.sl = 0; + qp_attr.ah_attr.src_path_bits = 0; + qp_attr.ah_attr.port_num = pd_->device()->port_num(); + + try { + check_rc( + ::ibv_modify_qp(qp_, &qp_attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MIN_RNR_TIMER | IBV_QP_MAX_DEST_RD_ATOMIC), + "failed to transition qp to rtr state"); + } catch (const std::exception &e) { + // RDMAPP_LOG_ERROR("%s", e.what()); + qp_ = nullptr; + destroy(); + throw; + } + } + + /** + * @brief This function transitions the Queue Pair to the RTS state. + * + */ + void rts() { + struct ibv_qp_attr qp_attr = {}; + ::bzero(&qp_attr, sizeof(qp_attr)); + qp_attr.qp_state = IBV_QPS_RTS; + qp_attr.timeout = 14; + qp_attr.retry_cnt = 7; + qp_attr.rnr_retry = 7; + qp_attr.max_rd_atomic = 16; + qp_attr.sq_psn = sq_psn_; + + try { + check_rc(::ibv_modify_qp(qp_, &qp_attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC), + "failed to transition qp to rts state"); + } catch (std::exception const &e) { + // RDMAPP_LOG_ERROR("%s", e.what()); + qp_ = nullptr; + destroy(); + throw; + } + } + + private: + /** + * @brief This function posts a recv request on the Queue Pair's own RQ. 
+   *
+   * @param recv_wr The work request to post.
+   * @param bad_recv_wr A pointer to a work request that will be set to the
+   * first work request that failed to post.
+   */
+  void post_recv_rq(struct ibv_recv_wr const &recv_wr,
+                    struct ibv_recv_wr *&bad_recv_wr) const {
+    // RDMAPP_LOG_TRACE("post recv wr_id=%p addr=%p",
+    //                  reinterpret_cast(recv_wr.wr_id),
+    //                  reinterpret_cast(recv_wr.sg_list->addr));
+    check_rc(::ibv_post_recv(qp_, const_cast(&recv_wr),
+                             &bad_recv_wr),
+             "failed to post recv");
+  }
+
+  /**
+   * @brief This function posts a recv request on the Queue Pair's SRQ.
+   *
+   * @param recv_wr The work request to post.
+   * @param bad_recv_wr A pointer to a work request that will be set to the
+   * first work request that failed to post.
+   */
+  void post_recv_srq(struct ibv_recv_wr const &recv_wr,
+                     struct ibv_recv_wr *&bad_recv_wr) const {
+    check_rc(
+        ::ibv_post_srq_recv(
+            raw_srq_, const_cast(&recv_wr), &bad_recv_wr),
+        "failed to post srq recv");
+  }
+};
+
+/**
+ * @brief Hands a work completion to the awaitable that posted the request.
+ * The awaitables store their own address in wr_id when posting, so wr_id is
+ * cast back and complete(wc) is called on it. Completions with the
+ * IBV_WC_RECV bit set in opcode go to the recv awaitable, all others to the
+ * send awaitable.
+ *
+ * NOTE(review): the verbs documentation lists opcode among the fields that
+ * are not valid when wc.status is not IBV_WC_SUCCESS - confirm that failed
+ * recv completions are still routed to the right awaitable type.
+ *
+ * @param wc The work completion to dispatch.
+ */
+static inline void process_wc(struct ibv_wc const &wc) {
+  if (wc.opcode & IBV_WC_RECV) {
+    auto &recv_awaitable =
+        *reinterpret_cast(wc.wr_id);
+    recv_awaitable.complete(wc);
+  }
+  else {
+    auto &send_awaitable =
+        *reinterpret_cast(wc.wr_id);
+    send_awaitable.complete(wc);
+  }
+}
+
+}  // namespace rdmapp
diff --git a/include/ylt/standalone/rdmapp/srq.h b/include/ylt/standalone/rdmapp/srq.h
new file mode 100644
index 000000000..d810d5529
--- /dev/null
+++ b/include/ylt/standalone/rdmapp/srq.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include
+
+#include "pd.h"
+
+namespace rdmapp {
+
+/**
+ * @brief This class represents a Shared Receive Queue.
+ *
+ */
+class srq {
+  struct ibv_srq *srq_;
+  pd_ptr pd_;
+  friend class qp;
+
+ public:
+  /**
+   * @brief Construct a new srq object
+   *
+   * @param pd The protection domain to use.
+   * @param max_wr The maximum number of outstanding work requests.
+ */ + srq(pd_ptr pd, size_t max_wr = 1024) : srq_(nullptr), pd_(pd) { + struct ibv_srq_init_attr srq_init_attr; + srq_init_attr.srq_context = this; + srq_init_attr.attr.max_sge = 1; + srq_init_attr.attr.max_wr = max_wr; + srq_init_attr.attr.srq_limit = max_wr; + + srq_ = ::ibv_create_srq(pd_->pd_, &srq_init_attr); + check_ptr(srq_, "failed to create srq"); + } + + /** + * @brief Destroy the srq object and the associated shared receive queue. + * + */ + ~srq() { + if (srq_ == nullptr) [[unlikely]] { + return; + } + + if (auto rc = ::ibv_destroy_srq(srq_); rc != 0) [[unlikely]] { + } + else { + } + } +}; + +typedef srq *srq_ptr; + +} // namespace rdmapp \ No newline at end of file diff --git a/src/coro_rdma/examples/CMakeLists.txt b/src/coro_rdma/examples/CMakeLists.txt new file mode 100644 index 000000000..c81bfa831 --- /dev/null +++ b/src/coro_rdma/examples/CMakeLists.txt @@ -0,0 +1,31 @@ + +if("${yaLanTingLibs_SOURCE_DIR}" STREQUAL "${CMAKE_SOURCE_DIR}") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/output/examples/coro_rdma) +else() + # else find installed yalantinglibs + cmake_minimum_required(VERSION 3.15) + project(file_transfer) + if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") + endif() + set(CMAKE_CXX_STANDARD 20) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + set(CMAKE_INCLUDE_CURRENT_DIR ON) + find_package(Threads REQUIRED) + link_libraries(Threads::Threads) + # if you have install ylt + find_package(yalantinglibs REQUIRED) + link_libraries(yalantinglibs::yalantinglibs) + # else + # include_directories(include) + # include_directories(include/ylt/thirdparty) + + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcoroutines") + #-ftree-slp-vectorize with coroutine cause link error. disable it util gcc fix. 
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fno-tree-slp-vectorize") + endif() +endif() + +add_executable(coro_rdma_example example.cpp) +target_link_libraries(coro_rdma_example ibverbs) diff --git a/src/coro_rdma/examples/example.cpp b/src/coro_rdma/examples/example.cpp new file mode 100644 index 000000000..6e35e19af --- /dev/null +++ b/src/coro_rdma/examples/example.cpp @@ -0,0 +1,365 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "asio/posix/stream_descriptor.hpp" + +rdmapp::device_ptr g_device; +rdmapp::pd_ptr g_pd; +rdmapp::cq_ptr g_cq; +rdmapp::comp_channel_ptr g_channel; + +async_simple::coro::Lazy recv_qp( + asio::ip::tcp::socket &socket) { + std::array + header; + { + auto [ec, size] = + co_await coro_io::async_read(socket, asio::buffer(header)); + if (ec) { + std::cerr << ec.message() << std::endl; + throw std::runtime_error("read qp header failed"); + } + } + auto remote_qp = rdmapp::deserialized_qp::deserialize(header.data()); + auto const remote_gid_str = + rdmapp::device::gid_hex_string(remote_qp.header.gid); + fprintf(stderr, + "received header gid=%s lid=%u qpn=%u psn=%u user_data_size=%u\n", + remote_gid_str.c_str(), remote_qp.header.lid, remote_qp.header.qp_num, + remote_qp.header.sq_psn, remote_qp.header.user_data_size); + if (remote_qp.header.user_data_size > 0) { + remote_qp.user_data.resize(remote_qp.header.user_data_size); + auto [ec, size] = + co_await coro_io::async_read(socket, asio::buffer(remote_qp.user_data)); + if (ec) { + std::cerr << ec.message() << std::endl; + throw std::runtime_error("read qp user data failed"); + } + } + co_return remote_qp; +} + +async_simple::coro::Lazy send_qp(asio::ip::tcp::socket &socket, + rdmapp::qp const &qp) { + auto serialized_qp = qp.serialize(); + auto [ec, size] = + co_await coro_io::async_write(socket, asio::buffer(serialized_qp)); + co_return ec; +} + +class rdma_qp_client { + asio::io_context 
*ctx_; + coro_io::ExecutorWrapper<> executor_; + uint16_t port_; + std::string address_; + std::error_code errc_ = {}; + asio::ip::tcp::socket socket_; + + public: + rdma_qp_client(asio::io_context &ctx, unsigned short port, + std::string address = "0.0.0.0") + : ctx_(&ctx), + executor_(ctx.get_executor()), + port_(port), + address_(address), + socket_(ctx) {} + + async_simple::coro::Lazy handle_qp(rdmapp::qp_ptr qp) { + char buffer[6]; + auto buffer_mr = std::make_unique( + qp->pd()->reg_mr(&buffer[0], sizeof(buffer))); + + /* Send/Recv */ + auto [n, _] = co_await qp->recv(buffer_mr.get()); + std::cout << "Received " << n << " bytes from server: " << buffer + << std::endl; + std::copy_n("world", sizeof(buffer), buffer); + co_await qp->send(buffer_mr.get()); + std::cout << "Sent to server: " << buffer << std::endl; + + /* Read/Write */ + char remote_mr_serialized[rdmapp::remote_mr::kSerializedSize]; + auto remote_mr_header_buffer = + std::make_unique(qp->pd()->reg_mr( + &remote_mr_serialized[0], sizeof(remote_mr_serialized))); + co_await qp->recv(remote_mr_header_buffer.get()); + auto remote_mr = rdmapp::remote_mr::deserialize(remote_mr_serialized); + std::cout << "Received mr addr=" << remote_mr.addr() + << " length=" << remote_mr.length() + << " rkey=" << remote_mr.rkey() << " from server" << std::endl; + auto wc = co_await qp->read(remote_mr, buffer_mr.get()); + std::cout << "Read " << wc.byte_len << " bytes from server: " << buffer + << std::endl; + std::copy_n("world", sizeof(buffer), buffer); + co_await qp->write_with_imm(remote_mr, buffer_mr.get(), 1); + + /* Atomic Fetch-and-Add (FA)/Compare-and-Swap (CS) */ + char counter_mr_serialized[rdmapp::remote_mr::kSerializedSize]; + auto counter_mr_header_mr = + std::make_unique(qp->pd()->reg_mr( + &counter_mr_serialized[0], sizeof(counter_mr_serialized))); + co_await qp->recv(counter_mr_header_mr.get()); + auto remote_counter_mr = + rdmapp::remote_mr::deserialize(counter_mr_serialized); + std::cout << "Received mr 
addr=" << remote_counter_mr.addr() + << " length=" << remote_counter_mr.length() + << " rkey=" << remote_counter_mr.rkey() << " from server" + << std::endl; + uint64_t counter = 0; + auto local_counter_mr = std::make_unique( + qp->pd()->reg_mr(&counter, sizeof(counter))); + co_await qp->fetch_and_add(remote_counter_mr, local_counter_mr.get(), 1); + std::cout << "Fetched and added from server: " << counter << std::endl; + co_await qp->write_with_imm(remote_mr, buffer_mr.get(), 1); + co_await qp->compare_and_swap(remote_counter_mr, local_counter_mr.get(), 43, + 4422); + std::cout << "Compared and swapped from server: " << counter << std::endl; + co_await qp->write_with_imm(remote_mr, buffer_mr.get(), 1); + + co_return; + } + + async_simple::coro::Lazy handle_server_connection( + asio::ip::tcp::socket &socket) { + auto qp = std::make_unique(g_pd, g_cq, g_cq); + co_await send_qp(socket, *qp); + std::cerr << "send qp success" << std::endl; + auto remote_qp = co_await recv_qp(socket); + std::cerr << "recv qp success" << std::endl; + qp->rtr(remote_qp.header.lid, remote_qp.header.qp_num, + remote_qp.header.sq_psn, remote_qp.header.gid); + qp->user_data() = std::move(remote_qp.user_data); + qp->rts(); + co_await handle_qp(qp.get()); + } + + async_simple::coro::Lazy run() { + auto ec = co_await coro_io::async_connect(&executor_, socket_, address_, + std::to_string(port_)); + if (ec) { + throw std::system_error(ec, "connect failed"); + } + co_await handle_server_connection(socket_); + } +}; + +class rdma_qp_server { + asio::io_context *ctx_; + coro_io::ExecutorWrapper<> executor_; + uint16_t port_; + std::string address_; + std::error_code errc_ = {}; + asio::ip::tcp::acceptor acceptor_; + std::promise acceptor_close_waiter_; + + public: + rdma_qp_server(asio::io_context &ctx, unsigned short port, + std::string address = "0.0.0.0") + : ctx_(&ctx), + executor_(ctx.get_executor()), + port_(port), + address_(address), + acceptor_(ctx) {} + + std::error_code listen() { + 
asio::error_code ec; + + asio::ip::tcp::resolver::query query(address_, std::to_string(port_)); + asio::ip::tcp::resolver resolver(acceptor_.get_executor()); + asio::ip::tcp::resolver::iterator it = resolver.resolve(query, ec); + + asio::ip::tcp::resolver::iterator it_end; + if (ec || it == it_end) { + if (ec) { + return ec; + } + return std::make_error_code(std::errc::address_not_available); + } + + auto endpoint = it->endpoint(); + ec = acceptor_.open(endpoint.protocol(), ec); + + if (ec) { + return ec; + } +#ifdef __GNUC__ + ec = acceptor_.set_option(asio::ip::tcp::acceptor::reuse_address(true), ec); +#endif + ec = acceptor_.bind(endpoint, ec); + if (ec) { + std::error_code ignore_ec; + ignore_ec = acceptor_.cancel(ignore_ec); + ignore_ec = acceptor_.close(ignore_ec); + return ec; + } + ec = acceptor_.listen(asio::socket_base::max_listen_connections, ec); + if (ec) { + return ec; + } + auto local_ep = acceptor_.local_endpoint(ec); + if (ec) { + return ec; + } + port_ = local_ep.port(); + std::cerr << "listen success, port: " << port_ << std::endl; + return {}; + } + + async_simple::coro::Lazy handle_qp(rdmapp::qp_ptr qp) { + /* Send/Recv */ + char buffer[6] = "hello"; + auto local_mr = std::make_unique( + qp->pd()->reg_mr(&buffer[0], sizeof(buffer))); + co_await qp->send(local_mr.get()); + std::cout << "Sent to client: " << buffer << std::endl; + co_await qp->recv(local_mr.get()); + std::cout << "Received from client: " << buffer << std::endl; + + /* Read/Write */ + std::copy_n("hello", sizeof(buffer), buffer); + auto local_mr_serialized = local_mr->serialize(); + auto local_mr_header_mr = std::make_unique( + qp->pd()->reg_mr(&local_mr_serialized[0], local_mr_serialized.size())); + co_await qp->send(local_mr_header_mr.get()); + std::cout << "Sent mr addr=" << local_mr->addr() + << " length=" << local_mr->length() + << " rkey=" << local_mr->rkey() << " to client" << std::endl; + auto [_, imm] = co_await qp->recv(local_mr.get()); + assert(imm.has_value()); + 
std::cout << "Written by client (imm=" << imm.value() << "): " << buffer + << std::endl; + + /* Atomic */ + uint64_t counter = 42; + auto counter_mr = std::make_unique( + qp->pd()->reg_mr(&counter, sizeof(counter))); + auto counter_mr_serialized = counter_mr->serialize(); + auto counter_mr_header_mr = + std::make_unique(qp->pd()->reg_mr( + &counter_mr_serialized[0], counter_mr_serialized.size())); + co_await qp->send(&*counter_mr_header_mr); + std::cout << "Sent mr addr=" << counter_mr->addr() + << " length=" << counter_mr->length() + << " rkey=" << counter_mr->rkey() << " to client" << std::endl; + imm = (co_await qp->recv(local_mr.get())).second; + assert(imm.has_value()); + std::cout << "Fetched and added by client: " << counter << std::endl; + imm = (co_await qp->recv(local_mr.get())).second; + assert(imm.has_value()); + std::cout << "Compared and swapped by client: " << counter << std::endl; + } + + async_simple::coro::Lazy handle_client_connection( + asio::ip::tcp::socket socket) { + auto remote_qp = co_await recv_qp(socket); + auto qp = std::make_unique( + remote_qp.header.lid, remote_qp.header.qp_num, remote_qp.header.sq_psn, + remote_qp.header.gid, g_pd, g_cq, g_cq); + std::cerr << "recv qp success" << std::endl; + co_await send_qp(socket, *qp); + std::cerr << "send qp success" << std::endl; + co_await handle_qp(&*qp); + } + + async_simple::coro::Lazy<> loop() { + for (;;) { + asio::ip::tcp::socket socket(executor_.get_asio_executor()); + auto ec = co_await coro_io::async_accept(acceptor_, socket); + if (ec) { + std::cerr << "accept failed: " << ec.message() << std::endl; + if (ec == asio::error::operation_aborted || + ec == asio::error::bad_descriptor) { + acceptor_close_waiter_.set_value(); + co_return; + } + continue; + } + std::cout << "accpeted connection " << socket.remote_endpoint() + << std::endl; + handle_client_connection(std::move(socket)).via(&executor_).detach(); + } + } +}; + +async_simple::coro::Lazy process_rdma_cq( + asio::io_context &ctx, 
rdmapp::comp_channel_ptr channel) {
+  asio::posix::stream_descriptor cq_fd(ctx.get_executor(), channel->fd());
+  while (!ctx.stopped()) {
+    coro_io::callback_awaitor awaitor;
+    auto ec = co_await awaitor.await_resume([&cq_fd](auto handler) {
+      cq_fd.async_wait(asio::posix::stream_descriptor::wait_read,
+                       [handler](const auto &ec) mutable {
+                         handler.set_value_then_resume(ec);
+                       });
+    });
+    if (ec) {
+      std::cerr << "failed to wait: " << ec.message() << std::endl;
+      co_return;
+    }
+
+    auto cq = channel->get_event();
+    cq->ack_event();
+    cq->request_notify();
+    struct ibv_wc wc;
+    while (cq->poll(wc)) {
+      rdmapp::process_wc(wc);
+    }
+  }
+}
+
+/**
+ * Entry point of the example.
+ *
+ * With one argument ([port]) it runs the server, with two ([ip] [port]) the
+ * client; otherwise it prints the usage. Returns -1 when the server fails to
+ * listen and 0 otherwise.
+ */
+int main(int argc, char *argv[]) {
+  // For server run ./example [port]
+  // For client run ./example [ip] [port]
+  auto socket_ctx = std::make_unique();
+  auto executor_wrapper =
+      std::make_unique>(socket_ctx->get_executor());
+  g_device = new rdmapp::device();
+  g_pd = new rdmapp::pd(g_device);
+
+  g_channel = new rdmapp::comp_channel(g_device);
+  g_channel->set_non_blocking();
+  g_cq = new rdmapp::cq(g_device, 128, g_channel);
+  g_cq->request_notify();
+
+  process_rdma_cq(*socket_ctx, g_channel).via(&*executor_wrapper).detach();
+  std::jthread worker_thread([&socket_ctx] {
+    socket_ctx->run();
+  });
+  // process_rdma_cq always keeps an async_wait pending, so run() in the
+  // worker thread only returns once the context is stopped. The jthread
+  // destructor joins, so every way out of main has to stop the context first
+  // or the process hangs (previously the usage and listen-failure paths did).
+  // The guard is declared after worker_thread and is therefore destroyed
+  // before it.
+  struct context_stopper {
+    asio::io_context &ctx;
+    ~context_stopper() { ctx.stop(); }
+  } stopper{*socket_ctx};
+
+  if (argc == 2) {
+    std::cout << "server mode, press Control+C to exit" << std::endl;
+    std::string port_str = argv[1];
+    rdma_qp_server server(*socket_ctx, std::stoi(port_str));
+    auto ec = server.listen();
+    if (ec) {
+      std::cerr << "listen failed: " << ec.message() << std::endl;
+      return -1;
+    }
+    async_simple::coro::syncAwait(server.loop());
+  }
+  else if (argc == 3) {
+    std::cout << "client mode, connecting to server" << std::endl;
+    std::string ip = argv[1];
+    std::string port_str = argv[2];
+    rdma_qp_client client(*socket_ctx, std::stoi(port_str), ip);
+    // TODO: fix memory leak
+    async_simple::coro::syncAwait(client.run());
+    socket_ctx->stop();
+  }
+  else {
+    std::cerr << "server usage: " << argv[0] << " [port]" << std::endl;
+    std::cerr << "client usage: " << argv[0] << " [ip] [port]" << std::endl;
+  }
+  return 0;
+}
\ No newline at end of file