From 42eb40e2cd69247b81ddd8fc3792968bab78c5c5 Mon Sep 17 00:00:00 2001 From: Yael Yacobovich Date: Mon, 10 Feb 2025 17:45:49 +0200 Subject: [PATCH] TL/UCP: Allow self copy in allgather using network loopback --- src/components/tl/ucp/allgather/allgather.c | 33 +++++++++ src/components/tl/ucp/allgather/allgather.h | 11 +++ .../tl/ucp/allgather/allgather_knomial.c | 74 ++++++++++++------- .../tl/ucp/allgather/allgather_neighbor.c | 34 ++++++--- .../tl/ucp/allgather/allgather_ring.c | 36 +++++---- .../tl/ucp/allgather/allgather_sparbit.c | 8 +- src/components/tl/ucp/tl_ucp.c | 14 ++-- src/components/tl/ucp/tl_ucp.h | 1 + src/components/tl/ucp/tl_ucp_coll.h | 4 + test/gtest/coll/test_allgather.cc | 22 +++--- 10 files changed, 167 insertions(+), 70 deletions(-) diff --git a/src/components/tl/ucp/allgather/allgather.c b/src/components/tl/ucp/allgather/allgather.c index 769c4fb981..cc790a1ab4 100644 --- a/src/components/tl/ucp/allgather/allgather.c +++ b/src/components/tl/ucp/allgather/allgather.c @@ -58,3 +58,36 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, algo_num); return str; } + +ucc_status_t loopback_self_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task) +{ + ucc_status_t status; + status = ucc_tl_ucp_send_nb(sbuf, data_size, smem, rank, team, task); + if (UCC_OK != status) { + task->super.status = status; + return task->super.status; + } + status = ucc_tl_ucp_recv_nb(rbuf, data_size, rmem, rank, team, task); + if (UCC_OK != status) { + task->super.status = status; + return task->super.status; + } + return UCC_OK; +} +ucc_status_t allgather_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task) +{ + ucc_status_t status; + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; + if (use_loopback) { + status = loopback_self_copy(rbuf, sbuf, data_size, rmem, smem, rank, team, task); + } else { + status = ucc_mc_memcpy(rbuf, sbuf, data_size, rmem, smem); + } + return status; +} diff --git a/src/components/tl/ucp/allgather/allgather.h b/src/components/tl/ucp/allgather/allgather.h index 61733a4ab7..58255a5891 100644 --- a/src/components/tl/ucp/allgather/allgather.h +++ b/src/components/tl/ucp/allgather/allgather.h @@ -7,6 +7,7 @@ #define ALLGATHER_H_ #include "../tl_ucp.h" #include "../tl_ucp_coll.h" +#include "tl_ucp_sendrecv.h" enum { UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL, @@ -38,6 +39,16 @@ static inline int ucc_tl_ucp_allgather_alg_from_str(const char *str) ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task); +ucc_status_t loopback_self_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task); + +ucc_status_t allgather_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task); + /* Ring */ ucc_status_t ucc_tl_ucp_allgather_ring_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index 1fbcf773cc..1d93735fb3 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -13,6 +13,7 @@ #include "coll_patterns/sra_knomial.h" #include "utils/ucc_math.h" #include "utils/ucc_coll_utils.h" +#include "allgather.h" #define SAVE_STATE(_phase) \ do { \ @@ -54,22 +55,21 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_coll_args_t *args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_kn_radix_t radix = task->allgather_kn.p.radix; + ucc_tl_ucp_task_t * task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_coll_args_t * args = &TASK_ARGS(task); + ucc_tl_ucp_team_t * team = TASK_TEAM(task); + ucc_kn_radix_t radix = task->allgather_kn.p.radix; uint8_t node_type = task->allgather_kn.p.node_type; ucc_knomial_pattern_t *p = &task->allgather_kn.p; - void *rbuf = GET_DST(args); + void * rbuf = GET_DST(args); ucc_memory_type_t mem_type = GET_MT(args); size_t dt_size = ucc_dt_size(GET_DT(args)); ucc_rank_t size = task->subset.map.ep_num; size_t data_size = GET_TOTAL_COUNT(args, size); - ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ? - args->root : 0; - ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); - size_t local = GET_LOCAL_COUNT(args, size, rank); + ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ? args->root : 0; + ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); + size_t local = GET_LOCAL_COUNT(args, size, rank); + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; void *sbuf; ptrdiff_t peer_seg_offset, local_seg_offset; ucc_rank_t peer, peer_dist; @@ -78,8 +78,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) ucc_status_t status; size_t extra_count; - EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", - task->allgather_kn.etask); + if (use_loopback) { + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; + } + } else { + EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", + task->allgather_kn.etask); + } task->allgather_kn.etask = NULL; UCC_KN_GOTO_PHASE(task->allgather_kn.phase); if (KN_NODE_EXTRA == node_type) { @@ -209,6 +215,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) ct == UCC_COLL_TYPE_BCAST ? args->root : 0, size); ucc_ee_executor_task_args_t eargs = {0}; + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; ucc_status_t status; ptrdiff_t offset; ucc_ee_executor_t *exec; @@ -225,21 +232,34 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) ucc_dt_size(args->dst.info.datatype); rbuf = args->dst.info.buffer; if (!UCC_IS_INPLACE(*args)) { - status = ucc_coll_task_get_executor(&task->super, &exec); - if (ucc_unlikely(status != UCC_OK)) { - task->super.status = status; - return status; - } - eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; - eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset); - eargs.copy.src = args->src.info.buffer; - eargs.copy.len = args->src.info.count * - ucc_dt_size(args->src.info.datatype); - status = ucc_ee_executor_task_post(exec, &eargs, - &task->allgather_kn.etask); - if (ucc_unlikely(status != UCC_OK)) { - task->super.status = status; - return status; + if (use_loopback) { + status = loopback_self_copy( + PTR_OFFSET(args->dst.info.buffer, offset), + args->src.info.buffer, + args->src.info.count * ucc_dt_size(args->src.info.datatype), + args->dst.info.mem_type, args->src.info.mem_type, rank, + team, task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + } else { + /* Executer */ + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset); + eargs.copy.src = args->src.info.buffer; + eargs.copy.len = + args->src.info.count * ucc_dt_size(args->src.info.datatype); + status = ucc_ee_executor_task_post(exec, &eargs, + &task->allgather_kn.etask); + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } } } } else if (ct == UCC_COLL_TYPE_ALLGATHERV) { diff --git a/src/components/tl/ucp/allgather/allgather_neighbor.c b/src/components/tl/ucp/allgather/allgather_neighbor.c index 534c197e4e..bcc88f113e 100644 --- a/src/components/tl/ucp/allgather/allgather_neighbor.c +++ b/src/components/tl/ucp/allgather/allgather_neighbor.c @@ -12,6 +12,7 @@ #include "utils/ucc_coll_utils.h" #include "components/mc/ucc_mc.h" + static ucc_rank_t get_recv_from_rank(ucc_rank_t rank, ucc_rank_t size, int i) { const int i_parity = i % 2; @@ -81,9 +82,11 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task) ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; size_t count = TASK_ARGS(task).dst.info.count; size_t data_size = (count / tsize) * ucc_dt_size(dt); + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; ucc_rank_t neighbors[2], i; int i_parity, even_rank; void *tmprecv, *tmpsend; + int counter; if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; @@ -98,8 +101,13 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task) neighbors[1] = (trank + 1) % tsize; } - while (task->tagged.send_posted < (tsize / 2)) { - i = task->tagged.send_posted; + if ((!UCC_IS_INPLACE(TASK_ARGS(task))) && use_loopback) { + counter = task->tagged.send_posted - 1; + } else { + counter = task->tagged.send_posted; + } + while (counter < (tsize / 2)) { + i = counter; i_parity = i % 2; tmprecv = @@ -114,10 +122,15 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task) UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(tmprecv, 2 * data_size, rmem, neighbors[i_parity], team, task), task, out); - + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } + if ((!UCC_IS_INPLACE(TASK_ARGS(task))) && use_loopback) { + counter = task->tagged.send_posted - 1; + } else { + counter = task->tagged.send_posted; + } } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); @@ -133,8 +146,8 @@ ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task) ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_tl_ucp_team_t *team = TASK_TEAM(task); size_t count = TASK_ARGS(task).dst.info.count; - void *sbuf = TASK_ARGS(task).src.info.buffer; - void *rbuf = TASK_ARGS(task).dst.info.buffer; + void * sbuf = TASK_ARGS(task).src.info.buffer; + void * rbuf = TASK_ARGS(task).dst.info.buffer; ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type; ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; @@ -144,19 +157,20 @@ ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task) ucc_status_t status; ucc_rank_t neighbor; void *tmprecv, *tmpsend; - + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_neighbor_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf, - data_size, rmem, smem); + status = allgather_copy(PTR_OFFSET(rbuf, data_size * trank), sbuf, + data_size, rmem, smem, trank, team, task); if (ucc_unlikely(UCC_OK != status)) { return status; } } - + while (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + } if (trank % 2) { neighbor = (trank - 1 + tsize) % tsize; } else { @@ -173,6 +187,8 @@ ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task) UCPCHECK_GOTO( ucc_tl_ucp_recv_nb(tmprecv, data_size, rmem, neighbor, team, task), task, out); + + out: return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index 07178aea25..059d9a1bd4 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -31,15 +31,16 @@ static ucc_rank_t ucc_tl_ucp_allgather_ring_get_recv_block(ucc_subset_t *subset, void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t trank = task->subset.myrank; - ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; - void *rbuf = TASK_ARGS(task).dst.info.buffer; - ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; - size_t count = TASK_ARGS(task).dst.info.count; - ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; - size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = task->subset.myrank; + ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + void * rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + size_t count = TASK_ARGS(task).dst.info.count; + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + size_t data_size = (count / tsize) * ucc_dt_size(dt); + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; ucc_rank_t sendto, recvfrom, sblock, rblock; int step; void *buf; @@ -49,9 +50,10 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) } sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); + step = + use_loopback ? task->tagged.send_posted - 1 : task->tagged.send_posted; - while (task->tagged.send_posted < tsize - 1) { - step = task->tagged.send_posted; + while (step < tsize - 1) { sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, step); rblock = task->allgather_ring.get_recv_block(&task->subset, trank, @@ -67,6 +69,8 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } + step = use_loopback ? task->tagged.send_posted - 1 + : task->tagged.send_posted; } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); task->super.status = UCC_OK; @@ -79,13 +83,14 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_tl_ucp_team_t *team = TASK_TEAM(task); size_t count = TASK_ARGS(task).dst.info.count; - void *sbuf = TASK_ARGS(task).src.info.buffer; - void *rbuf = TASK_ARGS(task).dst.info.buffer; + void * sbuf = TASK_ARGS(task).src.info.buffer; + void * rbuf = TASK_ARGS(task).dst.info.buffer; ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type; ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; ucc_rank_t trank = task->subset.myrank; ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + ucc_rank_t rank = ucc_ep_map_eval(task->subset.map, trank); size_t data_size = (count / tsize) * ucc_dt_size(dt); ucc_status_t status; ucc_rank_t block; @@ -96,13 +101,12 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) if (!UCC_IS_INPLACE(TASK_ARGS(task))) { block = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0); - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), - sbuf, data_size, rmem, smem); + status = allgather_copy(PTR_OFFSET(rbuf, data_size * block), sbuf, + data_size, rmem, smem, rank, team, task); if (ucc_unlikely(UCC_OK != status)) { return status; } } - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } diff --git a/src/components/tl/ucp/allgather/allgather_sparbit.c b/src/components/tl/ucp/allgather/allgather_sparbit.c index 0edfc4d4a3..f453d5ab5b 100644 --- a/src/components/tl/ucp/allgather/allgather_sparbit.c +++ b/src/components/tl/ucp/allgather/allgather_sparbit.c @@ -114,8 +114,8 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_start(ucc_coll_task_t *coll_task) ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_tl_ucp_team_t *team = TASK_TEAM(task); size_t count = TASK_ARGS(task).dst.info.count; - void *sbuf = TASK_ARGS(task).src.info.buffer; - void *rbuf = TASK_ARGS(task).dst.info.buffer; + void * sbuf = TASK_ARGS(task).src.info.buffer; + void * rbuf = TASK_ARGS(task).dst.info.buffer; ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type; ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; @@ -131,8 +131,8 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_start(ucc_coll_task_t *coll_task) task->allgather_sparbit.data_expected = 1; if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf, - data_size, rmem, smem); + status = allgather_copy(PTR_OFFSET(rbuf, data_size * trank), sbuf, + data_size, rmem, smem, trank, team, task); if (ucc_unlikely(UCC_OK != status)) { return status; } diff --git a/src/components/tl/ucp/tl_ucp.c b/src/components/tl/ucp/tl_ucp.c index 7db99bdaf2..a0e49de3c9 100644 --- a/src/components/tl/ucp/tl_ucp.c +++ b/src/components/tl/ucp/tl_ucp.c @@ -48,7 +48,7 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoallv_pairwise_num_posts), UCC_CONFIG_TYPE_ULUNITS}, -/* TODO: add radix to config once it's fully supported by the algorithm + /* TODO: add radix to config once it's fully supported by the algorithm {"ALLTOALLV_HYBRID_RADIX", "2", "Radix of the Hybrid Alltoallv algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoallv_hybrid_radix), @@ -140,6 +140,12 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_kn_radix), UCC_CONFIG_TYPE_UINT}, + {"ALLGATHER_USE_LOOPBACK", "0", + "If set to 1 performs network loopback for self copy, otherwise uses mc " + "cuda copy", + ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_use_loopback), + UCC_CONFIG_TYPE_BOOL}, + {"BCAST_KN_RADIX", "4", "Radix of the recursive-knomial bcast algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, bcast_kn_radix), UCC_CONFIG_TYPE_UINT}, @@ -196,10 +202,8 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { ucc_offsetof(ucc_tl_ucp_lib_config_t, reduce_scatterv_ring_bidirectional), UCC_CONFIG_TYPE_BOOL}, - {"USE_TOPO", "try", - "Allow usage of tl ucp topo", - ucc_offsetof(ucc_tl_ucp_lib_config_t, use_topo), - UCC_CONFIG_TYPE_TERNARY}, + {"USE_TOPO", "try", "Allow usage of tl ucp topo", + ucc_offsetof(ucc_tl_ucp_lib_config_t, use_topo), UCC_CONFIG_TYPE_TERNARY}, {"RANKS_REORDERING", "y", "Use topology information in TL UCP to reorder ranks. Requires topo info", diff --git a/src/components/tl/ucp/tl_ucp.h b/src/components/tl/ucp/tl_ucp.h index 3c439f4ae5..6d31c5aead 100644 --- a/src/components/tl/ucp/tl_ucp.h +++ b/src/components/tl/ucp/tl_ucp.h @@ -55,6 +55,7 @@ typedef struct ucc_tl_ucp_lib_config { ucc_mrange_uint_t allreduce_sra_kn_radix; uint32_t reduce_scatter_kn_radix; uint32_t allgather_kn_radix; + int allgather_use_loopback; uint32_t bcast_kn_radix; ucc_mrange_uint_t bcast_sag_kn_radix; uint32_t reduce_kn_radix; diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 2769244d39..93e511ca0a 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -201,11 +201,15 @@ typedef struct ucc_tl_ucp_task { struct { ucc_mc_buffer_header_t *scratch_header; size_t scratch_size; + int phase; } allgather_bruck; struct { uint32_t i; int data_expected; } allgather_sparbit; + struct { + int phase; + } allgather_neighbor; struct { ucc_rank_t dist; uint32_t radix; diff --git a/test/gtest/coll/test_allgather.cc b/test/gtest/coll/test_allgather.cc index c48bb8303d..0c0a3dba6c 100644 --- a/test/gtest/coll/test_allgather.cc +++ b/test/gtest/coll/test_allgather.cc @@ -9,7 +9,7 @@ using Param_0 = std::tuple; using Param_1 = std::tuple; -using Param_2 = std::tuple; +using Param_2 = std::tuple; class test_allgather : public UccCollArgs, public ucc::test { @@ -259,16 +259,18 @@ class test_allgather_alg : public test_allgather, UCC_TEST_P(test_allgather_alg, alg) { - const ucc_datatype_t dtype = std::get<0>(GetParam()); - const ucc_memory_type_t mem_type = std::get<1>(GetParam()); - const int count = std::get<2>(GetParam()); - const gtest_ucc_inplace_t inplace = std::get<3>(GetParam()); - int n_procs = 5; + const ucc_datatype_t dtype = std::get<0>(GetParam()); + const ucc_memory_type_t mem_type = std::get<1>(GetParam()); + const int count = std::get<2>(GetParam()); + const gtest_ucc_inplace_t inplace = std::get<3>(GetParam()); + int n_procs = 5; char tune[32]; + std::string use_loopback = std::get<5>(GetParam()); sprintf(tune, "allgather:@%s:inf", std::get<4>(GetParam()).c_str()); ucc_job_env_t env = {{"UCC_CL_BASIC_TUNE", "inf"}, - {"UCC_TL_UCP_TUNE", tune}}; + {"UCC_TL_UCP_TUNE", tune}, + {"UCC_TL_UCP_ALLGATHER_USE_LOOPBACK", use_loopback}}; UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); UccTeam_h team = job.create_team(n_procs); UccCollCtxVec ctxs; @@ -296,7 +298,8 @@ INSTANTIATE_TEST_CASE_P( #endif ::testing::Values(1,3,8192), // count ::testing::Values(TEST_INPLACE, TEST_NO_INPLACE), - ::testing::Values("knomial", "ring", "neighbor", "bruck", "sparbit")), + ::testing::Values("knomial", "ring", "neighbor", "bruck", "sparbit"), + ::testing::Values("1", "0")), [](const testing::TestParamInfo& info) { std::string name; name += ucc_datatype_str(std::get<0>(info.param)); @@ -304,5 +307,6 @@ INSTANTIATE_TEST_CASE_P( name += std::string("_count_")+std::to_string(std::get<2>(info.param)); name += std::string("_inplace_")+std::to_string(std::get<3>(info.param)); name += std::string("_")+std::get<4>(info.param); + name += std::string("_use_loopback_")+std::get<5>(info.param); return name; - }); + }); \ No newline at end of file