Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync with NCCL v2.21.5-1 #153

Merged
merged 2 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/p2p_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ struct ncclIbRequest {
// Local GID information stored in ncclIbNetCommDevBase (member gidInfo).
struct ncclIbGidInfo {
// NOTE(review): not assigned in the visible code; presumably the port's link layer - confirm.
uint8_t link_layer;
// GID value read back with wrap_ibv_query_gid() at localGidIndex.
union ibv_gid localGid;
// GID table index selected by ncclIbGetGidIndex(); passed to ncclIbRtrQp() as the source GID index.
int32_t localGidIndex;
};

typedef struct ncclIbNetCommDevBase {
int ibDevN;
struct ibv_pd* pd;
struct ibv_cq* cq;
uint64_t pad[1];
uint64_t pad[2];
struct ncclIbGidInfo gidInfo;
} ncclIbNetCommDevBase;

Expand Down
249 changes: 230 additions & 19 deletions src/ib_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ static int ncclNIbDevs = -1;
pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER;
int ncclIbRelaxedOrderingEnabled = 0;

NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", 0);
NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1);
NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2);
NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0);
NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18);
NCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7);
Expand Down Expand Up @@ -63,6 +64,211 @@ int ncclIbRelaxedOrderingCapable(void) {
return 1;
}

/*
 * Read the address family requested through NCCL_IB_ADDR_FAMILY.
 *
 * Returns AF_INET6 only when the variable is exactly "AF_INET6"; an unset,
 * empty, "AF_INET" or unrecognized value all yield AF_INET.
 */
static sa_family_t envIbAddrFamily(void) {
  const char* requested = ncclGetEnv("NCCL_IB_ADDR_FAMILY");

  // Unset or empty: use the default without logging anything.
  if (requested == NULL || strlen(requested) == 0) {
    return AF_INET;
  }

  INFO(NCCL_ENV, "NCCL_IB_ADDR_FAMILY set by environment to %s", requested);

  return (strcmp(requested, "AF_INET6") == 0) ? AF_INET6 : AF_INET;
}

/*
 * Parse NCCL_IB_ADDR_RANGE, expected as "<address>/<prefix length>".
 *
 * af:   AF_INET parses into the IPv4 buffer, anything else into the IPv6 buffer.
 * mask: out; receives the prefix length, or 0 whenever NULL is returned.
 *
 * Returns a pointer to function-static storage holding the parsed address
 * (overwritten by the next call), or NULL when the variable is unset, empty,
 * or malformed.
 */
static void* envIbAddrRange(sa_family_t af, int* mask) {
  static struct in_addr addr;
  static struct in6_addr addr6;
  void *ret = (af == AF_INET) ? (void *)&addr : (void *)&addr6;

  *mask = 0;

  const char* env = ncclGetEnv("NCCL_IB_ADDR_RANGE");
  if (NULL == env || strlen(env) == 0) {
    return NULL;
  }

  INFO(NCCL_ENV, "NCCL_IB_ADDR_RANGE set by environment to %s", env);

  char addrString[128] = { 0 };
  snprintf(addrString, sizeof(addrString), "%s", env);

  // Test the separator itself for NULL *before* stepping past it; doing the
  // "+ 1" first turns a missing '/' into a write through address 0.
  char *separator = strstr(addrString, "/");
  if (NULL == separator) {
    WARN("NET/IB: NCCL_IB_ADDR_RANGE '%s' has no '/<mask>' suffix, ignoring address", addrString);
    return NULL;
  }
  *separator = '\0';
  const char *addrStrPtr = addrString;
  const char *maskStrPtr = separator + 1;

  // inet_pton returns 1 on success, 0 for an unparsable address and -1 for an
  // unsupported family; only 1 leaves a usable address in ret.
  if (inet_pton(af, addrStrPtr, ret) != 1) {
    WARN("NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6");
    return NULL;
  }

  // Range-check on the long so that negative or oversized values are rejected
  // before being narrowed to int.
  const long maxMask = (af == AF_INET) ? 32 : 128;
  long parsedMask = strtol(maskStrPtr, NULL, 10);
  if (parsedMask < 0 || parsedMask > maxMask) {
    WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", (int)parsedMask, (af == AF_INET) ? "AF_INET" : "AF_INET6");
    return NULL;
  }

  *mask = (int)parsedMask;
  return ret;
}

/*
 * Classify a GID by the address family it encodes.
 *
 * Returns AF_INET when the 16 raw bytes form an IPv4-mapped address
 * (words 0 and 1 zero, word 2 == 0x0000ffff) or the multicast variant
 * (word 0 == 0xff0e0000, word 1 zero, word 2 == 0x0000ffff); AF_INET6 otherwise.
 */
static sa_family_t getGidAddrFamily(union ibv_gid* gid) {
  const struct in6_addr *in6 = (struct in6_addr *)gid->raw;

  // Both IPv4-mapped forms share the same words 1 and 2.
  if (in6->s6_addr32[1] != 0 || in6->s6_addr32[2] != htonl(0x0000ffff)) {
    return AF_INET6;
  }

  if (in6->s6_addr32[0] == 0 || in6->s6_addr32[0] == htonl(0xff0e0000)) {
    return AF_INET;
  }
  return AF_INET6;
}

static bool matchGidAddrPrefix(sa_family_t af, void* prefix, int prefixlen, union ibv_gid* gid) {
struct in_addr *base = NULL;
struct in6_addr *base6 = NULL;
struct in6_addr *addr6 = NULL;;
if (af == AF_INET) {
base = (struct in_addr *)prefix;
} else {
base6 = (struct in6_addr *)prefix;
}
addr6 = (struct in6_addr *)gid->raw;

#define NETMASK(bits) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1)))

int i = 0;
while (prefixlen > 0 && i < 4) {
if (af == AF_INET) {
int mask = NETMASK(prefixlen);
if ((base->s_addr & mask) ^ (addr6->s6_addr32[3] & mask)) {
break;
}
prefixlen = 0;
break;
} else {
if (prefixlen >= 32) {
if (base6->s6_addr32[i] ^ addr6->s6_addr32[i]) {
break;
}
prefixlen -= 32;
++i;
} else {
int mask = NETMASK(prefixlen);
if ((base6->s6_addr32[i] & mask) ^ (addr6->s6_addr32[i] & mask)) {
break;
}
prefixlen = 0;
}
}
}

return (prefixlen == 0) ? true : false;
}

/*
 * Report whether a GID table entry holds a configured address.
 *
 * Returns false for the all-zero GID and for a GID whose first word is
 * 0xfe800000 with every remaining word zero; true for anything else.
 */
static bool configuredGid(union ibv_gid* gid) {
  const struct in6_addr *in6 = (struct in6_addr *)gid->raw;

  // Any non-zero bit in words 1..3 means the entry is populated.
  if ((in6->s6_addr32[1] | in6->s6_addr32[2] | in6->s6_addr32[3]) != 0) {
    return true;
  }

  return !(in6->s6_addr32[0] == 0 || in6->s6_addr32[0] == htonl(0xfe800000));
}

/*
 * Report whether a GID starts with the fe80:0000:0000:0000 prefix,
 * i.e. word 0 equals 0xfe800000 and word 1 is zero.
 */
static bool linkLocalGid(union ibv_gid* gid) {
  const struct in6_addr *in6 = (struct in6_addr *)gid->raw;
  return in6->s6_addr32[0] == htonl(0xfe800000) && in6->s6_addr32[1] == 0UL;
}

/* A GID is usable when it is configured and does not carry the link-local prefix. */
static bool validGid(union ibv_gid* gid) {
  if (!configuredGid(gid)) {
    return false;
  }
  return !linkLocalGid(gid);
}

/*
 * Read the RoCE version of one GID table entry from sysfs
 * (/sys/class/infiniband/<dev>/ports/<port>/gid_attrs/types/<index>).
 *
 * deviceName: device name used in the sysfs path.
 * portNum:    port number used in the sysfs path.
 * gidIndex:   GID table index used in the sysfs path.
 * version:    out; 1 for "IB/RoCE v1" or "RoCE v1", 2 for "RoCE v2", and -1
 *             when the file is empty or holds any other string, so the caller
 *             never reads an unset value.
 *
 * Returns ncclSystemError if the path does not fit in the buffer or the file
 * cannot be opened or read; ncclSuccess otherwise.
 */
static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) {
  char gidRoceVerStr[16] = { 0 };
  char roceTypePath[PATH_MAX] = { 0 };

  *version = -1;

  // Bounded formatting: a truncated path must not be opened.
  int pathLen = snprintf(roceTypePath, sizeof(roceTypePath), "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex);
  if (pathLen < 0 || (size_t)pathLen >= sizeof(roceTypePath)) {
    return ncclSystemError;
  }

  int fd = open(roceTypePath, O_RDONLY);
  if (fd == -1) {
    return ncclSystemError;
  }
  // Leave the last byte untouched so the buffer stays NUL-terminated.
  ssize_t ret = read(fd, gidRoceVerStr, sizeof(gidRoceVerStr) - 1);
  close(fd);

  if (ret == -1) {
    return ncclSystemError;
  }

  if (strlen(gidRoceVerStr)) {
    if (strncmp(gidRoceVerStr, "IB/RoCE v1", strlen("IB/RoCE v1")) == 0 || strncmp(gidRoceVerStr, "RoCE v1", strlen("RoCE v1")) == 0) {
      *version = 1;
    } else if (strncmp(gidRoceVerStr, "RoCE v2", strlen("RoCE v2")) == 0) {
      *version = 2;
    }
  }

  return ncclSuccess;
}

/*
 * Decide whether gidIndexCandidate should replace the currently selected *gidIndex.
 *
 * context, portNum: device context and port whose GID table is queried.
 * af:               address family requested by the user.
 * prefix/prefixlen: subnet the candidate GID must match (see matchGidAddrPrefix).
 * roceVer:          RoCE version requested by the user.
 * gidIndexCandidate: table index being evaluated.
 * gidIndex:         in/out; current selection, overwritten when the candidate wins.
 *
 * The candidate wins when:
 *   - its family differs from the current GID's family, equals the requested
 *     family, and it matches the subnet; or
 *   - it has the requested family, is a valid GID, matches the subnet, its RoCE
 *     version equals the requested one, and either the current GID has a
 *     different RoCE version or the current GID is not valid.
 *
 * Returns ncclSuccess, or the error propagated by NCCLCHECK from a failed GID
 * query or RoCE version lookup.
 */
static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t portNum, sa_family_t af, void* prefix, int prefixlen, int roceVer, int gidIndexCandidate, int* gidIndex) {
  union ibv_gid gid, gidCandidate;
  NCCLCHECK(wrap_ibv_query_gid(context, portNum, *gidIndex, &gid));
  NCCLCHECK(wrap_ibv_query_gid(context, portNum, gidIndexCandidate, &gidCandidate));

  sa_family_t usrFam = af;
  sa_family_t gidFam = getGidAddrFamily(&gid);
  sa_family_t gidCandidateFam = getGidAddrFamily(&gidCandidate);
  bool gidCandidateMatchSubnet = matchGidAddrPrefix(usrFam, prefix, prefixlen, &gidCandidate);

  if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet) {
    *gidIndex = gidIndexCandidate;
  } else {
    if (gidCandidateFam != usrFam || !validGid(&gidCandidate) || !gidCandidateMatchSubnet) {
      return ncclSuccess;
    }
    int usrRoceVer = roceVer;
    // Start from a sentinel: the lookup only assigns on a recognized sysfs
    // type string, and the values are compared right after the calls.
    int gidRoceVerNum = -1;
    int gidRoceVerNumCandidate = -1;
    const char* deviceName = wrap_ibv_get_device_name(context->device);
    NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, *gidIndex, &gidRoceVerNum));
    NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, gidIndexCandidate, &gidRoceVerNumCandidate));
    if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid(&gid)) && gidRoceVerNumCandidate == usrRoceVer) {
      *gidIndex = gidIndexCandidate;
    }
  }

  return ncclSuccess;
}

/*
 * Choose the GID table index to use for a port.
 *
 * context, portNum: device context and port whose GID table is scanned.
 * gidTblLen:        number of entries in that table.
 * gidIndex:         out; the selected index.
 *
 * A non-negative IbGidIndex parameter is returned as-is. Otherwise the scan
 * starts from index 0 and offers every index in [1, gidTblLen) to
 * ncclUpdateGidIndex, using the address family, RoCE version and address
 * range taken from the environment.
 *
 * Returns ncclSuccess, or the first error propagated from ncclUpdateGidIndex.
 */
static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) {
  // Explicit user choice short-circuits the automatic selection.
  *gidIndex = ncclParamIbGidIndex();
  if (*gidIndex >= 0) {
    return ncclSuccess;
  }

  const sa_family_t wantedFamily = envIbAddrFamily();
  const int wantedRoceVersion = ncclParamIbRoceVersionNum();
  int subnetBits;
  void *subnet = envIbAddrRange(wantedFamily, &subnetBits);

  *gidIndex = 0;
  int candidate = 1;
  while (candidate < gidTblLen) {
    NCCLCHECK(ncclUpdateGidIndex(context, portNum, wantedFamily, subnet, subnetBits, wantedRoceVersion, candidate, gidIndex));
    ++candidate;
  }

  return ncclSuccess;
}


NCCL_PARAM(IbDisable, "IBEXT_DISABLE", 0);
NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1);
NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1);
Expand Down Expand Up @@ -373,7 +579,7 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
return ncclSuccess;
}

ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTR;
Expand All @@ -382,21 +588,20 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbD
qpAttr.rq_psn = 0;
qpAttr.max_dest_rd_atomic = 1;
qpAttr.min_rnr_timer = 12;
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
qpAttr.ah_attr.sl = ncclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
if (info->link_layer == IBV_LINK_LAYER_ETHERNET || info->is_global) {
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid;
qpAttr.ah_attr.grh.flow_label = 0;
qpAttr.ah_attr.grh.sgid_index = ncclParamIbGidIndex();
qpAttr.ah_attr.grh.sgid_index = sGidIndex;
qpAttr.ah_attr.grh.hop_limit = 255;
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
} else {
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
}
qpAttr.ah_attr.sl = ncclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER));
return ncclSuccess;
}
Expand Down Expand Up @@ -515,7 +720,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
);

if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) {
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &commDev->base.gidInfo.localGid));

NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id;
}
Expand All @@ -533,9 +740,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
// Print just the QPs for this dev
if (comm->base.qps[q].devIndex == i)
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x",
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, ncclParamIbGidIndex(),
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex,
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
}
}
}
Expand Down Expand Up @@ -603,12 +810,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet

// Assign per-QP remDev
comm->base.qps[q].remDevIdx = remQpInfo->devIndex;
int devIndex = comm->base.qps[q].devIndex;
ncclIbSendCommDev* commDev = comm->devs + devIndex;
uint8_t gidIndex = commDev->base.gidInfo.localGidIndex;

struct ibv_qp* qp = comm->base.qps[q].qp;
if (remQpInfo->ece_supported && remQpInfo->ece_supported)
NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported));

NCCLCHECK(ncclIbRtrQp(qp, remQpInfo->qpn, remDevInfo));
NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo));
NCCLCHECK(ncclIbRtsQp(qp));
}

Expand Down Expand Up @@ -708,7 +918,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
ibDevN = mergedDev->devs[i];
NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base));
ibDev = ncclIbDevs + ibDevN;
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &rCommDev->base.gidInfo.localGid));
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid));
}

// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
Expand Down Expand Up @@ -746,7 +957,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
if (meta.qpInfo[q].ece_supported)
NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported));
}
NCCLCHECK(ncclIbRtrQp(qp->qp, remMeta.qpInfo[q].qpn, remDevInfo));
NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo));
NCCLCHECK(ncclIbRtsQp(qp->qp));
}

Expand Down Expand Up @@ -784,8 +995,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
#endif
);
devInfo.mtu = ibDev->portAttr.active_mtu;
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
}

// Fill Handle
Expand Down Expand Up @@ -1432,7 +1643,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
return ncclInternalError;
}
if (req->nreqs == 1) {
req->recv.sizes[0] += wc->imm_data;
req->recv.sizes[0] = wc->imm_data;
}
}
req->events[i]--;
Expand Down
Loading
Loading