From 3bc119b4a8b08604f97c4c8833e10ee4c134f27b Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 20 Dec 2024 18:33:40 +0000 Subject: [PATCH] Cleaning up tests in chlo_ops.mlir backported twice --- stablehlo/tests/ops_chlo.mlir | 216 ---------------------------------- 1 file changed, 216 deletions(-) diff --git a/stablehlo/tests/ops_chlo.mlir b/stablehlo/tests/ops_chlo.mlir index f3e7f77736..b5a021043b 100644 --- a/stablehlo/tests/ops_chlo.mlir +++ b/stablehlo/tests/ops_chlo.mlir @@ -289,222 +289,6 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso // ----- -// ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n] -func.func @ragged_dot_non_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [2], - lhs_ragged_dimensions = [1], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> - func.return %0 : tensor<2x11x7xf32> -} - -// ----- - -// ragged_dot mode 2: [m,k], [k,n], [g] -> [g,m,n] -func.func @ragged_dot_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x2x11x7xf32> { - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [2], - rhs_group_dimensions = [] - >, - precision_config = [#chlo, #chlo] - } : (tensor<2x11x5xf32>, tensor<2x5x7xf32>, tensor<3xi64>) -> tensor<3x2x11x7xf32> - func.return %0 : tensor<3x2x11x7xf32> -} - -// ----- - -// ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n] -func.func @ragged_dot_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> { - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [] - >, - precision_config = [#chlo, #chlo] - } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32> - func.return %0 : tensor<3x11x7xf32> -} - -// ----- - -func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rhs : tensor<3x2x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{contracting dimension sizes must match}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x5xf32>, tensor<3x2x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - -func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3x2xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - -func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<2xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - -func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{There must be exactly one ragged dimension in the lhs}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [0, 1], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - -func.func @ragged_dot_rhs_group_dim_is_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> { - // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [1], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32> - func.return %0 : tensor<3x11x7xf32> -} - -// ----- - -func.func @ragged_dot_rhs_group_dim_is_contracting(%lhs : tensor<11x3xf32>, %rhs : tensor<3x3x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [1] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x3xf32>, tensor<3x3x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - -func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { - // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [2], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> - func.return %0 : tensor<2x11x7xf32> -} - -// ----- - -func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1], - lhs_ragged_dimensions = [1], - rhs_group_dimensions = [0] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - -func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tensor<11x5xf32>, %rhs : tensor<5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}} - %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { - ragged_dot_dimension_numbers = #chlo.ragged_dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [0], - lhs_ragged_dimensions = [0], - rhs_group_dimensions = [] - >, - precision_config = [#chlo, #chlo] - } : (tensor<11x5xf32>, tensor<5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> - func.return %0 : tensor<11x7xf32> -} - -// ----- - func.func @top_k(%arg0 : tensor) { // expected-error @+2 {{failed to infer returned types}} // @expected-error @+1{{operand's rank must be at least 1}}