[MHLO] Handle dynamic dimensions in HLO<->MHLO
- Fix constant-zero creation in the ConvertOp HLO->MHLO translation
- Fix broadcast_in_dim lowering with bounded dimensions from MHLO->HLO
- Don't use StableHLO verification methods on MHLO ReshapeOp with bounded dynamic outputs

```
$ cat /tmp/t.hlo
HloModule main, entry_computation_layout={(pred[<=1801,1]{1,0})->pred[<=1801,1]{1,0}}

ENTRY %convert_with_predicate (Arg_0.1: pred[<=1801,1]) -> pred[<=1801,1] {
  %Arg_0.1 = pred[<=1801,1] parameter(0)
  ROOT %convert_pred = pred[<=1801,1] convert(%Arg_0.1)
}

$ xla-translate /tmp/t.hlo --hlo-text-to-mlir-hlo
func.func @main(%arg0: tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>> {
  %0 = mhlo.constant dense<false> : tensor<1801x1xi1>
  %1 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<i32>
  %2 = "mhlo.set_dimension_size"(%0, %1) <{dimension = 0 : i64}> : (tensor<1801x1xi1>, tensor<i32>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
  %3 = mhlo.compare  NE, %arg0, %2 : (tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>, tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>) -> tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
  return %3 : tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
}
```

Without these fixes, this fails when creating the `mhlo.constant dense<false>` that feeds the compare, since constants cannot have a bounded dynamic shape.
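
For reference, this is a sketch (not tool output) of the constant the importer previously tried to build, which the verifier rejects because element literals must have a static shape:

```mlir
// Rejected: a constant cannot carry a bounded dynamic shape directly.
%0 = mhlo.constant dense<false> : tensor<?x1xi1, #mhlo.type_extensions<bounds = [1801, ?]>>
```

The fix instead materializes the zero at the padded static type (`tensor<1801x1xi1>` above) and narrows it with `get_dimension_size`/`set_dimension_size`, as shown in the translated output.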

PiperOrigin-RevId: 718162448
GleasonK authored and Google-ML-Automation committed Jan 27, 2025
1 parent c96f21e commit 33037f1
Showing 16 changed files with 474 additions and 89 deletions.
217 changes: 212 additions & 5 deletions third_party/stablehlo/temporary.patch
@@ -170,6 +170,79 @@ diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/
} // namespace hlo
} // namespace mlir

diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp
--- stablehlo/stablehlo/dialect/Base.cpp
+++ stablehlo/stablehlo/dialect/Base.cpp
@@ -780,5 +780,22 @@
numScales == rankedType.getDimSize(quantDim));
}

+bool hasSingleBoundedDimension(Type type) {
+ RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
+ auto boundedAttr =
+ dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
+ if (!boundedAttr) return false;
+
+ // Count the dimensions whose bound is not kDynamic.
+ int64_t numBoundedDims = llvm::count_if(
+ boundedAttr.getBounds(),
+ [](int64_t bound) { return !ShapedType::isDynamic(bound); });
+ // Also check that there are only bounded dims and no unbounded dims.
+ int64_t numDynamicDims = llvm::count_if(
+ rankedType.getShape(),
+ [](int64_t bound) { return ShapedType::isDynamic(bound); });
+ return numBoundedDims == 1 && numDynamicDims == 1;
+}
+
} // namespace hlo
} // namespace mlir
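
For intuition (illustrative types, not part of the patch), the predicate holds only when exactly one dimension is dynamic and that dimension carries a bound:

```mlir
// hasSingleBoundedDimension(type):
//   tensor<?x2xf32, #stablehlo.bounds<5, ?>>   -> true  (one bounded dynamic dim)
//   tensor<?x?xf32, #stablehlo.bounds<5, 5>>   -> false (two bounded dims)
//   tensor<?x?xf32, #stablehlo.bounds<5, ?>>   -> false (second dynamic dim has no bound)
//   tensor<2x2xf32>                            -> false (fully static, no bounds encoding)
```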
diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h
--- stablehlo/stablehlo/dialect/Base.h
+++ stablehlo/stablehlo/dialect/Base.h
@@ -101,6 +101,9 @@
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

+// Returns true if the given type has a single bounded dimension.
+bool hasSingleBoundedDimension(Type type);
+
// TODO(zhouxin) Move type inference related methods to TypeInference.cpp

std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td
--- stablehlo/stablehlo/dialect/Base.td
+++ stablehlo/stablehlo/dialect/Base.td
@@ -29,6 +29,20 @@
def I32RankedTensor : RankedTensorOf<[I32]>;

def UI32RankedTensor : RankedTensorOf<[UI32]>;
+
+//===----------------------------------------------------------------------===//
+// HLO type constraints.
+//===----------------------------------------------------------------------===//
+
+// Note: Bounded dynamism is largely unspecced and this feature needs more
+// thought as it is adopted by modern frameworks. The current support is
+// designed to allow existing TF programs to be representable in StableHLO and
+// is subject to change as a formal design for bounded dynamism is developed.
+def HLO_HasSingleBoundedDimensionPred
+ : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;
+
+def HLO_HasStaticOrSingleBoundedShapePred
+ : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;

//===----------------------------------------------------------------------===//
// HLO type definitions.
@@ -267,6 +281,9 @@
def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;

+def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">;
+
def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;

def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt
--- stablehlo/stablehlo/dialect/CMakeLists.txt
+++ stablehlo/stablehlo/dialect/CMakeLists.txt
@@ -440,11 +513,10 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td
--- stablehlo/stablehlo/dialect/StablehloOps.td
+++ stablehlo/stablehlo/dialect/StablehloOps.td
@@ -327,6 +327,23 @@
```mlir
@@ -328,6 +328,23 @@
%result = stablehlo.exponential %operand : tensor<2x2xf64>
```
+ }];
}];
+ let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
+ DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
+ let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
@@ -461,12 +533,57 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
}];
+ }];
}

def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
@@ -1963,7 +1980,7 @@
DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
);

- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
+ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);

let hasVerifier = 1;

@@ -2715,7 +2732,7 @@

let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand);

- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
+ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
--- stablehlo/stablehlo/dialect/TypeInference.cpp
+++ stablehlo/stablehlo/dialect/TypeInference.cpp
@@ -3724,9 +3724,8 @@
Value operand,
ArrayRef<int64_t> broadcastDimensions,
Value result) {
+ // broadcast_in_dim_c1
auto operandType = cast<RankedTensorType>(operand.getType());
-
- // broadcast_in_dim_c1
if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType,
result.getType())))
return failure();
@@ -4658,11 +4657,12 @@
Value result) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operandTy = cast<RankedTensorType>(operand.getType());
- if (!operandTy.hasStaticShape()) return success();
+ auto resultTy = cast<RankedTensorType>(result.getType());
+ if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return success();

// If the operand type is statically shaped (not required) the number of
// elements must match that of the result type.
- auto resultTy = cast<RankedTensorType>(result.getType());
int64_t numResultElements = resultTy.getNumElements();
int64_t numOperandElements = operandTy.getNumElements();
if (numResultElements != numOperandElements)
@@ -5057,5 +5057,30 @@
return success();
}
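
A quick illustration of the relaxed reshape check (hypothetical IR, assuming `stablehlo-opt` verification): the static element-count test still fires when both sides are static, and is skipped when the result has a bounded dynamic dimension, since its element count is only known at runtime:

```mlir
// Still rejected: 2*3 = 6 elements cannot reshape into 4.
%bad = stablehlo.reshape %a : (tensor<2x3xf32>) -> tensor<4xf32>

// Now accepted: the result's element count depends on the runtime size of the
// bounded dimension, so the static check is skipped (see the new
// ops_stablehlo_bounded_dynamism.mlir tests below).
%ok = stablehlo.reshape %b : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
```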
@@ -1349,7 +1466,30 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl
diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir
--- stablehlo/stablehlo/tests/ops_stablehlo.mlir
+++ stablehlo/stablehlo/tests/ops_stablehlo.mlir
@@ -1775,6 +1775,30 @@
@@ -1274,6 +1274,22 @@

// -----

+// CHECK-LABEL: func @broadcast_in_dim_dynamic_i1
+func.func @broadcast_in_dim_dynamic_i1(%arg0: tensor<?xi32>) -> tensor<1x3xi32> {
+ %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<?xi32>) -> tensor<1x3xi32>
+ return %0 : tensor<1x3xi32>
+}
+
+// -----
+
+func.func @broadcast_in_dim_dynamic_result(%arg0: tensor<3xi32>) -> tensor<?x3xi32> {
+ // expected-error@+1 {{must be statically shaped or single bounded dimension tensor}}
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<3xi32>) -> tensor<?x3xi32>
+ func.return %0 : tensor<?x3xi32>
+}
+
+// -----
+
// Regression test for b/180052624, where this was improperly marked as an
// invalid stablehlo.broadcast_in_dim op.
// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
@@ -1775,6 +1791,30 @@
// expected-error@+1 {{'precision_config' failed to satisfy constraint}}
%0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #stablehlo<precision HIGHEST>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
func.return %0: tensor<2x2xi32>
@@ -1380,6 +1520,73 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/
}

// -----
diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
--- stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
+++ stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
@@ -0,0 +1,63 @@
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file captures some quirks of bounded dynamism in StableHLO that are
+// included to allow StableHLO to represent existing TF programs.
+
+// CHECK-LABEL: reshape_with_single_bounded_dimension
+func.func @reshape_with_single_bounded_dimension(%arg0: tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>> {
+ %0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
+ // CHECK: return {{.*}} #stablehlo.bounds<?, 5>
+ return %0 : tensor<2x?xf32, #stablehlo.bounds<?, 5>>
+}
+
+// -----
+
+// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension
+func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
+ %0 = stablehlo.reshape %arg0 : (tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+ // CHECK: return {{.*}} #stablehlo.bounds<?, 5>
+ return %0 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+}
+
+// -----
+
+func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>> {
+ // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
+ %0 = stablehlo.reshape %arg0 : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>>
+ return %0 : tensor<?x?xf32, #stablehlo.bounds<5, 5>>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_in_dim_with_single_bounded_dimension
+func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>> {
+ %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
+ // CHECK: return {{.*}} #stablehlo.bounds<?, ?, 5>
+ return %0 : tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
+}
+
+// -----
+
+func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>> {
+ // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
+ %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
+ return %0 : tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
+}
+
+// -----
+
+// CHECK-LABEL: constant_splat_broadcast
+func.func @constant_splat_broadcast() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
+ %0 = stablehlo.constant dense<1.0> : tensor<f32>
+ %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+ // CHECK: tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+ return %1 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+}
+
+// -----
+
+func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
+ // expected-error@+2 {{elements literal type must have static shape}}
+ %c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+ return %c : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
+}
diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir
--- stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir
+++ stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir
10 changes: 10 additions & 0 deletions xla/hlo/builder/xla_builder_test.cc
@@ -566,6 +566,16 @@ TEST(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
GmockMatch(m::Broadcast(m::Reshape(m::Broadcast()))));
}

TEST(XlaBuilderTest, BroadcastInDimWithBoundedDim) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[2, <=3]"));
auto x = Parameter(&b, 0, shape, "x");
BroadcastInDim(x, {1, 2, 3},
/*broadcast_dimensions=*/{1, 2});
TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
EXPECT_THAT(GetRoot(*module), GmockMatch(m::Broadcast()));
}

TEST(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
1 change: 1 addition & 0 deletions xla/hlo/translate/hlo_to_mhlo/BUILD
@@ -115,6 +115,7 @@ cc_library(
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:SparseTensorDialect",
"@llvm-project//mlir:Support",
"@stablehlo//:base",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
71 changes: 65 additions & 6 deletions xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <unordered_map>
@@ -44,6 +45,7 @@ limitations under the License.
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
@@ -54,6 +56,7 @@ limitations under the License.
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/Base.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
@@ -170,6 +173,59 @@ Operation* createReturnOp(mlir::OpBuilder& builder, mlir::Location loc,
return builder.create<mlir::mhlo::ReturnOp>(loc, operands);
}

// Creates a constant of zeros like the given MLIR type. If the type has
// bounded dynamism, the constant is created at the padded static shape and its
// dynamic dimension is set to the dimension size of the operand.
//
// Example ZeroLike([<=5] operand):
// %c = constant dense<0> : tensor<5xf32>
// %0 = get_dimension_size %operand
// %1 = set_dimension_size %c, %0, bounded_dim={0}
//
// Note: Currently this only supports a single bounded dimension.
absl::StatusOr<mlir::Value> createConstantZeroLike(mlir::Value operand,
Shape input_shape,
mlir::OpBuilder* builder,
mlir::Location loc) {
TF_ASSIGN_OR_RETURN(
mlir::RankedTensorType type,
ConvertTensorShapeToType<mlir::RankedTensorType>(input_shape, *builder));

LLVM_DEBUG(llvm::dbgs() << "CreateConstantZeroLike: " << operand << ", "
<< type << '\n');
if (type.hasStaticShape())
return builder
->create<mlir::mhlo::ConstantOp>(loc, builder->getZeroAttr(type))
->getResult(0);

// Note: Currently this only supports a single bounded dimension.
if (!mlir::hlo::hasSingleBoundedDimension(type))
return Internal(
"Currently HLO to MHLO only supports a single bounded dimension.");

auto bounded_dim = std::distance(type.getShape().begin(),
llvm::find_if(type.getShape(), [](auto dim) {
return mlir::ShapedType::isDynamic(dim);
}));

// Create a constant with no bounded dynamism, drop tensor encoding.
ArrayRef<int64_t> padded_dims(input_shape.dimensions().begin(),
input_shape.dimensions().end());
auto padded_type =
mlir::RankedTensorType::get(padded_dims, type.getElementType());
auto padded_constant = builder->create<mlir::mhlo::ConstantOp>(
loc, builder->getZeroAttr(padded_type));

// Get the operand's runtime dimension size and set it on the constant.
auto dim_size = builder->create<mlir::mhlo::GetDimensionSizeOp>(
loc, operand, builder->getI64IntegerAttr(bounded_dim));
std::vector<mlir::Value> operands = {padded_constant->getResult(0), dim_size};
std::vector<mlir::NamedAttribute> attributes{builder->getNamedAttr(
"dimension", builder->getI64IntegerAttr(bounded_dim))};
return builder->create<mlir::mhlo::SetDimensionSizeOp>(loc, type, operands,
attributes);
}

} // namespace

void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
@@ -1871,13 +1927,16 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(

// Return type is boolean, let's use `operand != 0` instead of Convert.
Shape input_shape = instruction->operand(0)->shape();
TF_ASSIGN_OR_RETURN(mlir::Type type,
ConvertTensorShapeToType<mlir::RankedTensorType>(
input_shape, *func_builder));
auto zero = func_builder->create<mlir::mhlo::ConstantOp>(
loc, func_builder->getZeroAttr(type));
TF_ASSIGN_OR_RETURN(
mlir::Value zero,
createConstantZeroLike(operands[0], input_shape, func_builder, loc));
std::vector<mlir::Value> compare_operands = {operands[0], zero};
std::vector<mlir::NamedAttribute> attributes = {builder_->getNamedAttr(
"comparison_direction", mlir::mhlo::ComparisonDirectionAttr::get(
func_builder->getContext(),
mlir::mhlo::ComparisonDirection::NE))};
return {func_builder->create<mlir::mhlo::CompareOp>(
loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)};
loc, result_type, compare_operands, attributes)};
}
case HloOpcode::kOptimizationBarrier: {
llvm::SmallVector<Value> flattened_operands;
1 change: 1 addition & 0 deletions xla/hlo/translate/hlo_to_mhlo/tests/BUILD
@@ -22,6 +22,7 @@ lit_test_suite(
"if_conditional.hlo",
"import.hlo",
"import_async.hlo",
"import_bounded_dynamism.hlo",
"import_entry_computation_layout.hlo",
"layouts_and_names.hlo",
"location.hlo",