sessions: add support for ucx and more #12723

Merged: 1 commit, Oct 23, 2024
13 changes: 3 additions & 10 deletions ompi/communicator/comm.c
@@ -24,7 +24,7 @@
* Copyright (c) 2015 Mellanox Technologies. All rights reserved.
* Copyright (c) 2017-2022 IBM Corporation. All rights reserved.
* Copyright (c) 2021 Nanook Consulting. All rights reserved.
* Copyright (c) 2018-2022 Triad National Security, LLC. All rights
* Copyright (c) 2018-2024 Triad National Security, LLC. All rights
* reserved.
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
* $COPYRIGHT$
@@ -1738,7 +1738,7 @@ int ompi_intercomm_create_from_groups (ompi_group_t *local_group, int local_lead
ompi_communicator_t **newintercomm)
{
ompi_communicator_t *newcomp = NULL, *local_comm, *leader_comm = MPI_COMM_NULL;
ompi_comm_extended_cid_block_t new_block;
ompi_comm_extended_cid_block_t new_block = {0};
bool i_am_leader = local_leader == local_group->grp_my_rank;
ompi_proc_t **rprocs;
uint64_t data[4];
@@ -1864,14 +1864,7 @@ int ompi_intercomm_create_from_groups (ompi_group_t *local_group, int local_lead
return rc;
}

/* will be using a communicator ID derived from the bridge communicator to save some time */
new_block.block_cid.cid_base = data[1];
new_block.block_cid.cid_sub.u64 = data[2];
new_block.block_nextsub = 0;
new_block.block_nexttag = 0;
new_block.block_level = (int8_t) data[3];

rc = ompi_comm_nextcid (newcomp, NULL, NULL, (void *) tag, &new_block, false, OMPI_COMM_CID_GROUP_NEW);
rc = ompi_comm_nextcid (newcomp, NULL, NULL, (void *) tag, NULL, false, OMPI_COMM_CID_GROUP_NEW);
if ( OMPI_SUCCESS != rc ) {
OBJ_RELEASE(newcomp);
return rc;
213 changes: 176 additions & 37 deletions ompi/communicator/comm_cid.c
@@ -310,21 +310,16 @@ static int ompi_comm_ext_cid_new_block (ompi_communicator_t *newcomm, ompi_commu
const void *arg0, const void *arg1, bool send_first, int mode,
ompi_request_t **req)
{
pmix_info_t pinfo, *results = NULL;
pmix_info_t *pinfo, *results = NULL;
size_t nresults;
opal_process_name_t *name_array = NULL;
char *tag = NULL;
size_t proc_count;
size_t cid_base = 0;
opal_process_name_t opal_proc_name;
bool cid_base_set = false;
char *tag = NULL;
size_t proc_count = 0, rproc_count = 0, tproc_count = 0, cid_base = 0UL, ninfo;
int rc, leader_rank;
int ret = OMPI_SUCCESS;
pmix_proc_t *procs = NULL;

rc = ompi_group_to_proc_name_array (newcomm->c_local_group, &name_array, &proc_count);
if (OPAL_UNLIKELY(OMPI_SUCCESS != rc)) {
return rc;
}
pmix_proc_t *procs;
void *grpinfo = NULL, *list = NULL;
pmix_data_array_t darray;

switch (mode) {
case OMPI_COMM_CID_GROUP_NEW:
@@ -341,15 +336,75 @@ static int ompi_comm_ext_cid_new_block (ompi_communicator_t *newcomm, ompi_commu
break;
}

PMIX_INFO_LOAD(&pinfo, PMIX_GROUP_ASSIGN_CONTEXT_ID, NULL, PMIX_BOOL);
grpinfo = PMIx_Info_list_start();
if (NULL == grpinfo) {
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto fn_exit;
}

rc = PMIx_Info_list_add(grpinfo, PMIX_GROUP_ASSIGN_CONTEXT_ID, NULL, PMIX_BOOL);
if (PMIX_SUCCESS != rc) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Info_list_add failed %s %d", PMIx_Error_string(rc), __LINE__));
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto fn_exit;
}

list = PMIx_Info_list_start();

size_t c_index = (size_t)newcomm->c_index;
rc = PMIx_Info_list_add(list, PMIX_GROUP_LOCAL_CID, &c_index, PMIX_SIZE);
if (PMIX_SUCCESS != rc) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Info_list_add failed %s %d", PMIx_Error_string(rc), __LINE__));
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto fn_exit;
}

rc = PMIx_Info_list_convert(list, &darray);
if (PMIX_SUCCESS != rc) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Info_list_convert failed %s %d", PMIx_Error_string(rc), __LINE__));
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto fn_exit;
}
rc = PMIx_Info_list_add(grpinfo, PMIX_GROUP_INFO, &darray, PMIX_DATA_ARRAY);
PMIX_DATA_ARRAY_DESTRUCT(&darray);
if (PMIX_SUCCESS != rc) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Info_list_add failed %s %d", PMIx_Error_string(rc), __LINE__));
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto fn_exit;
}

rc = PMIx_Info_list_convert(grpinfo, &darray);
if (PMIX_SUCCESS != rc) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Info_list_convert failed %s %d", PMIx_Error_string(rc), __LINE__));
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto fn_exit;
}

pinfo = (pmix_info_t*)darray.array;
ninfo = darray.size;

proc_count = newcomm->c_local_group->grp_proc_count;
if ( OMPI_COMM_IS_INTER (newcomm) ){
rproc_count = newcomm->c_remote_group->grp_proc_count;
}

PMIX_PROC_CREATE(procs, proc_count + rproc_count);

PMIX_PROC_CREATE(procs, proc_count);
for (size_t i = 0 ; i < proc_count; ++i) {
OPAL_PMIX_CONVERT_NAME(&procs[i],&name_array[i]);
opal_proc_name = ompi_group_get_proc_name(newcomm->c_local_group, i);
OPAL_PMIX_CONVERT_NAME(&procs[i],&opal_proc_name);
}
for (size_t i = 0; i < rproc_count; ++i) {
opal_proc_name = ompi_group_get_proc_name(newcomm->c_remote_group, i);
OPAL_PMIX_CONVERT_NAME(&procs[proc_count+i],&opal_proc_name);
}

rc = PMIx_Group_construct(tag, procs, proc_count, &pinfo, 1, &results, &nresults);
PMIX_INFO_DESTRUCT(&pinfo);
tproc_count = proc_count + rproc_count;

OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "calling PMIx_Group_construct - tag %s size %ld ninfo %ld cid_base %ld\n",
tag, tproc_count, ninfo, cid_base));
rc = PMIx_Group_construct(tag, procs, tproc_count, pinfo, ninfo, &results, &nresults);
PMIX_DATA_ARRAY_DESTRUCT(&darray);

Member:
There must be some optimizations behind this centralized group creation for it to scale. I would love to hear more about it.

Contributor:
Depends on the RTE, of course, as all PMIx does is construct the request and pass it up. For PRRTE, it is just the tree-based allgather we use for pretty much all collectives. I don't believe anyone has exerted much effort towards optimizing it.
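
For readers who want to picture the pattern being referenced, below is a minimal, hypothetical sketch of a binomial-tree allgather: each rank's contribution is gathered up the tree to rank 0, and the combined buffer is then broadcast back down. This is not PRRTE's implementation; send_buf()/recv_buf() are assumed stand-ins for whatever point-to-point transport the RTE provides, and rank 0 is taken to be the tree root.

/*
 * Illustrative sketch only -- not PRRTE code.  send_buf() and recv_buf()
 * are hypothetical stand-ins for the RTE's point-to-point transport.
 */
#include <string.h>

extern void send_buf(int peer, const void *buf, size_t len);
extern void recv_buf(int peer, void *buf, size_t len);

/* Allgather one len-byte contribution per rank into 'all' (size * len bytes). */
void tree_allgather(int rank, int size, const void *mine, void *all, size_t len)
{
    int mask;

    memcpy((char *)all + (size_t)rank * len, mine, len);

    /* Upward phase: binomial gather toward rank 0.  When a rank sends, it
     * forwards the contiguous blocks [rank, min(rank + mask, size)). */
    for (mask = 1; mask < size; mask <<= 1) {
        if (rank & mask) {
            int nblk = (rank + mask <= size) ? mask : size - rank;
            send_buf(rank - mask, (char *)all + (size_t)rank * len, (size_t)nblk * len);
            break;                       /* this rank is done contributing */
        } else if (rank + mask < size) {
            int child = rank + mask;
            int nblk = (child + mask <= size) ? mask : size - child;
            recv_buf(child, (char *)all + (size_t)child * len, (size_t)nblk * len);
        }
    }

    /* Downward phase: binomial broadcast of the complete buffer from rank 0. */
    for (mask = 1; mask < size; mask <<= 1) {
        if (rank & mask) {
            recv_buf(rank - mask, all, (size_t)size * len);
            break;
        }
    }
    for (mask >>= 1; mask > 0; mask >>= 1) {
        if (rank + mask < size) {
            send_buf(rank + mask, all, (size_t)size * len);
        }
    }
}

Both phases finish in O(log size) communication rounds, which is the scaling property the reply above is relying on.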

Member:
That's exactly what I'm concerned about, as I have seen no performance report on this change.

if(PMIX_SUCCESS != rc) {
char msg_string[1024];
switch (rc) {
@@ -361,7 +416,7 @@ static int ompi_comm_ext_cid_new_block (ompi_communicator_t *newcomm, ompi_commu
"MPI_Comm_create_from_group/MPI_Intercomm_create_from_groups",
msg_string);

ret = MPI_ERR_UNSUPPORTED_OPERATION;
rc = MPI_ERR_UNSUPPORTED_OPERATION;
break;
case PMIX_ERR_NOT_SUPPORTED:
sprintf(msg_string,"PMIx server does not support PMIx Group operations");
@@ -370,10 +425,10 @@ static int ompi_comm_ext_cid_new_block (ompi_communicator_t *newcomm, ompi_commu
true,
"MPI_Comm_create_from_group/MPI_Intercomm_create_from_groups",
msg_string);
ret = MPI_ERR_UNSUPPORTED_OPERATION;
rc = MPI_ERR_UNSUPPORTED_OPERATION;
break;
default:
ret = opal_pmix_convert_status(rc);
rc = opal_pmix_convert_status(rc);
break;
}
goto fn_exit;
@@ -383,23 +438,28 @@ static int ompi_comm_ext_cid_new_block (ompi_communicator_t *newcomm, ompi_commu
if (PMIX_CHECK_KEY(&results[i], PMIX_GROUP_CONTEXT_ID)) {
PMIX_VALUE_GET_NUMBER(rc, &results[i].value, cid_base, size_t);
if(PMIX_SUCCESS != rc) {
ret = opal_pmix_convert_status(rc);
rc = opal_pmix_convert_status(rc);
goto fn_exit;
}
cid_base_set = true;
break;
}
}

OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Group_construct - tag %s size %ld ninfo %ld cid_base %ld\n",
tag, tproc_count, ninfo, cid_base));

/* destruct the group */
rc = PMIx_Group_destruct (tag, NULL, 0);
if(PMIX_SUCCESS != rc) {
ret = opal_pmix_convert_status(rc);
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Group_destruct failed %s", PMIx_Error_string(rc)));
rc = opal_pmix_convert_status(rc);
goto fn_exit;
}

if (!cid_base_set) {
opal_show_help("help-comm.txt", "cid-base-not-set", true);
ret = OMPI_ERROR;
rc = OMPI_ERROR;
goto fn_exit;
}

@@ -412,16 +472,19 @@ static int ompi_comm_ext_cid_new_block (ompi_communicator_t *newcomm, ompi_commu
}

if(NULL != procs) {
PMIX_PROC_FREE(procs, proc_count);
PMIX_PROC_FREE(procs, tproc_count);
procs = NULL;
}

if(NULL != name_array) {
free (name_array);
name_array = NULL;
if (NULL != grpinfo) {
PMIx_Info_list_release(grpinfo);
}

return ret;
if (NULL != list) {
PMIx_Info_list_release(list);
}

return rc;
}

static int ompi_comm_nextcid_ext_nb (ompi_communicator_t *newcomm, ompi_communicator_t *comm,
@@ -446,6 +509,15 @@ static int ompi_comm_nextcid_ext_nb (ompi_communicator_t *newcomm, ompi_communic
block = &comm->c_contextidb;
}

for (unsigned int i = ompi_mpi_communicators.lowest_free ; i < mca_pml.pml_max_contextid ; ++i) {
bool flag = opal_pointer_array_test_and_set_item (&ompi_mpi_communicators, i, newcomm);
if (true == flag) {
newcomm->c_index = i;
break;
}
}
assert(newcomm->c_index > 2);

if (NULL == arg1) {
if (OMPI_COMM_CID_GROUP == mode || OMPI_COMM_CID_GROUP_NEW == mode ||
!ompi_comm_extended_cid_block_available (&comm->c_contextidb)) {
@@ -468,14 +540,6 @@ static int ompi_comm_nextcid_ext_nb (ompi_communicator_t *newcomm, ompi_communic
(void) ompi_comm_extended_cid_block_new (block, &newcomm->c_contextidb, is_new_block);
}

for (unsigned int i = ompi_mpi_communicators.lowest_free ; i < mca_pml.pml_max_contextid ; ++i) {
bool flag = opal_pointer_array_test_and_set_item (&ompi_mpi_communicators, i, newcomm);
if (true == flag) {
newcomm->c_index = i;
break;
}
}

newcomm->c_contextid = newcomm->c_contextidb.block_cid;

opal_hash_table_set_value_ptr (&ompi_comm_hash, &newcomm->c_contextid,
@@ -502,7 +566,7 @@ int ompi_comm_nextcid_nb (ompi_communicator_t *newcomm, ompi_communicator_t *com
functions but the pml does not support these functions so return not supported */
if (NULL == comm) {
char msg_string[1024];
sprintf(msg_string,"The PML being used - %s - does not support MPI sessions related features",
sprintf(msg_string,"The PML being used - %s - does not support MPI sessions related features",
mca_pml_base_selected_component.pmlm_version.mca_component_name);
opal_show_help("help-comm.txt",
"MPI function not supported",
@@ -886,6 +950,7 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c
ompi_comm_cid_context_t *context;
ompi_comm_request_t *request;
ompi_request_t *subreq;
uint32_t comm_size;
int ret = 0;

/* the caller should not pass NULL for comm (it may be the same as *newcomm) */
@@ -907,6 +972,25 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c

request->context = &context->super;

/* Prep communicator for handling remote cids if needed */

if (!OMPI_COMM_IS_GLOBAL_INDEX(*newcomm)) {
if (OMPI_COMM_IS_INTER(*newcomm)) {
comm_size = ompi_comm_remote_size(*newcomm);
} else {
comm_size = ompi_comm_size(*newcomm);
}

(*newcomm)->c_index_vec = (uint32_t *)calloc(comm_size, sizeof(uint32_t));
if (NULL == (*newcomm)->c_index_vec) {
return OMPI_ERR_OUT_OF_RESOURCE;
}

if (OMPI_COMM_IS_INTRA(*newcomm)) {
(*newcomm)->c_index_vec[(*newcomm)->c_my_rank] = (*newcomm)->c_index;
}
}

if (MPI_UNDEFINED != (*newcomm)->c_local_group->grp_my_rank) {
/* Initialize the PML stuff in the newcomm */
if ( OMPI_SUCCESS != (ret = MCA_PML_CALL(add_comm(*newcomm))) ) {
@@ -963,6 +1047,61 @@ int ompi_comm_activate (ompi_communicator_t **newcomm, ompi_communicator_t *comm
return rc;
}

int ompi_comm_get_remote_cid_from_pmix (ompi_communicator_t *comm, int dest, uint32_t *remote_cid)
{
ompi_proc_t *ompi_proc;
pmix_proc_t pmix_proc;
pmix_info_t tinfo[2];
pmix_value_t *val = NULL;
ompi_comm_extended_cid_t excid;
int rc = OMPI_SUCCESS;
size_t remote_cid64;

assert(NULL != remote_cid);

ompi_proc = ompi_comm_peer_lookup(comm, dest);
OPAL_PMIX_CONVERT_NAME(&pmix_proc, &ompi_proc->super.proc_name);

PMIx_Info_construct(&tinfo[0]);
PMIX_INFO_LOAD(&tinfo[0], PMIX_TIMEOUT, &ompi_pmix_connect_timeout, PMIX_UINT32);

excid = ompi_comm_get_extended_cid(comm);

PMIX_INFO_CONSTRUCT(&tinfo[1]);
PMIX_INFO_LOAD(&tinfo[1], PMIX_GROUP_CONTEXT_ID, &excid.cid_base, PMIX_SIZE);
PMIX_INFO_SET_QUALIFIER(&tinfo[1]);
if (PMIX_SUCCESS != (rc = PMIx_Get(&pmix_proc, PMIX_GROUP_LOCAL_CID, tinfo, 2, &val))) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Get failed for PMIX_GROUP_LOCAL_CID cid_base %ld %s", excid.cid_base, PMIx_Error_string(rc)));
rc = OMPI_ERR_NOT_FOUND;
goto done;
}

if (NULL == val) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Get failed for PMIX_GROUP_LOCAL_CID val returned NULL"));
rc = OMPI_ERR_NOT_FOUND;
goto done;
}

if (val->type != PMIX_SIZE) {
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Get failed for PMIX_GROUP_LOCAL_CID type mismatch"));
rc = OMPI_ERR_TYPE_MISMATCH;
goto done;
}

PMIX_VALUE_GET_NUMBER(rc, val, remote_cid64, size_t);
rc = OMPI_SUCCESS;
*remote_cid = (uint32_t)remote_cid64;
comm->c_index_vec[dest] = (uint32_t)remote_cid64;
OPAL_OUTPUT_VERBOSE((10, ompi_comm_output, "PMIx_Get PMIX_GROUP_LOCAL_CID %d for cid_base %ld", *remote_cid, excid.cid_base));

done:
if (NULL != val) {
PMIX_VALUE_RELEASE(val);
}

return rc;
}

static int ompi_comm_activate_nb_complete (ompi_comm_request_t *request)
{
ompi_comm_cid_context_t *context = (ompi_comm_cid_context_t *) request->context;