Commit 57dd7cd
Extend unit tests for insert explicit reshards with some interesting examples of fully replicated operations and fully pointwise operations.

PiperOrigin-RevId: 718869283
Google-ML-Automation authored and copybara-github committed Jan 23, 2025
1 parent 08c7c56 commit 57dd7cd
Showing 1 changed file with 132 additions and 3 deletions.
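
For context, the pass under test inserts explicit sdy.reshard ops wherever an op's operand or result shardings disagree with the sharding annotated on the op. A minimal before/after sketch of the pointwise case exercised by the @add test below, assuming @mesh is a two-axis mesh with axes "x" and "y" (its definition lies outside this hunk):

// Before: %arg1 carries no sharding, while the add is sharded on "x" along dim 0.
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>

// After: the pass reshards %arg1 to match before the add executes.
%1 = sdy.reshard %arg1 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
%0 = stablehlo.add %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>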
@@ -544,14 +544,28 @@ func.func @reverse(%arg0: tensor<4x32x8x2xf32> {sdy.sharding = #sdy.sharding<@me
// CHECK-LABEL: func @bitcast_convert_upcast
func.func @bitcast_convert_upcast(%arg0: tensor<4x2x2xui32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}) -> tensor<4x2xui64> {
// CHECK-NOT: sdy.reshard
- %0 = stablehlo.bitcast_convert %arg0 : (tensor<4x2x2xui32>) -> tensor<4x2xui64>
+ %0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<4x2x2xui32>) -> tensor<4x2xui64>
return %0 : tensor<4x2xui64>
}

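// Like the upcast above, but the casting dimension (the trailing size-2 dim folded into ui64) is also sharded, on "y"; still no reshard is expected.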
// CHECK-LABEL: func @bitcast_convert_upcast_casting_dim_is_sharded
func.func @bitcast_convert_upcast_casting_dim_is_sharded(%arg0: tensor<4x2x2xui32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {"y"}]>}) -> tensor<4x2xui64> {
// CHECK-NOT: sdy.reshard
%0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<4x2x2xui32>) -> tensor<4x2xui64>
return %0 : tensor<4x2xui64>
}

// CHECK-LABEL: func @bitcast_convert_downcast
func.func @bitcast_convert_downcast(%arg0: tensor<4x2xui64> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> tensor<4x2x2xui32> {
// CHECK-NOT: sdy.reshard
- %0 = stablehlo.bitcast_convert %arg0 : (tensor<4x2xui64>) -> tensor<4x2x2xui32>
+ %0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}, {}]>]>} : (tensor<4x2xui64>) -> tensor<4x2x2xui32>
return %0 : tensor<4x2x2xui32>
}

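// Downcast counterpart: the result gains a casting dimension sharded on "y" while dim 0 stays sharded on "x"; no reshard is expected.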
// CHECK-LABEL: func @bitcast_convert_downcast_casting_dim_is_sharded
func.func @bitcast_convert_downcast_casting_dim_is_sharded(%arg0: tensor<4x2xui64> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x2x2xui32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {"y"}]>}) {
// CHECK-NOT: sdy.reshard
%0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}, {"y"}]>]>} : (tensor<4x2xui64>) -> tensor<4x2x2xui32>
return %0 : tensor<4x2x2xui32>
}

@@ -562,13 +576,94 @@ func.func @broadcast_in_dim(%arg0: tensor<2x3x5x1x7xf32> {sdy.sharding = #sdy.sh
return %0 : tensor<2x5x3x11x7x13xf32>
}

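// Single-operand concatenate: a fully replicated operand feeds an op sharded on "x" along dim 0; no reshard is expected.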
// CHECK-LABEL: func @concatenate_single_input
func.func @concatenate_single_input(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}, {}]>}) -> (tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}) {
// CHECK-NOT: sdy.reshard
%0 = stablehlo.concatenate %arg0, dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}, {}]>]>} : (tensor<4x32x256xf32>) -> tensor<4x32x256xf32>
return %0 : tensor<4x32x256xf32>
}

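// Fully replicated concatenate with an "x"-sharded first operand; %arg1 now carries an explicit fully replicated sharding, and no reshard is expected.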
// CHECK-LABEL: func @concatenate
- func.func @concatenate(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<4x48x256xf32>) -> tensor<4x80x256xf32> {
+ func.func @concatenate(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<4x48x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}, {}]>}) -> tensor<4x80x256xf32> {
// CHECK-NOT: sdy.reshard
%0 = stablehlo.concatenate %arg0, %arg1, dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32>
return %0 : tensor<4x80x256xf32>
}

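// The concatenation dimension (dim 1) of %arg0 is sharded on "y" and the op carries no sharding; no reshard is expected.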
// CHECK-LABEL: func @concatenate_replicated_dim_is_sharded
func.func @concatenate_replicated_dim_is_sharded(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}, %arg1: tensor<4x48x256xf32>) -> tensor<4x80x256xf32> {
// CHECK-NOT: sdy.reshard
%0 = stablehlo.concatenate %arg0, %arg1, dim = 1 : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32>
return %0 : tensor<4x80x256xf32>
}

// CHECK-LABEL: func @add
func.func @add(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, %arg1: tensor<4x32xf32>) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg1 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
// CHECK-NEXT: stablehlo.add %arg0, %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

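// Both operands share the larger "x" sharding; the add keeps it and only the result is resharded to the requested "y" sharding.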
// CHECK-LABEL: func @add_input_sharding_is_larger
func.func @add_input_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, %arg1: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32>) {
// CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ADD]] <@mesh, [{"y"}, {}]> : tensor<4x32xf32>
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

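// The requested output sharding "x" is larger; the add still runs with the operands' "y" sharding and the result is resharded afterwards.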
// CHECK-LABEL: func @add_output_sharding_is_larger
func.func @add_output_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg1: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
// CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {}]>]>} : tensor<4x32xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ADD]] <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

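// Inputs are sharded on dim 0 and the op on dim 1; both operands are resharded to the combined [{"x"}, {"y"}] sharding, the add runs there, and the result is resharded to [{}, {"y"}].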
// CHECK-LABEL: func @add_input_and_output_sharded_on_separate_dims
func.func @add_input_and_output_sharded_on_separate_dims(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, %arg1: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
// CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<4x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"x"}, {"y"}]> : tensor<4x32xf32>
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : tensor<4x32xf32>
// CHECK-NEXT: %[[RESHARD3:.*]] = sdy.reshard %[[ADD]] <@mesh, [{}, {"y"}]> : tensor<4x32xf32>
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

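// The replicated input is resharded to the "x" sharding that the negate and the result require.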
// CHECK-LABEL: func @negate
func.func @negate(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
// CHECK-NEXT: stablehlo.negate %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
%0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

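// The input's "x" sharding is larger; the negate keeps it and the result is resharded to "y".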
// CHECK-LABEL: func @negate_input_sharding_is_larger
func.func @negate_input_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) {
// CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[NEGATE]] <@mesh, [{"y"}, {}]> : tensor<4x32xf32>
%0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

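// The requested output sharding "x" is larger; unlike the add case above, the input is resharded first and the negate runs directly with "x".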
// CHECK-LABEL: func @negate_output_sharding_is_larger
func.func @negate_output_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
// CHECK-NEXT: stablehlo.negate %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
%0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

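// Input sharded on dim 0, output on dim 1: the input is resharded to the combined [{"x"}, {"y"}] sharding, the negate runs there, and the result is resharded to [{}, {"y"}].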
// CHECK-LABEL: func @negate_input_and_output_sharded_on_separate_dims
func.func @negate_input_and_output_sharded_on_separate_dims(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) {
// CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<4x32xf32>
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : tensor<4x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[NEGATE]] <@mesh, [{}, {"y"}]> : tensor<4x32xf32>
%0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

// CHECK-LABEL: func @dynamic_slice
func.func @dynamic_slice(%arg0: tensor<32x4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<32x1x2xf32> {
// CHECK-NOT: sdy.reshard
@@ -666,6 +761,28 @@ func.func @sort_input_and_output_shardings_are_different_on_sorting_dimension(%a
return %0 : tensor<4x32x8xi32>
}

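// The operand and result shardings disagree ("x" vs "y") only on a non-sort dimension; no reshard is expected.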
// CHECK-LABEL: func @sort_incompatible_on_nonsort_dimensions
func.func @sort_incompatible_on_nonsort_dimensions(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) {
// CHECK-NOT: sdy.reshard
%0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = stablehlo.compare GT, %arg2, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
}) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>)
return %0 : tensor<4x32x8xi32>
}

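// The operand and result agree on the non-sort dimension sharding; no reshard is expected.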
// CHECK-LABEL: func @sort_compatible_on_nonsort_dimension
func.func @sort_compatible_on_nonsort_dimension(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) {
// CHECK-NOT: sdy.reshard
%0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = stablehlo.compare GT, %arg2, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
}) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>)
return %0 : tensor<4x32x8xi32>
}

// CHECK-LABEL: func @transpose
func.func @transpose(%arg0: tensor<256x32x64x100xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}, {}]>}) -> tensor<100x32x256x64xf32> {
// CHECK-NOT: sdy.reshard
@@ -685,6 +802,18 @@ func.func @triangular_solve(%arg0: tensor<8x3x3xf32> {sdy.sharding = #sdy.shardi
return %0 : tensor<8x3x5xf32>
}

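// %arg0 additionally shards dim 1 on "y" while the op's result sharding is fully replicated; no reshard is expected.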
// CHECK-LABEL: func @triangular_solve_replicated_dim_is_sharded
func.func @triangular_solve_replicated_dim_is_sharded(%arg0: tensor<8x3x3xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}, %arg1: tensor<8x3x5xf32>) -> tensor<8x3x5xf32> {
// CHECK-NOT: sdy.reshard
%0 = "stablehlo.triangular_solve"(%arg0, %arg1) <{
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
}> {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}, {}]>]>} : (tensor<8x3x3xf32>, tensor<8x3x5xf32>) -> tensor<8x3x5xf32>
return %0 : tensor<8x3x5xf32>
}

// CHECK-LABEL: func @fft_complex
func.func @fft_complex(%arg0: tensor<8x32x64xcomplex<f32>> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}) -> tensor<8x32x64xcomplex<f32>> {
// CHECK-NOT: sdy.reshard