-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MSCCLPP user buffer registration APIs and integrate with RCCL #1477
Changes from 3 commits
6c32a32
a552427
184cea2
8617629
8bfaf30
4f2dff1
8568024
de5fa5c
46c0438
ad036d4
5afa3e5
4a9aa09
65dacd5
373b9f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,89 @@ | ||||||||||||||||||
diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h | ||||||||||||||||||
index 7f50792..b8b146d 100644 | ||||||||||||||||||
--- a/apps/nccl/include/nccl.h | ||||||||||||||||||
+++ b/apps/nccl/include/nccl.h | ||||||||||||||||||
@@ -344,6 +344,12 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun | ||||||||||||||||||
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, | ||||||||||||||||||
ncclComm_t comm, cudaStream_t stream); | ||||||||||||||||||
|
||||||||||||||||||
+/* | ||||||||||||||||||
+ * Register/Deregister | ||||||||||||||||||
+ */ | ||||||||||||||||||
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle); | ||||||||||||||||||
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle); | ||||||||||||||||||
+ncclResult_t ncclBuffIsRegistered(ncclComm_t comm, const void* buff, size_t count, bool* registered); | ||||||||||||||||||
/* | ||||||||||||||||||
* Send | ||||||||||||||||||
* | ||||||||||||||||||
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu | ||||||||||||||||||
index a697be2..d8497e7 100644 | ||||||||||||||||||
--- a/apps/nccl/src/nccl.cu | ||||||||||||||||||
+++ b/apps/nccl/src/nccl.cu | ||||||||||||||||||
@@ -577,6 +577,67 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t | ||||||||||||||||||
return ncclSuccess; | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) { | ||||||||||||||||||
+ size_t buffBytes = size; | ||||||||||||||||||
+ CUdeviceptr buffBasePtr; | ||||||||||||||||||
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); | ||||||||||||||||||
+ | ||||||||||||||||||
+ int rank = comm->comm->bootstrap()->getRank(); | ||||||||||||||||||
+ channelKey buffKey{(void*)buffBasePtr, buffBytes}; | ||||||||||||||||||
+ | ||||||||||||||||||
+ std::vector<mscclpp::RegisteredMemory> remoteMemories; | ||||||||||||||||||
+ | ||||||||||||||||||
+ // Creating the channels | ||||||||||||||||||
+ auto buffIt = comm->channelScratchInfos.find(buffKey); | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there's no guarantee this will be called from a single thread, we probably need to lock access to channelScratchInfos and the rest of the maps. I suggest using one lock if this function doesn't fall on the critical path. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This map is supposed to be checked by the host process and MSCCLPP is invoked only when using one GPU per process case. |
||||||||||||||||||
+ if (buffIt == comm->channelScratchInfos.end()) { | ||||||||||||||||||
+ std::vector<mscclpp::SmChannel> channels = | ||||||||||||||||||
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr)); | ||||||||||||||||||
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; | ||||||||||||||||||
+ buffIt = comm->channelScratchInfos.emplace(buffKey, channelInfo).first; | ||||||||||||||||||
+ } | ||||||||||||||||||
+ auto sendIt = comm->channelInInfos.find(buffKey); | ||||||||||||||||||
+ if (sendIt == comm->channelInInfos.end()) { | ||||||||||||||||||
+ std::vector<mscclpp::SmChannel> channels = | ||||||||||||||||||
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr)); | ||||||||||||||||||
+ | ||||||||||||||||||
+ remoteMemories = | ||||||||||||||||||
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc); | ||||||||||||||||||
+ std::vector<mscclpp::SmChannel> channels1 = | ||||||||||||||||||
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr)); | ||||||||||||||||||
+ | ||||||||||||||||||
+ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)}; | ||||||||||||||||||
+ sendIt = comm->channelInInfos.emplace(buffKey, channelInfo).first; | ||||||||||||||||||
+ } | ||||||||||||||||||
+ auto recvIt = comm->channelOutInfos.find(buffKey); | ||||||||||||||||||
+ if (recvIt == comm->channelOutInfos.end()) { | ||||||||||||||||||
+ remoteMemories = | ||||||||||||||||||
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc); | ||||||||||||||||||
+ std::vector<mscclpp::SmChannel> outChannels = | ||||||||||||||||||
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr)); | ||||||||||||||||||
+ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; | ||||||||||||||||||
+ recvIt = comm->channelOutInfos.emplace(buffKey, channelInfo).first; | ||||||||||||||||||
+ } | ||||||||||||||||||
+ *handle = (void*) buffBasePtr; | ||||||||||||||||||
+ | ||||||||||||||||||
+ return ncclSuccess; | ||||||||||||||||||
+} | ||||||||||||||||||
+ | ||||||||||||||||||
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) { | ||||||||||||||||||
+ if (comm && handle) { | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we free any resources associated with
Same thing can be done for channelOut and channelIn |
||||||||||||||||||
+ handle = nullptr; | ||||||||||||||||||
+ } | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementation nullifies the
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a NCCL API and cannot be changed. I am still working on some of these fixes. |
||||||||||||||||||
+ return ncclSuccess; | ||||||||||||||||||
+} | ||||||||||||||||||
+ | ||||||||||||||||||
+NCCL_API ncclResult_t ncclBuffIsRegistered(ncclComm_t comm, const void* buff, size_t count, bool* registered){ | ||||||||||||||||||
+ size_t buffBytes; | ||||||||||||||||||
+ CUdeviceptr buffBasePtr; | ||||||||||||||||||
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); | ||||||||||||||||||
+ channelKey buffKey{(void*)buffBasePtr, buffBytes}; | ||||||||||||||||||
+ auto buffIt = comm->channelScratchInfos.find(buffKey); | ||||||||||||||||||
+ *registered = buffIt != comm->channelScratchInfos.end(); | ||||||||||||||||||
+ return ncclSuccess; | ||||||||||||||||||
+} | ||||||||||||||||||
NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) { | ||||||||||||||||||
// TODO: implement this function | ||||||||||||||||||
return ncclInternalError; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
#include "net.h" | ||
#include "register.h" | ||
#include "api_trace.h" | ||
#include "mscclpp/mscclpp_nccl.h" | ||
|
||
ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) { | ||
struct ncclRegCache* cache = &comm->regCache; | ||
|
@@ -155,32 +156,59 @@ NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size | |
ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) { | ||
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); | ||
if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister")); | ||
NCCLCHECK(ncclRegister(comm, buff, size, handle)); | ||
|
||
#ifdef ENABLE_MSCCLPP | ||
if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold){ | ||
bool isManagedBuffer = false; | ||
CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(buff))); | ||
if(!isManagedBuffer){ | ||
INFO(NCCL_INIT, "MSCCL++: ncclCommRegister"); | ||
NCCLCHECK(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle)); | ||
} | ||
else{ | ||
WARN("MSCCL++: Cannot register user-buffers on managed memory"); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will do nothing, issue a warning, then fall through and return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the expectation is that if no registration happens in MSCCL++ due to the buffer being in managed memory, no MSCCL++ or RCCL registration will happen and execution will fall back to RCCL kernel without UBR. I actually need to do a small adjustment to not cause a RCCL deregistration error when MSCCL++ detects no registered buffer because of this. I just tested a simple fix for this. But since we always fallback to RCCL kernel when MSCCL++ is not enabled, I think a warning without return an error is OK. Also, I haven't seen any errors returned in any UBR related routines, except when RCCL cannot find the registration handle during deregistration. Of course, this definitely makes it harder to see if anything goes wrong. |
||
} | ||
else | ||
#endif | ||
{ | ||
INFO(NCCL_INIT, "RCCL: ncclCommRegister"); | ||
NCCLCHECK(ncclRegister(comm, buff, size, handle)); | ||
} | ||
isaki001 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return ncclSuccess; | ||
} | ||
|
||
NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); | ||
ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) { | ||
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); | ||
struct ncclReg* reg = (struct ncclReg*)handle; | ||
struct ncclRegCache* cache = &comm->regCache; | ||
int slot; | ||
for (slot=0; slot<cache->population && cache->slots[slot] != reg; slot++); | ||
if (slot == cache->population) { | ||
WARN("Deregister: Could not find handle"); | ||
return ncclInvalidUsage; | ||
} | ||
if (--reg->refs) return ncclSuccess; | ||
NCCLCHECK(ncclNetDeregister(comm, reg)); | ||
if (reg->state & NVLS_REG_COMPLETE) { | ||
NCCLCHECK(ncclNvlsDeregBuffer(®->mcHandle, reg->regAddr, reg->dev, reg->regSize)); | ||
reg->regAddr = (CUdeviceptr)NULL; | ||
} | ||
if (reg->state & COLLNET_REG_COMPLETE) { | ||
NCCLCHECK(ncclCollnetDeregBuffer(comm, reg->proxyconn, reg->collnetHandle)); | ||
} | ||
free(reg); | ||
memmove(cache->slots+slot, cache->slots+slot+1, (cache->population-slot-1)*sizeof(struct ncclReg*)); | ||
cache->population -= 1; | ||
|
||
#ifdef ENABLE_MSCCLPP | ||
if (comm->mscclCompatible){ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here you should do the same checks that you did before registering the buffer. |
||
NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle)); | ||
} | ||
else | ||
#endif | ||
{ | ||
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); | ||
struct ncclReg* reg = (struct ncclReg*)handle; | ||
struct ncclRegCache* cache = &comm->regCache; | ||
int slot; | ||
for (slot=0; slot<cache->population && cache->slots[slot] != reg; slot++); | ||
if (slot == cache->population) { | ||
WARN("Deregister: Could not find handle"); | ||
return ncclInvalidUsage; | ||
} | ||
if (--reg->refs) return ncclSuccess; | ||
NCCLCHECK(ncclNetDeregister(comm, reg)); | ||
if (reg->state & NVLS_REG_COMPLETE) { | ||
NCCLCHECK(ncclNvlsDeregBuffer(®->mcHandle, reg->regAddr, reg->dev, reg->regSize)); | ||
reg->regAddr = (CUdeviceptr)NULL; | ||
} | ||
if (reg->state & COLLNET_REG_COMPLETE) { | ||
NCCLCHECK(ncclCollnetDeregBuffer(comm, reg->proxyconn, reg->collnetHandle)); | ||
} | ||
free(reg); | ||
memmove(cache->slots+slot, cache->slots+slot+1, (cache->population-slot-1)*sizeof(struct ncclReg*)); | ||
cache->population -= 1; | ||
} | ||
return ncclSuccess; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change does not seem to be necessary