diff --git a/include/p2p_plugin.h b/include/p2p_plugin.h index b859777c..76a86d48 100644 --- a/include/p2p_plugin.h +++ b/include/p2p_plugin.h @@ -77,13 +77,14 @@ struct ncclIbRequest { struct ncclIbGidInfo { uint8_t link_layer; union ibv_gid localGid; + int32_t localGidIndex; }; typedef struct ncclIbNetCommDevBase { int ibDevN; struct ibv_pd* pd; struct ibv_cq* cq; - uint64_t pad[1]; + uint64_t pad[2]; struct ncclIbGidInfo gidInfo; } ncclIbNetCommDevBase; diff --git a/src/ib_plugin.c b/src/ib_plugin.c index 190932f8..8040f099 100644 --- a/src/ib_plugin.c +++ b/src/ib_plugin.c @@ -35,7 +35,8 @@ static int ncclNIbDevs = -1; pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER; int ncclIbRelaxedOrderingEnabled = 0; -NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", 0); +NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1); +NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2); NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0); NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18); NCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7); @@ -63,6 +64,211 @@ int ncclIbRelaxedOrderingCapable(void) { return 1; } +static sa_family_t envIbAddrFamily(void) { + sa_family_t family = AF_INET; + const char* env = ncclGetEnv("NCCL_IB_ADDR_FAMILY"); + if (env == NULL || strlen(env) == 0) { + return family; + } + + INFO(NCCL_ENV, "NCCL_IB_ADDR_FAMILY set by environment to %s", env); + + if (strcmp(env, "AF_INET") == 0) { + family = AF_INET; + } else if (strcmp(env, "AF_INET6") == 0) { + family = AF_INET6; + } + + return family; +} + +static void* envIbAddrRange(sa_family_t af, int* mask) { + *mask = 0; + static struct in_addr addr; + static struct in6_addr addr6; + void *ret = (af == AF_INET) ? (void *)&addr : (void *)&addr6; + + const char* env = ncclGetEnv("NCCL_IB_ADDR_RANGE"); + if (NULL == env || strlen(env) == 0) { + return NULL; + } + + INFO(NCCL_ENV, "NCCL_IB_ADDR_RANGE set by environment to %s", env); + + char addrString[128] = { 0 }; + snprintf(addrString, 128, "%s", env); + char *addrStrPtr = addrString; + char *maskStrPtr = strstr(addrString, "/") + 1; + if (NULL == maskStrPtr) { + return NULL; + } + *(maskStrPtr - 1) = '\0'; + + if (inet_pton(af, addrStrPtr, ret) == 0) { + WARN("NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + return NULL; + } + + *mask = (int)strtol(maskStrPtr, NULL, 10); + if (af == AF_INET && *mask > 32) { + WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + *mask = 0; + ret = NULL; + } else if (af == AF_INET6 && *mask > 128) { + WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + *mask = 0; + ret = NULL; + } + + return ret; +} + +static sa_family_t getGidAddrFamily(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + bool isIpV4Mapped = ((a->s6_addr32[0] | a->s6_addr32[1]) | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL; + bool isIpV4MappedMulticast = (a->s6_addr32[0] == htonl(0xff0e0000) && ((a->s6_addr32[1] | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL)); + return (isIpV4Mapped || isIpV4MappedMulticast) ? AF_INET : AF_INET6; +} + +static bool matchGidAddrPrefix(sa_family_t af, void* prefix, int prefixlen, union ibv_gid* gid) { + struct in_addr *base = NULL; + struct in6_addr *base6 = NULL; + struct in6_addr *addr6 = NULL;; + if (af == AF_INET) { + base = (struct in_addr *)prefix; + } else { + base6 = (struct in6_addr *)prefix; + } + addr6 = (struct in6_addr *)gid->raw; + +#define NETMASK(bits) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1))) + + int i = 0; + while (prefixlen > 0 && i < 4) { + if (af == AF_INET) { + int mask = NETMASK(prefixlen); + if ((base->s_addr & mask) ^ (addr6->s6_addr32[3] & mask)) { + break; + } + prefixlen = 0; + break; + } else { + if (prefixlen >= 32) { + if (base6->s6_addr32[i] ^ addr6->s6_addr32[i]) { + break; + } + prefixlen -= 32; + ++i; + } else { + int mask = NETMASK(prefixlen); + if ((base6->s6_addr32[i] & mask) ^ (addr6->s6_addr32[i] & mask)) { + break; + } + prefixlen = 0; + } + } + } + + return (prefixlen == 0) ? true : false; +} + +static bool configuredGid(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + int trailer = (a->s6_addr32[1] | a->s6_addr32[2] | a->s6_addr32[3]); + if (((a->s6_addr32[0] | trailer) == 0UL) || ((a->s6_addr32[0] == htonl(0xfe800000)) && (trailer == 0UL))) { + return false; + } + return true; +} + +static bool linkLocalGid(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + if (a->s6_addr32[0] == htonl(0xfe800000) && a->s6_addr32[1] == 0UL) { + return true; + } + return false; +} + +static bool validGid(union ibv_gid* gid) { + return (configuredGid(gid) && !linkLocalGid(gid)); +} + +static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) { + char gidRoceVerStr[16] = { 0 }; + char roceTypePath[PATH_MAX] = { 0 }; + sprintf(roceTypePath, "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); + + int fd = open(roceTypePath, O_RDONLY); + if (fd == -1) { + return ncclSystemError; + } + int ret = read(fd, gidRoceVerStr, 15); + close(fd); + + if (ret == -1) { + return ncclSystemError; + } + + if (strlen(gidRoceVerStr)) { + if (strncmp(gidRoceVerStr, "IB/RoCE v1", strlen("IB/RoCE v1")) == 0 || strncmp(gidRoceVerStr, "RoCE v1", strlen("RoCE v1")) == 0) { + *version = 1; + } else if (strncmp(gidRoceVerStr, "RoCE v2", strlen("RoCE v2")) == 0) { + *version = 2; + } + } + + return ncclSuccess; +} + +static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t portNum, sa_family_t af, void* prefix, int prefixlen, int roceVer, int gidIndexCandidate, int* gidIndex) { + union ibv_gid gid, gidCandidate; + NCCLCHECK(wrap_ibv_query_gid(context, portNum, *gidIndex, &gid)); + NCCLCHECK(wrap_ibv_query_gid(context, portNum, gidIndexCandidate, &gidCandidate)); + + sa_family_t usrFam = af; + sa_family_t gidFam = getGidAddrFamily(&gid); + sa_family_t gidCandidateFam = getGidAddrFamily(&gidCandidate); + bool gidCandidateMatchSubnet = matchGidAddrPrefix(usrFam, prefix, prefixlen, &gidCandidate); + + if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet) { + *gidIndex = gidIndexCandidate; + } else { + if (gidCandidateFam != usrFam || !validGid(&gidCandidate) || !gidCandidateMatchSubnet) { + return ncclSuccess; + } + int usrRoceVer = roceVer; + int gidRoceVerNum, gidRoceVerNumCandidate; + const char* deviceName = wrap_ibv_get_device_name(context->device); + NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, *gidIndex, &gidRoceVerNum)); + NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, gidIndexCandidate, &gidRoceVerNumCandidate)); + if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid(&gid)) && gidRoceVerNumCandidate == usrRoceVer) { + *gidIndex = gidIndexCandidate; + } + } + + return ncclSuccess; +} + +static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) { + *gidIndex = ncclParamIbGidIndex(); + if (*gidIndex >= 0) { + return ncclSuccess; + } + + sa_family_t userAddrFamily = envIbAddrFamily(); + int userRoceVersion = ncclParamIbRoceVersionNum(); + int prefixlen; + void *prefix = envIbAddrRange(userAddrFamily, &prefixlen); + + *gidIndex = 0; + for (int gidIndexNext = 1; gidIndexNext < gidTblLen; ++gidIndexNext) { + NCCLCHECK(ncclUpdateGidIndex(context, portNum, userAddrFamily, prefix, prefixlen, userRoceVersion, gidIndexNext, gidIndex)); + } + + return ncclSuccess; +} + + NCCL_PARAM(IbDisable, "IBEXT_DISABLE", 0); NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1); NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1); @@ -373,7 +579,7 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, return ncclSuccess; } -ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbDevInfo* info) { +ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) { struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_RTR; @@ -392,7 +598,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbD qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn; qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid; qpAttr.ah_attr.grh.flow_label = 0; - qpAttr.ah_attr.grh.sgid_index = ncclParamIbGidIndex(); + qpAttr.ah_attr.grh.sgid_index = sGidIndex; qpAttr.ah_attr.grh.hop_limit = 255; qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc(); } @@ -514,7 +720,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet ); if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) { - NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &commDev->base.gidInfo.localGid)); + + NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex)); + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid)); devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix; devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id; } @@ -532,9 +740,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet // Print just the QPs for this dev if (comm->base.qps[q].devIndex == i) INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", - comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, - commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, ncclParamIbGidIndex(), - devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey); + comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, + commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex, + devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey); } } } @@ -602,12 +810,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet // Assign per-QP remDev comm->base.qps[q].remDevIdx = remQpInfo->devIndex; + int devIndex = comm->base.qps[q].devIndex; + ncclIbSendCommDev* commDev = comm->devs + devIndex; + uint8_t gidIndex = commDev->base.gidInfo.localGidIndex; struct ibv_qp* qp = comm->base.qps[q].qp; if (remQpInfo->ece_supported && remQpInfo->ece_supported) NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported)); - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo->qpn, remDevInfo)); + NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo)); NCCLCHECK(ncclIbRtsQp(qp)); } @@ -707,7 +918,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl ibDevN = mergedDev->devs[i]; NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base)); ibDev = ncclIbDevs + ibDevN; - NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &rCommDev->base.gidInfo.localGid)); + NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex)); + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid)); } // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. @@ -745,7 +957,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl if (meta.qpInfo[q].ece_supported) NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); } - NCCLCHECK(ncclIbRtrQp(qp->qp, remMeta.qpInfo[q].qpn, remDevInfo)); + NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo)); NCCLCHECK(ncclIbRtsQp(qp->qp)); } @@ -783,8 +995,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl #endif ); devInfo.mtu = ibDev->portAttr.active_mtu; - NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo)); - NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp)); + NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo)); + NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp)); } // Fill Handle @@ -1431,7 +1643,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { return ncclInternalError; } if (req->nreqs == 1) { - req->recv.sizes[0] += wc->imm_data; + req->recv.sizes[0] = wc->imm_data; } } req->events[i]--; diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c index 1663d85e..f91a424a 100644 --- a/src/p2p_plugin.c +++ b/src/p2p_plugin.c @@ -276,7 +276,7 @@ int ncclIbFindMatchingDev(int dev) { ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction) { int ncclNIbDevs = *num_devs; - + ncclResult_t ret; pluginLogFunction = logFunction; if (ncclNIbDevs == -1) { pthread_mutex_lock(&nccl_p2p_lock); @@ -287,7 +287,8 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb ncclNSharpDevs = 0; if (ncclFindInterfaces(ncclIbIfName, ncclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) { WARN("NET/IB : No IP interface found."); - return ncclInternalError; + ret = ncclInternalError; + goto fail; } // Detect IB cards @@ -302,7 +303,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb if (searchExact) userIbEnv++; int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS); - if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) return ncclInternalError; + if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; } for (int d=0; dname); - if (ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; } + if (ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; } continue; } for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) { @@ -394,9 +395,9 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb ncclNIbDevs++; nPorts++; } - if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; } + if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; } } - if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { return ncclInternalError; }; + if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { ret = ncclInternalError; goto fail; }; } if (ncclNIbDevs == 0) { INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found."); @@ -444,7 +445,9 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb pthread_mutex_unlock(&nccl_p2p_lock); } return ncclSuccess; - +fail: + pthread_mutex_unlock(&nccl_p2p_lock); + return ret; } ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port)