From 33037f173b6be51296c755e83a76317215966978 Mon Sep 17 00:00:00 2001
From: Kevin Gleason <gleasonk@google.com>
Date: Tue, 21 Jan 2025 18:58:18 -0800
Subject: [PATCH] [MHLO] Handle dynamic dimensions in HLO<->MHLO

- Fix creating the constant zero for ConvertOp in HLO->MHLO translation
- Fix bounded broadcast_in_dim lowering from MHLO->HLO
- Don't use StableHLO verification methods on MHLO ReshapeOp with bounded
  dynamic outputs

```
$ cat /tmp/t.hlo
HloModule main, entry_computation_layout={(pred[<=1801,1]{1,0})->pred[<=1801,1]{1,0}}

ENTRY %convert_with_predicate (Arg_0.1: pred[<=1801,1]) -> pred[<=1801,1] {
  %Arg_0.1 = pred[<=1801,1] parameter(0)
  ROOT %convert_pred = pred[<=1801,1] convert(%Arg_0.1)
}

$ xla-translate /tmp/t.hlo --hlo-text-to-mlir-hlo
func.func @main(%arg0: tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>> {
  %0 = mhlo.constant dense<false> : tensor<1801x1xi1>
  %1 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<i32>
  %2 = "mhlo.set_dimension_size"(%0, %1) <{dimension = 0 : i64}> : (tensor<1801x1xi1>, tensor<i32>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
  %3 = mhlo.compare NE, %arg0, %2 : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>, tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
  return %3 : tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
}
```

Previously this failed when trying to create the `mhlo.constant dense<false>` that
gets fed into the compare, since constants cannot have a bounded size.

PiperOrigin-RevId: 718162448
---
 third_party/stablehlo/temporary.patch         | 217 +++++++++++++++++-
 xla/hlo/builder/xla_builder_test.cc           |  10 +
 xla/hlo/translate/hlo_to_mhlo/BUILD           |   1 +
 .../hlo_to_mhlo/hlo_function_importer.cc      |  71 +++++-
 xla/hlo/translate/hlo_to_mhlo/tests/BUILD     |   1 +
 .../tests/import_bounded_dynamism.hlo         |  66 ++++++
 xla/hlo/translate/hlo_to_mhlo/translate.cc    |  20 +-
 .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc  |   5 +-
 xla/hlo/translate/mhlo_to_hlo/tests/BUILD     |   1 +
 .../tests/export_bounded_dynamism.mlir        |  35 +++
 xla/mlir_hlo/mhlo/IR/hlo_base.td              |   2 +
 xla/mlir_hlo/mhlo/IR/hlo_ops.cc               |  47 ----
 xla/mlir_hlo/mhlo/IR/hlo_ops.td               |   4 +-
 .../tests/Dialect/mhlo/bounded_dynamism.mlir  |  31 ++-
 .../mhlo/hlo-legalize-to-stablehlo.mlir       |  34 +--
 .../mhlo/stablehlo-legalize-to-hlo.mlir       |  18 ++
 16 files changed, 474 insertions(+), 89 deletions(-)
 create mode 100644 xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo
 create mode 100644 xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir
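For readers skimming the diff: MHLO and StableHLO spell the same bounded-dynamic
type with different encoding attributes. A minimal side-by-side sketch, using the
shape from the example above (illustrative only, not part of the diff):

```
// MHLO: dynamic dims print as `?`; bounds live in #mhlo.type_extensions.
tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>

// StableHLO: same shape; bounds live in #stablehlo.bounds.
tensor<?x1xi1, #stablehlo.bounds<1801, ?>>
```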
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index dcea9111ece40a..9d58a83ebdd9f0 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -170,6 +170,79 @@ diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/
 } // namespace hlo
 } // namespace mlir
+diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp
+--- stablehlo/stablehlo/dialect/Base.cpp
++++ stablehlo/stablehlo/dialect/Base.cpp
+@@ -780,5 +780,22 @@
+          numScales == rankedType.getDimSize(quantDim));
+ }
+
++bool hasSingleBoundedDimension(Type type) {
++  RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
++  auto boundedAttr =
++      dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
++  if (!boundedAttr) return false;
++
++  // Count the dims whose bound is not kDynamic.
++  int64_t numBoundedDims = llvm::count_if(
++      boundedAttr.getBounds(),
++      [](int64_t bound) { return !ShapedType::isDynamic(bound); });
++  // Also count the dynamic dims, so that a single bounded dim is only
++  // accepted when there are no additional unbounded dynamic dims.
++  int64_t numDynamicDims = llvm::count_if(
++      rankedType.getShape(),
++      [](int64_t dim) { return ShapedType::isDynamic(dim); });
++  return numBoundedDims == 1 && numDynamicDims == 1;
++}
++
+ } // namespace hlo
+ } // namespace mlir
+diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h
+--- stablehlo/stablehlo/dialect/Base.h
++++ stablehlo/stablehlo/dialect/Base.h
+@@ -101,6 +101,9 @@
+ // mentioned in the StableHLO specification.
+ bool isValidQuantizedDimension(Type type);
+
++// Returns true if the given type has a single bounded dimension.
++bool hasSingleBoundedDimension(Type type);
++
+ // TODO(zhouxin) Move type inference related methods to TypeInference.cpp
+
+ std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
+diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td
+--- stablehlo/stablehlo/dialect/Base.td
++++ stablehlo/stablehlo/dialect/Base.td
+@@ -29,6 +29,20 @@
+ def I32RankedTensor : RankedTensorOf<[I32]>;
+
+ def UI32RankedTensor : RankedTensorOf<[UI32]>;
++
++//===----------------------------------------------------------------------===//
++// HLO type constraints.
++//===----------------------------------------------------------------------===//
++
++// Note: Bounded dynamism is largely unspecified, and this feature needs more
++// thought as it is adopted by modern frameworks. The current support is
++// designed to allow existing TF programs to be representable in StableHLO and
++// is subject to change as a formal design for bounded dynamism is developed.
++def HLO_HasSingleBoundedDimensionPred
++  : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;
++
++def HLO_HasStaticOrSingleBoundedShapePred
++  : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;
+
+ //===----------------------------------------------------------------------===//
+ // HLO type definitions.
+@@ -267,6 +281,9 @@
+ def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+     [IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;
+
++def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
++    [IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">;
++
+ def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;
+
+ def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
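In practice the new `HLO_HasStaticOrSingleBoundedShapePred` admits exactly two kinds
of result types. A quick sketch of accepted and rejected shapes (the bound value 5 is
an arbitrary illustration; the tests added below exercise the same cases):

```
tensor<2x5xf32>                            // accepted: fully static
tensor<2x?xf32, #stablehlo.bounds<?, 5>>   // accepted: one dynamic dim, and it is bounded
tensor<?x?xf32, #stablehlo.bounds<5, 5>>   // rejected: more than one bounded-dynamic dim
tensor<?x5xf32>                            // rejected: dynamic dim with no bound
```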
diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt
--- stablehlo/stablehlo/dialect/CMakeLists.txt
+++ stablehlo/stablehlo/dialect/CMakeLists.txt
@@ -440,11 +513,10 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/
 diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td
 --- stablehlo/stablehlo/dialect/StablehloOps.td
 +++ stablehlo/stablehlo/dialect/StablehloOps.td
-@@ -327,6 +327,23 @@
-   ```mlir
+@@ -328,6 +328,23 @@
    %result = stablehlo.exponential %operand : tensor<2x2xf64>
    ```
-+  }];
+   }];
 +  let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
 +      DefaultValuedOptionalAttr:$result_accuracy);
 +  let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
@@ -461,12 +533,57 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d
 +
 +  let assemblyFormat = [{
 +    $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
-   }];
++  }];
 }
 
 def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
+@@ -1963,7 +1980,7 @@
+     DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
+   );
+
+-  let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
++  let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);
+
+   let hasVerifier = 1;
+
+@@ -2715,7 +2732,7 @@
+
+   let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand);
+
+-  let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
++  let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);
+   let hasVerifier = 1;
+
+   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
--- stablehlo/stablehlo/dialect/TypeInference.cpp
+++ stablehlo/stablehlo/dialect/TypeInference.cpp
+@@ -3724,9 +3724,8 @@
+                                          Value operand,
+                                          ArrayRef<int64_t> broadcastDimensions,
+                                          Value result) {
++  // broadcast_in_dim_c1
+   auto operandType = cast<RankedTensorType>(operand.getType());
+-
+-  // broadcast_in_dim_c1
+   if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType,
+                                                           result.getType())))
+     return failure();
+@@ -4658,11 +4657,12 @@
+                                 Value result) {
+   // If the operand type is dynamically shaped there is nothing to verify.
+   auto operandTy = cast<RankedTensorType>(operand.getType());
+-  if (!operandTy.hasStaticShape()) return success();
++  auto resultTy = cast<RankedTensorType>(result.getType());
++  if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
++    return success();
+
+   // If the operand type is statically shaped (not required) the number of
+   // elements must match that of the result type.
+-  auto resultTy = cast<RankedTensorType>(result.getType());
+   int64_t numResultElements = resultTy.getNumElements();
+   int64_t numOperandElements = operandTy.getNumElements();
+   if (numResultElements != numOperandElements)
@@ -5057,5 +5057,30 @@
   return success();
 }
@@ -1349,7 +1466,30 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl
 diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir
 --- stablehlo/stablehlo/tests/ops_stablehlo.mlir
 +++ stablehlo/stablehlo/tests/ops_stablehlo.mlir
-@@ -1775,6 +1775,30 @@
+@@ -1274,6 +1274,22 @@
+
+ // -----
+
++// CHECK-LABEL: func @broadcast_in_dim_dynamic_i1
++func.func @broadcast_in_dim_dynamic_i1(%arg0: tensor<?xi32>) -> tensor<1x3xi32> {
++  %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<?xi32>) -> tensor<1x3xi32>
++  return %0 : tensor<1x3xi32>
++}
++
++// -----
++
++func.func @broadcast_in_dim_dynamic_result(%arg0: tensor<3xi32>) -> tensor<?xi32> {
++  // expected-error@+1 {{must be statically shaped or single bounded dimension tensor}}
++  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 0>} : (tensor<3xi32>) -> tensor<?xi32>
++  func.return %0 : tensor<?xi32>
++}
++
++// -----
++
+ // Regression test for b/180052624, where this was improperly marked as an
+ // invalid stablehlo.broadcast_in_dim op.
+ // CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
+@@ -1775,6 +1791,30 @@
+   // expected-error@+1 {{'precision_config' failed to satisfy constraint}}
+   %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #stablehlo<precision DEFAULT>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+   func.return %0: tensor<2x2xi32>
@@ -1380,6 +1520,73 @@
 }
 
 // -----
+diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
+--- stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
++++ stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
+@@ -0,0 +1,63 @@
++// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s
++
++// This file captures some quirks of bounded dynamism in StableHLO that are
++// included to allow StableHLO to represent existing TF programs.
++
++// CHECK-LABEL: reshape_with_single_bounded_dimension
++func.func @reshape_with_single_bounded_dimension(%arg0: tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>> {
++  %0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
++  // CHECK: return {{.*}} #stablehlo.bounds<?, 5>
++  return %0 : tensor<2x?xf32, #stablehlo.bounds<?, 5>>
++}
++
++// -----
++
++// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension
++func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
++  %0 = stablehlo.reshape %arg0 : (tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++  // CHECK: return {{.*}} #stablehlo.bounds<?, 5>
++  return %0 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++}
++
++// -----
++
++func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>> {
++  // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
++  %0 = stablehlo.reshape %arg0 : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>>
++  return %0 : tensor<?x?xf32, #stablehlo.bounds<5, 5>>
++}
++
++// -----
++
++// CHECK-LABEL: broadcast_in_dim_with_single_bounded_dimension
++func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>> {
++  %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
++  // CHECK: return {{.*}} #stablehlo.bounds<?, ?, 5>
++  return %0 : tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
++}
++
++// -----
++
++func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>> {
++  // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
++  %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
++  return %0 : tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
++}
++
++// -----
++
++// CHECK-LABEL: constant_splat_broadcast
++func.func @constant_splat_broadcast() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
++  %0 = stablehlo.constant dense<1.0> : tensor<f32>
++  %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++  // CHECK: tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++  return %1 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++}
++
++// -----
++
++func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
++  // expected-error@+2 {{elements literal type must have static shape}}
++  %c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++  return %c : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
++}
diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir
--- stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir
+++ stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir
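The TypeInference.cpp change above is what lets the bounded reshape tests verify:
the element-count check in `verifyReshapeOp` is only meaningful when both operand
and result are static, since a bounded-dynamic side has an unknown runtime element
count. A sketch of the case that now verifies (bound values mirror the tests above):

```
// Both sides have at most 10 elements, but the runtime count is unknown,
// so the operand/result element-count comparison is skipped.
%0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>)
    -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
```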
diff --git a/xla/hlo/builder/xla_builder_test.cc b/xla/hlo/builder/xla_builder_test.cc
index e38125c7cd3b27..b8c72f95d8147c 100644
--- a/xla/hlo/builder/xla_builder_test.cc
+++ b/xla/hlo/builder/xla_builder_test.cc
@@ -566,6 +566,16 @@ TEST(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
               GmockMatch(m::Broadcast(m::Reshape(m::Broadcast()))));
 }
 
+TEST(XlaBuilderTest, BroadcastInDimWithBoundedDim) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[2, <=3]"));
+  auto x = Parameter(&b, 0, shape, "x");
+  BroadcastInDim(x, {1, 2, 3},
+                 /*broadcast_dimensions=*/{1, 2});
+  TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+  EXPECT_THAT(GetRoot(*module), GmockMatch(m::Broadcast()));
+}
+
 TEST(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
diff --git a/xla/hlo/translate/hlo_to_mhlo/BUILD b/xla/hlo/translate/hlo_to_mhlo/BUILD
index 8d97947d4cae1c..c96ea08b26563e 100644
--- a/xla/hlo/translate/hlo_to_mhlo/BUILD
+++ b/xla/hlo/translate/hlo_to_mhlo/BUILD
@@ -115,6 +115,7 @@ cc_library(
         "@llvm-project//mlir:SideEffectInterfaces",
         "@llvm-project//mlir:SparseTensorDialect",
         "@llvm-project//mlir:Support",
+        "@stablehlo//:base",
         "@tsl//tsl/platform:errors",
         "@tsl//tsl/platform:statusor",
     ],
diff --git a/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
index 09bc748fea7cd1..19eaa1ae6692c0 100644
--- a/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <iterator>
 #include
 #include
 #include
 #include
@@ -44,6 +45,7 @@ limitations under the License.
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
@@ -54,6 +56,7 @@ limitations under the License.
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "stablehlo/dialect/Base.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -170,6 +173,59 @@ Operation* createReturnOp(mlir::OpBuilder& builder, mlir::Location loc,
   return builder.create<mlir::func::ReturnOp>(loc, operands);
 }
 
+// Creates a constant of zeros like the given MLIR type. If the type has
+// bounded dynamism, the constant is created at the padded (bound) size, and
+// its dynamic dimension is then set to the dimension size of the operand.
+//
+// Example ZeroLike([<=5] operand):
+//   %c = constant dense<0> : tensor<5xf32>
+//   %0 = get_dimension_size %operand
+//   %1 = set_dimension_size %c, %0, bounded_dim={0}
+//
+// Note: Currently this only supports a single bounded dimension.
+absl::StatusOr<mlir::Value> createConstantZeroLike(mlir::Value operand,
+                                                   Shape input_shape,
+                                                   mlir::OpBuilder* builder,
+                                                   mlir::Location loc) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::RankedTensorType type,
+      ConvertTensorShapeToType<mlir::RankedTensorType>(input_shape, *builder));
+
+  LLVM_DEBUG(llvm::dbgs() << "CreateConstantZeroLike: " << operand << ", "
+                          << type << '\n');
+  if (type.hasStaticShape())
+    return builder
+        ->create<mlir::mhlo::ConstantOp>(loc, builder->getZeroAttr(type))
+        ->getResult(0);
+
+  // Note: Currently this only supports a single bounded dimension.
+  if (!mlir::hlo::hasSingleBoundedDimension(type))
+    return Internal(
+        "Currently HLO to MHLO only supports a single bounded dimension.");
+
+  auto bounded_dim = std::distance(type.getShape().begin(),
+                                   llvm::find_if(type.getShape(), [](auto dim) {
+                                     return mlir::ShapedType::isDynamic(dim);
+                                   }));
+
+  // Create a constant with no bounded dynamism, drop tensor encoding.
+  ArrayRef<int64_t> padded_dims(input_shape.dimensions().begin(),
+                                input_shape.dimensions().end());
+  auto padded_type =
+      mlir::RankedTensorType::get(padded_dims, type.getElementType());
+  auto padded_constant = builder->create<mlir::mhlo::ConstantOp>(
+      loc, builder->getZeroAttr(padded_type));
+
+  // Read the runtime dimension size from the operand and apply it to the
+  // padded constant.
+  auto dim_size = builder->create<mlir::mhlo::GetDimensionSizeOp>(
+      loc, operand, builder->getI64IntegerAttr(bounded_dim));
+  std::vector<mlir::Value> operands = {padded_constant->getResult(0), dim_size};
+  std::vector<mlir::NamedAttribute> attributes{builder->getNamedAttr(
+      "dimension", builder->getI64IntegerAttr(bounded_dim))};
+  return builder->create<mlir::mhlo::SetDimensionSizeOp>(loc, type, operands,
+                                                         attributes);
+}
+
 }  // namespace
 
 void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
@@ -1871,13 +1927,16 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       // Return type is boolean, let's use `operand != 0` instead of Convert.
       Shape input_shape = instruction->operand(0)->shape();
-      TF_ASSIGN_OR_RETURN(mlir::Type type,
-                          ConvertTensorShapeToType<mlir::RankedTensorType>(
-                              input_shape, *func_builder));
-      auto zero = func_builder->create<mlir::mhlo::ConstantOp>(
-          loc, func_builder->getZeroAttr(type));
+      TF_ASSIGN_OR_RETURN(
+          mlir::Value zero,
+          createConstantZeroLike(operands[0], input_shape, func_builder, loc));
+      std::vector<mlir::Value> compare_operands = {operands[0], zero};
+      std::vector<mlir::NamedAttribute> attributes = {builder_->getNamedAttr(
+          "comparison_direction", mlir::mhlo::ComparisonDirectionAttr::get(
+                                      func_builder->getContext(),
+                                      mlir::mhlo::ComparisonDirection::NE))};
       return {func_builder->create<mlir::mhlo::CompareOp>(
-          loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)};
+          loc, result_type, compare_operands, attributes)};
     }
     case HloOpcode::kOptimizationBarrier: {
       llvm::SmallVector<mlir::Value> flattened_operands;
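For contrast, only pred-valued converts take the constant/set_dimension_size/compare
path; converts to non-pred types keep a direct lowering. A condensed view of the two
import results (types abbreviated from the tests below):

```
// pred convert: imported as `operand != zero-like(operand)`.
%zero = "mhlo.set_dimension_size"(%cst, %dim) <{dimension = 0 : i64}> : (tensor<1801x1xi1>, tensor<i32>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
%cmp  = mhlo.compare NE, %arg0, %zero : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>, tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>

// f32 convert: imported as a plain mhlo.convert.
%cvt = mhlo.convert %arg0 : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
```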
diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/BUILD b/xla/hlo/translate/hlo_to_mhlo/tests/BUILD
index 19b7a2314245b8..6451a45f20fe47 100644
--- a/xla/hlo/translate/hlo_to_mhlo/tests/BUILD
+++ b/xla/hlo/translate/hlo_to_mhlo/tests/BUILD
@@ -22,6 +22,7 @@ lit_test_suite(
         "if_conditional.hlo",
         "import.hlo",
         "import_async.hlo",
+        "import_bounded_dynamism.hlo",
         "import_entry_computation_layout.hlo",
         "layouts_and_names.hlo",
         "location.hlo",
diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo
new file mode 100644
index 00000000000000..dbf4ffcd4db990
--- /dev/null
+++ b/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo
@@ -0,0 +1,66 @@
+// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s
+
+HloModule main, entry_computation_layout={(f32[1,1,<=1801]{2,1,0})->f32[1,16,1,<=1801]{3,2,1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[1,1,<=1801]) -> f32[1,16,1,<=1801] {
+  %Arg_0.1 = f32[1,1,<=1801] parameter(0)
+  // CHECK: [[BROADCAST:%.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>}> : (tensor<1x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 1801]>>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>>
+  // CHECK-NEXT: return [[BROADCAST]]
+  ROOT %broadcast.2 = f32[1,16,1,<=1801] broadcast(f32[1,1,<=1801] %Arg_0.1), dimensions={0,2,3}
+}
+
+// -----
+
+HloModule main, entry_computation_layout={(pred[<=1801,1]{1,0})->pred[<=1801,1]{1,0}}
+
+ENTRY %convert_with_predicate (Arg_0.1: pred[<=1801,1]) -> pred[<=1801,1] {
+  // CHECK: [[CST:%.*]] = mhlo.constant dense<false> : tensor<1801x1xi1>
+  // CHECK-NEXT: [[GDS:%.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<i32>
+  // CHECK-NEXT: [[SDS:%.*]] = "mhlo.set_dimension_size"([[CST]], [[GDS]]) <{dimension = 0 : i64}> : (tensor<1801x1xi1>, tensor<i32>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
+  // CHECK-NEXT: [[CMP:%.*]] = mhlo.compare NE, %arg0, [[SDS]] : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>, tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
+  // CHECK-NEXT: return [[CMP]] : tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
+  %Arg_0.1 = pred[<=1801,1] parameter(0)
+  ROOT %convert_pred = pred[<=1801,1] convert(%Arg_0.1)
+}
+
+// -----
+
+// TODO: Figure out how to write negative tests. This is likely an issue with
+// the custom -split-input-file handling.
+//
+// // expected-error {{HLO Module import failed: Currently HLO to MHLO only supports a single bounded dimension.}}
+// HloModule main, entry_computation_layout={(pred[<=1801,<=1801]{1,0})->pred[<=1801,<=1801]{1,0}}
+//
+// ENTRY %convert_with_predicate_multiple_bounded_dimensions (Arg_0.1: pred[<=1801,<=1801]) -> pred[<=1801,<=1801] {
+//   %Arg_0.1 = pred[<=1801,<=1801] parameter(0)
+//   ROOT %convert_pred = pred[<=1801,<=1801] convert(%Arg_0.1)
+// }
+
+HloModule main, entry_computation_layout={(f32[<=1801,1]{1,0})->f32[<=1801,1]{1,0}}
+
+ENTRY %convert_with_f32icate (Arg_0.1: f32[<=1801,1]) -> f32[<=1801,1] {
+  // CHECK: [[CVT:%.*]] = mhlo.convert %arg0 : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
+  // CHECK-NEXT: return [[CVT]] : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
+  %Arg_0.1 = f32[<=1801,1] parameter(0)
+  ROOT %convert_f32 = f32[<=1801,1] convert(%Arg_0.1)
+}
+
+// -----
+
+HloModule main, entry_computation_layout={(f32[<=5]{0})->f32[1,<=5]{1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[<=5]) -> f32[1,<=5] {
+  %Arg_0.1 = f32[<=5] parameter(0)
+  // CHECK: mhlo.reshape %arg0 : (tensor<?xf32, #mhlo.type_extensions<bounds = [5]>>) -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  ROOT %reshape.2 = f32[1,<=5] reshape(f32[<=5] %Arg_0.1)
+}
diff --git a/xla/hlo/translate/hlo_to_mhlo/translate.cc b/xla/hlo/translate/hlo_to_mhlo/translate.cc
index b6ae10b87f13de..e203008f7187d7 100644
--- a/xla/hlo/translate/hlo_to_mhlo/translate.cc
+++ b/xla/hlo/translate/hlo_to_mhlo/translate.cc
@@ -55,20 +55,21 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
 mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction(
     llvm::StringRef input, mlir::MLIRContext* context,
     bool import_all_computations, bool flatten_computation_args_result) {
+  mlir::OwningOpRef<mlir::ModuleOp> module =
+      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
+
   HloProto hlo_proto;
   std::string content(input.data(), input.size());
   if (!LoadHloProto(content, &hlo_proto)) {
-    LOG(ERROR) << "Failed to load proto";
+    module->emitError("Failed to load proto");
     return nullptr;
   }
 
-  mlir::OwningOpRef<mlir::ModuleOp> module =
-      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
   auto status = ConvertHloToMlirHlo(
       module.get(), hlo_proto.mutable_hlo_module(), import_all_computations,
       flatten_computation_args_result);
   if (!status.ok()) {
-    LOG(ERROR) << "Hlo module import failed: " << status;
+    module->emitError("Hlo module import failed: ") << status.message();
     return nullptr;
   }
 
@@ -78,22 +79,23 @@ mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction(
 mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction(
     llvm::StringRef input, mlir::MLIRContext* context,
     bool import_all_computations, bool flatten_computation_args_result) {
-  std::string content(input.data(), input.size());
+  mlir::OwningOpRef<mlir::ModuleOp> module =
+      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
 
+  std::string content(input.data(), input.size());
   auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
   if (!hlo_module_error.ok()) {
-    LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status();
+    module->emitError("HLO Module loading failed: ")
+        << hlo_module_error.status().message();
     return nullptr;
   }
 
   auto hlo_module = std::move(hlo_module_error.value());
-  mlir::OwningOpRef<mlir::ModuleOp> module =
-      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
   auto status = ConvertHloToMlirHlo(*module, hlo_module.get(),
                                     import_all_computations,
                                     flatten_computation_args_result);
   if (!status.ok()) {
-    LOG(ERROR) << "HLO Module import failed: " << status;
+    module->emitError("HLO Module import failed: ") << status.message();
     return nullptr;
   }
diff --git a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 8d143132452d4f..17c20a098390dd 100644
--- a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -1628,8 +1628,11 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
   if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
     return failure();
 
+  // Use TypeToShape to handle bounded dynamism.
+  // HLO expects broadcast sizes to use the bound's value, not kDynamic.
+  xla::Shape shape = xla::TypeToShape(type);
   value_map[op] =
-      BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
+      BroadcastInDim(operand, shape.dimensions(),
                      Convert_broadcast_dimensions(op.getBroadcastDimensions()));
   return success();
 }
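The one-line BroadcastInDim fix works because `xla::TypeToShape` folds the bound back
into the dimension list, so `shape.dimensions()` carries the bound's value wherever
the MLIR type has kDynamic. Roughly, for the export test below:

```
// MLIR result type:
//   tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>>
// TypeToShape(type) => f32[1,16,1,<=1801], so shape.dimensions() is
//   {1, 16, 1, 1801}
// which is what XlaBuilder's BroadcastInDim expects as its broadcast sizes,
// whereas type.getShape() would have supplied kDynamic for the last dim.
```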
diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/BUILD b/xla/hlo/translate/mhlo_to_hlo/tests/BUILD
index ff7078f4bdda61..f803689c11aa97 100644
--- a/xla/hlo/translate/mhlo_to_hlo/tests/BUILD
+++ b/xla/hlo/translate/mhlo_to_hlo/tests/BUILD
@@ -20,6 +20,7 @@ lit_test_suite(
         "export.mlir",
         "export_async.mlir",
         "export_and_check_layouts.mlir",
+        "export_bounded_dynamism.mlir",
         "export_entry_computation_layout.mlir",
         "export_large_constants.mlir",
         "export_replicas.mlir",
diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir
new file mode 100644
index 00000000000000..811bc4a44f8284
--- /dev/null
+++ b/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir
@@ -0,0 +1,35 @@
+// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: HloModule main
+func.func @main(%arg0: tensor<1x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 1801]>>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>> {
+  // CHECK: ROOT {{.*}} = f32[1,16,1,<=1801] broadcast
+  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>}> : (tensor<1x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 1801]>>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>>
+  return %0 : tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>>
+}
+
+// -----
+
+// CHECK-LABEL: HloModule main
+func.func @main(%arg0: tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>> {
+  // CHECK: ROOT{{.*}} = f32[<=1801,1] convert
+  %0 = mhlo.convert %arg0 : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
+  return %0 : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
+}
+
+// -----
+
+// CHECK-LABEL: HloModule main
+func.func @main(%arg0: tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>, %arg1: tensor<i32>) -> tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>> {
+  // CHECK: ROOT {{.*}} = f32[1,<=1800,512] set-dimension-size
+  %0 = "mhlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>, tensor<i32>) -> tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>
+  return %0 : tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>
+}
+
+// -----
+
+// CHECK-LABEL: HloModule main
+func.func @main(%arg0: tensor<?xf32, #mhlo.type_extensions<bounds = [5]>>) -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>> {
+  %0 = mhlo.reshape %arg0 : (tensor<?xf32, #mhlo.type_extensions<bounds = [5]>>) -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  // CHECK: f32[1,<=5] reshape(f32[<=5]
+  return %0 : tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+}
diff --git a/xla/mlir_hlo/mhlo/IR/hlo_base.td b/xla/mlir_hlo/mhlo/IR/hlo_base.td
index cfab8f3857ebc8..1f90057883629f 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_base.td
+++ b/xla/mlir_hlo/mhlo/IR/hlo_base.td
@@ -128,6 +128,8 @@ defvar MHLO_PredIntFpOrQuantizedTensor = HLO_PredIntFpOrQuantizedTensor;
 // it is for a legacy op which is only correct with static shapes.
 defvar MHLO_StaticShapeTensor = HLO_StaticShapeTensorOrPerAxisQuantizedTensor;
 
+defvar MHLO_StaticShapeOrBoundedDimTensor = HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor;
+
 defvar MHLO_StaticShapeTensorOrToken = HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken;
 
 defvar MHLO_StaticShapeIntOrFpTensor = HLO_StaticShapeIntOrFpTensor;
diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
index aa9c8fda99b6d0..a3df74b710cb59 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
+++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
@@ -2549,45 +2549,7 @@ LogicalResult BroadcastOp::reifyReturnTypeShapes(
 // BroadcastInDimOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-LogicalResult verifySingleBoundedDynamicDimension(Operation* op,
-                                                  RankedTensorType type) {
-  auto errFn = [&]() {
-    return op->emitOpError() << "load bearing ops with dynamic dimensions must "
-                                "have a single bounded dimension";
-  };
-
-  // Get bounded dims
-  TypeExtensionsAttr encoding =
-      dyn_cast_or_null<TypeExtensionsAttr>(type.getEncoding());
-  if (!encoding) return errFn();
-
-  // Check that all dynamic dims are bounded
-  ArrayRef<int64_t> bounds = encoding.getBounds();
-  ArrayRef<int64_t> shape = type.getShape();
-  for (auto [dim, bound] : llvm::zip(shape, bounds)) {
-    if (ShapedType::isDynamic(dim) && ShapedType::isDynamic(bound))
-      return errFn();
-  }
-
-  // Check single bounded dimension
-  if (llvm::count_if(bounds, [](int64_t dim) {
-        return !ShapedType::isDynamic(dim);
-      }) > 1) {
-    return errFn();
-  }
-  return success();
-}
-}  // namespace
-
 LogicalResult BroadcastInDimOp::verify() {
-  ShapedType resultType = getResult().getType();
-  if (!resultType.hasStaticShape()) {
-    RankedTensorType rankedResultType = cast<RankedTensorType>(resultType);
-    if (failed(verifySingleBoundedDynamicDimension(getOperation(),
-                                                   rankedResultType)))
-      return failure();
-  }
   return hlo::verifyBroadcastInDimOp(
       getLoc(), getOperand(),
       llvm::to_vector(getBroadcastDimensions().getValues<int64_t>()),
@@ -4629,15 +4591,6 @@ LogicalResult ReshapeOp::verify() {
   if (!operandType.hasRank() || !resultType.hasRank()) {
     return success();
   }
-
-  // Verify that dynamic outputs are all bounded.
-  // HLO allows this, and it is unclear if MHLO/StableHLO should, but it is
-  // needed to raise some TF programs from HLO to MHLO.
-  if (!resultType.hasStaticShape()) {
-    RankedTensorType resultType = cast<RankedTensorType>(getResult().getType());
-    if (failed(verifySingleBoundedDynamicDimension(getOperation(), resultType)))
-      return failure();
-  }
   return hlo::verifyReshapeOp(getLoc(), getOperand(), getResult());
 }
diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td
index 9fa2054f3bc069..25647b02b17ae7 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td
+++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td
@@ -2122,7 +2122,7 @@ def MHLO_BroadcastInDimOp : MHLO_Op<"broadcast_in_dim",
     MHLO_BroadcastDimAttr:$broadcast_dimensions
   );
 
-  let results = (outs MHLO_AnyTensor);
+  let results = (outs MHLO_StaticShapeOrBoundedDimTensor);
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -2881,7 +2881,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape",
 
   let arguments = (ins MHLO_AnyTensor:$operand);
 
-  let results = (outs MHLO_AnyTensor);
+  let results = (outs MHLO_StaticShapeOrBoundedDimTensor);
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir
index c674f734082e8a..9b9d8754f90721 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir
@@ -13,8 +13,17 @@ func.func @reshape_with_single_bounded_dimension(%arg0: tensor<?x2xf32, #mhlo.t
 
 // -----
 
+// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension
+func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor<?xf32, #mhlo.type_extensions<bounds = [5]>>) -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>> {
+  %0 = mhlo.reshape %arg0 : (tensor<?xf32, #mhlo.type_extensions<bounds = [5]>>) -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  // CHECK: return {{.*}} #mhlo.type_extensions<bounds = [?, 5]>
+  return %0 : tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+}
+
+// -----
+
 func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>>) -> tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>> {
-  // expected-error@+1 {{load bearing ops with dynamic dimensions must have a single bounded dimension}}
+  // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
   %0 = mhlo.reshape %arg0 : (tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>>) -> tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>>
   return %0 : tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>>
 }
@@ -31,7 +40,25 @@ func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32,
 // -----
 
 func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>>) -> tensor<2x?x?xf32, #mhlo.type_extensions<bounds = [?, 5, 5]>> {
-  // expected-error@+1 {{load bearing ops with dynamic dimensions must have a single bounded dimension}}
+  // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
   %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 5]>>) -> tensor<2x?x?xf32, #mhlo.type_extensions<bounds = [?, 5, 5]>>
   return %0 : tensor<2x?x?xf32, #mhlo.type_extensions<bounds = [?, 5, 5]>>
 }
+
+// -----
+
+// CHECK-LABEL: constant_splat_broadcast
+func.func @constant_splat_broadcast() -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>> {
+  %0 = mhlo.constant dense<1.0> : tensor<f32>
+  %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  // CHECK: tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  return %1 : tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+}
+
+// -----
+
+func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>> {
+  // expected-error@+2 {{elements literal type must have static shape}}
+  %c = mhlo.constant dense<1> : tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  return %c : tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+}
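With the verification moved into the shared ODS constraint, both dialects accept the
same bounded results and report the same diagnostic, so the legalization tests that
follow only need to check that the encoding attribute is rewritten. Schematically
(bound values match the tests):

```
// MHLO -> StableHLO:
//   #mhlo.type_extensions<bounds = [?, 5]>  becomes  #stablehlo.bounds<?, 5>
// StableHLO -> MHLO:
//   #stablehlo.bounds<?, 5>  becomes  #mhlo.type_extensions<bounds = [?, 5]>
%0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>)
    -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
```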
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
index 9a8f58d8e763e2..a69aac61d4649f 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
@@ -1739,6 +1739,23 @@ func.func @op_xor(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
   func.return %0 : tensor<i1>
 }
 
+// ============ BOUNDED DYNAMISM ============
+
+// CHECK-LABEL: bounded_dynamism_reshape
+func.func @bounded_dynamism_reshape(%arg0: tensor<?x2xf32, #mhlo.type_extensions<bounds = [5, ?]>>) -> tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>> {
+  // CHECK: stablehlo.reshape{{.*}}tensor<2x?xf32, #stablehlo.bounds<?, 5>>
+  %0 = "mhlo.reshape"(%arg0) : (tensor<?x2xf32, #mhlo.type_extensions<bounds = [5, ?]>>)
+      -> tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  return %0 : tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+}
+
+// CHECK-LABEL: bounded_dynamism_broadcast_in_dim
+func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>) -> tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>> {
+  // CHECK: stablehlo.broadcast_in_dim{{.*}}tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
+  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>) -> tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>>
+  return %0 : tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>>
+}
+
 // ============ TYPES ============
 
 // CHECK-LABEL: "type_i1"
@@ -2156,23 +2173,6 @@ func.func @op_fusion(%arg0: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-func.func @bounded_dynamism_reshape(%arg0: tensor<?x2xf32, #mhlo.type_extensions<bounds = [5, ?]>>) -> tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>> {
-  // expected-error@+1 {{'stablehlo.reshape' op result #0 must be statically shaped tensor}}
-  %0 = "mhlo.reshape"(%arg0) : (tensor<?x2xf32, #mhlo.type_extensions<bounds = [5, ?]>>)
-      -> tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
-  return %0 : tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
-}
-
-// -----
-
-func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>) -> tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>> {
-  // expected-error@+1 {{'stablehlo.broadcast_in_dim' op result #0 must be statically shaped tensor}}
-  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>) -> tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>>
-  return %0 : tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>>
-}
-
-// -----
-
 func.func @op_stochastic_convert(%arg0: tensor<f32>, %arg1: tensor<ui32>) -> tensor<i8> {
   // expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}}
   %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor<f32>, tensor<ui32>) -> tensor<i8>
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
index 9549615c457dea..f194eb0ffa9550 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
@@ -2018,6 +2018,24 @@ func.func @op_xor(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
 
 // -----
 
+// ============ BOUNDED DYNAMISM ============
+
+// CHECK-LABEL: bounded_dynamism_reshape
+func.func @bounded_dynamism_reshape(%arg0: tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>> {
+  // CHECK: mhlo.reshape{{.*}}tensor<2x?xf32, #mhlo.type_extensions<bounds = [?, 5]>>
+  %0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
+  return %0 : tensor<2x?xf32, #stablehlo.bounds<?, 5>>
+}
+
+// -----
+
+// CHECK-LABEL: bounded_dynamism_broadcast_in_dim
+func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>> {
+  // CHECK: mhlo.broadcast_in_dim{{.*}}tensor<2x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 5]>>
+  %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
+  return %0 : tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
+}
+
 // ============ TYPES ============
 
 // CHECK-LABEL: "type_i1"