Skip to content

Commit

Permalink
Differentiate blocking and non-blocking collectives in UCC
Browse files Browse the repository at this point in the history
Signed-off-by: Geoffroy Vallee <[email protected]>
Signed-off-by: George Bosilca <[email protected]>
  • Loading branch information
gvallee authored and bosilca committed Jun 13, 2024
1 parent 1438a79 commit 5c2dedb
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 20 deletions.
10 changes: 8 additions & 2 deletions ompi/mca/coll/ucc/coll_ucc_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
Expand All @@ -34,6 +35,7 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t

ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_ALLGATHER,
.src.info = {
.buffer = (void*)sbuf,
Expand All @@ -53,6 +55,10 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
if (blocking) {
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags |= UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -70,7 +76,7 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype
UCC_VERBOSE(3, "running ucc allgather");
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
ucc_module, &req, NULL));
true, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -94,7 +100,7 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
ucc_module, &req, coll_req));
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand Down
10 changes: 8 additions & 2 deletions ompi/mca/coll/ucc/coll_ucc_allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
struct ompi_datatype_t *sdtype,
void* rbuf, const int *rcounts, const int *rdisps,
struct ompi_datatype_t *rdtype,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
Expand All @@ -31,6 +32,7 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc

ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_ALLGATHERV,
.src.info = {
.buffer = (void*)sbuf,
Expand All @@ -51,6 +53,10 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
if (blocking) {
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags |= UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -71,7 +77,7 @@ int mca_coll_ucc_allgatherv(const void *sbuf, int scount,

COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
ucc_module, &req, NULL));
true, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -98,7 +104,7 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount,
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
ucc_module, &req, coll_req));
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand Down
13 changes: 10 additions & 3 deletions ompi/mca/coll/ucc/coll_ucc_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module,
struct ompi_op_t *op,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
{
Expand All @@ -32,6 +34,7 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *r
}
ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_ALLREDUCE,
.src.info = {
.buffer = (void*)sbuf,
Expand All @@ -51,6 +54,10 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *r
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
if (blocking) {
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags |= UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -67,7 +74,7 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count,

UCC_VERBOSE(3, "running ucc allreduce");
COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op,
ucc_module, &req, NULL));
true, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -90,7 +97,7 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count,
UCC_VERBOSE(3, "running ucc iallreduce");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op,
ucc_module, &req, coll_req));
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand Down
8 changes: 8 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc_alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
Expand All @@ -34,6 +35,7 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s

ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_ALLTOALL,
.src.info = {
.buffer = (void*)sbuf,
Expand All @@ -53,6 +55,10 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
if (blocking) {
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags |= UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -70,6 +76,7 @@ int mca_coll_ucc_alltoall(const void *sbuf, size_t scount, struct ompi_datatype_
UCC_VERBOSE(3, "running ucc alltoall");
COLL_UCC_CHECK(mca_coll_ucc_alltoall_init(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
true,
ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
Expand All @@ -94,6 +101,7 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_alltoall_init(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
false,
ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
Expand Down
10 changes: 8 additions & 2 deletions ompi/mca/coll/ucc/coll_ucc_alltoallv.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i
const int *sdisps, struct ompi_datatype_t *sdtype,
void* rbuf, const int *rcounts, const int *rdisps,
struct ompi_datatype_t *rdtype,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
Expand All @@ -31,6 +32,7 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i

ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_ALLTOALLV,
.src.info_v = {
.buffer = (void*)sbuf,
Expand All @@ -52,6 +54,10 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
if (blocking) {
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags |= UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -72,7 +78,7 @@ int mca_coll_ucc_alltoallv(const void *sbuf, const int *scounts,

COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init(sbuf, scounts, sdisps, sdtype,
rbuf, rcounts, rdisps, rdtype,
ucc_module, &req, NULL));
true, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -99,7 +105,7 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts,
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init(sbuf, scounts, sdisps, sdtype,
rbuf, rcounts, rdisps, rdtype,
ucc_module, &req, coll_req));
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand Down
11 changes: 8 additions & 3 deletions ompi/mca/coll/ucc/coll_ucc_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@

#include "coll_ucc_common.h"

static inline ucc_status_t mca_coll_ucc_barrier_init(mca_coll_ucc_module_t *ucc_module,
static inline ucc_status_t mca_coll_ucc_barrier_init(bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
{
ucc_coll_args_t coll = {
.mask = 0,
.coll_type = UCC_COLL_TYPE_BARRIER
};
if (blocking) {
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -29,7 +34,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm,
ucc_coll_req_h req;

UCC_VERBOSE(3, "running ucc barrier");
COLL_UCC_CHECK(mca_coll_ucc_barrier_init(ucc_module, &req, NULL));
COLL_UCC_CHECK(mca_coll_ucc_barrier_init(true, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -48,7 +53,7 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm,

UCC_VERBOSE(3, "running ucc ibarrier");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_barrier_init(ucc_module, &req, coll_req));
COLL_UCC_CHECK(mca_coll_ucc_barrier_init(false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand Down
12 changes: 9 additions & 3 deletions ompi/mca/coll/ucc/coll_ucc_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "coll_ucc_common.h"

static inline ucc_status_t mca_coll_ucc_bcast_init(void *buf, size_t count, struct ompi_datatype_t *dtype,
int root, mca_coll_ucc_module_t *ucc_module,
int root,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
{
Expand All @@ -30,6 +32,10 @@ static inline ucc_status_t mca_coll_ucc_bcast_init(void *buf, size_t count, stru
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
}
};
if (blocking) {
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -43,7 +49,7 @@ int mca_coll_ucc_bcast(void *buf, size_t count, struct ompi_datatype_t *dtype,
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
ucc_coll_req_h req;
UCC_VERBOSE(3, "running ucc bcast");
COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root,
COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root, true,
ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
Expand All @@ -65,7 +71,7 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype,

UCC_VERBOSE(3, "running ucc ibcast");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root,
COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root, false,
ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
Expand Down
1 change: 0 additions & 1 deletion ompi/mca/coll/ucc/coll_ucc_gatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, int scoun
{
ucc_datatype_t ucc_sdt, ucc_rdt;
int comm_rank = ompi_comm_rank(ucc_module->comm);
int comm_size = ompi_comm_size(ucc_module->comm);

ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
if (comm_rank == root) {
Expand Down
2 changes: 1 addition & 1 deletion ompi/mca/coll/ucc/coll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module)
mca_coll_ucc_module_clear(ucc_module);
}

int mca_coll_ucc_progress(void)
static int mca_coll_ucc_progress(void)
{
ucc_context_progress(mca_coll_ucc_component.ucc_context);
return OPAL_SUCCESS;
Expand Down
12 changes: 9 additions & 3 deletions ompi/mca/coll/ucc/coll_ucc_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op, int root,
bool blocking,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
Expand All @@ -31,7 +32,8 @@ static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf
goto fallback;
}
ucc_coll_args_t coll = {
.mask = 0,
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_REDUCE,
.root = root,
.src.info = {
Expand All @@ -52,6 +54,10 @@ static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
if (blocking) {
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags |= UCC_COLL_ARGS_HINT_OPTIMIZE_LATENCY;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -69,7 +75,7 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, size_t count,

UCC_VERBOSE(3, "running ucc reduce");
COLL_UCC_CHECK(mca_coll_ucc_reduce_init(sbuf, rbuf, count, dtype, op,
root, ucc_module, &req, NULL));
root, true, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -93,7 +99,7 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count,
UCC_VERBOSE(3, "running ucc ireduce");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_reduce_init(sbuf, rbuf, count, dtype, op, root,
ucc_module, &req, coll_req));
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand Down

0 comments on commit 5c2dedb

Please sign in to comment.