From 57dd7cdd3766455c11abac55137ef43525ed910f Mon Sep 17 00:00:00 2001
From: shardy authors
Date: Thu, 23 Jan 2025 08:03:24 -0800
Subject: [PATCH] Extend unit tests for insert explicit reshards with examples
 of fully replicated operations and fully pointwise operations.

PiperOrigin-RevId: 718869283
---
 .../export/test/insert_explicit_reshards.mlir | 135 +++++++++++++++++-
 1 file changed, 132 insertions(+), 3 deletions(-)

diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
index 570fc0e..54cefde 100644
--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
@@ -544,14 +544,28 @@ func.func @reverse(%arg0: tensor<4x32x8x2xf32> {sdy.sharding = #sdy.sharding<@me
 // CHECK-LABEL: func @bitcast_convert_upcast
 func.func @bitcast_convert_upcast(%arg0: tensor<4x2x2xui32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}) -> tensor<4x2xui64> {
   // CHECK-NOT: sdy.reshard
-  %0 = stablehlo.bitcast_convert %arg0 : (tensor<4x2x2xui32>) -> tensor<4x2xui64>
+  %0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<4x2x2xui32>) -> tensor<4x2xui64>
+  return %0 : tensor<4x2xui64>
+}
+
+// CHECK-LABEL: func @bitcast_convert_upcast_casting_dim_is_sharded
+func.func @bitcast_convert_upcast_casting_dim_is_sharded(%arg0: tensor<4x2x2xui32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {"y"}]>}) -> tensor<4x2xui64> {
+  // CHECK-NOT: sdy.reshard
+  %0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<4x2x2xui32>) -> tensor<4x2xui64>
   return %0 : tensor<4x2xui64>
 }
 
 // CHECK-LABEL: func @bitcast_convert_downcast
 func.func @bitcast_convert_downcast(%arg0: tensor<4x2xui64> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> tensor<4x2x2xui32> {
   // CHECK-NOT: sdy.reshard
-  %0 = stablehlo.bitcast_convert %arg0 : (tensor<4x2xui64>) -> tensor<4x2x2xui32>
+  %0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}, {}]>]>} : (tensor<4x2xui64>) -> tensor<4x2x2xui32>
+  return %0 : tensor<4x2x2xui32>
+}
+
+// CHECK-LABEL: func @bitcast_convert_downcast_casting_dim_is_sharded
+func.func @bitcast_convert_downcast_casting_dim_is_sharded(%arg0: tensor<4x2xui64> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x2x2xui32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {"y"}]>}) {
+  // CHECK-NOT: sdy.reshard
+  %0 = stablehlo.bitcast_convert %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}, {"y"}]>]>} : (tensor<4x2xui64>) -> tensor<4x2x2xui32>
   return %0 : tensor<4x2x2xui32>
 }
 
@@ -562,13 +576,94 @@ func.func @broadcast_in_dim(%arg0: tensor<2x3x5x1x7xf32> {sdy.sharding = #sdy.sh
   return %0 : tensor<2x5x3x11x7x13xf32>
 }
 
+// CHECK-LABEL: func @concatenate_single_input
+func.func @concatenate_single_input(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}, {}]>}) -> (tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}) {
+  // CHECK-NOT: sdy.reshard
+  %0 = stablehlo.concatenate %arg0, dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}, {}]>]>} : (tensor<4x32x256xf32>) -> tensor<4x32x256xf32>
+  return %0 : tensor<4x32x256xf32>
+}
+
 // CHECK-LABEL: func @concatenate
-func.func @concatenate(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<4x48x256xf32>) -> tensor<4x80x256xf32> {
+func.func @concatenate(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<4x48x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}, {}]>}) -> tensor<4x80x256xf32> {
+  // CHECK-NOT: sdy.reshard
+  %0 = stablehlo.concatenate %arg0, %arg1, dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32>
+  return %0 : tensor<4x80x256xf32>
+}
+
+// CHECK-LABEL: func @concatenate_replicated_dim_is_sharded
+func.func @concatenate_replicated_dim_is_sharded(%arg0: tensor<4x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}, %arg1: tensor<4x48x256xf32>) -> tensor<4x80x256xf32> {
   // CHECK-NOT: sdy.reshard
   %0 = stablehlo.concatenate %arg0, %arg1, dim = 1 : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32>
   return %0 : tensor<4x80x256xf32>
 }
 
+// CHECK-LABEL: func @add
+func.func @add(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, %arg1: tensor<4x32xf32>) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg1 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
+  // CHECK-NEXT: stablehlo.add %arg0, %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @add_input_sharding_is_larger
+func.func @add_input_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, %arg1: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32>) {
+  // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ADD]] <@mesh, [{"y"}, {}]> : tensor<4x32xf32>
+  %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @add_output_sharding_is_larger
+func.func @add_output_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg1: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
+  // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {}]>]>} : tensor<4x32xf32>
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ADD]] <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
+  %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @add_input_and_output_sharded_on_separate_dims
+func.func @add_input_and_output_sharded_on_separate_dims(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, %arg1: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) {
+  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<4x32xf32>
+  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"x"}, {"y"}]> : tensor<4x32xf32>
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : tensor<4x32xf32>
+  // CHECK-NEXT: %[[RESHARD3:.*]] = sdy.reshard %[[ADD]] <@mesh, [{}, {"y"}]> : tensor<4x32xf32>
+  %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @negate
+func.func @negate(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
+  // CHECK-NEXT: stablehlo.negate %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  %0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @negate_input_sharding_is_larger
+func.func @negate_input_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) {
+  // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[NEGATE]] <@mesh, [{"y"}, {}]> : tensor<4x32xf32>
+  %0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @negate_output_sharding_is_larger
+func.func @negate_output_sharding_is_larger(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<4x32xf32>
+  // CHECK-NEXT: stablehlo.negate %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  %0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// CHECK-LABEL: func @negate_input_and_output_sharded_on_separate_dims
+func.func @negate_input_and_output_sharded_on_separate_dims(%arg0: tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<4x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) {
+  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<4x32xf32>
+  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : tensor<4x32xf32>
+  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[NEGATE]] <@mesh, [{}, {"y"}]> : tensor<4x32xf32>
+  %0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} : tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
 // CHECK-LABEL: func @dynamic_slice
 func.func @dynamic_slice(%arg0: tensor<32x4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<32x1x2xf32> {
   // CHECK-NOT: sdy.reshard
@@ -666,6 +761,28 @@ func.func @sort_input_and_output_shardings_are_different_on_sorting_dimension(%a
   return %0 : tensor<4x32x8xi32>
 }
 
+// CHECK-LABEL: func @sort_incompatible_on_nonsort_dimensions
+func.func @sort_incompatible_on_nonsort_dimensions(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) {
[{}, {"y"}, {}]>}) { + // CHECK-NOT: sdy.reshard + %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.compare GT, %arg2, %arg3 : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + }) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>) + return %0 : tensor<4x32x8xi32> +} + +// CHECK-LABEL: func @sort_compatible_on_nonsort_dimension +func.func @sort_compatible_on_nonsort_dimension(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) { + // CHECK-NOT: sdy.reshard + %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.compare GT, %arg2, %arg3 : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + }) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>) + return %0 : tensor<4x32x8xi32> +} + // CHECK-LABEL: func @transpose func.func @transpose(%arg0: tensor<256x32x64x100xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}, {}]>}) -> tensor<100x32x256x64xf32> { // CHECK-NOT: sdy.reshard @@ -685,6 +802,18 @@ func.func @triangular_solve(%arg0: tensor<8x3x3xf32> {sdy.sharding = #sdy.shardi return %0 : tensor<8x3x5xf32> } +// CHECK-LABEL: func @triangular_solve_replicated_dim_is_sharded +func.func @triangular_solve_replicated_dim_is_sharded(%arg0: tensor<8x3x3xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}, %arg1: tensor<8x3x5xf32>) -> tensor<8x3x5xf32> { + // CHECK-NOT: sdy.reshard + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) <{ + left_side = true, + lower = true, + unit_diagonal = false, + transpose_a = #stablehlo + }> {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}, {}]>]>} : (tensor<8x3x3xf32>, tensor<8x3x5xf32>) -> tensor<8x3x5xf32> + return %0 : tensor<8x3x5xf32> +} + // CHECK-LABEL: func @fft_complex func.func @fft_complex(%arg0: tensor<8x32x64xcomplex> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}) -> tensor<8x32x64xcomplex> { // CHECK-NOT: sdy.reshard