Skip to content

Commit

Permalink
[StableHLO] Port TF bounded dynamism MHLO fixes to StableHLO.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718425491
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Jan 23, 2025
1 parent e795171 commit e9db0f5
Show file tree
Hide file tree
Showing 14 changed files with 294 additions and 75 deletions.
103 changes: 103 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,109 @@ diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeTo
return success();
}
};
diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp
--- stablehlo/stablehlo/dialect/Base.cpp
+++ stablehlo/stablehlo/dialect/Base.cpp
@@ -780,5 +780,18 @@
numScales == rankedType.getDimSize(quantDim));
}

+// Returns true iff `type` is a ranked tensor whose encoding carries exactly
+// one non-dynamic bound; any non-ranked type yields false.
+bool hasSingleBoundedDimension(Type type) {
+  auto rankedType = dyn_cast<RankedTensorType>(type);
+  if (!rankedType) return false;
+  auto boundedAttr =
+      dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
+  if (!boundedAttr) return false;
+  return llvm::count_if(boundedAttr.getBounds(), [](int64_t bound) {
+           return !ShapedType::isDynamic(bound);
+         }) == 1;
+}
+
} // namespace hlo
} // namespace mlir
diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h
--- stablehlo/stablehlo/dialect/Base.h
+++ stablehlo/stablehlo/dialect/Base.h
@@ -101,6 +101,9 @@
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

+// Returns true if the given type has a single bounded dimension.
+bool hasSingleBoundedDimension(Type type);
+
// TODO(zhouxin) Move type inference related methods to TypeInference.cpp

std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td
--- stablehlo/stablehlo/dialect/Base.td
+++ stablehlo/stablehlo/dialect/Base.td
@@ -29,6 +29,20 @@
def I32RankedTensor : RankedTensorOf<[I32]>;

def UI32RankedTensor : RankedTensorOf<[UI32]>;
+
+//===----------------------------------------------------------------------===//
+// HLO type constraints.
+//===----------------------------------------------------------------------===//
+
+// Note: Bounded dynamism is largely unspecced and this feature needs more
+// thought as it is adapted to modern frameworks. The current support is
+// designed to allow existing TF programs to be representable in StableHLO and
+// is subject to change as a formal design for bounded dynamism is developed.
+def HLO_HasSingleBoundedDimensionPred
+ : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;
+
+def HLO_HasStaticOrBoundedShapePred
+ : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;

//===----------------------------------------------------------------------===//
// HLO type definitions.
@@ -267,6 +281,9 @@
def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;

+def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension, HLO_HasStaticOrBoundedShapePred], "statically shaped or single bounded dimension tensor">;
+
def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;

def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td
--- stablehlo/stablehlo/dialect/StablehloOps.td
+++ stablehlo/stablehlo/dialect/StablehloOps.td
@@ -1963,7 +1963,7 @@
DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
);

- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
+ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);

let hasVerifier = 1;

@@ -2715,7 +2715,7 @@

let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand);

- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
+ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
--- stablehlo/stablehlo/dialect/TypeInference.cpp
+++ stablehlo/stablehlo/dialect/TypeInference.cpp
@@ -3726,6 +3726,9 @@
Value result) {
auto operandType = cast<RankedTensorType>(operand.getType());

+ // No additional verification on shapes with single bounded dimension.
+ if (!operandType.hasStaticShape()) return success();
+
// broadcast_in_dim_c1
if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType,
result.getType())))
diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
--- stablehlo/stablehlo/dialect/Version.cpp
+++ stablehlo/stablehlo/dialect/Version.cpp
Expand Down
10 changes: 10 additions & 0 deletions xla/hlo/builder/xla_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,16 @@ TEST(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
GmockMatch(m::Broadcast(m::Reshape(m::Broadcast()))));
}

// BroadcastInDim on an operand with a bounded dynamic dimension (f32[2, <=3])
// must build successfully, and the module root must match a broadcast.
TEST(XlaBuilderTest, BroadcastInDimWithBoundedDim) {
  XlaBuilder builder(TestName());
  TF_ASSERT_OK_AND_ASSIGN(const Shape bounded_shape,
                          ParseShape("f32[2, <=3]"));
  auto operand = Parameter(&builder, 0, bounded_shape, "x");
  BroadcastInDim(operand, {1, 2, 3}, /*broadcast_dimensions=*/{1, 2});
  TF_ASSERT_OK_AND_ASSIGN(const auto hlo_module, BuildHloModule(builder));
  EXPECT_THAT(GetRoot(*hlo_module), GmockMatch(m::Broadcast()));
}

TEST(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
Expand Down
72 changes: 66 additions & 6 deletions xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <unordered_map>
Expand All @@ -44,6 +45,7 @@ limitations under the License.
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -170,6 +172,61 @@ Operation* createReturnOp(mlir::OpBuilder& builder, mlir::Location loc,
return builder.create<mlir::mhlo::ReturnOp>(loc, operands);
}

// Materializes an all-zero value whose type is the MLIR type converted from
// `input_shape`.
//
// For a static shape this is a single zero constant. Otherwise a zero constant
// of the padded static shape (taken from `input_shape.dimensions()`) is
// created, and its dynamic dimension is set to the runtime size of the same
// dimension of `operand`.
//
// Example ZeroLike([<=5] operand):
//   %c = constant dense<0> : tensor<5xf32>
//   %0 = get_dimension_size %operand
//   %1 = set_dimension_size %c, %0, bounded_dim={0}
//
// Note: exactly one dynamic dimension is supported; any other count of dynamic
// dimensions returns an Internal error.
absl::StatusOr<mlir::Value> createConstantZeroLike(mlir::Value operand,
                                                   Shape input_shape,
                                                   mlir::OpBuilder* builder,
                                                   mlir::Location loc) {
  TF_ASSIGN_OR_RETURN(
      mlir::RankedTensorType type,
      ConvertTensorShapeToType<mlir::RankedTensorType>(input_shape, *builder));

  LLVM_DEBUG(llvm::dbgs() << "CreateConstantZeroLike: " << operand << ", "
                          << type << '\n');

  // Static case: a plain zero constant already has the requested type.
  if (type.hasStaticShape()) {
    return builder
        ->create<mlir::mhlo::ConstantOp>(loc, builder->getZeroAttr(type))
        ->getResult(0);
  }

  // Shared predicate used both to count and to locate the dynamic dimension.
  const auto is_dynamic_extent = [](auto extent) {
    return mlir::ShapedType::isDynamic(extent);
  };
  const auto extents = type.getShape();
  const auto num_dynamic = llvm::count_if(extents, is_dynamic_extent);
  if (num_dynamic != 1) {
    return Internal(
        "Currently HLO to MHLO only supports a single bounded dimension.");
  }
  const auto bounded_dim = std::distance(
      extents.begin(), llvm::find_if(extents, is_dynamic_extent));

  // Zero constant of the padded shape; the tensor encoding is dropped here.
  ArrayRef<int64_t> padded_dims(input_shape.dimensions().begin(),
                                input_shape.dimensions().end());
  auto padded_type =
      mlir::RankedTensorType::get(padded_dims, type.getElementType());
  auto padded_zero = builder->create<mlir::mhlo::ConstantOp>(
      loc, builder->getZeroAttr(padded_type));

  // Read the runtime size of the bounded dimension from the operand, then
  // apply it to the padded constant so the result has the bounded type.
  auto dim_attr = builder->getI64IntegerAttr(bounded_dim);
  auto runtime_size =
      builder->create<mlir::mhlo::GetDimensionSizeOp>(loc, operand, dim_attr);
  std::vector<mlir::Value> set_size_operands = {padded_zero->getResult(0),
                                                runtime_size};
  std::vector<mlir::NamedAttribute> set_size_attrs{
      builder->getNamedAttr("dimension", dim_attr)};
  return builder->create<mlir::mhlo::SetDimensionSizeOp>(
      loc, type, set_size_operands, set_size_attrs);
}

} // namespace

void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
Expand Down Expand Up @@ -1847,13 +1904,16 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(

// Return type is boolean, let's use `operand != 0` instead of Convert.
Shape input_shape = instruction->operand(0)->shape();
TF_ASSIGN_OR_RETURN(mlir::Type type,
ConvertTensorShapeToType<mlir::RankedTensorType>(
input_shape, *func_builder));
auto zero = func_builder->create<mlir::mhlo::ConstantOp>(
loc, func_builder->getZeroAttr(type));
TF_ASSIGN_OR_RETURN(
mlir::Value zero,
createConstantZeroLike(operands[0], input_shape, func_builder, loc));
std::vector<mlir::Value> compare_operands = {operands[0], zero};
std::vector<mlir::NamedAttribute> attributes = {builder_->getNamedAttr(
"comparison_direction", mlir::mhlo::ComparisonDirectionAttr::get(
func_builder->getContext(),
mlir::mhlo::ComparisonDirection::NE))};
return {func_builder->create<mlir::mhlo::CompareOp>(
loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)};
loc, result_type, compare_operands, attributes)};
}
case HloOpcode::kOptimizationBarrier: {
llvm::SmallVector<Value> flattened_operands;
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/translate/hlo_to_mhlo/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ lit_test_suite(
"if_conditional.hlo",
"import.hlo",
"import_async.hlo",
"import_bounded_dynamism.hlo",
"import_entry_computation_layout.hlo",
"layouts_and_names.hlo",
"location.hlo",
Expand Down
26 changes: 26 additions & 0 deletions xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism.hlo
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s


HloModule main, entry_computation_layout={(pred[<=1801,1]{1,0})->pred[<=1801,1]{1,0}}

ENTRY %convert_with_predicate (Arg_0.1: pred[<=1801,1]) -> pred[<=1801,1] {
// CHECK: [[CST:%.*]] = mhlo.constant dense<false> : tensor<1801x1xi1>
// CHECK-NEXT: [[GDS:%.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<i32>
// CHECK-NEXT: [[SDS:%.*]] = "mhlo.set_dimension_size"([[CST]], [[GDS]]) <{dimension = 0 : i64}> : (tensor<1801x1xi1>, tensor<i32>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
// CHECK-NEXT: [[CMP:%.*]] = mhlo.compare NE, %arg0, [[SDS]] : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>, tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
// CHECK-NEXT: return [[CMP]] : tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
%Arg_0.1 = pred[<=1801,1] parameter(0)
ROOT %convert_pred = pred[<=1801,1] convert(%Arg_0.1)
}

// -----

HloModule main, entry_computation_layout={(f32[<=1801,1]{1,0})->f32[<=1801,1]{1,0}}

ENTRY %convert_with_f32icate (Arg_0.1: f32[<=1801,1]) -> f32[<=1801,1] {
// CHECK: [[CVT:%.*]] = mhlo.convert %arg0 : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
// CHECK-NEXT: return [[CVT]] : tensor<?x1xf32, #mhlo.type_extensions<bounds = [1801, ?]>>
%Arg_0.1 = f32[<=1801,1] parameter(0)
ROOT %convert_f32 = f32[<=1801,1] convert(%Arg_0.1)
}

5 changes: 4 additions & 1 deletion xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1595,8 +1595,11 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
return failure();

// Use TypeToShape to handle bounded dynamism.
// HLO expects broadcast sizes to use the bound's value, not kDynamic.
xla::Shape shape = xla::TypeToShape(type);
value_map[op] =
BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
BroadcastInDim(operand, shape.dimensions(),
Convert_broadcast_dimensions(op.getBroadcastDimensions()));
return success();
}
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/translate/mhlo_to_hlo/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ lit_test_suite(
"export.mlir",
"export_async.mlir",
"export_and_check_layouts.mlir",
"export_bounded_dynamism.mlir",
"export_entry_computation_layout.mlir",
"export_large_constants.mlir",
"export_replicas.mlir",
Expand Down
17 changes: 17 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s

// CHECK-LITERAL: HloModule main
func.func @main(%arg0: tensor<1x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 1801]>>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>> {
// CHECK: ROOT {{.*}} = f32[1,16,1,<=1801] broadcast
%0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>}> : (tensor<1x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, 1801]>>) -> tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>>
return %0 : tensor<1x16x1x?xf32, #mhlo.type_extensions<bounds = [?, ?, ?, 1801]>>
}

// -----

// CHECK-LITERAL: HloModule main
func.func @main(%arg0: tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>, %arg1: tensor<i32>) -> tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>> {
// CHECK: ROOT {{.*}} = f32[1,<=1800,512] set-dimension-size
%0 = "mhlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>, tensor<i32>) -> tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>
return %0 : tensor<1x?x512xf32, #mhlo.type_extensions<bounds = [?, 1800, ?]>>
}
2 changes: 2 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_base.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ defvar MHLO_PredIntFpOrQuantizedTensor = HLO_PredIntFpOrQuantizedTensor;
// it is for a legacy op which is only correct with static shapes.
defvar MHLO_StaticShapeTensor = HLO_StaticShapeTensorOrPerAxisQuantizedTensor;

defvar MHLO_StaticShapeOrBoundedDimTensor = HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor;

defvar MHLO_StaticShapeTensorOrToken = HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken;

defvar MHLO_StaticShapeIntOrFpTensor = HLO_StaticShapeIntOrFpTensor;
Expand Down
47 changes: 0 additions & 47 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2526,45 +2526,7 @@ LogicalResult BroadcastOp::reifyReturnTypeShapes(
// BroadcastInDimOp
//===----------------------------------------------------------------------===//

namespace {
// Verifies that `type` expresses its dynamism through bounds: it must carry a
// TypeExtensionsAttr encoding, no dimension may be dynamic in both shape and
// bound, and at most one bound may be set. On violation an op error is emitted
// on `op` and failure is returned.
LogicalResult verifySingleBoundedDynamicDimension(Operation* op,
                                                  RankedTensorType type) {
  // Every rejection path reports the same diagnostic.
  auto reportViolation = [&]() {
    return op->emitOpError() << "load bearing ops with dynamic dimensions must "
                                "have a single bounded dimension";
  };

  // Without a type-extensions encoding there are no bounds to inspect.
  auto typeExtensions =
      dyn_cast_or_null<TypeExtensionsAttr>(type.getEncoding());
  if (!typeExtensions) return reportViolation();

  ArrayRef<int64_t> dimBounds = typeExtensions.getBounds();
  ArrayRef<int64_t> extents = type.getShape();

  // A dimension that is dynamic and also has no bound is rejected.
  for (auto [extent, bound] : llvm::zip(extents, dimBounds)) {
    const bool unboundedDynamic =
        ShapedType::isDynamic(extent) && ShapedType::isDynamic(bound);
    if (unboundedDynamic) return reportViolation();
  }

  // More than one set bound is rejected.
  const auto numBounded = llvm::count_if(
      dimBounds, [](int64_t bound) { return !ShapedType::isDynamic(bound); });
  if (numBounded > 1) return reportViolation();

  return success();
}
} // namespace

LogicalResult BroadcastInDimOp::verify() {
ShapedType resultType = getResult().getType();
if (!resultType.hasStaticShape()) {
RankedTensorType rankedResultType = cast<RankedTensorType>(resultType);
if (failed(verifySingleBoundedDynamicDimension(getOperation(),
rankedResultType)))
return failure();
}
return hlo::verifyBroadcastInDimOp(
getLoc(), getOperand(),
llvm::to_vector(getBroadcastDimensions().getValues<int64_t>()),
Expand Down Expand Up @@ -4606,15 +4568,6 @@ LogicalResult ReshapeOp::verify() {
if (!operandType.hasRank() || !resultType.hasRank()) {
return success();
}

// Verify that dynamic outputs are all bounded
// HLO allows this, and it is unclear if MHLO/StableHLO should, but it is
// needed to raise some TF programs from HLO to MHLO.
if (!resultType.hasStaticShape()) {
RankedTensorType resultType = cast<RankedTensorType>(getResult().getType());
if (failed(verifySingleBoundedDynamicDimension(getOperation(), resultType)))
return failure();
}
return hlo::verifyReshapeOp(getLoc(), getOperand(), getResult());
}

Expand Down
4 changes: 2 additions & 2 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -2118,7 +2118,7 @@ def MHLO_BroadcastInDimOp : MHLO_Op<"broadcast_in_dim",
MHLO_BroadcastDimAttr:$broadcast_dimensions
);

let results = (outs MHLO_AnyTensor);
let results = (outs MHLO_StaticShapeOrBoundedDimTensor);

let hasFolder = 1;
let hasCanonicalizer = 1;
Expand Down Expand Up @@ -2877,7 +2877,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape",

let arguments = (ins MHLO_AnyTensor:$operand);

let results = (outs MHLO_AnyTensor);
let results = (outs MHLO_StaticShapeOrBoundedDimTensor);
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand Down
Loading

0 comments on commit e9db0f5

Please sign in to comment.