-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MSCCLPP user buffer registration APIs and integrate with RCCL (#1477
) * ext-src: add MSCCLPP memory registration APIs * update mem-reg patch with mscclpp helper routine to check if buffer is registered * RCCL integration of MSCCL++ user-buffer registration APIs * only include mscclpp_nccl header if ENABLE_MSCCLPP is defined * ext-src: update mscclpp mem-reg patch * add helper routine to patch * check handle before MSCCL++ deregister * fix typo to replace send buff with recv buff * in case of no mscclpp registration, during the deRegister call, don't fall back to rccl deRegister which will return an error * Apply suggestions from code review Whitespace suggestions and reducing diffs to avoid future merge conflicts Co-authored-by: corey-derochie-amd <[email protected]> * rename helper functions and change their return type * set RCCL user-buffer registration to occur if attempting MSCCL++ registration with a buffer in managed memory --------- Co-authored-by: isaki001 <[email protected]> Co-authored-by: isaki001 <[email protected]> Co-authored-by: corey-derochie-amd <[email protected]>
- Loading branch information
1 parent
3fee623
commit e9b6bbc
Showing
6 changed files
with
215 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h | ||
index 7f50792..b8b146d 100644 | ||
--- a/apps/nccl/include/nccl.h | ||
+++ b/apps/nccl/include/nccl.h | ||
@@ -344,6 +344,13 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun | ||
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, | ||
ncclComm_t comm, cudaStream_t stream); | ||
|
||
+/* | ||
+ * Register/Deregister | ||
+ */ | ||
+// Registers the device allocation that contains `buff` with `comm` and stores an opaque | ||
+// handle for that registration in `*handle`. | ||
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle); | ||
+// Releases the registration identified by `handle` (a value produced by ncclCommRegister). | ||
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle); | ||
+// Returns true when the allocation that contains `buff` currently has a registration on `comm`. | ||
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count); | ||
+// Returns the size in bytes recorded for `handle` on `comm`, or 0 when there is none. | ||
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle); | ||
/* | ||
* Send | ||
* | ||
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu | ||
index a697be2..1d4af61 100644 | ||
--- a/apps/nccl/src/nccl.cu | ||
+++ b/apps/nccl/src/nccl.cu | ||
@@ -65,6 +65,7 @@ struct ncclComm { | ||
std::unordered_map<channelKey, ChannelInfo> channelInInfos; | ||
std::unordered_map<channelKey, ChannelInfo> channelOutInfos; | ||
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos; | ||
+ std::unordered_map<void*, channelKey> handleKeys; | ||
std::shared_ptr<char> scratchBuff; | ||
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories; | ||
|
||
@@ -73,6 +74,11 @@ struct ncclComm { | ||
uint32_t buffFlag; | ||
}; | ||
|
||
+// Payload behind the opaque handle that ncclCommRegister hands back to the caller. | ||
+struct handleInfo { | ||
+  // Base address of the registered allocation. | ||
+  void* buff; | ||
+  // IPC memory handle obtained for that base address. | ||
+  cudaIpcMemHandle_t ipcHandle; | ||
+}; | ||
+ | ||
static size_t ncclTypeSize(ncclDataType_t type) { | ||
switch (type) { | ||
case ncclInt8: | ||
@@ -577,6 +583,104 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t | ||
return ncclSuccess; | ||
} | ||
|
||
+// Registers the device allocation containing `buff` with the communicator. | ||
+// | ||
+// The base address and length of the enclosing allocation are taken from | ||
+// cuMemGetAddressRange (which overwrites the caller-supplied `size`), and scratch, input and | ||
+// output channel entries keyed by that range are created when they do not exist yet. | ||
+// On success `*handle` receives a malloc'd handleInfo; release it with ncclCommDeregister. | ||
+// | ||
+// Returns ncclInvalidArgument when comm, buff or handle is null, ncclSystemError when the | ||
+// handle cannot be allocated, and ncclSuccess otherwise. Driver/runtime failures are | ||
+// reported through MSCCLPP_CUTHROW / MSCCLPP_CUDATHROW. | ||
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) { | ||
+  if (comm == nullptr || buff == nullptr || handle == nullptr) { | ||
+    return ncclInvalidArgument; | ||
+  } | ||
+ | ||
+  size_t buffBytes = size; | ||
+  CUdeviceptr buffBasePtr; | ||
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); | ||
+  // Single explicit conversion; every later use needs a plain pointer. | ||
+  void* basePtr = (void*)buffBasePtr; | ||
+ | ||
+  int rank = comm->comm->bootstrap()->getRank(); | ||
+  channelKey buffKey{basePtr, buffBytes}; | ||
+ | ||
+  // Creating the channels. The sequence of setup calls is kept exactly as before. | ||
+  if (comm->channelScratchInfos.find(buffKey) == comm->channelScratchInfos.end()) { | ||
+    std::vector<mscclpp::SmChannel> channels = | ||
+        setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr); | ||
+    ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; | ||
+    comm->channelScratchInfos.emplace(buffKey, channelInfo); | ||
+  } | ||
+  if (comm->channelInInfos.find(buffKey) == comm->channelInInfos.end()) { | ||
+    std::vector<mscclpp::SmChannel> channels = | ||
+        setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr); | ||
+ | ||
+    std::vector<mscclpp::RegisteredMemory> remoteMemories = | ||
+        setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc); | ||
+    std::vector<mscclpp::SmChannel> channels1 = | ||
+        setupSmChannels(comm, remoteMemories, basePtr); | ||
+ | ||
+    ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)}; | ||
+    comm->channelInInfos.emplace(buffKey, channelInfo); | ||
+  } | ||
+  if (comm->channelOutInfos.find(buffKey) == comm->channelOutInfos.end()) { | ||
+    std::vector<mscclpp::RegisteredMemory> remoteMemories = | ||
+        setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc); | ||
+    std::vector<mscclpp::SmChannel> outChannels = | ||
+        setupSmChannels(comm, remoteMemories, basePtr); | ||
+    ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; | ||
+    comm->channelOutInfos.emplace(buffKey, channelInfo); | ||
+  } | ||
+ | ||
+  // Fetch the IPC handle before allocating so that a failure here cannot leak the handle. | ||
+  cudaIpcMemHandle_t ipcHandle; | ||
+  MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, basePtr)); | ||
+ | ||
+  struct handleInfo* p = static_cast<struct handleInfo*>(malloc(sizeof(struct handleInfo))); | ||
+  if (p == nullptr) { | ||
+    return ncclSystemError; | ||
+  } | ||
+  p->buff = basePtr; | ||
+  p->ipcHandle = ipcHandle; | ||
+  *handle = p; | ||
+ | ||
+  // Always (re)bind the key: the allocator may hand back an address that a previous, | ||
+  // already released handle used, and a leftover entry must not shadow this registration. | ||
+  comm->handleKeys[p] = buffKey; | ||
+ | ||
+  return ncclSuccess; | ||
+} | ||
+ | ||
+// Releases a registration created by ncclCommRegister. | ||
+// | ||
+// Removes the scratch, input and output channel entries stored under the key recorded for | ||
+// `handle`, forgets the handle itself and frees its memory. | ||
+// | ||
+// Returns ncclSuccess when comm or handle is null (nothing to do) or when the handle was | ||
+// released, and ncclInvalidArgument when `handle` is not known to this communicator. | ||
+// NOTE(review): channel entries are keyed by allocation, so releasing one handle also drops | ||
+// the channels of any other handle registered for the same allocation - confirm this is intended. | ||
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) { | ||
+  if (comm == nullptr || handle == nullptr) { | ||
+    return ncclSuccess; | ||
+  } | ||
+ | ||
+  // Look the handle up without inserting: operator[] would create a default entry for an | ||
+  // unknown pointer, and free() would then be called on memory this code never allocated. | ||
+  auto keyIt = comm->handleKeys.find(handle); | ||
+  if (keyIt == comm->handleKeys.end()) { | ||
+    return ncclInvalidArgument; | ||
+  } | ||
+  const channelKey buffKey = keyIt->second; | ||
+ | ||
+  // Drop the handle mapping as well, so a freed handle no longer reports a size. | ||
+  comm->handleKeys.erase(keyIt); | ||
+ | ||
+  // erase(key) is a no-op when the entry is absent. | ||
+  comm->channelScratchInfos.erase(buffKey); | ||
+  comm->channelInInfos.erase(buffKey); | ||
+  comm->channelOutInfos.erase(buffKey); | ||
+ | ||
+  free(handle); | ||
+  return ncclSuccess; | ||
+} | ||
+ | ||
+// Reports whether the allocation containing `buff` has an entry in comm->channelScratchInfos. | ||
+// | ||
+// The lookup key is the base address and length of the enclosing allocation as returned by | ||
+// cuMemGetAddressRange; `count` does not take part in the lookup. | ||
+// Returns false when comm or buff is null. Driver failures are reported through MSCCLPP_CUTHROW. | ||
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count){ | ||
+  (void)count;  // kept for interface compatibility only | ||
+  if (comm == nullptr || buff == nullptr) { | ||
+    return false; | ||
+  } | ||
+  size_t buffBytes; | ||
+  CUdeviceptr buffBasePtr; | ||
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); | ||
+  channelKey buffKey{(void*)buffBasePtr, buffBytes}; | ||
+  return comm->channelScratchInfos.find(buffKey) != comm->channelScratchInfos.end(); | ||
+} | ||
+// Returns the byte length stored for `handle` in comm->handleKeys. | ||
+// Yields 0 when comm or handle is null, or when the handle has no entry. | ||
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle) { | ||
+  if (!comm || !handle) { | ||
+    return 0; | ||
+  } | ||
+  const auto entry = comm->handleKeys.find(handle); | ||
+  if (entry == comm->handleKeys.end()) { | ||
+    return 0; | ||
+  } | ||
+  return entry->second.bytes; | ||
+} | ||
NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) { | ||
// TODO: implement this function | ||
return ncclInternalError; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters