Skip to content

Commit

Permalink
Sync IB changes with NCCL v2.21.5-1
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy committed Apr 15, 2024
1 parent 6aeb933 commit 3c80efc
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 21 deletions.
3 changes: 2 additions & 1 deletion include/p2p_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ struct ncclIbRequest {
struct ncclIbGidInfo {
uint8_t link_layer;
union ibv_gid localGid;
int32_t localGidIndex;
};

typedef struct ncclIbNetCommDevBase {
int ibDevN;
struct ibv_pd* pd;
struct ibv_cq* cq;
uint64_t pad[1];
uint64_t pad[2];
struct ncclIbGidInfo gidInfo;
} ncclIbNetCommDevBase;

Expand Down
238 changes: 225 additions & 13 deletions src/ib_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ static int ncclNIbDevs = -1;
pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER;
int ncclIbRelaxedOrderingEnabled = 0;

NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", 0);
NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1);
NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2);
NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0);
NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18);
NCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7);
Expand Down Expand Up @@ -63,6 +64,211 @@ int ncclIbRelaxedOrderingCapable(void) {
return 1;
}

static sa_family_t envIbAddrFamily(void) {
sa_family_t family = AF_INET;
const char* env = ncclGetEnv("NCCL_IB_ADDR_FAMILY");
if (env == NULL || strlen(env) == 0) {
return family;
}

INFO(NCCL_ENV, "NCCL_IB_ADDR_FAMILY set by environment to %s", env);

if (strcmp(env, "AF_INET") == 0) {
family = AF_INET;
} else if (strcmp(env, "AF_INET6") == 0) {
family = AF_INET6;
}

return family;
}

static void* envIbAddrRange(sa_family_t af, int* mask) {
*mask = 0;
static struct in_addr addr;
static struct in6_addr addr6;
void *ret = (af == AF_INET) ? (void *)&addr : (void *)&addr6;

const char* env = ncclGetEnv("NCCL_IB_ADDR_RANGE");
if (NULL == env || strlen(env) == 0) {
return NULL;
}

INFO(NCCL_ENV, "NCCL_IB_ADDR_RANGE set by environment to %s", env);

char addrString[128] = { 0 };
snprintf(addrString, 128, "%s", env);
char *addrStrPtr = addrString;
char *maskStrPtr = strstr(addrString, "/") + 1;
if (NULL == maskStrPtr) {
return NULL;
}
*(maskStrPtr - 1) = '\0';

if (inet_pton(af, addrStrPtr, ret) == 0) {
WARN("NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6");
return NULL;
}

*mask = (int)strtol(maskStrPtr, NULL, 10);
if (af == AF_INET && *mask > 32) {
WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6");
*mask = 0;
ret = NULL;
} else if (af == AF_INET6 && *mask > 128) {
WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6");
*mask = 0;
ret = NULL;
}

return ret;
}

static sa_family_t getGidAddrFamily(union ibv_gid* gid) {
const struct in6_addr *a = (struct in6_addr *)gid->raw;
bool isIpV4Mapped = ((a->s6_addr32[0] | a->s6_addr32[1]) | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL;
bool isIpV4MappedMulticast = (a->s6_addr32[0] == htonl(0xff0e0000) && ((a->s6_addr32[1] | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL));
return (isIpV4Mapped || isIpV4MappedMulticast) ? AF_INET : AF_INET6;
}

static bool matchGidAddrPrefix(sa_family_t af, void* prefix, int prefixlen, union ibv_gid* gid) {
struct in_addr *base = NULL;
struct in6_addr *base6 = NULL;
struct in6_addr *addr6 = NULL;;
if (af == AF_INET) {
base = (struct in_addr *)prefix;
} else {
base6 = (struct in6_addr *)prefix;
}
addr6 = (struct in6_addr *)gid->raw;

#define NETMASK(bits) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1)))

int i = 0;
while (prefixlen > 0 && i < 4) {
if (af == AF_INET) {
int mask = NETMASK(prefixlen);
if ((base->s_addr & mask) ^ (addr6->s6_addr32[3] & mask)) {
break;
}
prefixlen = 0;
break;
} else {
if (prefixlen >= 32) {
if (base6->s6_addr32[i] ^ addr6->s6_addr32[i]) {
break;
}
prefixlen -= 32;
++i;
} else {
int mask = NETMASK(prefixlen);
if ((base6->s6_addr32[i] & mask) ^ (addr6->s6_addr32[i] & mask)) {
break;
}
prefixlen = 0;
}
}
}

return (prefixlen == 0) ? true : false;
}

static bool configuredGid(union ibv_gid* gid) {
const struct in6_addr *a = (struct in6_addr *)gid->raw;
int trailer = (a->s6_addr32[1] | a->s6_addr32[2] | a->s6_addr32[3]);
if (((a->s6_addr32[0] | trailer) == 0UL) || ((a->s6_addr32[0] == htonl(0xfe800000)) && (trailer == 0UL))) {
return false;
}
return true;
}

static bool linkLocalGid(union ibv_gid* gid) {
const struct in6_addr *a = (struct in6_addr *)gid->raw;
if (a->s6_addr32[0] == htonl(0xfe800000) && a->s6_addr32[1] == 0UL) {
return true;
}
return false;
}

static bool validGid(union ibv_gid* gid) {
return (configuredGid(gid) && !linkLocalGid(gid));
}

static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) {
char gidRoceVerStr[16] = { 0 };
char roceTypePath[PATH_MAX] = { 0 };
sprintf(roceTypePath, "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex);

int fd = open(roceTypePath, O_RDONLY);
if (fd == -1) {
return ncclSystemError;
}
int ret = read(fd, gidRoceVerStr, 15);
close(fd);

if (ret == -1) {
return ncclSystemError;
}

if (strlen(gidRoceVerStr)) {
if (strncmp(gidRoceVerStr, "IB/RoCE v1", strlen("IB/RoCE v1")) == 0 || strncmp(gidRoceVerStr, "RoCE v1", strlen("RoCE v1")) == 0) {
*version = 1;
} else if (strncmp(gidRoceVerStr, "RoCE v2", strlen("RoCE v2")) == 0) {
*version = 2;
}
}

return ncclSuccess;
}

static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t portNum, sa_family_t af, void* prefix, int prefixlen, int roceVer, int gidIndexCandidate, int* gidIndex) {
union ibv_gid gid, gidCandidate;
NCCLCHECK(wrap_ibv_query_gid(context, portNum, *gidIndex, &gid));
NCCLCHECK(wrap_ibv_query_gid(context, portNum, gidIndexCandidate, &gidCandidate));

sa_family_t usrFam = af;
sa_family_t gidFam = getGidAddrFamily(&gid);
sa_family_t gidCandidateFam = getGidAddrFamily(&gidCandidate);
bool gidCandidateMatchSubnet = matchGidAddrPrefix(usrFam, prefix, prefixlen, &gidCandidate);

if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet) {
*gidIndex = gidIndexCandidate;
} else {
if (gidCandidateFam != usrFam || !validGid(&gidCandidate) || !gidCandidateMatchSubnet) {
return ncclSuccess;
}
int usrRoceVer = roceVer;
int gidRoceVerNum, gidRoceVerNumCandidate;
const char* deviceName = wrap_ibv_get_device_name(context->device);
NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, *gidIndex, &gidRoceVerNum));
NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, gidIndexCandidate, &gidRoceVerNumCandidate));
if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid(&gid)) && gidRoceVerNumCandidate == usrRoceVer) {
*gidIndex = gidIndexCandidate;
}
}

return ncclSuccess;
}

static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) {
*gidIndex = ncclParamIbGidIndex();
if (*gidIndex >= 0) {
return ncclSuccess;
}

sa_family_t userAddrFamily = envIbAddrFamily();
int userRoceVersion = ncclParamIbRoceVersionNum();
int prefixlen;
void *prefix = envIbAddrRange(userAddrFamily, &prefixlen);

*gidIndex = 0;
for (int gidIndexNext = 1; gidIndexNext < gidTblLen; ++gidIndexNext) {
NCCLCHECK(ncclUpdateGidIndex(context, portNum, userAddrFamily, prefix, prefixlen, userRoceVersion, gidIndexNext, gidIndex));
}

return ncclSuccess;
}


NCCL_PARAM(IbDisable, "IBEXT_DISABLE", 0);
NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1);
NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1);
Expand Down Expand Up @@ -373,7 +579,7 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
return ncclSuccess;
}

ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTR;
Expand All @@ -392,7 +598,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbD
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid;
qpAttr.ah_attr.grh.flow_label = 0;
qpAttr.ah_attr.grh.sgid_index = ncclParamIbGidIndex();
qpAttr.ah_attr.grh.sgid_index = sGidIndex;
qpAttr.ah_attr.grh.hop_limit = 255;
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
}
Expand Down Expand Up @@ -514,7 +720,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
);

if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) {
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &commDev->base.gidInfo.localGid));

NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id;
}
Expand All @@ -532,9 +740,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
// Print just the QPs for this dev
if (comm->base.qps[q].devIndex == i)
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x",
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, ncclParamIbGidIndex(),
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex,
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
}
}
}
Expand Down Expand Up @@ -602,12 +810,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet

// Assign per-QP remDev
comm->base.qps[q].remDevIdx = remQpInfo->devIndex;
int devIndex = comm->base.qps[q].devIndex;
ncclIbSendCommDev* commDev = comm->devs + devIndex;
uint8_t gidIndex = commDev->base.gidInfo.localGidIndex;

struct ibv_qp* qp = comm->base.qps[q].qp;
if (remQpInfo->ece_supported && remQpInfo->ece_supported)
NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported));

NCCLCHECK(ncclIbRtrQp(qp, remQpInfo->qpn, remDevInfo));
NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo));
NCCLCHECK(ncclIbRtsQp(qp));
}

Expand Down Expand Up @@ -707,7 +918,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
ibDevN = mergedDev->devs[i];
NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base));
ibDev = ncclIbDevs + ibDevN;
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &rCommDev->base.gidInfo.localGid));
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid));
}

// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
Expand Down Expand Up @@ -745,7 +957,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
if (meta.qpInfo[q].ece_supported)
NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported));
}
NCCLCHECK(ncclIbRtrQp(qp->qp, remMeta.qpInfo[q].qpn, remDevInfo));
NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo));
NCCLCHECK(ncclIbRtsQp(qp->qp));
}

Expand Down Expand Up @@ -783,8 +995,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
#endif
);
devInfo.mtu = ibDev->portAttr.active_mtu;
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
}

// Fill Handle
Expand Down Expand Up @@ -1431,7 +1643,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
return ncclInternalError;
}
if (req->nreqs == 1) {
req->recv.sizes[0] += wc->imm_data;
req->recv.sizes[0] = wc->imm_data;
}
}
req->events[i]--;
Expand Down
17 changes: 10 additions & 7 deletions src/p2p_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ int ncclIbFindMatchingDev(int dev) {
ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction)
{
int ncclNIbDevs = *num_devs;

ncclResult_t ret;
pluginLogFunction = logFunction;
if (ncclNIbDevs == -1) {
pthread_mutex_lock(&nccl_p2p_lock);
Expand All @@ -287,7 +287,8 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
ncclNSharpDevs = 0;
if (ncclFindInterfaces(ncclIbIfName, ncclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) {
WARN("NET/IB : No IP interface found.");
return ncclInternalError;
ret = ncclInternalError;
goto fail;
}

// Detect IB cards
Expand All @@ -302,7 +303,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
if (searchExact) userIbEnv++;
int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS);

if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) return ncclInternalError;
if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; }

for (int d=0; d<nIbDevs; d++) {
struct ibv_context * context;
Expand All @@ -314,7 +315,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
struct ibv_device_attr devAttr;
if (ncclSuccess != wrap_ibv_query_device(context, &devAttr)) {
WARN("NET/IB : Unable to query device %s", devices[d]->name);
if (ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; }
if (ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; }
continue;
}
for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) {
Expand Down Expand Up @@ -394,9 +395,9 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
ncclNIbDevs++;
nPorts++;
}
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; }
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; }
}
if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { return ncclInternalError; };
if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { ret = ncclInternalError; goto fail; };
}
if (ncclNIbDevs == 0) {
INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found.");
Expand Down Expand Up @@ -444,7 +445,9 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
pthread_mutex_unlock(&nccl_p2p_lock);
}
return ncclSuccess;

fail:
pthread_mutex_unlock(&nccl_p2p_lock);
return ret;
}

ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port)
Expand Down

0 comments on commit 3c80efc

Please sign in to comment.