forked from NVIDIA/TransformerEngine
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add NVLS-MC based UB kernels (NVIDIA#721)
Fix license, and sign off everything Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Loading branch information
Showing
5 changed files
with
536 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
255 changes: 255 additions & 0 deletions
255
transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
/************************************************************************* | ||
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See LICENSE for license information. | ||
************************************************************************/ | ||
|
||
#include "ipcsocket.h" | ||
#include <errno.h> | ||
#include <stdlib.h> | ||
#include <string.h> | ||
|
||
#define WARN(...) \ | ||
{} | ||
#define TRACE(...) \ | ||
{} | ||
#define SYSCHECK(...) \ | ||
{} | ||
#define EQCHECK(...) \ | ||
{} | ||
|
||
// Enable Linux abstract socket naming | ||
#define USE_ABSTRACT_SOCKET | ||
|
||
#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx" | ||
|
||
/* | ||
* Create a Unix Domain Socket | ||
*/ | ||
ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, | ||
volatile uint32_t *abortFlag) { | ||
int fd = -1; | ||
struct sockaddr_un cliaddr; | ||
char temp[NCCL_IPC_SOCKNAME_LEN] = ""; | ||
|
||
if (handle == NULL) { | ||
return ncclInternalError; | ||
} | ||
|
||
handle->fd = -1; | ||
handle->socketName[0] = '\0'; | ||
if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) { | ||
WARN("UDS: Socket creation error : %s (%d)", strerror(errno), errno); | ||
return ncclSystemError; | ||
} | ||
|
||
bzero(&cliaddr, sizeof(cliaddr)); | ||
cliaddr.sun_family = AF_UNIX; | ||
|
||
// Create unique name for the socket. | ||
int len = | ||
snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); | ||
if (len > (sizeof(cliaddr.sun_path) - 1)) { | ||
WARN("UDS: Cannot bind provided name to socket. Name too large"); | ||
return ncclInternalError; | ||
} | ||
#ifndef USE_ABSTRACT_SOCKET | ||
unlink(temp); | ||
#endif | ||
|
||
TRACE(NCCL_INIT, "UDS: Creating socket %s", temp); | ||
|
||
strncpy(cliaddr.sun_path, temp, len); | ||
#ifdef USE_ABSTRACT_SOCKET | ||
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick | ||
#endif | ||
if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) { | ||
WARN("UDS: Binding to socket %s failed : %s (%d)", temp, strerror(errno), | ||
errno); | ||
close(fd); | ||
return ncclSystemError; | ||
} | ||
|
||
handle->fd = fd; | ||
strcpy(handle->socketName, temp); // NOLINT(*) | ||
|
||
handle->abortFlag = abortFlag; | ||
// Mark socket as non-blocking | ||
if (handle->abortFlag) { | ||
int flags; | ||
EQCHECK(flags = fcntl(fd, F_GETFL), -1); | ||
SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); | ||
} | ||
|
||
return ncclSuccess; | ||
} | ||
|
||
ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd) { | ||
if (handle == NULL) { | ||
WARN("ncclSocketGetFd: pass NULL socket"); | ||
return ncclInvalidArgument; | ||
} | ||
if (fd) | ||
*fd = handle->fd; | ||
return ncclSuccess; | ||
} | ||
|
||
ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { | ||
if (handle == NULL) { | ||
return ncclInternalError; | ||
} | ||
if (handle->fd <= 0) { | ||
return ncclSuccess; | ||
} | ||
#ifndef USE_ABSTRACT_SOCKET | ||
if (handle->socketName[0] != '\0') { | ||
unlink(handle->socketName); | ||
} | ||
#endif | ||
close(handle->fd); | ||
|
||
return ncclSuccess; | ||
} | ||
|
||
ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, | ||
int *recvFd) { | ||
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; | ||
struct iovec iov[1]; | ||
|
||
// Union to guarantee alignment requirements for control array | ||
union { | ||
struct cmsghdr cm; | ||
char control[CMSG_SPACE(sizeof(int))]; | ||
} control_un; | ||
|
||
struct cmsghdr *cmptr; | ||
char dummy_buffer[1]; | ||
int ret; | ||
|
||
msg.msg_control = control_un.control; | ||
msg.msg_controllen = sizeof(control_un.control); | ||
|
||
if (hdr == NULL) { | ||
iov[0].iov_base = reinterpret_cast<void *>(dummy_buffer); | ||
iov[0].iov_len = sizeof(dummy_buffer); | ||
} else { | ||
iov[0].iov_base = hdr; | ||
iov[0].iov_len = hdrLen; | ||
} | ||
|
||
msg.msg_iov = iov; | ||
msg.msg_iovlen = 1; | ||
|
||
while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) { | ||
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { | ||
WARN("UDS: Receiving data over socket failed : %d", errno); | ||
return ncclSystemError; | ||
} | ||
if (handle->abortFlag && *handle->abortFlag) | ||
return ncclInternalError; | ||
} | ||
|
||
if (recvFd != NULL) { | ||
if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && | ||
(cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) { | ||
if ((cmptr->cmsg_level != SOL_SOCKET) || | ||
(cmptr->cmsg_type != SCM_RIGHTS)) { | ||
WARN("UDS: Receiving data over socket failed"); | ||
return ncclSystemError; | ||
} | ||
|
||
memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd)); | ||
} else { | ||
WARN("UDS: Receiving data over socket %s failed", handle->socketName); | ||
return ncclSystemError; | ||
} | ||
TRACE(NCCL_INIT | NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, | ||
handle->socketName); | ||
} | ||
|
||
return ncclSuccess; | ||
} | ||
|
||
ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) { | ||
return ncclIpcSocketRecvMsg(handle, NULL, 0, recvFd); | ||
} | ||
|
||
ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, | ||
const int sendFd, int rank, uint64_t hash) { | ||
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; | ||
struct iovec iov[1]; | ||
char temp[NCCL_IPC_SOCKNAME_LEN]; | ||
|
||
union { | ||
struct cmsghdr cm; | ||
char control[CMSG_SPACE(sizeof(int))]; | ||
} control_un; | ||
|
||
struct cmsghdr *cmptr; | ||
char dummy_buffer[1]; | ||
struct sockaddr_un cliaddr; | ||
|
||
// Construct client address to send this shareable handle to | ||
bzero(&cliaddr, sizeof(cliaddr)); | ||
cliaddr.sun_family = AF_UNIX; | ||
|
||
int len = | ||
snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); | ||
if (len > (sizeof(cliaddr.sun_path) - 1)) { | ||
WARN("UDS: Cannot connect to provided name for socket. Name too large"); | ||
return ncclInternalError; | ||
} | ||
(void)strncpy(cliaddr.sun_path, temp, len); | ||
|
||
#ifdef USE_ABSTRACT_SOCKET | ||
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick | ||
#endif | ||
|
||
TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d to UDS socket %s", hdr, hdrLen, | ||
temp); | ||
|
||
if (sendFd != -1) { | ||
TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp); | ||
|
||
msg.msg_control = control_un.control; | ||
msg.msg_controllen = sizeof(control_un.control); | ||
|
||
cmptr = CMSG_FIRSTHDR(&msg); | ||
cmptr->cmsg_len = CMSG_LEN(sizeof(int)); | ||
cmptr->cmsg_level = SOL_SOCKET; | ||
cmptr->cmsg_type = SCM_RIGHTS; | ||
memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd)); | ||
} | ||
|
||
msg.msg_name = reinterpret_cast<void *>(&cliaddr); | ||
msg.msg_namelen = sizeof(struct sockaddr_un); | ||
|
||
if (hdr == NULL) { | ||
iov[0].iov_base = reinterpret_cast<void *>(dummy_buffer); | ||
iov[0].iov_len = sizeof(dummy_buffer); | ||
} else { | ||
iov[0].iov_base = hdr; | ||
iov[0].iov_len = hdrLen; | ||
} | ||
msg.msg_iov = iov; | ||
msg.msg_iovlen = 1; | ||
msg.msg_flags = 0; | ||
|
||
ssize_t sendResult; | ||
while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) { | ||
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { | ||
WARN("UDS: Sending data over socket %s failed : %s (%d)", temp, | ||
strerror(errno), errno); | ||
return ncclSystemError; | ||
} | ||
if (handle->abortFlag && *handle->abortFlag) | ||
return ncclInternalError; | ||
} | ||
|
||
return ncclSuccess; | ||
} | ||
|
||
ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, | ||
int rank, uint64_t hash) { | ||
return ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/************************************************************************* | ||
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See LICENSE for license information. | ||
************************************************************************/ | ||
|
||
#ifndef NCCL_IPCSOCKET_H | ||
#define NCCL_IPCSOCKET_H | ||
|
||
// #include "nccl.h" | ||
#include <errno.h> | ||
#include <fcntl.h> | ||
#include <inttypes.h> | ||
#include <memory.h> | ||
#include <stdio.h> | ||
#include <sys/mman.h> | ||
#include <sys/socket.h> | ||
#include <sys/types.h> | ||
#include <sys/un.h> | ||
#include <sys/wait.h> | ||
#include <unistd.h> | ||
|
||
typedef enum { | ||
ncclSuccess = 0, | ||
ncclUnhandledCudaError = 1, | ||
ncclSystemError = 2, | ||
ncclInternalError = 3, | ||
ncclInvalidArgument = 4, | ||
ncclInvalidUsage = 5, | ||
ncclRemoteError = 6, | ||
ncclInProgress = 7, | ||
ncclNumResults = 8 | ||
} ncclResult_t; | ||
|
||
#define NCCL_IPC_SOCKNAME_LEN 64 | ||
|
||
struct ncclIpcSocket { | ||
int fd; | ||
char socketName[NCCL_IPC_SOCKNAME_LEN]; | ||
volatile uint32_t *abortFlag; | ||
}; | ||
|
||
ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, | ||
uint64_t hash, volatile uint32_t *abortFlag); | ||
ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle); | ||
ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd); | ||
|
||
ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd); | ||
ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, | ||
int rank, uint64_t hash); | ||
|
||
#endif /* NCCL_IPCSOCKET_H */ |
Oops, something went wrong.