Move the sharding axes from dimensions that need replication to batch dimensions, such that we replace an `all-gather` with an `all-to-all`.

Given the following input:
```
ENTRY entry {
  %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
  %param1 = f32[14,116] parameter(1), sharding={devices=[1,2]0,1}
  ROOT %concatenate = f32[14,373] concatenate(%param0, %param1),
    dimensions={1}, sharding={devices=[1,2]0,1}
}
```
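To make the example concrete: `{devices=[1,2]0,1}` tiles dimension 1 across the two devices, which is why the partitioned modules below operate on `f32[14,129]` and `f32[14,58]` shards. The following is a minimal standalone sketch (not part of this commit) that decodes the annotation with XLA's public `HloSharding` API, assuming `IotaTile` and `TileShape` behave as in current XLA:

```cpp
// Sketch only (not code from this commit): decode the {devices=[1,2]} sharding
// of the example operands and print the per-device shard shapes.
#include <iostream>

#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/shape_util.h"

int main() {
  // {devices=[1,2]<=[2]}: dimension 0 unsplit, dimension 1 split over 2 devices.
  const xla::HloSharding sharding = xla::HloSharding::IotaTile({1, 2});
  const xla::Shape param0 = xla::ShapeUtil::MakeShape(xla::F32, {14, 257});
  const xla::Shape param1 = xla::ShapeUtil::MakeShape(xla::F32, {14, 116});
  // TileShape rounds up: f32[14,257] -> f32[14,129] and f32[14,116] -> f32[14,58].
  std::cout << xla::ShapeUtil::HumanString(sharding.TileShape(param0)) << "\n"
            << xla::ShapeUtil::HumanString(sharding.TileShape(param1)) << "\n";
  return 0;
}
```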

Before this change, the partitioner generates an `all-gather`:
```
ENTRY %entry_spmd (param: f32[14,129], param.1: f32[14,58]) -> f32[14,187] {
  %param = f32[14,129]{1,0} parameter(0), sharding={devices=[1,2]<=[2]}
  %all-gather = f32[14,258]{1,0} all-gather(f32[14,129]{1,0} %param), channel_id=1, replica_groups=[1,2]<=[2], dimensions={1}, use_global_device_ids=true
  %slice = f32[14,257]{1,0} slice(f32[14,258]{1,0} %all-gather), slice={[0:14], [0:257]}
  %param.1 = f32[14,58]{1,0} parameter(1), sharding={devices=[1,2]<=[2]}
  %all-gather.1 = f32[14,116]{1,0} all-gather(f32[14,58]{1,0} %param.1), channel_id=2, replica_groups=[1,2]<=[2], dimensions={1}, use_global_device_ids=true
  %concatenate.1 = f32[14,373]{1,0} concatenate(f32[14,257]{1,0} %slice, f32[14,116]{1,0} %all-gather.1), dimensions={1}
  %constant = f32[] constant(0)
  %pad = f32[14,374]{1,0} pad(f32[14,373]{1,0} %concatenate.1, f32[] %constant), padding=0_0x0_1
  %constant.1 = s32[] constant(0)
  %constant.2 = s32[2]{0} constant({0, 187})
  %partition-id = u32[] partition-id()
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %constant.2, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(s32[1]{0} %dynamic-slice)
  ROOT %dynamic-slice.1 = f32[14,187]{1,0} dynamic-slice(f32[14,374]{1,0} %pad, s32[] %constant.1, s32[] %reshape), dynamic_slice_sizes={14,187}
}
```

With this change, the `all-gather` is replaced by an `all-to-all`:
```
ENTRY %entry_spmd (param: f32[14,129], param.1: f32[14,58]) -> f32[14,187] {
  %param = f32[14,129]{1,0} parameter(0), sharding={devices=[1,2]<=[2]}
  %reshape.1 = f32[2,7,129]{2,1,0} reshape(f32[14,129]{1,0} %param)
  %all-to-all = f32[2,7,129]{2,1,0} all-to-all(f32[2,7,129]{2,1,0} %reshape.1), channel_id=1, replica_groups={{0,1}}, dimensions={0}
  %transpose = f32[7,2,129]{2,0,1} transpose(f32[2,7,129]{2,1,0} %all-to-all), dimensions={1,0,2}
  %reshape.2 = f32[7,258]{1,0} reshape(f32[7,2,129]{2,0,1} %transpose)
  %slice = f32[7,257]{1,0} slice(f32[7,258]{1,0} %reshape.2), slice={[0:7], [0:257]}
  %param.1 = f32[14,58]{1,0} parameter(1), sharding={devices=[1,2]<=[2]}
  %reshape.5 = f32[2,7,58]{2,1,0} reshape(f32[14,58]{1,0} %param.1)
  %all-to-all.1 = f32[2,7,58]{2,1,0} all-to-all(f32[2,7,58]{2,1,0} %reshape.5), channel_id=2, replica_groups={{0,1}}, dimensions={0}
  %transpose.1 = f32[7,2,58]{2,0,1} transpose(f32[2,7,58]{2,1,0} %all-to-all.1), dimensions={1,0,2}
  %reshape.6 = f32[7,116]{1,0} reshape(f32[7,2,58]{2,0,1} %transpose.1)
  %concatenate.1 = f32[7,373]{1,0} concatenate(f32[7,257]{1,0} %slice, f32[7,116]{1,0} %reshape.6), dimensions={1}
  %constant.20 = f32[] constant(0)
  %pad = f32[7,374]{1,0} pad(f32[7,373]{1,0} %concatenate.1, f32[] %constant.20), padding=0_0x0_1
  %reshape.9 = f32[7,2,187]{2,1,0} reshape(f32[7,374]{1,0} %pad)
  %all-to-all.2 = f32[7,2,187]{2,1,0} all-to-all(f32[7,2,187]{2,1,0} %reshape.9), channel_id=3, replica_groups={{0,1}}, dimensions={1}
  %transpose.2 = f32[2,7,187]{2,0,1} transpose(f32[7,2,187]{2,1,0} %all-to-all.2), dimensions={1,0,2}
  ROOT %reshape.10 = f32[14,187]{1,0} reshape(f32[2,7,187]{2,0,1} %transpose.2)
}
```
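The key pattern above is `reshape` → `all-to-all` → `transpose` → `reshape`, which converts a tensor whose columns are split across devices into one whose rows are split (and, for the root, back again). Below is a self-contained sketch, plain C++ with no XLA dependency and purely illustrative variable names, that simulates this sequence for the padded 14x258 operand on two "devices" and checks the resulting layout:

```cpp
// Standalone illustration (not from this commit): simulate on the host how
// reshape -> all-to-all -> transpose -> reshape reshards a tensor from
// "each device owns half the columns" to "each device owns half the rows".
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int kDevices = 2, kRows = 14, kCols = 258;
  constexpr int kLocalRows = kRows / kDevices;  // 7
  constexpr int kLocalCols = kCols / kDevices;  // 129

  // Global tensor A[r][c] = r * kCols + c.
  // Each device starts with its column shard: local[d][r][c] = A[r][d*129 + c].
  std::vector<std::vector<std::vector<int64_t>>> local(
      kDevices, std::vector<std::vector<int64_t>>(
                    kRows, std::vector<int64_t>(kLocalCols)));
  for (int d = 0; d < kDevices; ++d)
    for (int r = 0; r < kRows; ++r)
      for (int c = 0; c < kLocalCols; ++c)
        local[d][r][c] = static_cast<int64_t>(r) * kCols + d * kLocalCols + c;

  // Reshape [14,129] -> [2,7,129] (split the row dimension), then all-to-all
  // on the new leading dimension: device d sends chunk i to device i and
  // receives chunk d from device i, so recv[d][i][r][c] = local[i][d*7+r][c].
  std::vector<std::vector<std::vector<std::vector<int64_t>>>> recv(
      kDevices,
      std::vector<std::vector<std::vector<int64_t>>>(
          kDevices, std::vector<std::vector<int64_t>>(
                        kLocalRows, std::vector<int64_t>(kLocalCols))));
  for (int d = 0; d < kDevices; ++d)
    for (int i = 0; i < kDevices; ++i)
      for (int r = 0; r < kLocalRows; ++r)
        for (int c = 0; c < kLocalCols; ++c)
          recv[d][i][r][c] = local[i][d * kLocalRows + r][c];

  // Transpose [2,7,129] -> [7,2,129] and reshape to [7,258]: element
  // (r, i*129 + c) of device d's result is A[d*7 + r][i*129 + c], i.e.
  // device d now owns rows [d*7, (d+1)*7) of the full-width tensor.
  for (int d = 0; d < kDevices; ++d)
    for (int r = 0; r < kLocalRows; ++r)
      for (int i = 0; i < kDevices; ++i)
        for (int c = 0; c < kLocalCols; ++c)
          assert(recv[d][i][r][c] ==
                 static_cast<int64_t>(d * kLocalRows + r) * kCols +
                     i * kLocalCols + c);
  return 0;
}
```

In this two-device example, each all-to-all exchanges half as many bytes per device as the all-gather it replaces, and the concatenate now runs on `f32[7,*]` shards instead of fully replicated `f32[14,*]` operands.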

PiperOrigin-RevId: 718546009
ZixuanJiang authored and Google-ML-Automation committed Jan 23, 2025
1 parent fcad14e commit 0b2047b
Showing 2 changed files with 40 additions and 30 deletions.
21 changes: 16 additions & 5 deletions xla/service/spmd/spmd_partitioner.cc
```diff
@@ -2773,11 +2773,22 @@ absl::Status SpmdPartitioningVisitor::HandleElementwiseWithDimsToReplicate(
     return DefaultAction(hlo);
   }
 
-  // 1. Replicate the final sharding along `dims_to_replicate` to get
-  // temp_sharding.
-  const HloSharding temp_sharding =
-      hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
-          sharding, dims_to_replicate);
+  // 1. Obtain the temp_sharding by moving or replicating the sharding tiles.
+  HloSharding temp_sharding = sharding;
+  std::function<bool(int64_t)> not_in_dims_to_replicate = [&](int64_t dim) {
+    return !absl::c_linear_search(dims_to_replicate, dim);
+  };
+  for (int64_t dim : dims_to_replicate) {
+    if (std::optional<int64_t> target_dim =
+            hlo_sharding_util::GetFirstTargetDimToMoveShardingTiles(
+                hlo->shape(), temp_sharding, dim, not_in_dims_to_replicate)) {
+      temp_sharding = hlo_sharding_util::MoveAndMergeShardingTiles(
+          temp_sharding, dim, *target_dim);
+    } else {
+      temp_sharding = hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
+          temp_sharding, {dim});
+    }
+  }
 
   // 2. Reshard the operands to temp_sharding.
   std::vector<HloInstruction*> new_operands;
```
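For the concatenate example, the new loop runs once with `dim = 1`: `GetFirstTargetDimToMoveShardingTiles` finds dimension 0 as a viable target (the only dimension not in `dims_to_replicate`), so `MoveAndMergeShardingTiles` turns `{devices=[1,2]<=[2]}` into `{devices=[2,1]<=[2]}` instead of replicating. A minimal sketch of that single step, reusing the helper signatures from the diff above (the `main` wrapper and variable names are hypothetical, not code from this commit):

```cpp
// Sketch only: compute the temp_sharding the new logic would pick for the
// concatenate example, using the hlo_sharding_util helpers from the diff.
#include <cstdint>
#include <iostream>
#include <optional>

#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/shape_util.h"

int main() {
  // Result shape and sharding of the concatenate: f32[14,373], {devices=[1,2]}.
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {14, 373});
  const xla::HloSharding sharding = xla::HloSharding::IotaTile({1, 2});
  // The concatenation dimension (1) would normally be replicated.
  const int64_t dim_to_replicate = 1;
  std::optional<int64_t> target_dim =
      xla::hlo_sharding_util::GetFirstTargetDimToMoveShardingTiles(
          shape, sharding, dim_to_replicate,
          [&](int64_t dim) { return dim != dim_to_replicate; });
  const xla::HloSharding temp_sharding =
      target_dim.has_value()
          ? xla::hlo_sharding_util::MoveAndMergeShardingTiles(
                sharding, dim_to_replicate, *target_dim)
          : xla::hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
                sharding, {dim_to_replicate});
  // Expected for this example: {devices=[2,1]<=[2]}, i.e. sharded on dim 0.
  std::cout << temp_sharding.ToString() << "\n";
  return 0;
}
```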
49 changes: 24 additions & 25 deletions xla/service/spmd/spmd_partitioner_test.cc
```diff
@@ -2323,27 +2323,22 @@ ENTRY entry {
   VLOG(1) << module->ToString();
 
   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[14,129]"));
-  auto param0_adjusted =
-      AllOf(op::Select(op::Compare(op::Add(), op::Broadcast(op::Constant())),
-                       param0, op::Broadcast(op::Constant())),
-            op::Shape("f32[14,129]"));
-  auto param0_replicated = AllOf(op::AllReduce(op::DynamicUpdateSlice(
-                                     op::Broadcast(), param0_adjusted, _, _)),
-                                 op::Shape("f32[14,257]"));
+  auto param0_resharded = AllOf(
+      op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(param0))))),
+      op::Shape("f32[7,257]"));
 
   auto param1 = AllOf(op::Parameter(1), op::Shape("f32[14,58]"));
-  auto param1_replicated = AllOf(
-      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), param1, _, _)),
-      op::Shape("f32[14,116]"));
+  auto param1_resharded =
+      AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(param1)))),
+            op::Shape("f32[7,116]"));
 
-  auto concatenate =
-      AllOf(op::Concatenate(param0_replicated, param1_replicated),
-            op::Shape("f32[14,373]"));
+  auto concatenate = AllOf(op::Concatenate(param0_resharded, param1_resharded),
+                           op::Shape("f32[7,373]"));
 
-  const auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root, AllOf(op::DynamicSlice(op::Pad(concatenate, op::Constant()), _, _),
-                  op::Shape("f32[14,187]")));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              AllOf(op::Reshape(op::Transpose(
+                        op::AllToAll(op::Reshape(op::Pad(concatenate, _))))),
+                    op::Shape("f32[14,187]")));
 }
 
 TEST_P(SpmdPartitioningTest, ConcatenateAlongBothDimensions) {
@@ -15552,16 +15547,20 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/8));
 
+  // TODO(b/353990256). Involuntary full rematerialization between shardings
+  // {devices=[2,2,2]<=[8]} to {devices=[8,1,1]<=[8]}.
   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[16,16,16]"));
-  auto param0_reshard =
-      AllOf(op::Shape("f32[16,32,32]"),
-            op::AllReduce(op::AllReduce(
-                op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _))));
+  auto param0_replicated = op::AllReduce(op::AllReduce(
+      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _))));
+  auto param0_reshard = AllOf(op::Shape("f32[4,32,32]"),
+                              op::DynamicSlice(param0_replicated, _, _, _));
   auto cholesky =
-      AllOf(op::Cholesky(param0_reshard), op::Shape("f32[16,32,32]"));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      AllOf(op::DynamicSlice(cholesky, _, _, _), op::Shape("f32[16,16,16]")));
+      AllOf(op::Cholesky(param0_reshard), op::Shape("f32[4,32,32]"));
+  auto cholesky_replicated =
+      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), cholesky, _, _, _));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              AllOf(op::DynamicSlice(cholesky_replicated, _, _, _),
+                    op::Shape("f32[16,16,16]")));
 }
 
 TEST_P(SpmdPartitioningTest, TriangularSolve) {
```
