-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MSCCLPP user buffer registration APIs and integrate with RCCL #1477
Changes from all commits
6c32a32
a552427
184cea2
8617629
8bfaf30
4f2dff1
8568024
de5fa5c
46c0438
ad036d4
5afa3e5
4a9aa09
65dacd5
373b9f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h | ||
index 7f50792..b8b146d 100644 | ||
--- a/apps/nccl/include/nccl.h | ||
+++ b/apps/nccl/include/nccl.h | ||
@@ -344,6 +344,13 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun | ||
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, | ||
ncclComm_t comm, cudaStream_t stream); | ||
|
||
+/* | ||
+ * Register/Deregister | ||
+ */ | ||
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle); | ||
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle); | ||
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count); | ||
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle); | ||
/* | ||
* Send | ||
* | ||
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu | ||
index a697be2..1d4af61 100644 | ||
--- a/apps/nccl/src/nccl.cu | ||
+++ b/apps/nccl/src/nccl.cu | ||
@@ -65,6 +65,7 @@ struct ncclComm { | ||
std::unordered_map<channelKey, ChannelInfo> channelInInfos; | ||
std::unordered_map<channelKey, ChannelInfo> channelOutInfos; | ||
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos; | ||
+ std::unordered_map<void*, channelKey> handleKeys; | ||
std::shared_ptr<char> scratchBuff; | ||
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories; | ||
|
||
@@ -73,6 +74,11 @@ struct ncclComm { | ||
uint32_t buffFlag; | ||
}; | ||
|
||
+struct handleInfo { | ||
+ void * buff; | ||
+ cudaIpcMemHandle_t ipcHandle; | ||
+}; | ||
+ | ||
static size_t ncclTypeSize(ncclDataType_t type) { | ||
switch (type) { | ||
case ncclInt8: | ||
@@ -577,6 +583,104 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t | ||
return ncclSuccess; | ||
} | ||
|
||
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) { | ||
+ size_t buffBytes = size; | ||
+ CUdeviceptr buffBasePtr; | ||
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); | ||
+ | ||
+ int rank = comm->comm->bootstrap()->getRank(); | ||
+ channelKey buffKey{(void*)buffBasePtr, buffBytes}; | ||
+ | ||
+ std::vector<mscclpp::RegisteredMemory> remoteMemories; | ||
+ | ||
+ // Creating the channels | ||
+ auto buffIt = comm->channelScratchInfos.find(buffKey); | ||
+ if (buffIt == comm->channelScratchInfos.end()) { | ||
+ std::vector<mscclpp::SmChannel> channels = | ||
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr)); | ||
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; | ||
+ buffIt = comm->channelScratchInfos.emplace(buffKey, channelInfo).first; | ||
+ } | ||
+ auto sendIt = comm->channelInInfos.find(buffKey); | ||
+ if (sendIt == comm->channelInInfos.end()) { | ||
+ std::vector<mscclpp::SmChannel> channels = | ||
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr)); | ||
+ | ||
+ remoteMemories = | ||
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc); | ||
+ std::vector<mscclpp::SmChannel> channels1 = | ||
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr)); | ||
+ | ||
+ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)}; | ||
+ sendIt = comm->channelInInfos.emplace(buffKey, channelInfo).first; | ||
+ } | ||
+ auto recvIt = comm->channelOutInfos.find(buffKey); | ||
+ if (recvIt == comm->channelOutInfos.end()) { | ||
+ remoteMemories = | ||
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc); | ||
+ std::vector<mscclpp::SmChannel> outChannels = | ||
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr)); | ||
+ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; | ||
+ recvIt = comm->channelOutInfos.emplace(buffKey, channelInfo).first; | ||
+ } | ||
+ | ||
+ cudaIpcMemHandle_t ipcHandle; | ||
+ MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, buffBasePtr)); | ||
+ | ||
+ struct handleInfo *p = (struct handleInfo *) malloc(sizeof(struct handleInfo)); | ||
+ p->buff = buffBasePtr; | ||
+ p->ipcHandle = ipcHandle; | ||
+ *handle = p; | ||
+ | ||
+ auto it = comm->handleKeys.find(*handle); | ||
+ if (it == comm->handleKeys.end()) { | ||
+ comm->handleKeys[*handle] = buffKey; | ||
+ } | ||
+ | ||
+ return ncclSuccess; | ||
+} | ||
+ | ||
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) { | ||
+ if (comm && handle) { | ||
+ channelKey buffKey = comm->handleKeys[handle]; | ||
+ | ||
+ auto scratchIt = comm->channelScratchInfos.find(buffKey); | ||
+ if (scratchIt != comm->channelScratchInfos.end()) { | ||
+ comm->channelScratchInfos.erase(scratchIt); | ||
+ } | ||
+ | ||
+ auto inIt = comm->channelInInfos.find(buffKey); | ||
+ if (inIt != comm->channelInInfos.end()) { | ||
+ comm->channelInInfos.erase(inIt); | ||
+ } | ||
+ | ||
+ auto outIt = comm->channelOutInfos.find(buffKey); | ||
+ if (outIt != comm->channelOutInfos.end()) { | ||
+ comm->channelOutInfos.erase(outIt); | ||
+ } | ||
+ | ||
+ free(handle); | ||
+ } | ||
+ return ncclSuccess; | ||
+} | ||
+ | ||
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count){ | ||
+ size_t buffBytes; | ||
+ CUdeviceptr buffBasePtr; | ||
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); | ||
+ channelKey buffKey{(void*)buffBasePtr, buffBytes}; | ||
+ auto buffIt = comm->channelScratchInfos.find(buffKey); | ||
+ bool registered = buffIt != comm->channelScratchInfos.end(); | ||
+ return registered; | ||
+} | ||
+size_t | ||
+mscclpp_BufferSize(ncclComm_t comm, void* handle){ | ||
+ if (!(comm && handle)){ | ||
+ return 0; | ||
+ } | ||
+ auto buffKeyIt = comm->handleKeys.find(handle); | ||
+ return buffKeyIt != comm->handleKeys.end() ? buffKeyIt->second.bytes : 0; | ||
+} | ||
NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) { | ||
// TODO: implement this function | ||
return ncclInternalError; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,9 @@ | |
#include "net.h" | ||
#include "register.h" | ||
#include "api_trace.h" | ||
#ifdef ENABLE_MSCCLPP | ||
#include "mscclpp/mscclpp_nccl.h" | ||
#endif | ||
|
||
ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) { | ||
struct ncclRegCache* cache = &comm->regCache; | ||
|
@@ -155,12 +158,36 @@ NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size | |
// Register a user buffer with the communicator. When MSCCL++ support is
// compiled in and the buffer qualifies (non-zero size, multiple of 32 bytes,
// within the MSCCL++ threshold, not managed memory), registration is routed
// to MSCCL++; otherwise it falls through to the regular RCCL path.
ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) {
  NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm"));
  if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister"));
#ifdef ENABLE_MSCCLPP
  // (size & 31) == 0 — MSCCL++ requires the size to be a 32-byte multiple.
  const bool mscclppEligible =
      comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold;
  if (mscclppEligible) {
    bool isManagedBuffer = false;
    CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED,
                                     const_cast<void*>(buff)));
    if (!isManagedBuffer) {
      INFO(NCCL_INIT, "MSCCL++: ncclCommRegister");
      NCCLCHECK(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle));
      return ncclSuccess;
    }
    // Managed memory cannot be registered with MSCCL++; fall back to RCCL.
    WARN("MSCCL++: Cannot register user-buffers on managed memory. RCCL user-buffer registration will occur.");
  }
#endif
  INFO(NCCL_INIT, "RCCL: ncclCommRegister");
  NCCLCHECK(ncclRegister(comm, buff, size, handle));
  return ncclSuccess;
}
|
||
NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); | ||
ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) { | ||
|
||
#ifdef ENABLE_MSCCLPP | ||
const size_t size = mscclpp_BufferSize(comm->mscclpp_comm, handle); | ||
if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold) { | ||
NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle)); | ||
return ncclSuccess; | ||
} | ||
#endif | ||
|
||
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); | ||
struct ncclReg* reg = (struct ncclReg*)handle; | ||
struct ncclRegCache* cache = &comm->regCache; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there's no guarantee this will be called from a single thread, we probably need to lock access to channelScratchInfos and the rest of the maps. I suggest using one lock if this function doesn't fall on the critical path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This map is only checked by the host process, and MSCCL++ is invoked only in the one-GPU-per-process case.