diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index e3368b5228fc0..81030079835a3 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -56,6 +56,109 @@ diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeTo return success(); } }; +diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp +--- stablehlo/stablehlo/dialect/Base.cpp ++++ stablehlo/stablehlo/dialect/Base.cpp +@@ -780,5 +780,18 @@ + numScales == rankedType.getDimSize(quantDim)); + } + ++bool hasSingleBoundedDimension(Type type) { ++ RankedTensorType rankedType = dyn_cast(type); ++ auto boundedAttr = ++ dyn_cast_or_null(rankedType.getEncoding()); ++ if (!boundedAttr) return false; ++ ++ // count if bounded attr size is not kDynamic ++ int64_t numBoundedDims = llvm::count_if( ++ boundedAttr.getBounds(), ++ [](int64_t bound) { return !ShapedType::isDynamic(bound); }); ++ return numBoundedDims == 1; ++} ++ + } // namespace hlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h +--- stablehlo/stablehlo/dialect/Base.h ++++ stablehlo/stablehlo/dialect/Base.h +@@ -101,6 +101,9 @@ + // mentioned in the StableHLO specification. + bool isValidQuantizedDimension(Type type); + ++// Returns true if the given type has a single bounded dimension. ++bool hasSingleBoundedDimension(Type type); ++ + // TODO(zhouxin) Move type inference related methods to TypeInference.cpp + + std::pair inferConcatenatedDimAndBound(int64_t leftSize, +diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td +--- stablehlo/stablehlo/dialect/Base.td ++++ stablehlo/stablehlo/dialect/Base.td +@@ -29,6 +29,20 @@ + def I32RankedTensor : RankedTensorOf<[I32]>; + + def UI32RankedTensor : RankedTensorOf<[UI32]>; ++ ++//===----------------------------------------------------------------------===// ++// HLO type constraints. ++//===----------------------------------------------------------------------===// ++ ++// Note: Bounded dynamisms is largely unspecced and this feature needs more ++// thoguht as it is adopted to modern frameworks. The current support is ++// designed to allow existing TF programs to be representable in StableHLO and ++// is subject to change as a formal design for boudned dynamism is developed. ++def HLO_HasSingleBoundedDimensionPred ++ : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">; ++ ++def HLO_HasStaticOrBoundedShapePred ++ : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>; + + //===----------------------------------------------------------------------===// + // HLO type definitions. +@@ -267,6 +281,9 @@ + def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">; + ++def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], ++ [IsValidQuantizedDimension, HLO_HasStaticOrBoundedShapePred], "statically shaped or single bounded dimension tensor">; ++ + def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>; + + def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>; +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td +--- stablehlo/stablehlo/dialect/StablehloOps.td ++++ stablehlo/stablehlo/dialect/StablehloOps.td +@@ -1963,7 +1963,7 @@ + DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/ + ); + +- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); ++ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor); + + let hasVerifier = 1; + +@@ -2715,7 +2715,7 @@ + + let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand); + +- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); ++ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor); + let hasVerifier = 1; + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp +--- stablehlo/stablehlo/dialect/TypeInference.cpp ++++ stablehlo/stablehlo/dialect/TypeInference.cpp +@@ -3726,6 +3726,9 @@ + Value result) { + auto operandType = cast(operand.getType()); + ++ // No additional verification on shapes with single bounded dimension. ++ if (!operandType.hasStaticShape()) return success(); ++ + // broadcast_in_dim_c1 + if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType, + result.getType()))) diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp --- stablehlo/stablehlo/dialect/Version.cpp +++ stablehlo/stablehlo/dialect/Version.cpp diff --git a/xla/hlo/builder/xla_builder_test.cc b/xla/hlo/builder/xla_builder_test.cc index e38125c7cd3b2..b8c72f95d8147 100644 --- a/xla/hlo/builder/xla_builder_test.cc +++ b/xla/hlo/builder/xla_builder_test.cc @@ -566,6 +566,16 @@ TEST(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { GmockMatch(m::Broadcast(m::Reshape(m::Broadcast())))); } +TEST(XlaBuilderTest, BroadcastInDimWithBoundedDim) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[2, <=3]")); + auto x = Parameter(&b, 0, shape, "x"); + BroadcastInDim(x, {1, 2, 3}, + /*broadcast_dimensions=*/{1, 2}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), GmockMatch(m::Broadcast())); +} + TEST(XlaBuilderTest, BroadcastInDimWithNegativeSize) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); diff --git a/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index a40e1d571a339..8e292193404a0 100644 --- a/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -44,6 +45,7 @@ limitations under the License. #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" @@ -170,6 +172,61 @@ Operation* createReturnOp(mlir::OpBuilder& builder, mlir::Location loc, return builder.create(loc, operands); } +// Creates an array of zeros like the given MLIR type, if type has bounded +// dynamism, the constant is padded and set to the dimenison size of the +// operand. +// +// Example ZeroLike([<=5] operand): +// %c = constant dense<0> : tensor<5xf32> +// %0 = get_dimension_size %operand +// %1 = set_dimension_size %c, %0, bounded_dim={0} +// +// Note: Currently this only supports a single bounded dimension. +absl::StatusOr createConstantZeroLike(mlir::Value operand, + Shape input_shape, + mlir::OpBuilder* builder, + mlir::Location loc) { + TF_ASSIGN_OR_RETURN( + mlir::RankedTensorType type, + ConvertTensorShapeToType(input_shape, *builder)); + + LLVM_DEBUG(llvm::dbgs() << "CreateConstantZeroLike: " << operand << ", " + << type << '\n'); + if (type.hasStaticShape()) + return builder + ->create(loc, builder->getZeroAttr(type)) + ->getResult(0); + + // Note: Currently this only supports a single bounded dimension. + if (llvm::count_if(type.getShape(), [](auto dim) { + return mlir::ShapedType::isDynamic(dim); + }) != 1) + return Internal( + "Currently HLO to MHLO only supports a single bounded dimension."); + + auto bounded_dim = std::distance(type.getShape().begin(), + llvm::find_if(type.getShape(), [](auto dim) { + return mlir::ShapedType::isDynamic(dim); + })); + + // Create a constant with no bounded dynamism, drop tensor encoding. + ArrayRef padded_dims(input_shape.dimensions().begin(), + input_shape.dimensions().end()); + auto padded_type = + mlir::RankedTensorType::get(padded_dims, type.getElementType()); + auto padded_constant = builder->create( + loc, builder->getZeroAttr(padded_type)); + + // Get or Set the dimensions size based on the operand type. + auto dim_size = builder->create( + loc, operand, builder->getI64IntegerAttr(bounded_dim)); + std::vector operands = {padded_constant->getResult(0), dim_size}; + std::vector attributes{builder->getNamedAttr( + "dimension", builder->getI64IntegerAttr(bounded_dim))}; + return builder->create(loc, type, operands, + attributes); +} + } // namespace void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands( @@ -1847,13 +1904,16 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( // Return type is boolean, let's use `operand != 0` instead of Convert. Shape input_shape = instruction->operand(0)->shape(); - TF_ASSIGN_OR_RETURN(mlir::Type type, - ConvertTensorShapeToType( - input_shape, *func_builder)); - auto zero = func_builder->create( - loc, func_builder->getZeroAttr(type)); + TF_ASSIGN_OR_RETURN( + mlir::Value zero, + createConstantZeroLike(operands[0], input_shape, func_builder, loc)); + std::vector compare_operands = {operands[0], zero}; + std::vector attributes = {builder_->getNamedAttr( + "comparison_direction", mlir::mhlo::ComparisonDirectionAttr::get( + func_builder->getContext(), + mlir::mhlo::ComparisonDirection::NE))}; return {func_builder->create( - loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)}; + loc, result_type, compare_operands, attributes)}; } case HloOpcode::kOptimizationBarrier: { llvm::SmallVector flattened_operands; diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/BUILD b/xla/hlo/translate/hlo_to_mhlo/tests/BUILD index 19b7a2314245b..6451a45f20fe4 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/BUILD +++ b/xla/hlo/translate/hlo_to_mhlo/tests/BUILD @@ -22,6 +22,7 @@ lit_test_suite( "if_conditional.hlo", "import.hlo", "import_async.hlo", + "import_bounded_dynamism.hlo", "import_entry_computation_layout.hlo", "layouts_and_names.hlo", "location.hlo", diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo new file mode 100644 index 0000000000000..76996887f4481 --- /dev/null +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo @@ -0,0 +1,26 @@ +// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s + + +HloModule main, entry_computation_layout={(pred[<=1801,1]{1,0})->pred[<=1801,1]{1,0}} + +ENTRY %convert_with_predicate (Arg_0.1: pred[<=1801,1]) -> pred[<=1801,1] { + // CHECK: [[CST:%.*]] = mhlo.constant dense : tensor<1801x1xi1> + // CHECK-NEXT: [[GDS:%.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK-NEXT: [[SDS:%.*]] = "mhlo.set_dimension_size"([[CST]], [[GDS]]) <{dimension = 0 : i64}> : (tensor<1801x1xi1>, tensor) -> tensor> + // CHECK-NEXT: [[CMP:%.*]] = mhlo.compare NE, %arg0, [[SDS]] : (tensor>, tensor>) -> tensor> + // CHECK-NEXT: return [[CMP]] : tensor> + %Arg_0.1 = pred[<=1801,1] parameter(0) + ROOT %convert_pred = pred[<=1801,1] convert(%Arg_0.1) +} + +// ----- + +HloModule main, entry_computation_layout={(f32[<=1801,1]{1,0})->f32[<=1801,1]{1,0}} + +ENTRY %convert_with_f32icate (Arg_0.1: f32[<=1801,1]) -> f32[<=1801,1] { + // CHECK: [[CVT:%.*]] = mhlo.convert %arg0 : tensor> + // CHECK-NEXT: return [[CVT]] : tensor> + %Arg_0.1 = f32[<=1801,1] parameter(0) + ROOT %convert_f32 = f32[<=1801,1] convert(%Arg_0.1) +} + diff --git a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index c67092d2c9585..3534a00efca20 100644 --- a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -1595,8 +1595,11 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) return failure(); + // Use TypeToShape to handle bounded dynamism. + // HLO expects broadcast sizes to use the bound's value, not kDynamic. + xla::Shape shape = xla::TypeToShape(type); value_map[op] = - BroadcastInDim(operand, Convert_ArrayRef(type.getShape()), + BroadcastInDim(operand, shape.dimensions(), Convert_broadcast_dimensions(op.getBroadcastDimensions())); return success(); } diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/BUILD b/xla/hlo/translate/mhlo_to_hlo/tests/BUILD index ff7078f4bdda6..f803689c11aa9 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/BUILD +++ b/xla/hlo/translate/mhlo_to_hlo/tests/BUILD @@ -20,6 +20,7 @@ lit_test_suite( "export.mlir", "export_async.mlir", "export_and_check_layouts.mlir", + "export_bounded_dynamism.mlir", "export_entry_computation_layout.mlir", "export_large_constants.mlir", "export_replicas.mlir", diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir new file mode 100644 index 0000000000000..e116f78a70b05 --- /dev/null +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir @@ -0,0 +1,17 @@ +// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s + +// CHECK-LITERAL: HloModule main +func.func @main(%arg0: tensor<1x1x?xf32, #mhlo.type_extensions>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions> { + // CHECK: ROOT {{.*}} = f32[1,16,1,<=1801] broadcast + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>}> : (tensor<1x1x?xf32, #mhlo.type_extensions>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions> + return %0 : tensor<1x16x1x?xf32, #mhlo.type_extensions> +} + +// ----- + +// CHECK-LITERAL: HloModule main +func.func @main(%arg0: tensor<1x?x512xf32, #mhlo.type_extensions>, %arg1: tensor) -> tensor<1x?x512xf32, #mhlo.type_extensions> { + // CHECK: ROOT {{.*}} = f32[1,<=1800,512] set-dimension-size + %0 = "mhlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<1x?x512xf32, #mhlo.type_extensions>, tensor) -> tensor<1x?x512xf32, #mhlo.type_extensions> + return %0 : tensor<1x?x512xf32, #mhlo.type_extensions> +} \ No newline at end of file diff --git a/xla/mlir_hlo/mhlo/IR/hlo_base.td b/xla/mlir_hlo/mhlo/IR/hlo_base.td index cfab8f3857ebc..1f90057883629 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_base.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_base.td @@ -128,6 +128,8 @@ defvar MHLO_PredIntFpOrQuantizedTensor = HLO_PredIntFpOrQuantizedTensor; // it is for a legacy op which is only correct with static shapes. defvar MHLO_StaticShapeTensor = HLO_StaticShapeTensorOrPerAxisQuantizedTensor; +defvar MHLO_StaticShapeOrBoundedDimTensor = HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor; + defvar MHLO_StaticShapeTensorOrToken = HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken; defvar MHLO_StaticShapeIntOrFpTensor = HLO_StaticShapeIntOrFpTensor; diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 925ed003c7e6f..b66ec5f0402d6 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -2526,45 +2526,7 @@ LogicalResult BroadcastOp::reifyReturnTypeShapes( // BroadcastInDimOp //===----------------------------------------------------------------------===// -namespace { -LogicalResult verifySingleBoundedDynamicDimension(Operation* op, - RankedTensorType type) { - auto errFn = [&]() { - return op->emitOpError() << "load bearing ops with dynamic dimensions must " - "have a single bounded dimension"; - }; - - // Get bounded dims - TypeExtensionsAttr encoding = - dyn_cast_or_null(type.getEncoding()); - if (!encoding) return errFn(); - - // Check that all dynamic dims are bounded - ArrayRef bounds = encoding.getBounds(); - ArrayRef shape = type.getShape(); - for (auto [dim, bound] : llvm::zip(shape, bounds)) { - if (ShapedType::isDynamic(dim) && ShapedType::isDynamic(bound)) - return errFn(); - } - - // Check single bounded dimension - if (llvm::count_if(bounds, [](int64_t dim) { - return !ShapedType::isDynamic(dim); - }) > 1) { - return errFn(); - } - return success(); -} -} // namespace - LogicalResult BroadcastInDimOp::verify() { - ShapedType resultType = getResult().getType(); - if (!resultType.hasStaticShape()) { - RankedTensorType rankedResultType = cast(resultType); - if (failed(verifySingleBoundedDynamicDimension(getOperation(), - rankedResultType))) - return failure(); - } return hlo::verifyBroadcastInDimOp( getLoc(), getOperand(), llvm::to_vector(getBroadcastDimensions().getValues()), @@ -4606,15 +4568,6 @@ LogicalResult ReshapeOp::verify() { if (!operandType.hasRank() || !resultType.hasRank()) { return success(); } - - // Verify that dynamic outputs are all bounded - // HLO allows this, and it is unclear if MHLO/StableHLO should, but it is - // needed to raise some TF programs from HLO to MHLO. - if (!resultType.hasStaticShape()) { - RankedTensorType resultType = cast(getResult().getType()); - if (failed(verifySingleBoundedDynamicDimension(getOperation(), resultType))) - return failure(); - } return hlo::verifyReshapeOp(getLoc(), getOperand(), getResult()); } diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 99fbf3e26cace..eca363748b099 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2118,7 +2118,7 @@ def MHLO_BroadcastInDimOp : MHLO_Op<"broadcast_in_dim", MHLO_BroadcastDimAttr:$broadcast_dimensions ); - let results = (outs MHLO_AnyTensor); + let results = (outs MHLO_StaticShapeOrBoundedDimTensor); let hasFolder = 1; let hasCanonicalizer = 1; @@ -2877,7 +2877,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape", let arguments = (ins MHLO_AnyTensor:$operand); - let results = (outs MHLO_AnyTensor); + let results = (outs MHLO_StaticShapeOrBoundedDimTensor); let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir index c674f734082e8..9b9d8754f9072 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/bounded_dynamism.mlir @@ -13,8 +13,17 @@ func.func @reshape_with_single_bounded_dimension(%arg0: tensor>) -> tensor<1x?xf32, #mhlo.type_extensions> { + %0 = mhlo.reshape %arg0 : (tensor>) -> tensor<1x?xf32, #mhlo.type_extensions> + // CHECK: return {{.*}} #mhlo.type_extensions> +} + +// ----- + func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor>) -> tensor> { - // expected-error@+1 {{load bearing ops with dynamic dimensions must have a single bounded dimension}} + // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}} %0 = mhlo.reshape %arg0 : (tensor>) -> tensor> return %0 : tensor> } @@ -31,7 +40,25 @@ func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, // ----- func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor>) -> tensor<2x?x?xf32, #mhlo.type_extensions> { - // expected-error@+1 {{load bearing ops with dynamic dimensions must have a single bounded dimension}} + // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}} %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor>) -> tensor<2x?x?xf32, #mhlo.type_extensions> return %0 : tensor<2x?x?xf32, #mhlo.type_extensions> } + +// ----- + +// CHECK-LABEL: constant_splat_broadcast +func.func @constant_splat_broadcast() -> tensor<1x?xf32, #mhlo.type_extensions> { + %0 = mhlo.constant dense<1.0> : tensor + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x?xf32, #mhlo.type_extensions> + // CHECK: tensor<1x?xf32, #mhlo.type_extensions> + return %1 : tensor<1x?xf32, #mhlo.type_extensions> +} + +// ----- + +func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #mhlo.type_extensions> { + // expected-error@+2 {{elements literal type must have static shape}} + %c = mhlo.constant dense<1> : tensor<1x?xf32, #mhlo.type_extensions> + return %c : tensor<1x?xf32, #mhlo.type_extensions> +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index e3a50d613cb94..b061d6b3beb9c 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1726,6 +1726,23 @@ func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// ============ BOUNDED DYNAMISM ============ + +// CHECK-LABEL: bounded_dynamism_reshape +func.func @bounded_dynamism_reshape(%arg0: tensor>) -> tensor> { + // CHECK: stablehlo.reshape{{.*}}tensor> + %0 = "mhlo.reshape"(%arg0) : (tensor>) + -> tensor> + return %0 : tensor> +} + +// CHECK-LABEL: bounded_dynamism_broadcast_in_dim +func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #mhlo.type_extensions>) -> tensor<2x1x?xf32, #mhlo.type_extensions> { + // CHECK: stablehlo.broadcast_in_dim{{.*}}tensor<2x1x?xf32, #stablehlo.bounds> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x?xf32, #mhlo.type_extensions>) -> tensor<2x1x?xf32, #mhlo.type_extensions> + return %0 : tensor<2x1x?xf32, #mhlo.type_extensions> +} + // ============ TYPES ============ // CHECK-LABEL: "type_i1" @@ -2143,23 +2160,6 @@ func.func @op_fusion(%arg0: tensor) -> tensor { // ----- -func.func @bounded_dynamism_reshape(%arg0: tensor>) -> tensor> { - // expected-error@+1 {{'stablehlo.reshape' op result #0 must be statically shaped tensor}} - %0 = "mhlo.reshape"(%arg0) : (tensor>) - -> tensor> - return %0 : tensor> -} - -// ----- - -func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #mhlo.type_extensions>) -> tensor<2x1x?xf32, #mhlo.type_extensions> { - // expected-error@+1 {{'stablehlo.broadcast_in_dim' op result #0 must be statically shaped tensor}} - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x?xf32, #mhlo.type_extensions>) -> tensor<2x1x?xf32, #mhlo.type_extensions> - return %0 : tensor<2x1x?xf32, #mhlo.type_extensions> -} - -// ----- - func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> tensor { // expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}} %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index fdf12a56cefb0..5eaac811d1942 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1715,6 +1715,22 @@ func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// ============ BOUNDED DYNAMISM ============ + +// CHECK-LABEL: bounded_dynamism_reshape +func.func @bounded_dynamism_reshape(%arg0: tensor>) -> tensor> { + // CHECK: mhlo.reshape{{.*}}tensor> + %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor> + return %0 : tensor> +} + +// CHECK-LABEL: bounded_dynamism_broadcast_in_dim +func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #stablehlo.bounds>) -> tensor<2x1x?xf32, #stablehlo.bounds> { + // CHECK: mhlo.broadcast_in_dim{{.*}}tensor<2x1x?xf32, #mhlo.type_extensions> + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds>) -> tensor<2x1x?xf32, #stablehlo.bounds> + return %0 : tensor<2x1x?xf32, #stablehlo.bounds> +} + // ============ TYPES ============ // CHECK-LABEL: "type_i1"