Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move the sharding axes from dimensions that need replication to batch dimensions, such that we replace an `all-gather` with an `all-to-all`.

Given the following input:

```
ENTRY entry {
  %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
  %param1 = f32[14,116] parameter(1), sharding={devices=[1,2]0,1}
  ROOT %concatenate = f32[14,373] concatenate(%param0, %param1), dimensions={1}, sharding={devices=[1,2]0,1}
}
```

The partitioner generates `all-gather` before this change:

```
ENTRY %entry_spmd (param: f32[14,129], param.1: f32[14,58]) -> f32[14,187] {
  %param = f32[14,129]{1,0} parameter(0), sharding={devices=[1,2]<=[2]}
  %all-gather = f32[14,258]{1,0} all-gather(f32[14,129]{1,0} %param), channel_id=1, replica_groups=[1,2]<=[2], dimensions={1}, use_global_device_ids=true
  %slice = f32[14,257]{1,0} slice(f32[14,258]{1,0} %all-gather), slice={[0:14], [0:257]}
  %param.1 = f32[14,58]{1,0} parameter(1), sharding={devices=[1,2]<=[2]}
  %all-gather.1 = f32[14,116]{1,0} all-gather(f32[14,58]{1,0} %param.1), channel_id=2, replica_groups=[1,2]<=[2], dimensions={1}, use_global_device_ids=true
  %concatenate.1 = f32[14,373]{1,0} concatenate(f32[14,257]{1,0} %slice, f32[14,116]{1,0} %all-gather.1), dimensions={1}
  %constant = f32[] constant(0)
  %pad = f32[14,374]{1,0} pad(f32[14,373]{1,0} %concatenate.1, f32[] %constant), padding=0_0x0_1
  %constant.1 = s32[] constant(0)
  %constant.2 = s32[2]{0} constant({0, 187})
  %partition-id = u32[] partition-id()
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %constant.2, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(s32[1]{0} %dynamic-slice)
  ROOT %dynamic-slice.1 = f32[14,187]{1,0} dynamic-slice(f32[14,374]{1,0} %pad, s32[] %constant.1, s32[] %reshape), dynamic_slice_sizes={14,187}
}
```

With this change, `all-gather` is replaced by `all-to-all`:

```
ENTRY %entry_spmd (param: f32[14,129], param.1: f32[14,58]) -> f32[14,187] {
  %param = f32[14,129]{1,0} parameter(0), sharding={devices=[1,2]<=[2]}
  %reshape.1 = f32[2,7,129]{2,1,0} reshape(f32[14,129]{1,0} %param)
  %all-to-all = f32[2,7,129]{2,1,0} all-to-all(f32[2,7,129]{2,1,0} %reshape.1), channel_id=1, replica_groups={{0,1}}, dimensions={0}
  %transpose = f32[7,2,129]{2,0,1} transpose(f32[2,7,129]{2,1,0} %all-to-all), dimensions={1,0,2}
  %reshape.2 = f32[7,258]{1,0} reshape(f32[7,2,129]{2,0,1} %transpose)
  %slice = f32[7,257]{1,0} slice(f32[7,258]{1,0} %reshape.2), slice={[0:7], [0:257]}
  %param.1 = f32[14,58]{1,0} parameter(1), sharding={devices=[1,2]<=[2]}
  %reshape.5 = f32[2,7,58]{2,1,0} reshape(f32[14,58]{1,0} %param.1)
  %all-to-all.1 = f32[2,7,58]{2,1,0} all-to-all(f32[2,7,58]{2,1,0} %reshape.5), channel_id=2, replica_groups={{0,1}}, dimensions={0}
  %transpose.1 = f32[7,2,58]{2,0,1} transpose(f32[2,7,58]{2,1,0} %all-to-all.1), dimensions={1,0,2}
  %reshape.6 = f32[7,116]{1,0} reshape(f32[7,2,58]{2,0,1} %transpose.1)
  %concatenate.1 = f32[7,373]{1,0} concatenate(f32[7,257]{1,0} %slice, f32[7,116]{1,0} %reshape.6), dimensions={1}
  %constant.20 = f32[] constant(0)
  %pad = f32[7,374]{1,0} pad(f32[7,373]{1,0} %concatenate.1, f32[] %constant.20), padding=0_0x0_1
  %reshape.9 = f32[7,2,187]{2,1,0} reshape(f32[7,374]{1,0} %pad)
  %all-to-all.2 = f32[7,2,187]{2,1,0} all-to-all(f32[7,2,187]{2,1,0} %reshape.9), channel_id=3, replica_groups={{0,1}}, dimensions={1}
  %transpose.2 = f32[2,7,187]{2,0,1} transpose(f32[7,2,187]{2,1,0} %all-to-all.2), dimensions={1,0,2}
  ROOT %reshape.10 = f32[14,187]{1,0} reshape(f32[2,7,187]{2,0,1} %transpose.2)
}
```

PiperOrigin-RevId: 718546009
- Loading branch information