diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h
--- a/apps/nccl/include/nccl.h
+++ b/apps/nccl/include/nccl.h
@@ -344,6 +344,16 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
 ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
                             ncclComm_t comm, cudaStream_t stream);
 
+/*
+ * Register
+ *
+ * ncclCommRegister prepares the device allocation that contains buff for use
+ * with comm and stores an opaque token for it in *handle.
+ * ncclCommDeregister releases what was prepared for handle.
+ */
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle);
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);
+
 /*
  * Send
  *
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -566,6 +566,68 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
   return ncclSuccess;
 }
 
+// Registers the device allocation containing `buff` with `comm` by creating and caching, ahead of time, the
+// channel sets keyed by that allocation's (base address, size). `size` is not consulted: the whole allocation
+// reported by cuMemGetAddressRange is registered. On success `*handle` receives the allocation's base address.
+// Returns ncclInvalidArgument if `comm`, `buff` or `handle` is null; throws via MSCCLPP_CUTHROW if the address
+// range lookup fails.
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
+  (void)size;
+  if (comm == nullptr || buff == nullptr || handle == nullptr) return ncclInvalidArgument;
+
+  size_t buffBytes;
+  CUdeviceptr buffBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
+
+  void* const basePtr = (void*)buffBasePtr;
+  channelKey buffKey{basePtr, buffBytes};
+
+  // Each cache is filled only when the allocation is not already present, so registering twice is harmless.
+  if (comm->channelScratchInfos.find(buffKey) == comm->channelScratchInfos.end()) {
+    auto channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
+    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    comm->channelScratchInfos.emplace(buffKey, channelInfo);
+  }
+
+  if (comm->channelInInfos.find(buffKey) == comm->channelInInfos.end()) {
+    auto channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
+    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    comm->channelInInfos.emplace(buffKey, channelInfo);
+  }
+
+  if (comm->channelOutInfos.find(buffKey) == comm->channelOutInfos.end()) {
+    // These channels are built over remote memories set up for this allocation, not over the scratch memories.
+    const int rank = comm->comm->bootstrap()->getRank();
+    auto remoteMemories = setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc);
+    auto outChannels = setupSmChannels(comm, remoteMemories, basePtr);
+    ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
+    comm->channelOutInfos.emplace(buffKey, channelInfo);
+  }
+
+  *handle = basePtr;
+  return ncclSuccess;
+}
+
+// Drops the channel sets cached for the allocation identified by `handle` (a value produced by
+// ncclCommRegister). Deregistering an allocation that has no cached entries is a no-op.
+// Returns ncclInvalidArgument if `comm` or `handle` is null, or if the address range lookup on `handle` fails.
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
+  if (comm == nullptr || handle == nullptr) return ncclInvalidArgument;
+
+  // Rebuild the cache key exactly as ncclCommRegister derived it.
+  size_t buffBytes;
+  CUdeviceptr buffBasePtr;
+  if (cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)handle) != CUDA_SUCCESS) {
+    return ncclInvalidArgument;
+  }
+  channelKey buffKey{(void*)buffBasePtr, buffBytes};
+
+  comm->channelScratchInfos.erase(buffKey);
+  comm->channelInInfos.erase(buffKey);
+  comm->channelOutInfos.erase(buffKey);
+  return ncclSuccess;
+}
+
 NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
   // TODO: implement this function
   return ncclInternalError;