Skip to content

Commit

Permalink
Update coll framework count/disp arrays for bigcount
Browse files Browse the repository at this point in the history
This updates the coll framework functions using count/displacement
arrays to support bigcount. Instead of directly using pointers to
bigcount/non-bigcount type arrays, this adds special descriptor types,
ompi_count_array and ompi_disp_array, that internally hold a union of
either type. Collective components can now access count and displacement
values through inline get functions on the descriptors, allowing use of
both bigcount/non-bigcount arrays, depending on how the descriptors were
initialized.

Co-authored-by: Howard Pritchard <[email protected]>
Signed-off-by: Jake Tronge <[email protected]>
  • Loading branch information
jtronge and hppritcha committed Jun 14, 2024
1 parent 55c0bda commit 9aa2745
Show file tree
Hide file tree
Showing 116 changed files with 1,816 additions and 1,337 deletions.
6 changes: 5 additions & 1 deletion ompi/communicator/comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -2402,6 +2402,8 @@ int ompi_comm_determine_first ( ompi_communicator_t *intercomm, int high )
int rank, rsize;
int *rcounts;
int *rdisps;
ompi_count_array rcounts_desc;
ompi_disp_array rdisps_desc;
int scount=0;
int rc;

Expand Down Expand Up @@ -2429,8 +2431,10 @@ int ompi_comm_determine_first ( ompi_communicator_t *intercomm, int high )
scount = 1;
}

OMPI_COUNT_ARRAY_INIT(&rcounts_desc, rcounts);
OMPI_DISP_ARRAY_INIT(&rdisps_desc, rdisps);
rc = intercomm->c_coll->coll_allgatherv(&high, scount, MPI_INT,
&rhigh, rcounts, rdisps,
&rhigh, &rcounts_desc, &rdisps_desc,
MPI_INT, intercomm,
intercomm->c_coll->coll_allgatherv_module);
if ( NULL != rdisps ) {
Expand Down
153 changes: 94 additions & 59 deletions ompi/mca/coll/base/coll_base_allgatherv.c

Large diffs are not rendered by default.

65 changes: 34 additions & 31 deletions ompi/mca/coll/base/coll_base_alltoallv.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
* and count) to send the data to the other.
*/
int
mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts, const int *rdisps,
mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, ompi_count_array *rcounts, ompi_disp_array *rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
Expand All @@ -72,7 +72,7 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
if (i == rank) {
continue;
}
packed_size = rcounts[i] * type_size;
packed_size = ompi_count_array_get(rcounts, i) * type_size;
max_size = opal_max(packed_size, max_size);
}

Expand Down Expand Up @@ -111,11 +111,11 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
right = (rank + i) % size;
left = (rank + size - i) % size;

if( 0 != rcounts[right] ) { /* nothing to exchange with the peer on the right */
if( 0 != ompi_count_array_get(rcounts, right) ) { /* nothing to exchange with the peer on the right */
ompi_proc_t *right_proc = ompi_comm_peer_lookup(comm, right);
opal_convertor_clone(right_proc->super.proc_convertor, &convertor, 0);
opal_convertor_prepare_for_send(&convertor, &rdtype->super, rcounts[right],
(char *) rbuf + rdisps[right] * extent);
opal_convertor_prepare_for_send(&convertor, &rdtype->super, ompi_count_array_get(rcounts, right),
(char *) rbuf + ompi_disp_array_get(rdisps, right) * extent);
packed_size = max_size;
err = opal_convertor_pack(&convertor, &iov, &iov_count, &packed_size);
if (1 != err) {
Expand All @@ -124,17 +124,19 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
}

/* Receive data from the right */
err = MCA_PML_CALL(irecv ((char *) rbuf + rdisps[right] * extent, rcounts[right], rdtype,
err = MCA_PML_CALL(irecv ((char *) rbuf + ompi_disp_array_get(rdisps, right) * extent,
ompi_count_array_get(rcounts, right), rdtype,
right, MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto error_hndl;
}
}

if( (left != right) && (0 != rcounts[left]) ) {
if( (left != right) && (0 != ompi_count_array_get(rcounts, left)) ) {
/* Send data to the left */
err = MCA_PML_CALL(send ((char *) rbuf + rdisps[left] * extent, rcounts[left], rdtype,
err = MCA_PML_CALL(send ((char *) rbuf + ompi_disp_array_get(rdisps, left) * extent,
ompi_count_array_get(rcounts, left), rdtype,
left, MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
comm));
if (MPI_SUCCESS != err) {
Expand All @@ -149,15 +151,16 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
}

/* Receive data from the left */
err = MCA_PML_CALL(irecv ((char *) rbuf + rdisps[left] * extent, rcounts[left], rdtype,
err = MCA_PML_CALL(irecv ((char *) rbuf + ompi_disp_array_get(rdisps, left) * extent,
ompi_count_array_get(rcounts, left), rdtype,
left, MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto error_hndl;
}
}

if( 0 != rcounts[right] ) { /* nothing to exchange with the peer on the right */
if( 0 != ompi_count_array_get(rcounts, right) ) { /* nothing to exchange with the peer on the right */
/* Send data to the right */
err = MCA_PML_CALL(send ((char *) tmp_buffer, packed_size, MPI_PACKED,
right, MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
Expand Down Expand Up @@ -191,9 +194,9 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
}

int
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, const int *sdisps,
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, ompi_count_array *scounts, ompi_disp_array *sdisps,
struct ompi_datatype_t *sdtype,
void* rbuf, const int *rcounts, const int *rdisps,
void* rbuf, ompi_count_array *rcounts, ompi_disp_array *rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
Expand Down Expand Up @@ -230,21 +233,21 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
recvfrom = (rank + size - step) % size;

/* Determine sending and receiving locations */
psnd = (char*)sbuf + (ptrdiff_t)sdisps[sendto] * sext;
prcv = (char*)rbuf + (ptrdiff_t)rdisps[recvfrom] * rext;
psnd = (char*)sbuf + ompi_disp_array_get(sdisps, sendto) * sext;
prcv = (char*)rbuf + ompi_disp_array_get(rdisps, recvfrom) * rext;

/* send and receive */
if (0 < rcounts[recvfrom] && 0 < rdtype_size) {
err = MCA_PML_CALL(irecv(prcv, rcounts[recvfrom], rdtype, recvfrom,
if (0 < ompi_count_array_get(rcounts, recvfrom) && 0 < rdtype_size) {
err = MCA_PML_CALL(irecv(prcv, ompi_count_array_get(rcounts, recvfrom), rdtype, recvfrom,
MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto err_hndl;
}
}

if (0 < scounts[sendto] && 0 < sdtype_size) {
err = MCA_PML_CALL(send(psnd, scounts[sendto], sdtype, sendto,
if (0 < ompi_count_array_get(scounts, sendto) && 0 < sdtype_size) {
err = MCA_PML_CALL(send(psnd, ompi_count_array_get(scounts, sendto), sdtype, sendto,
MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD, comm));
if (MPI_SUCCESS != err) {
line = __LINE__;
Expand Down Expand Up @@ -280,9 +283,9 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
* differently and so will not have to duplicate code.
*/
int
ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts, const int *sdisps,
ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, ompi_count_array *scounts, ompi_disp_array *sdisps,
struct ompi_datatype_t *sdtype,
void *rbuf, const int *rcounts, const int *rdisps,
void *rbuf, ompi_count_array *rcounts, ompi_disp_array *rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
Expand Down Expand Up @@ -313,11 +316,11 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
ompi_datatype_type_extent(rdtype, &rext);

/* Simple optimization - handle send to self first */
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[rank] * sext;
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[rank] * rext;
if (0 < scounts[rank] && 0 < sdtype_size) {
err = ompi_datatype_sndrcv(psnd, scounts[rank], sdtype,
prcv, rcounts[rank], rdtype);
psnd = ((char *) sbuf) + ompi_disp_array_get(sdisps, rank) * sext;
prcv = ((char *) rbuf) + ompi_disp_array_get(rdisps, rank) * rext;
if (0 < ompi_count_array_get(scounts, rank) && 0 < sdtype_size) {
err = ompi_datatype_sndrcv(psnd, ompi_count_array_get(scounts, rank), sdtype,
prcv, ompi_count_array_get(rcounts, rank), rdtype);
if (MPI_SUCCESS != err) {
return err;
}
Expand All @@ -339,10 +342,10 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
continue;
}

if (0 < rcounts[i] && 0 < rdtype_size) {
if (0 < ompi_count_array_get(rcounts, i) && 0 < rdtype_size) {
++nreqs;
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[i] * rext;
err = MCA_PML_CALL(irecv_init(prcv, rcounts[i], rdtype,
prcv = ((char *) rbuf) + ompi_disp_array_get(rdisps, i) * rext;
err = MCA_PML_CALL(irecv_init(prcv, ompi_count_array_get(rcounts, i), rdtype,
i, MCA_COLL_BASE_TAG_ALLTOALLV, comm,
preq++));
if (MPI_SUCCESS != err) { goto err_hndl; }
Expand All @@ -355,10 +358,10 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
continue;
}

if (0 < scounts[i] && 0 < sdtype_size) {
if (0 < ompi_count_array_get(scounts, i) && 0 < sdtype_size) {
++nreqs;
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[i] * sext;
err = MCA_PML_CALL(isend_init(psnd, scounts[i], sdtype,
psnd = ((char *) sbuf) + ompi_disp_array_get(sdisps, i) * sext;
err = MCA_PML_CALL(isend_init(psnd, ompi_count_array_get(scounts, i), sdtype,
i, MCA_COLL_BASE_TAG_ALLTOALLV,
MCA_PML_BASE_SEND_STANDARD, comm,
preq++));
Expand Down
20 changes: 10 additions & 10 deletions ompi/mca/coll/base/coll_base_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,27 @@ typedef enum COLLTYPE {

/* defined arg lists to simply auto inclusion of user overriding decision functions */
#define ALLGATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLGATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLGATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *displs, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLREDUCE_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define ALLTOALL_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLTOALLV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLTOALLW_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t * const sendtypes[], void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
#define ALLTOALLV_BASE_ARGS const void *sendbuf, ompi_count_array *sendcounts, ompi_disp_array *sdispls, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *rdispls, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLTOALLW_BASE_ARGS const void *sendbuf, ompi_count_array *sendcounts, ompi_disp_array *sdispls, struct ompi_datatype_t * const sendtypes[], void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *rdispls, struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
#define BARRIER_BASE_ARGS struct ompi_communicator_t *comm
#define BCAST_BASE_ARGS void *buffer, size_t count, struct ompi_datatype_t *datatype, int root, struct ompi_communicator_t *comm
#define EXSCAN_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define GATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define GATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define GATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *displs, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define REDUCE_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm
#define REDUCESCATTER_BASE_ARGS const void *sendbuf, void *recvbuf, const int recvcounts[], struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define REDUCESCATTER_BASE_ARGS const void *sendbuf, void *recvbuf, ompi_count_array *recvcounts, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define REDUCESCATTERBLOCK_BASE_ARGS const void *sendbuf, void *recvbuf, size_t recvcount, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define SCAN_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define SCATTER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define SCATTERV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int displs[], struct ompi_datatype_t *sendtype, void *recvbuf, int recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define SCATTERV_BASE_ARGS const void *sendbuf, ompi_count_array *sendcounts, ompi_disp_array *displs, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLGATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLGATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLGATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *displs, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALL_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLW_BASE_ARGS const void *sendbuf, const int sendcounts[], const MPI_Aint sdispls[], struct ompi_datatype_t * const sendtypes[], void *recvbuf, const int recvcounts[], const MPI_Aint rdispls[], struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLV_BASE_ARGS const void *sendbuf, ompi_count_array *sendcounts, ompi_disp_array *sdispls, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *rdispls, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLW_BASE_ARGS const void *sendbuf, ompi_count_array *sendcounts, ompi_disp_array *sdispls, struct ompi_datatype_t * const sendtypes[], void *recvbuf, ompi_count_array *recvcounts, ompi_disp_array *rdispls, struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm

#define ALLGATHER_ARGS ALLGATHER_BASE_ARGS, mca_coll_base_module_t *module
#define ALLGATHERV_ARGS ALLGATHERV_BASE_ARGS, mca_coll_base_module_t *module
Expand Down Expand Up @@ -227,7 +227,7 @@ int mca_coll_base_alltoall_intra_basic_inplace(const void *rbuf, size_t rcount,
/* AlltoAllV */
int ompi_coll_base_alltoallv_intra_pairwise(ALLTOALLV_ARGS);
int ompi_coll_base_alltoallv_intra_basic_linear(ALLTOALLV_ARGS);
int mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts, const int *rdisps,
int mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, ompi_count_array *rcounts, ompi_disp_array *rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module); /* special version for INPLACE */
Expand Down
Loading

0 comments on commit 9aa2745

Please sign in to comment.