Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MSCCLPP user buffer registration APIs and integrate with RCCL #1477

Merged
merged 14 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions cmake/MSCCLPP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ if(ENABLE_MSCCLPP)
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)

message(STATUS "Building mscclpp only for gfx942.")

Expand Down Expand Up @@ -98,13 +102,18 @@ if(ENABLE_MSCCLPP)

find_package(mscclpp_nccl REQUIRED)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)

endif()

execute_process(COMMAND objcopy
Expand Down
89 changes: 89 additions & 0 deletions ext-src/mem-reg.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h
index 7f50792..b8b146d 100644
--- a/apps/nccl/include/nccl.h
+++ b/apps/nccl/include/nccl.h
@@ -344,6 +344,12 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream);

+/*
+ * Register/Deregister
+ */
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle);
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);
+ncclResult_t ncclBuffIsRegistered(ncclComm_t comm, const void* buff, size_t count, bool* registered);
/*
* Send
*
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index a697be2..d8497e7 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -577,6 +577,67 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change does not seem to be necessary

return ncclSuccess;
}

+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
+ size_t buffBytes = size;
+ CUdeviceptr buffBasePtr;
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
+
+ int rank = comm->comm->bootstrap()->getRank();
+ channelKey buffKey{(void*)buffBasePtr, buffBytes};
+
+ std::vector<mscclpp::RegisteredMemory> remoteMemories;
+
+ // Creating the channels
+ auto buffIt = comm->channelScratchInfos.find(buffKey);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's no guarantee this will be called from a single thread, we probably need to lock access to channelScratchInfos and the rest of the maps. I suggest using one lock if this function doesn't fall on the critical path.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This map is supposed to be checked by the host process and MSCCLPP is invoked only when using one GPU per process case.

+ if (buffIt == comm->channelScratchInfos.end()) {
+ std::vector<mscclpp::SmChannel> channels =
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr));
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
+ buffIt = comm->channelScratchInfos.emplace(buffKey, channelInfo).first;
+ }
+ auto sendIt = comm->channelInInfos.find(buffKey);
+ if (sendIt == comm->channelInInfos.end()) {
+ std::vector<mscclpp::SmChannel> channels =
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr));
+
+ remoteMemories =
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc);
+ std::vector<mscclpp::SmChannel> channels1 =
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr));
+
+ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
+ sendIt = comm->channelInInfos.emplace(buffKey, channelInfo).first;
+ }
+ auto recvIt = comm->channelOutInfos.find(buffKey);
+ if (recvIt == comm->channelOutInfos.end()) {
+ remoteMemories =
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc);
+ std::vector<mscclpp::SmChannel> outChannels =
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr));
+ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)};
+ recvIt = comm->channelOutInfos.emplace(buffKey, channelInfo).first;
+ }
+ *handle = (void*) buffBasePtr;
+
+ return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
+ if (comm && handle) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we free any resources associated with handle here?
Example:

auto it = comm->channelScratchInfos.find(buffKey);
if (it != comm->channelScratchInfos.end()) {
  comm->channelScratchInfos.erase(it);
}

Same thing can be done for channelOut and channelIn

+ handle = nullptr;
+ }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation nullifies the handle inside the function, but since it's passed as void*, the change doesn't reflect in the caller's context. This can lead to dangling pointers. I suggest passing handle as a void** to resolve this issue, if API change is possible

Suggested change
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
+ if (comm && handle) {
+ handle = nullptr;
+ }
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void** handle) {
+ if (comm && handle && *handle) {
+ *handle = nullptr;
+ }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a NCCL API and cannot be changed. I am still working on some of these fixes.

+ return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclBuffIsRegistered(ncclComm_t comm, const void* buff, size_t count, bool* registered){
+ size_t buffBytes;
+ CUdeviceptr buffBasePtr;
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
+ channelKey buffKey{(void*)buffBasePtr, buffBytes};
+ auto buffIt = comm->channelScratchInfos.find(buffKey);
+ *registered = buffIt != comm->channelScratchInfos.end();
+ return ncclSuccess;
+}
NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
// TODO: implement this function
return ncclInternalError;
6 changes: 6 additions & 0 deletions src/include/mscclpp/mscclpp_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ extern "C" {
/* See ncclAllGather. */
ncclResult_t mscclpp_ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, mscclppComm_t comm, hipStream_t stream);

ncclResult_t mscclpp_ncclCommRegister(mscclppComm_t comm, void* buff, size_t size, void** handle);

ncclResult_t mscclpp_ncclCommDeregister(mscclppComm_t comm, void* handle);

ncclResult_t mscclpp_ncclBuffIsRegistered(mscclppComm_t comm, const void* buff, size_t count, bool* registered);
}

namespace std {
Expand Down
18 changes: 16 additions & 2 deletions src/misc/msccl/msccl_lifecycle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,15 @@ ncclResult_t mscclEnqueueCheck(
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
}

bool sendBuffRegistered = false;
bool recvBuffRegistered = false;
mscclpp_ncclBuffIsRegistered(comm->mscclpp_comm, sendBuff, count, &sendBuffRegistered);
mscclpp_ncclBuffIsRegistered(comm->mscclpp_comm, sendBuff, count, &recvBuffRegistered);
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
const bool buffsRegistedNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;

/* check if one rank per GPU and graph mode is enabled */
if ((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
if ((graphMode || buffsRegistedNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
bool isManagedBuffer = false;
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
Expand Down Expand Up @@ -565,8 +572,15 @@ ncclResult_t mscclEnqueueCheck(
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
}

bool sendBuffRegistered = false;
bool recvBuffRegistered = false;
mscclpp_ncclBuffIsRegistered(comm->mscclpp_comm, sendBuff, count, &sendBuffRegistered);
mscclpp_ncclBuffIsRegistered(comm->mscclpp_comm, sendBuff, count, &recvBuffRegistered);
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
const bool buffsRegistedNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;

/* check if one rank per GPU and graph mode is enabled */
if ((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
if ((graphMode || buffsRegistedNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
bool isManagedBuffer = false;
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
Expand Down
3 changes: 3 additions & 0 deletions src/misc/mscclpp/mscclpp_nccl_syms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ ncclRedOpDestroy mscclpp_ncclRedOpDestroy
ncclReduce mscclpp_ncclReduce
ncclReduceScatter mscclpp_ncclReduceScatter
ncclSend mscclpp_ncclSend
ncclCommRegister mscclpp_ncclCommRegister
ncclCommDeregister mscclpp_ncclCommDeregister
ncclBuffIsRegistered mscclpp_ncclBuffIsRegistered
72 changes: 50 additions & 22 deletions src/register.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "net.h"
#include "register.h"
#include "api_trace.h"
#include "mscclpp/mscclpp_nccl.h"

ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) {
struct ncclRegCache* cache = &comm->regCache;
Expand Down Expand Up @@ -155,32 +156,59 @@ NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size
ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) {
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm"));
if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister"));
NCCLCHECK(ncclRegister(comm, buff, size, handle));

#ifdef ENABLE_MSCCLPP
if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold){
bool isManagedBuffer = false;
CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(buff)));
if(!isManagedBuffer){
INFO(NCCL_INIT, "MSCCL++: ncclCommRegister");
NCCLCHECK(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle));
}
else{
WARN("MSCCL++: Cannot register user-buffers on managed memory");
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will do nothing, issue a warning, then fall through and return ncclSuccess. Is that the right behaviour? Should it not return an error?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the expectation is that if no registration happens in MSCCL++ due to the buffer being in managed memory, no MSCCL++ or RCCL registration will happen and execution will fall back to RCCL kernel without UBR.

I actually need to do a small adjustment to not cause a RCCL deregistration error when MSCCL++ detects no registered buffer because of this. I just tested a simple fix for this.

But since we always fallback to RCCL kernel when MSCCL++ is not enabled, I think a warning without return an error is OK. Also, I haven't seen any errors returned in any UBR related routines, except when RCCL cannot find the registration handle during deregistration. Of course, this definitely makes it harder to see if anything goes wrong.

}
else
#endif
{
INFO(NCCL_INIT, "RCCL: ncclCommRegister");
NCCLCHECK(ncclRegister(comm, buff, size, handle));
}
isaki001 marked this conversation as resolved.
Show resolved Hide resolved
return ncclSuccess;
}

NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle);
ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) {
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm"));
struct ncclReg* reg = (struct ncclReg*)handle;
struct ncclRegCache* cache = &comm->regCache;
int slot;
for (slot=0; slot<cache->population && cache->slots[slot] != reg; slot++);
if (slot == cache->population) {
WARN("Deregister: Could not find handle");
return ncclInvalidUsage;
}
if (--reg->refs) return ncclSuccess;
NCCLCHECK(ncclNetDeregister(comm, reg));
if (reg->state & NVLS_REG_COMPLETE) {
NCCLCHECK(ncclNvlsDeregBuffer(&reg->mcHandle, reg->regAddr, reg->dev, reg->regSize));
reg->regAddr = (CUdeviceptr)NULL;
}
if (reg->state & COLLNET_REG_COMPLETE) {
NCCLCHECK(ncclCollnetDeregBuffer(comm, reg->proxyconn, reg->collnetHandle));
}
free(reg);
memmove(cache->slots+slot, cache->slots+slot+1, (cache->population-slot-1)*sizeof(struct ncclReg*));
cache->population -= 1;

#ifdef ENABLE_MSCCLPP
if (comm->mscclCompatible){
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you should do the same checks that you did before registering the buffer.
if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold){

NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle));
}
else
#endif
{
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm"));
struct ncclReg* reg = (struct ncclReg*)handle;
struct ncclRegCache* cache = &comm->regCache;
int slot;
for (slot=0; slot<cache->population && cache->slots[slot] != reg; slot++);
if (slot == cache->population) {
WARN("Deregister: Could not find handle");
return ncclInvalidUsage;
}
if (--reg->refs) return ncclSuccess;
NCCLCHECK(ncclNetDeregister(comm, reg));
if (reg->state & NVLS_REG_COMPLETE) {
NCCLCHECK(ncclNvlsDeregBuffer(&reg->mcHandle, reg->regAddr, reg->dev, reg->regSize));
reg->regAddr = (CUdeviceptr)NULL;
}
if (reg->state & COLLNET_REG_COMPLETE) {
NCCLCHECK(ncclCollnetDeregBuffer(comm, reg->proxyconn, reg->collnetHandle));
}
free(reg);
memmove(cache->slots+slot, cache->slots+slot+1, (cache->population-slot-1)*sizeof(struct ncclReg*));
cache->population -= 1;
}
return ncclSuccess;
}