Skip to content

Commit

Permalink
[xla:gpu] Synchronize device activity before initializing NCCL clique
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652551858
  • Loading branch information
ezhulenev authored and copybara-github committed Jul 15, 2024
1 parent 32600b1 commit 8d2a60d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ cc_library(
"//xla:debug_options_flags",
"//xla:executable_run_options",
"//xla:status_macros",
"//xla:util",
"//xla/service:global_device_id",
"//xla/service:lockable",
"//xla/service:rendezvous",
Expand Down
31 changes: 27 additions & 4 deletions xla/service/gpu/runtime/nccl_clique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/container/btree_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand All @@ -48,6 +47,7 @@ limitations under the License.
#include "xla/service/rendezvous.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/hash.h"
Expand Down Expand Up @@ -259,15 +259,28 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
// Start NCCL clique heart beat monitor when create a first clique.
StartNcclCliqueHeartBeatMonitor();

using RendezvousArg = std::pair<NcclApi::DeviceRank, /*synchronized=*/bool>;

// Initializes a NcclClique for given device ranks and returns a lock that
// gives access to clique communicators.
auto initialize = [&](absl::Span<const NcclApi::DeviceRank* const> args)
auto initialize = [&](absl::Span<const RendezvousArg* const> args)
-> absl::StatusOr<NcclClique::Lock> {
TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key));

// Check that all ranks successfully synchronized device activity before
// trying to instantiate NCCL communicators.
for (const RendezvousArg* arg : args) {
if (auto& [device_rank, synchronized] = *arg; !synchronized) {
return Internal(
"Failed to synchronize device activity on rank %d. Do not attempt "
"to initialize NCCL clique.",
device_rank.rank);
}
}

std::vector<NcclApi::DeviceRank> ranks;
ranks.reserve(args.size());
for (auto* arg : args) ranks.emplace_back(*arg);
for (auto* arg : args) ranks.emplace_back(arg->first);

// Sort device ranks, mainly to get more readable logs below, NCCL does
// not care in what order ranks are initialized.
Expand Down Expand Up @@ -322,9 +335,19 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
rank, clique_key.ToString(), run_id.ToInt());

NcclApi::DeviceRank device_rank = {device, rank};
bool synchronized = device->SynchronizeAllActivity();

// We choose not to exit early on failed synchronization, because it will lead
// to a deadlock, as not all participants will arrive to a rendezvous point,
// instead we check synchronization result in the initialization callback.
//
// Unfortunately we can't share synchronization result across different
// processes, so we still might end up in a deadlock situation when some
// processes are not able to synchronize device activity.
RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized);

return RendezvousSingle<absl::StatusOr<NcclClique::Lock>>(
initialization_rendezvous_name, rendezvous_key, device_rank,
initialization_rendezvous_name, rendezvous_key, rendezvous_arg,
num_local_participants, initialize, WarnStuckTimeout(),
TerminateTimeout());
}
Expand Down

0 comments on commit 8d2a60d

Please sign in to comment.