diff --git a/src/nccl_ofi_interface_neuron.c b/src/nccl_ofi_interface_neuron.c index 8ab2f83d5..151a20efc 100644 --- a/src/nccl_ofi_interface_neuron.c +++ b/src/nccl_ofi_interface_neuron.c @@ -6,6 +6,23 @@ #include "nccl_ofi.h" #include "nccl_ofi_api.h" +#include "nccl_ofi_param.h" + +static ncclResult_t init_v4(ncclDebugLogger_t logFunction) +{ + /* + * RDMA protocol `connect()` returns a valid send communicator only + * after a connect response message is received from peer. Because the + * v4 net-plugin `connect()` API is expected to synchronously return a + * valid send communicator (a behaviour that was changed since v5+), + * this RDMA protocol behaviour is incompatible with v4 `connect()` + * API. + */ + if (ofi_nccl_protocol() == NULL) { + setenv("OFI_NCCL_PROTOCOL", "SENDRECV", 0); + } + return nccl_net_ofi_init(logFunction); +} static ncclResult_t getProperties_v4(int dev_id, ncclNetProperties_v4_t *props) { @@ -29,10 +46,9 @@ static ncclResult_t getProperties_v4(int dev_id, ncclNetProperties_v4_t *props) return ncclSuccess; } - NCCL_OFI_EXPORT_SYMBOL const ncclNet_v4_t ncclNetPlugin_v4 = { .name = "AWS Libfabric", - .init = nccl_net_ofi_init, + .init = init_v4, .devices = nccl_net_ofi_devices, .getProperties = getProperties_v4, .listen = nccl_net_ofi_listen_v4,