Skip to content

Commit

Permalink
[gpu][ds-fusion] Add handling for offset module in ds-fusion thunk
Browse files Browse the repository at this point in the history
This patch adds support for offset modules in ds fusion thunk. It also
moves the `ResourceRequests` structure from `gpu_executable.cc` to
`gpu_executable.h` because a valid implementation of the abstract class
`Thunk::ResourceRequests` is required for calling `Thunk::Prepare()`.
  • Loading branch information
shraiysh committed Dec 23, 2024
1 parent 7e03b71 commit 3edf4ab
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 176 deletions.
312 changes: 151 additions & 161 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,190 +197,180 @@ static PersistentCliquesMap& GetPersistentCliquesMap() {
return *persistent_cliques;
}

// Shared resources required for thunk initialization and execution.
class ResourceRequests : public Thunk::ResourceRequests {
public:
absl::Status AddClique(const GpuCliqueKey& clique_key,
int32_t num_local_participants) final {
VLOG(5) << "Add collective clique request: " << clique_key.ToString()
<< "; num_local_participants: " << num_local_participants;

// Check if there is already a clique request for this clique key.
if (auto it = cliques_.find(clique_key); it != cliques_.end()) {
// We can't have multiple requests for a same clique key with different
// number of local participants as we can acquire a clique only once and
// we have to know how many executables will join the rendezvous.
if (it->second.num_local_participants != num_local_participants) {
return absl::InternalError(absl::StrFormat(
"Clique request for a clique key %s has number of local "
"participants %d different from previously requested value of %d. "
"This will lead to deadlock at run time and is an XLA compiler "
"bug. Please report it to XLA team.",
clique_key.ToString(), num_local_participants,
it->second.num_local_participants));
}
return absl::OkStatus();
}
} // namespace

// XLA compiler guarantees that all collective operations have the same
// order on all replicas. We rely on this property to assign unique id to
// clique requests simply based on the number of already recorded requests.
int64_t id = cliques_.size();
cliques_.try_emplace(clique_key,
CliqueRequest{clique_key, num_local_participants, id});
absl::Status ResourceRequests::AddClique(const GpuCliqueKey& clique_key,
int32_t num_local_participants) {
VLOG(5) << "Add collective clique request: " << clique_key.ToString()
<< "; num_local_participants: " << num_local_participants;

// Check if there is already a clique request for this clique key.
if (auto it = cliques_.find(clique_key); it != cliques_.end()) {
// We can't have multiple requests for a same clique key with different
// number of local participants as we can acquire a clique only once and
// we have to know how many executables will join the rendezvous.
if (it->second.num_local_participants != num_local_participants) {
return absl::InternalError(absl::StrFormat(
"Clique request for a clique key %s has number of local "
"participants %d different from previously requested value of %d. "
"This will lead to deadlock at run time and is an XLA compiler "
"bug. Please report it to XLA team.",
clique_key.ToString(), num_local_participants,
it->second.num_local_participants));
}
return absl::OkStatus();
}

absl::StatusOr<Thunk::CollectiveCliques> AcquireCollectiveCliques(
const Thunk::CollectiveExecuteParams& params,
bool use_persistent_cliques) {
if (cliques_.empty()) return Thunk::CollectiveCliques();

VLOG(2) << "Acquire " << cliques_.size()
<< " collective cliques for global device id "
<< params.global_device_id.value()
<< "; run_id=" << params.run_id.ToInt()
<< "; max number of channels for collectives "
<< params.collective_max_nchannels
<< "; max number of channels for p2p " << params.p2p_max_nchannels
<< "; use_persistent_cliques=" << use_persistent_cliques;

std::vector<CliqueRequest> ordered_cliques = GetOrderedCliqueRequests();
for (size_t i = 0; i < ordered_cliques.size(); ++i) {
const CliqueRequest& r = ordered_cliques[i];
VLOG(2) << " clique #" << i << " (for global device id "
<< params.global_device_id.value() << ")"
<< ": num_local_participants=" << r.num_local_participants
<< "; id=" << r.id << "; key=" << r.key.ToString();
}
// XLA compiler guarantees that all collective operations have the same
// order on all replicas. We rely on this property to assign unique id to
// clique requests simply based on the number of already recorded requests.
int64_t id = cliques_.size();
cliques_.try_emplace(clique_key,
CliqueRequest{clique_key, num_local_participants, id});
return absl::OkStatus();
}

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"AcquireCollectiveCliques",
{{"num_cliques", cliques_.size()},
{"use_persistent_cliques", use_persistent_cliques}});
});
absl::StatusOr<Thunk::CollectiveCliques>
ResourceRequests::AcquireCollectiveCliques(
const Thunk::CollectiveExecuteParams& params, bool use_persistent_cliques) {
if (cliques_.empty()) return Thunk::CollectiveCliques();

VLOG(2) << "Acquire " << cliques_.size()
<< " collective cliques for global device id "
<< params.global_device_id.value()
<< "; run_id=" << params.run_id.ToInt()
<< "; max number of channels for collectives "
<< params.collective_max_nchannels
<< "; max number of channels for p2p " << params.p2p_max_nchannels
<< "; use_persistent_cliques=" << use_persistent_cliques;

std::vector<CliqueRequest> ordered_cliques = GetOrderedCliqueRequests();
for (size_t i = 0; i < ordered_cliques.size(); ++i) {
const CliqueRequest& r = ordered_cliques[i];
VLOG(2) << " clique #" << i << " (for global device id "
<< params.global_device_id.value() << ")"
<< ": num_local_participants=" << r.num_local_participants
<< "; id=" << r.id << "; key=" << r.key.ToString();
}

auto start_micros = tsl::Env::Default()->NowMicros();
tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"AcquireCollectiveCliques",
{{"num_cliques", cliques_.size()},
{"use_persistent_cliques", use_persistent_cliques}});
});

AcquiredCliquesMap cliques_map;
int32_t num_transient_cliques = 0;
auto start_micros = tsl::Env::Default()->NowMicros();

for (const CliqueRequest& r : ordered_cliques) {
std::optional<RankId> rank = r.key.rank(params.global_device_id);
AcquiredCliquesMap cliques_map;
int32_t num_transient_cliques = 0;

if (!rank.has_value()) {
return absl::InternalError(absl::StrCat(
"Can't find global device id ", params.global_device_id.value(),
" in clique key ", r.key.ToString()));
}
for (const CliqueRequest& r : ordered_cliques) {
std::optional<RankId> rank = r.key.rank(params.global_device_id);

bool is_local = r.key.devices().size() == r.num_local_participants;
TF_ASSIGN_OR_RETURN(const CliqueIdCallback* clique_id_callback,
params.collectives->GetCliqueIdCallback(
params.nccl_clique_id_callback, is_local));
if (!rank.has_value()) {
return absl::InternalError(absl::StrCat(
"Can't find global device id ", params.global_device_id.value(),
" in clique key ", r.key.ToString()));
}

int64_t max_channels = r.key.stream_kind() == AsyncStreamKind::kCollective
? params.collective_max_nchannels
: params.p2p_max_nchannels;
bool is_local = r.key.devices().size() == r.num_local_participants;
TF_ASSIGN_OR_RETURN(const CliqueIdCallback* clique_id_callback,
params.collectives->GetCliqueIdCallback(
params.nccl_clique_id_callback, is_local));

// Check if we have a persistent clique for this key.
if (use_persistent_cliques) {
auto& pc = GetPersistentCliquesMap();
absl::MutexLock lock(&pc.mutex);
int64_t max_channels = r.key.stream_kind() == AsyncStreamKind::kCollective
? params.collective_max_nchannels
: params.p2p_max_nchannels;

if (auto it = pc.cliques_map.find(r.key); it != pc.cliques_map.end()) {
VLOG(2) << "Found persistent clique for key " << r.key.ToString();
cliques_map[r.key] = it->second;
continue;
}
}
// Check if we have a persistent clique for this key.
if (use_persistent_cliques) {
auto& pc = GetPersistentCliquesMap();
absl::MutexLock lock(&pc.mutex);

// If we don't have a persistent clique we have to acquire a transient
// one.
TF_ASSIGN_OR_RETURN(
std::shared_ptr<LockableGpuClique::Lock> clique,
AcquireGpuClique(params.collectives, params.executor, params.run_id,
r.key, *clique_id_callback, *rank,
r.num_local_participants, cliques_map,
max_channels));
++num_transient_cliques;

// Take a copy of the clique lock, so that we can reuse it. This is
// potentially unsafe in the case when we have multiple racing executions
// of XLA, as we might observe partial state and some of the replicas will
// use persistent clique, and others will try to acquire a new one.
//
// However given that persistent cliques is an unsafe escape hatch, any
// racing execution together with persistent cliques will lead to
// deadlocks anyway, so we don't bother to fix this. If anyone is doing
// it, it's 100% their fault and they will suffer.
if (use_persistent_cliques) {
auto& pc = GetPersistentCliquesMap();
absl::MutexLock lock(&pc.mutex);
pc.cliques_map[r.key] = clique;
if (auto it = pc.cliques_map.find(r.key); it != pc.cliques_map.end()) {
VLOG(2) << "Found persistent clique for key " << r.key.ToString();
cliques_map[r.key] = it->second;
continue;
}

cliques_map[r.key] = std::move(clique);
}

auto end_micros = tsl::Env::Default()->NowMicros();
VLOG(2) << "Acquired " << cliques_map.size()
<< " collective cliques for global device id "
<< params.global_device_id.value() << " in "
<< (end_micros - start_micros) << " μs"
<< "; run_id=" << params.run_id.ToInt()
<< "; num_transient_cliques=" << num_transient_cliques;
// If we don't have a persistent clique we have to acquire a transient
// one.
TF_ASSIGN_OR_RETURN(
std::shared_ptr<LockableGpuClique::Lock> clique,
AcquireGpuClique(params.collectives, params.executor, params.run_id,
r.key, *clique_id_callback, *rank,
r.num_local_participants, cliques_map, max_channels));
++num_transient_cliques;

// Take a copy of the clique lock, so that we can reuse it. This is
// potentially unsafe in the case when we have multiple racing executions
// of XLA, as we might observe partial state and some of the replicas will
// use persistent clique, and others will try to acquire a new one.
//
// However given that persistent cliques is an unsafe escape hatch, any
// racing execution together with persistent cliques will lead to
// deadlocks anyway, so we don't bother to fix this. If anyone is doing
// it, it's 100% their fault and they will suffer.
if (use_persistent_cliques) {
auto& pc = GetPersistentCliquesMap();
absl::MutexLock lock(&pc.mutex);
pc.cliques_map[r.key] = clique;
}

return Thunk::CollectiveCliques(std::move(cliques_map),
num_transient_cliques);
cliques_map[r.key] = std::move(clique);
}

private:
struct CliqueRequest {
GpuCliqueKey key;
int64_t num_local_participants;
int64_t id;
};
auto end_micros = tsl::Env::Default()->NowMicros();
VLOG(2) << "Acquired " << cliques_map.size()
<< " collective cliques for global device id "
<< params.global_device_id.value() << " in "
<< (end_micros - start_micros) << " μs"
<< "; run_id=" << params.run_id.ToInt()
<< "; num_transient_cliques=" << num_transient_cliques;

// Return clique requests deterministically ordered using a comparison
// function that produces identical ordering for all participating ranks.
//
// Example: 8 ranks split into different groups of communicators
//
// Group #0: [0,1], [2,3], [4,5], [6,7]
// Group #1: [0,4], [1,5], [2,6], [3,7]
//
// Both groups #0 and #1 can be acquired by splitting the [0...7] clique. To avoid
// deadlocks all participants should acquire all cliques in a group #0 before
// acquiring any cliques in a group #1.
//
// We rely on clique request id to guarantee that the order is identical
// on all participating ranks (including ranks running on different hosts).
std::vector<CliqueRequest> GetOrderedCliqueRequests() {
std::vector<CliqueRequest> cliques;
cliques.reserve(cliques_.size());
for (const auto& [_, request] : cliques_) cliques.push_back(request);

absl::c_sort(cliques, [](const CliqueRequest& a, const CliqueRequest& b) {
// Acquire larger cliques first to be able to split them later.
if (a.key.devices().size() > b.key.devices().size()) return true;
if (b.key.devices().size() > a.key.devices().size()) return false;

// If cliques have the same size prefer cliques with smaller stream id.
if (a.key.stream_id().value() < b.key.stream_id().value()) return true;
if (b.key.stream_id().value() < a.key.stream_id().value()) return false;

// Prefer cliques with smaller id (comes earlier in execution order).
return a.id < b.id;
});

return cliques;
}
return Thunk::CollectiveCliques(std::move(cliques_map),
num_transient_cliques);
}

absl::flat_hash_map<GpuCliqueKey, CliqueRequest> cliques_;
};
// Return clique requests deterministically ordered using a comparison
// function that produces identical ordering for all participating ranks.
//
// Example: 8 ranks split into different groups of communicators
//
// Group #0: [0,1], [2,3], [4,5], [6,7]
// Group #1: [0,4], [1,5], [2,6], [3,7]
//
// Both groups #0 and #1 can be acquired by splitting the [0...7] clique. To avoid
// deadlocks all participants should acquire all cliques in a group #0 before
// acquiring any cliques in a group #1.
//
// We rely on clique request id to guarantee that the order is identical
// on all participating ranks (including ranks running on different hosts).
std::vector<ResourceRequests::CliqueRequest>
ResourceRequests::GetOrderedCliqueRequests() {
std::vector<CliqueRequest> cliques;
cliques.reserve(cliques_.size());
for (const auto& [_, request] : cliques_) cliques.push_back(request);

absl::c_sort(cliques, [](const CliqueRequest& a, const CliqueRequest& b) {
// Acquire larger cliques first to be able to split them later.
if (a.key.devices().size() > b.key.devices().size()) return true;
if (b.key.devices().size() > a.key.devices().size()) return false;

// If cliques have the same size prefer cliques with smaller stream id.
if (a.key.stream_id().value() < b.key.stream_id().value()) return true;
if (b.key.stream_id().value() < a.key.stream_id().value()) return false;

// Prefer cliques with smaller id (comes earlier in execution order).
return a.id < b.id;
});

return cliques;
}

namespace {
absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options,
se::EventBasedTimer* execution_timer,
se::Stream* stream_to_sync);
Expand Down
36 changes: 36 additions & 0 deletions xla/service/gpu/gpu_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,42 @@ limitations under the License.
namespace xla {
namespace gpu {

// Shared resources required for thunk initialization and execution.
//
// Thunks record the collective cliques they need via AddClique(), and the
// executable later acquires locks for all recorded cliques at once via
// AcquireCollectiveCliques() in a deterministic order shared by all ranks.
class ResourceRequests : public Thunk::ResourceRequests {
 public:
  // Records a request for a collective clique identified by `clique_key`
  // with `num_local_participants` local participants. Returns an error if
  // the same clique key was already requested with a different participant
  // count, as that would deadlock at run time.
  absl::Status AddClique(const GpuCliqueKey& clique_key,
                         int32_t num_local_participants) final;

  // Acquires all recorded cliques for the execution described by `params`.
  // When `use_persistent_cliques` is true, acquired cliques are cached in a
  // process-wide map and reused by later executions instead of being
  // re-acquired as transient cliques.
  absl::StatusOr<Thunk::CollectiveCliques> AcquireCollectiveCliques(
      const Thunk::CollectiveExecuteParams& params,
      bool use_persistent_cliques);

 private:
  // A single recorded clique request. `id` is assigned from the number of
  // requests recorded so far and is used to order acquisitions identically
  // on all ranks.
  struct CliqueRequest {
    GpuCliqueKey key;
    int64_t num_local_participants;
    int64_t id;
  };

  // Return clique requests deterministically ordered using a comparison
  // function that produces identical ordering for all participating ranks.
  //
  // Example: 8 ranks split into different groups of communicators
  //
  //   Group #0: [0,1], [2,3], [4,5], [6,7]
  //   Group #1: [0,4], [1,5], [2,6], [3,7]
  //
  // Both groups #0 and #1 can be acquired by splitting the [0...7] clique.
  // To avoid deadlocks all participants should acquire all cliques in
  // group #0 before acquiring any cliques in group #1.
  //
  // We rely on clique request id to guarantee that the order is identical
  // on all participating ranks (including ranks running on different hosts).
  std::vector<CliqueRequest> GetOrderedCliqueRequests();

  absl::flat_hash_map<GpuCliqueKey, CliqueRequest> cliques_;
};

// GPU-targeting implementation of the XLA Executable interface.
//
// Launches the given GPU kernel via the StreamExecutor.
Expand Down
Loading

0 comments on commit 3edf4ab

Please sign in to comment.