Skip to content

Commit

Permalink
apps/nccl: add memory registration APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
nusislam committed Dec 18, 2024
1 parent 1e82dd4 commit 1884809
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
8 changes: 8 additions & 0 deletions apps/nccl/include/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,14 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream);


/*
* Register
*
* ncclCommRegister registers the buffer starting at `buff` with `comm` and
* stores an opaque handle in `*handle`. `size` is presumably the buffer length
* in bytes (TODO confirm against the implementation, which may not consult it).
*
* ncclCommDeregister takes a handle previously produced by ncclCommRegister
* on the same communicator.
*/
ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle);
ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);


/*
* Send
*
Expand Down
54 changes: 53 additions & 1 deletion apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
channelKey recvKey{(void*)recvBasePtr, recvBytes};
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;

// Creating the channels
if (count * ncclTypeSize(datatype) <= comm->largeMessageSizeBoundary) {
auto sendIt = comm->channelScratchInfos.find(sendKey);
Expand Down Expand Up @@ -566,6 +566,58 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
return ncclSuccess;
}

// Registers a user buffer with the communicator by eagerly creating the SM
// channels keyed on the allocation that contains `buff`, so the channel tables
// (channelScratchInfos / channelInInfos / channelOutInfos) already hold an
// entry for that allocation before it is used.
//
// comm   - communicator whose channel tables are populated; must not be null.
// buff   - any address inside a device allocation; must not be null.
// size   - not consulted: the enclosing allocation reported by
//          cuMemGetAddressRange is registered as a whole.
// handle - out-parameter; receives the base address of the enclosing
//          allocation. Must not be null.
//
// Returns ncclInvalidArgument if any pointer argument is null, ncclSuccess
// otherwise. A failing cuMemGetAddressRange is reported through
// MSCCLPP_CUTHROW.
NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
  // Validate before touching anything: the code below dereferences `comm` and
  // writes through `handle`, and a null `buff` cannot belong to an allocation.
  if (comm == nullptr || buff == nullptr || handle == nullptr) return ncclInvalidArgument;
  (void)size;

  size_t buffBytes;
  CUdeviceptr buffBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));

  // NOTE(review): the original code advanced comm->buffFlag here while computing
  // a scratch-buffer index that was never used. The increment is preserved so
  // the communicator state seen by later calls is unchanged; confirm whether
  // registration is really meant to advance this counter.
  ++(comm->buffFlag);

  int rank = comm->comm->bootstrap()->getRank();
  void* basePtr = (void*)buffBasePtr;
  channelKey buffKey{basePtr, buffBytes};

  // Channels from this allocation to the peers' scratch memories.
  if (comm->channelScratchInfos.find(buffKey) == comm->channelScratchInfos.end()) {
    std::vector<mscclpp::SmChannel> channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
    comm->channelScratchInfos.emplace(buffKey, channelInfo);
  }

  // NOTE(review): the input-channel table is populated with channels that also
  // target comm->remoteScratchRegMemories, exactly like the scratch table
  // above; confirm this is the intended remote side for channelInInfos.
  if (comm->channelInInfos.find(buffKey) == comm->channelInInfos.end()) {
    std::vector<mscclpp::SmChannel> channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
    comm->channelInInfos.emplace(buffKey, channelInfo);
  }

  // Output channels need the peers' view of this allocation, which is set up
  // over CUDA IPC only when no entry exists yet.
  if (comm->channelOutInfos.find(buffKey) == comm->channelOutInfos.end()) {
    std::vector<mscclpp::RegisteredMemory> remoteMemories =
        setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc);
    std::vector<mscclpp::SmChannel> outChannels = setupSmChannels(comm, remoteMemories, basePtr);
    ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
    comm->channelOutInfos.emplace(buffKey, channelInfo);
  }

  // The handle is simply the allocation base address (the pointer half of the
  // key used in the channel tables).
  *handle = basePtr;

  return ncclSuccess;
}

// Counterpart of ncclCommRegister. Currently a no-op: no channel-table entry is
// removed and no resource is released. `handle` is received by value, so the
// caller's copy is never modified by this function.
// Always returns ncclSuccess.
NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
  (void)comm;
  (void)handle;
  return ncclSuccess;
}


NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
// TODO: implement this function
return ncclInternalError;
Expand Down

0 comments on commit 1884809

Please sign in to comment.