Skip to content

Commit

Permalink
Wrap HLO strings in collective permute decomposer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720302446
  • Loading branch information
frgossen authored and Google-ML-Automation committed Jan 27, 2025
1 parent 285edf7 commit c977652
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 124 deletions.
25 changes: 19 additions & 6 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_collective_permute_decomposer_threshold(
std::numeric_limits<int64_t>::max());
opts.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(false);
opts.set_xla_gpu_experimental_pipeline_parallelism_opt_level(
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE);

opts.set_xla_partitioning_algorithm(
DebugOptions::PARTITIONING_ALGORITHM_NOOP);
Expand Down Expand Up @@ -494,6 +495,19 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
return true;
};

// Custom "sub-parser" lambda for
// xla_gpu_experimental_pipeline_parallelism_opt_level.
auto setter_for_xla_gpu_experimental_pipeline_parallelism_opt_level =
[debug_options](const std::string& value) {
DebugOptions::PipelineParallelismOptLevel level;
if (!DebugOptions::PipelineParallelismOptLevel_Parse(value, &level)) {
return false;
}
debug_options->set_xla_gpu_experimental_pipeline_parallelism_opt_level(
level);
return true;
};

// Custom "sub-parser" lambda for xla_partitioning_algorithm.
auto setter_for_xla_partitioning_algorithm =
[debug_options](const std::string& value) {
Expand Down Expand Up @@ -1667,11 +1681,10 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_collective_permute_decomposer_threshold(),
"Collective permute decomposer threshold."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_pipeline_parallelism_opt",
bool_setter_for(
&DebugOptions::
set_xla_gpu_experimental_enable_pipeline_parallelism_opt),
debug_options->xla_gpu_experimental_enable_pipeline_parallelism_opt(),
"xla_gpu_experimental_pipeline_parallelism_opt_level",
setter_for_xla_gpu_experimental_pipeline_parallelism_opt_level,
DebugOptions::PipelineParallelismOptLevel_Name(
debug_options->xla_gpu_experimental_pipeline_parallelism_opt_level()),
"Experimental optimizations for SPMD-based pipeline parallelism on "
"GPU."));
flag_list->push_back(tsl::Flag(
Expand Down
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ cc_library(
srcs = ["collective_permute_decomposer.cc"],
hdrs = ["collective_permute_decomposer.h"],
deps = [
":call_graph",
":collective_ops_utils",
":source_target_pairs",
"//xla:shape_util",
Expand Down
23 changes: 19 additions & 4 deletions xla/service/collective_permute_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/call_graph.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/source_target_pairs.h"
Expand All @@ -49,7 +50,7 @@ namespace {
// operations without any cycle in their (source, target) relationship,
// with only one input and without any context data.
bool ShouldDecompose(const HloCollectivePermuteInstruction& collective_permute,
int64_t threshold_in_bytes) {
int64_t threshold_in_bytes, const CallGraph& call_graph) {
const Shape& result_shape = collective_permute.shape();

// Skip the transformation if result is not an array, such as containing
Expand All @@ -58,11 +59,23 @@ bool ShouldDecompose(const HloCollectivePermuteInstruction& collective_permute,
return false;
}

// Respect threshold to limit this pass.
if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
return false;
}
return !SourceTargetPairs(collective_permute.source_target_pairs())
.HasCycles();

// Do not decompose cycles as this leads to deadlocks in NCCL.
if (SourceTargetPairs(collective_permute.source_target_pairs()).HasCycles()) {
return false;
}

// Only decompose in loop body to allow for pipelining.
auto callers = call_graph.GetComputationCallers(collective_permute.parent());
if (callers.size() != 1 || callers.front()->opcode() != HloOpcode::kWhile) {
return false;
}

return true;
}

// Returns true for a pipelineable collective-permute. As a simple heuristic,
Expand Down Expand Up @@ -204,6 +217,8 @@ absl::Status EnforceOrderOfSendRecvChains(
absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);

bool changed = false;
std::vector<HloComputation*> all_computations =
module->MakeComputationPostOrder(execution_threads);
Expand Down Expand Up @@ -242,7 +257,7 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(

HloCollectivePermuteInstruction* cp =
Cast<HloCollectivePermuteInstruction>(instr);
if (!ShouldDecompose(*cp, threshold_in_bytes_)) {
if (!ShouldDecompose(*cp, threshold_in_bytes_, *call_graph)) {
continue;
}
// Record collective-permute to be decomposed.
Expand Down
Loading

0 comments on commit c977652

Please sign in to comment.