Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: enable cuda support v1 - EAGER with GDR COPY #20

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions src/ucp/api/ucp_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ typedef struct ucp_rkey *ucp_rkey_h;
*/
typedef struct ucp_mem *ucp_mem_h;

/**
 * @ingroup UCP_ADDR_DN
 * @brief UCP Address Domain
 *
 * Address Domain handle is an opaque object representing a memory address
 * domain.
 */
typedef struct ucp_addr_dn *ucp_addr_dn_h;

/**
* @ingroup UCP_WORKER
Expand Down
40 changes: 40 additions & 0 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ static ucp_mem_t ucp_mem_dummy_handle = {
.md_map = 0
};

/*
 * Placeholder address-domain handle: its MD map is empty and its id is
 * UCT_MD_ADDR_DOMAIN_LAST. ucp_addr_domain_detect_mds() points its output
 * at this object before probing any MD.
 */
ucp_addr_dn_t ucp_addr_dn_dummy_handle = {
.md_map = 0,
.id = UCT_MD_ADDR_DOMAIN_LAST
};


/**
* Unregister memory from all memory domains.
* Save in *alloc_md_memh_p the memory handle of the allocating MD, if such exists.
Expand Down Expand Up @@ -106,6 +112,40 @@ static ucs_status_t ucp_memh_reg_mds(ucp_context_h context, ucp_mem_h memh,
return UCS_OK;
}

/**
 * Detect which address domain @a addr belongs to by probing the MDs of
 * @a context.
 *
 * MDs without UCT_MD_FLAG_ADDR_DN are ignored. The remaining MDs are probed in
 * index order with uct_md_mem_detect() until one of them returns UCS_OK; a new
 * descriptor is then allocated and initialized with that MD's address domain
 * id and a one-bit MD map. Later MDs are not probed again: an MD is added to
 * the MD map when it advertises the same address domain id.
 *
 * @param [in]  context    Context whose MDs are probed.
 * @param [in]  addr       Address to classify.
 * @param [out] addr_dn_h  Set to the newly allocated descriptor if a domain
 *                         was detected, otherwise to
 *                         &ucp_addr_dn_dummy_handle. It is never left NULL,
 *                         including on allocation failure.
 *
 * @return UCS_OK on success (whether or not a domain was detected),
 *         UCS_ERR_NO_MEMORY if the descriptor could not be allocated.
 *
 * @note A detected descriptor is allocated with ucs_malloc() and is not
 *       released by this function.
 * @note The eager_lane field of a newly allocated descriptor is not
 *       initialized here.
 */
ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr, ucp_addr_dn_h *addr_dn_h)
{
    ucp_addr_dn_h detected = NULL;
    ucs_status_t status;
    unsigned md_index;

    /* Until a domain is detected, the caller sees the shared dummy handle */
    *addr_dn_h = &ucp_addr_dn_dummy_handle;

    /*TODO: return if no MDs with address domain detect */

    for (md_index = 0; md_index < context->num_mds; ++md_index) {
        /* TODO (from review): remove cap flag and use
         * UCT_MD_ADDR_DOMAIN_DEFAULT */
        if (!(context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_ADDR_DN)) {
            continue;
        }

        if (detected != NULL) {
            /* Domain already known: only collect MDs of the same domain.
             * Keying on the handle (not on the domain id value) guarantees
             * that a second allocation can never replace and leak the
             * first one. */
            if (detected->id == context->tl_mds[md_index].attr.cap.addr_dn) {
                detected->md_map |= UCS_BIT(md_index);
            }
            continue;
        }

        status = uct_md_mem_detect(context->tl_mds[md_index].md, addr);
        if (status != UCS_OK) {
            continue;
        }

        /* Allocate into a local so that on failure the output still points
         * at the dummy handle instead of NULL */
        detected = ucs_malloc(sizeof(*detected), "ucp_addr_dn_h");
        if (detected == NULL) {
            return UCS_ERR_NO_MEMORY;
        }

        detected->id     = context->tl_mds[md_index].attr.cap.addr_dn;
        detected->md_map = UCS_BIT(md_index);
        *addr_dn_h       = detected;
    }

    return UCS_OK;
}
/**
* @return Whether MD number 'md_index' is selected by the configuration as part
* of allocation method number 'config_method_index'.
Expand Down
20 changes: 20 additions & 0 deletions src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ typedef struct ucp_mem_desc {
} ucp_mem_desc_t;


/**
 * Memory Address Domain descriptor.
 * Contains domain information of the memory address it belongs to.
 */
typedef struct ucp_addr_dn {
ucp_md_map_t md_map; /* Bitmap of MDs which own this address domain */
uct_addr_domain_t id; /* Address domain index */
/* NOTE(review): presumably the lane used for eager sends from this address
 * domain - confirm against the code that sets and reads it */
ucp_lane_index_t eager_lane;
} ucp_addr_dn_t;

void ucp_rkey_resolve_inner(ucp_rkey_h rkey, ucp_ep_h ep);

ucs_status_t ucp_mpool_malloc(ucs_mpool_t *mp, size_t *size_p, void **chunk_p);
Expand All @@ -72,6 +82,16 @@ void ucp_mpool_free(ucs_mpool_t *mp, void *chunk);

void ucp_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk);

/**
 * Detects the address domain of @a addr using the MDs of @a context.
 * Once one MD successfully detects the domain, detection is skipped on the
 * subsequent MDs.
 *
 * @param [in]  context    Context whose MDs are queried.
 * @param [in]  addr       Address whose domain should be detected.
 * @param [out] addr_dn_h  Filled with the resulting address domain handle.
 *
 * @return UCS_OK, or UCS_ERR_NO_MEMORY if the handle could not be allocated.
 **/
ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr,
ucp_addr_dn_h *addr_dn_h);


/* Placeholder address domain handle, defined in ucp_mm.c */
extern ucp_addr_dn_t ucp_addr_dn_dummy_handle;

static UCS_F_ALWAYS_INLINE uct_mem_h
ucp_memh2uct(ucp_mem_h memh, ucp_md_index_t md_idx)
{
Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ typedef void (*ucp_request_callback_t)(ucp_request_t *req);
struct ucp_request {
ucs_status_t status; /* Operation status */
uint16_t flags; /* Request flags */
ucp_addr_dn_h addr_dn_h; /* Memory domain handle */
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

embed the fields into the req struct


union {
struct {
Expand Down
2 changes: 2 additions & 0 deletions src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer,
req->recv.state.offset = 0;
req->recv.worker = worker;

ucp_addr_domain_detect_mds(worker->context, buffer, &(req->addr_dn_h));

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_IOV:
req->recv.state.dt.iov.iov_offset = 0;
Expand Down