Skip to content

Commit

Permalink
CUDA: protocol changes for eager support for domain memory
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy committed Oct 12, 2017
1 parent 71fa1ed commit ac2cc56
Show file tree
Hide file tree
Showing 13 changed files with 389 additions and 68 deletions.
19 changes: 19 additions & 0 deletions src/ucp/core/ucp_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,7 @@ void ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config)
{
ucp_context_h context = worker->context;
ucp_ep_rma_config_t *rma_config;
ucp_ep_addr_domain_config_t *domain_config;
uct_iface_attr_t *iface_attr;
uct_md_attr_t *md_attr;
ucp_rsc_index_t rsc_index;
Expand All @@ -917,6 +918,7 @@ void ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config)
config->tag.eager.zcopy_auto_thresh = 0;
config->am.zcopy_auto_thresh = 0;
config->p2p_lanes = 0;
config->domain_lanes = 0;
config->bcopy_thresh = context->config.ext.bcopy_thresh;
config->tag.lane = UCP_NULL_LANE;
config->tag.proto = &ucp_tag_eager_proto;
Expand Down Expand Up @@ -1004,6 +1006,23 @@ void ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config)
}
}

/* Configuration for memory domains */
for (lane = 0; lane < config->key.num_lanes; ++lane) {
if (config->key.domain_lanes[lane] == UCP_NULL_LANE) {
continue;
}
config->domain_lanes |= UCS_BIT(lane);

domain_config = &config->domain[lane];
rsc_index = config->key.lanes[lane].rsc_index;
iface_attr = &worker->ifaces[rsc_index].attr;

domain_config->tag.eager.max_short = iface_attr->cap.am.max_short;
//TODO: zcopy threshold should be based on the ep AM lane capability with domain addr(i.e can UCT do zcopy from domain)
memset(domain_config->tag.eager.zcopy_thresh, 0, UCP_MAX_IOV * sizeof(size_t));

}

/* Configuration for remote memory access */
for (lane = 0; lane < config->key.num_lanes; ++lane) {
if (ucp_ep_config_get_rma_prio(config->key.rma_lanes, lane) == -1) {
Expand Down
25 changes: 24 additions & 1 deletion src/ucp/core/ucp_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ typedef struct ucp_ep_config_key {
/* Lanes for atomic operations, sorted by priority, highest first */
ucp_lane_index_t amo_lanes[UCP_MAX_LANES];

/* Lanes for domain operations, sorted by priority, highest first */
ucp_lane_index_t domain_lanes[UCP_MAX_LANES];

/* Bitmap of remote mds which are reachable from this endpoint (with any set
* of transports which could be selected in the future).
*/
Expand All @@ -106,6 +109,15 @@ typedef struct ucp_ep_rma_config {
} ucp_ep_rma_config_t;


typedef struct ucp_ep_addr_domain_config {
struct {
struct {
ssize_t max_short;
size_t zcopy_thresh[UCP_MAX_IOV];
} eager;
} tag;
} ucp_ep_addr_domain_config_t;

/*
* Configuration for AM and tag offload protocols
*/
Expand Down Expand Up @@ -136,6 +148,10 @@ typedef struct ucp_ep_config {
*/
ucp_lane_map_t p2p_lanes;

/* Bitmap of which lanes are domain lanes
*/
ucp_lane_map_t domain_lanes;

/* Configuration for each lane that provides RMA */
ucp_ep_rma_config_t rma[UCP_MAX_LANES];
/* Threshold for switching from put_short to put_bcopy */
Expand Down Expand Up @@ -179,8 +195,11 @@ typedef struct ucp_ep_config {
* (currently it's only AM based). */
const ucp_proto_t *proto;
} stream;
} ucp_ep_config_t;

/* Configuration of all domains */
ucp_ep_addr_domain_config_t domain[UCP_MAX_LANES];

} ucp_ep_config_t;

/**
* Remote protocol layer endpoint
Expand Down Expand Up @@ -245,4 +264,8 @@ size_t ucp_ep_config_get_zcopy_auto_thresh(size_t iovcnt,
const ucp_context_h context,
double bandwidth);

ucp_lane_index_t ucp_config_find_domain_lane(const ucp_ep_config_t *config,
const ucp_lane_index_t *lanes,
ucp_md_map_t dn_md_map);
ucs_status_t ucp_ep_set_domain_lanes(ucp_ep_h ep, ucp_mem_type_h mem_type_h);
#endif
1 change: 1 addition & 0 deletions src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ typedef struct ucp_mem_desc {
typedef struct ucp_mem_type {
ucp_md_map_t md_map; /* Which MDs have own ths addr Domain */
uct_memory_type_t id; /* memory type */
ucp_lane_index_t eager_lane;
} ucp_mem_type_t;

void ucp_rkey_resolve_inner(ucp_rkey_h rkey, ucp_ep_h ep);
Expand Down
121 changes: 121 additions & 0 deletions src/ucp/dt/dt.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

#include "dt.h"
#include <ucp/core/ucp_request.inl>


size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,
Expand Down Expand Up @@ -44,3 +45,123 @@ size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,
state->offset += result_len;
return result_len;
}

static UCS_F_ALWAYS_INLINE ucs_status_t ucp_dn_dt_unpack(ucp_request_t *req, void *buffer, size_t buffer_size,
const void *recv_data, size_t recv_length)
{
ucs_status_t status;
ucp_worker_h worker = req->recv.worker;
ucp_context_h context = worker->context;
ucp_ep_h ep = ucp_worker_ep_find(worker, worker->uuid);
ucp_ep_config_t *config = ucp_ep_config(ep);
ucp_md_map_t dn_md_map = req->mem_type.md_map;
ucp_lane_index_t dn_lane;
ucp_rsc_index_t rsc_index;
uct_iface_attr_t *iface_attr;
unsigned md_index;
uct_mem_h memh;
uct_iov_t iov;

if (recv_length == 0) {
return UCS_OK;
}

while (1) {
dn_lane = ucp_config_find_domain_lane(config, config->key.domain_lanes, dn_md_map);
if (dn_lane == UCP_NULL_LANE) {
ucs_error("Not find address domain lane.");
return UCS_ERR_IO_ERROR;
}
rsc_index = ucp_ep_get_rsc_index(ep, dn_lane);
iface_attr = &worker->ifaces[rsc_index].attr;
md_index = config->key.lanes[dn_lane].dst_md_index;
if (!(iface_attr->cap.flags & UCT_IFACE_FLAG_PUT_ZCOPY)) {
dn_md_map |= ~UCS_BIT(md_index);
continue;
}
break;
}


status = uct_md_mem_reg(context->tl_mds[md_index].md, buffer, buffer_size,
UCT_MD_MEM_ACCESS_REMOTE_PUT, &memh);
if (status != UCS_OK) {
ucs_error("Failed to reg address %p with md %s", buffer,
context->tl_mds[md_index].rsc.md_name);
return status;
}

ucs_assert(buffer_size >= recv_length);
iov.buffer = (void *)recv_data;
iov.length = recv_length;
iov.count = 1;
iov.memh = UCT_MEM_HANDLE_NULL;


status = uct_ep_put_zcopy(ep->uct_eps[dn_lane], &iov, 1, (uint64_t)buffer,
(uct_rkey_t )memh, NULL);
if (status != UCS_OK) {
uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
ucs_error("Failed to perform uct_ep_put_zcopy to address %p", recv_data);
return status;
}

status = uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
if (status != UCS_OK) {
ucs_error("Failed to dereg address %p with md %s", buffer,
context->tl_mds[md_index].rsc.md_name);
return status;
}

return UCS_OK;
}


ucs_status_t ucp_dt_unpack(ucp_request_t *req, ucp_datatype_t datatype, void *buffer, size_t buffer_size,
ucp_dt_state_t *state, const void *recv_data, size_t recv_length, int last)
{
ucp_dt_generic_t *dt_gen;
size_t offset = state->offset;
ucs_status_t status;

if (ucs_unlikely((recv_length + offset) > buffer_size)) {
ucs_trace_req("message truncated: recv_length %zu offset %zu buffer_size %zu",
recv_length, offset, buffer_size);
if (UCP_DT_IS_GENERIC(datatype) && last) {
ucp_dt_generic(datatype)->ops.finish(state->dt.generic.state);
}
return UCS_ERR_MESSAGE_TRUNCATED;
}

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_CONTIG:
if (ucs_likely(UCP_IS_DEFAULT_MEMORY_TYPE(req->mem_type.id))) {
UCS_PROFILE_NAMED_CALL("memcpy_recv", memcpy, buffer + offset,
recv_data, recv_length);
return UCS_OK;
} else {
return ucp_dn_dt_unpack(req, buffer, buffer_size, recv_data, recv_length);
}

case UCP_DATATYPE_IOV:
UCS_PROFILE_CALL(ucp_dt_iov_scatter, buffer, state->dt.iov.iovcnt,
recv_data, recv_length, &state->dt.iov.iov_offset,
&state->dt.iov.iovcnt_offset);
return UCS_OK;

case UCP_DATATYPE_GENERIC:
dt_gen = ucp_dt_generic(datatype);
status = UCS_PROFILE_NAMED_CALL("dt_unpack", dt_gen->ops.unpack,
state->dt.generic.state, offset,
recv_data, recv_length);
if (last) {
UCS_PROFILE_NAMED_CALL_VOID("dt_finish", dt_gen->ops.finish,
state->dt.generic.state);
}
return status;

default:
ucs_error("unexpected datatype=%lx", datatype);
return UCS_ERR_INVALID_PARAM;
}
}
50 changes: 4 additions & 46 deletions src/ucp/dt/dt.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <uct/api/uct.h>
#include <ucs/debug/profile.h>
#include <string.h>
#include <ucp/core/ucp_types.h>


/**
Expand Down Expand Up @@ -72,51 +73,8 @@ size_t ucp_dt_length(ucp_datatype_t datatype, size_t count,
size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,
ucp_dt_state_t *state, size_t length);

static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_dt_unpack(ucp_datatype_t datatype, void *buffer, size_t buffer_size,
ucp_dt_state_t *state, const void *recv_data,
size_t recv_length, int last)
{
ucp_dt_generic_t *dt_gen;
size_t offset = state->offset;
ucs_status_t status;

if (ucs_unlikely((recv_length + offset) > buffer_size)) {
ucs_trace_req("message truncated: recv_length %zu offset %zu buffer_size %zu",
recv_length, offset, buffer_size);
if (UCP_DT_IS_GENERIC(datatype) && last) {
ucp_dt_generic(datatype)->ops.finish(state->dt.generic.state);
}
return UCS_ERR_MESSAGE_TRUNCATED;
}

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_CONTIG:
UCS_PROFILE_NAMED_CALL("memcpy_recv", memcpy, buffer + offset,
recv_data, recv_length);
return UCS_OK;

case UCP_DATATYPE_IOV:
UCS_PROFILE_CALL(ucp_dt_iov_scatter, buffer, state->dt.iov.iovcnt,
recv_data, recv_length, &state->dt.iov.iov_offset,
&state->dt.iov.iovcnt_offset);
return UCS_OK;

case UCP_DATATYPE_GENERIC:
dt_gen = ucp_dt_generic(datatype);
status = UCS_PROFILE_NAMED_CALL("dt_unpack", dt_gen->ops.unpack,
state->dt.generic.state, offset,
recv_data, recv_length);
if (last) {
UCS_PROFILE_NAMED_CALL_VOID("dt_finish", dt_gen->ops.finish,
state->dt.generic.state);
}
return status;

default:
ucs_error("unexpected datatype=%lx", datatype);
return UCS_ERR_INVALID_PARAM;
}
}
ucs_status_t ucp_dt_unpack(ucp_request_t *req, ucp_datatype_t datatype,
void *buffer, size_t buffer_size, ucp_dt_state_t *state,
const void *recv_data, size_t recv_length, int last);

#endif
4 changes: 2 additions & 2 deletions src/ucp/tag/eager.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_eager_unexp_match(ucp_worker_h worker, ucp_recv_desc_t *rdesc, ucp_tag_t tag,
unsigned flags, void *buffer, size_t count,
ucp_datatype_t datatype, ucp_dt_state_t *state,
ucp_tag_recv_info_t *info)
ucp_request_t *req, ucp_tag_recv_info_t *info)
{
size_t recv_len, hdr_len;
ucs_status_t status;
Expand All @@ -110,7 +110,7 @@ ucp_eager_unexp_match(ucp_worker_h worker, ucp_recv_desc_t *rdesc, ucp_tag_t tag
UCP_WORKER_STAT_EAGER_CHUNK(worker, UNEXP);
hdr_len = rdesc->hdr_len;
recv_len = rdesc->length - hdr_len;
status = ucp_dt_unpack(datatype, buffer, count, state, data + hdr_len,
status = ucp_dt_unpack(req, datatype, buffer, count, state, data + hdr_len,
recv_len, flags & UCP_RECV_DESC_FLAG_LAST);
state->offset += recv_len;

Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/eager_rcv.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ ucp_eager_handler(void *arg, void *data, size_t length, unsigned am_flags,
if (req != NULL) {
UCS_PROFILE_REQUEST_EVENT(req, "eager_recv", recv_len);

status = ucp_dt_unpack(req->recv.datatype, req->recv.buffer,
status = ucp_dt_unpack(req, req->recv.datatype, req->recv.buffer,
req->recv.length, &req->recv.state,
data + hdr_len, recv_len,
flags & UCP_RECV_DESC_FLAG_LAST);
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/offload.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void ucp_tag_offload_completed(uct_tag_context_t *self, uct_tag_t stag,
}

if (req->recv.rdesc != NULL) {
status = ucp_dt_unpack(req->recv.datatype, req->recv.buffer, req->recv.length,
status = ucp_dt_unpack(req, req->recv.datatype, req->recv.buffer, req->recv.length,
&req->recv.state, req->recv.rdesc + 1, length, 1);
ucs_mpool_put_inline(req->recv.rdesc);
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/ucp/tag/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_handler,
}

UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_data_recv", recv_len);
status = ucp_dt_unpack(rreq->recv.datatype, rreq->recv.buffer,
status = ucp_dt_unpack(rreq, rreq->recv.datatype, rreq->recv.buffer,
rreq->recv.length, &rreq->recv.state,
data + hdr_len, recv_len, 0);
if ((status == UCS_OK) || (status == UCS_INPROGRESS)) {
Expand Down Expand Up @@ -764,9 +764,9 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_last_handler,
/* Check that total received length matches RTS->length */
ucs_assert(rreq->recv.info.length == rreq->recv.state.offset + recv_len);
UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_data_last_recv", recv_len);
status = ucp_dt_unpack(rreq->recv.datatype, rreq->recv.buffer,
rreq->recv.length, &rreq->recv.state,
data + hdr_len, recv_len, 1);
status = ucp_dt_unpack(rreq, rreq->recv.datatype, rreq->recv.buffer,
rreq->recv.length, &rreq->recv.state,
data + hdr_len, recv_len, 1);
} else {
ucs_trace_data("drop last segment for rreq %p, length %zu, status %s",
rreq, recv_len, ucs_status_string(rreq->status));
Expand Down
4 changes: 3 additions & 1 deletion src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
UCS_PROFILE_REQUEST_EVENT(req, "eager_match", 0);
status = ucp_eager_unexp_match(worker, rdesc, recv_tag, flags,
buffer, buffer_size, datatype,
&req->recv.state, info);
&req->recv.state, req, info);
ucs_trace_req("release receive descriptor %p", rdesc);
if (status != UCS_INPROGRESS) {
goto out_release_desc;
Expand Down Expand Up @@ -128,6 +128,8 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer,
req->recv.state.offset = 0;
req->recv.worker = worker;

ucp_addr_domain_detect_mds(worker->context, buffer, &req->mem_type);

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_IOV:
req->recv.state.dt.iov.iov_offset = 0;
Expand Down
Loading

0 comments on commit ac2cc56

Please sign in to comment.