From 2914231703afd52f84551676ce02b0d5fcae7110 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Mon, 14 Oct 2024 15:37:20 -0700 Subject: [PATCH] Add result accuracy attribute to ExpOp in StableHlo. PiperOrigin-RevId: 685857742 --- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 34 ++++++ xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 113 ++++++++++++++++++ xla/mlir_hlo/mhlo/IR/hlo_ops.td | 4 + xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td | 32 +++++ xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td | 22 ++++ xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc | 76 +++++++++++- .../hlo_legalize_to_stablehlo.cc | 28 +++++ .../stablehlo_legalize_to_hlo.cc | 28 +++++ .../mhlo/hlo-legalize-to-stablehlo.mlir | 9 ++ xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 6 + .../mhlo/stablehlo-legalize-to-hlo.mlir | 9 ++ 11 files changed, 359 insertions(+), 2 deletions(-) diff --git a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index e837d47418a141..accc5864ab2005 100644 --- a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "mhlo/IR/hlo_ops.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -638,6 +639,39 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( return output; } +// Converts ResultAccuracyAttr to XLA ResultAccuracy proto. +static xla::ResultAccuracy Convert_result_accuracy( + std::optional + optional_result_accuracy_attr) { + if (!optional_result_accuracy_attr.has_value()) return xla::ResultAccuracy(); + + auto result_accuracy = xla::ResultAccuracy(); + if (optional_result_accuracy_attr.value().getMode().getValue() == + mlir::mhlo::ResultAccuracyMode::TOLERANCE) { + result_accuracy.mutable_tolerance()->set_atol( + optional_result_accuracy_attr.value().getAtol().convertToFloat()); + result_accuracy.mutable_tolerance()->set_rtol( + optional_result_accuracy_attr.value().getRtol().convertToFloat()); + result_accuracy.mutable_tolerance()->set_ulps( + optional_result_accuracy_attr.value().getUlps()); + } else { + xla::ResultAccuracy::Mode mode; + auto result_accuracy_mode = + ::mlir::mhlo::stringifyResultAccuracyMode( + optional_result_accuracy_attr.value().getMode().getValue()) + .str(); + if (xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) { + result_accuracy.set_mode(mode); + } else { + auto* context = optional_result_accuracy_attr.value().getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "unexpected result accuracy mode " << result_accuracy_mode; + return xla::ResultAccuracy(); + } + } + return result_accuracy; +} + // Returns an OpSharding proto from the "sharding" attribute of the op. If the // op doesn't have a sharding attribute or the sharding attribute is invalid, // returns std::nullopt. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index d3d6bb6c158963..fc05ff0f6b82be 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -1214,6 +1214,29 @@ LogicalResult SparseDotOp::verify() { return success(); } +// ===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +LogicalResult ResultAccuracyAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) { + return hlo::verifyResultAccuracyAttr( + emitError, atol, rtol, ulps, + stringifyResultAccuracyMode(mode.getValue())); +} + +LogicalResult ExpOp::verify() { + if (auto attr = getResultAccuracyAttr()) { + if (failed(ResultAccuracyAttr::verify([&] { return this->emitError(); }, + attr.getAtol(), attr.getRtol(), + attr.getUlps(), attr.getMode()))) { + return failure(); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // FftOp //===----------------------------------------------------------------------===// @@ -7212,6 +7235,96 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, return success(); } +//===----------------------------------------------------------------------===// +// Custom unary op +//===----------------------------------------------------------------------===// + +void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const { + odsPrinter << ""; +} + +Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) { + // ResultAccuractAttr ::= `<` AtolAccuracy `,` RtolAccuracy `, + // ` UlpAccuracy `,` ModeAccuracy `>` + // AtolAccuracy ::= `atol` `=` APFloat + // RtolAccuracy ::= `rtol` `=` APFloat + // UlpAccuracy ::= `ulps` `=` int64_t + // ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr + + // Parse literal '<' + if (parser.parseLess()) return {}; + // Parse AtolAccuracy + if (parser.parseKeyword("atol") || parser.parseEqual()) return {}; + FailureOr _result_atol = [&]() -> FailureOr { + double value; + if (failed(parser.parseFloat(value))) { + return failure(); + } + return APFloat(value); + }(); + if (failed(_result_atol)) { + parser.emitError(parser.getCurrentLocation(), + "failed to parse StableHLO_ResultAccuracyAttr parameter " + "'atol' which is to be a `::llvm::APFloat`"); + return {}; + } + // Parse literal ',' + if (parser.parseComma()) return {}; + // Parse RtolAccuracy + if (parser.parseKeyword("rtol") || parser.parseEqual()) return {}; + FailureOr _result_rtol = [&]() -> FailureOr { + double value; + if (failed(parser.parseFloat(value))) { + return failure(); + } + return APFloat(value); + }(); + if (failed(_result_rtol)) { + parser.emitError(parser.getCurrentLocation(), + "failed to parse StableHLO_ResultAccuracyAttr parameter " + "'rtol' which is to be a `::llvm::APFloat`"); + return {}; + } + // Parse literal ',' + if (parser.parseComma()) return {}; + // Parse UlpAccuracy + if (parser.parseKeyword("ulps") || parser.parseEqual()) return {}; + int64_t _result_ulps; + if (failed(parser.parseInteger(_result_ulps))) { + parser.emitError(parser.getCurrentLocation(), + "failed to parse StableHLO_ResultAccuracyAttr parameter " + "'ulps' which is to be a `int64_t`"); + return {}; + } + // Parse literal ',' + if (parser.parseComma()) return {}; + // Parse ModeAccuracy + if (parser.parseKeyword("mode") || parser.parseEqual()) return {}; + ResultAccuracyModeAttr mode_attr; + if (failed(parser.parseAttribute(mode_attr))) { + parser.emitError( + parser.getCurrentLocation(), + "failed to parse StableHLO_ResultAccuracyAttr parameter 'mode' which " + "is to be a `::mlir::mhlo::ResultAccuracyModeAttr`"); + return {}; + } + // Parse literal '>' + if (parser.parseGreater()) return {}; + return ResultAccuracyAttr::get(parser.getContext(), *_result_atol, + *_result_rtol, _result_ulps, mode_attr); +} + // Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex // (1) the parameter must be valid // (2) there must be a subshape at the given indices diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 4eb95ef326a659..1f3f7d188bc968 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -343,6 +343,10 @@ def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential", %result = mhlo.exponential %operand : tensor<2x2xf64> ``` }]; + let arguments = (ins MHLO_FpComplexOrQuantizedIntTensor:$operand, + DefaultValuedOptionalAttr:$result_accuracy); + let results = (outs MHLO_FpComplexOrQuantizedIntTensor:$result); + let hasVerifier = 1; let hasFolder = 1; } def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one", diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index c2c3b0aca31ff0..6b6118a44bc563 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -332,4 +332,36 @@ def MHLO_BoolElementsAttr : let convertFromStorage = "$_self"; } +//==----------------------------------------------------------------------===// +// Result Accuracy attributes +//===----------------------------------------------------------------------===// + +def MHLO_APFloatV1 : APFloatParameter<""> { + let parser = [{ + [&]() -> FailureOr { + double value; + if (failed($_parser.parseFloat(value))) { + return failure(); + } + return APFloat(value); + }() + }]; + let printer = "$_printer.printFloat($_self);"; +} + +def MHLO_ResultAccuracyAttr : AttrDef { + let mnemonic = "result_accuracy"; + let summary = "The requested accuracy for unary ops."; + let parameters = (ins + "APFloat":$atol, + "APFloat":$rtol, + "int64_t":$ulps, + MHLO_ResultAccuracyModeAttr:$mode + + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))"; +} + #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ATTRS diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td index 3e4039ef9598ad..94011333340643 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -260,4 +260,26 @@ def MHLO_RngAlgorithmAttr : EnumAttr; +def MHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; +def MHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; + +def MHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode", + "XLA result accuracy mode.", + [ + MHLO_RESULT_ACCURACY_DEFAULT, + MHLO_RESULT_ACCURACY_HIGHEST, + MHLO_RESULT_ACCURACY_TOLERANCE + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_ResultAccuracyModeAttr : EnumAttr; + + #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS diff --git a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc index 0afccba2e587d6..f86bf93c51f88b 100644 --- a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc +++ b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mhlo/IR/mhlo_bytecode.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" @@ -173,6 +174,18 @@ enum AttributeCode { /// operandTupleIndices : svarint[] /// } kOutputOperandAlias = 16, + + // ResultAccuracyModeAttr { + // mode: varint (encoded enum) + // } + kResultAccuracyModeAttr = 17, + + // ResultAccuracyAttr { + // atol: APFloat + // rtol: APFloat + // ulps: svarint + // } + kResultAccuracyAttr = 18, }; /// This enum contains marker codes used to indicate which type is @@ -251,6 +264,10 @@ class MhloBytecodeInterface : public BytecodeDialectInterface { TransposeAttr readTransposeAttr(DialectBytecodeReader &reader) const; TypeExtensionsAttr readTypeExtensionsAttr( DialectBytecodeReader &reader) const; + ResultAccuracyModeAttr readResultAccuracyModeAttr( + DialectBytecodeReader &reader) const; + ResultAccuracyAttr readResultAccuracyAttr( + DialectBytecodeReader &reader) const; // TO ADD ATTRIBUTE: Include a write method for each attribute in StableHLO // Ex: void write(SomeAttr attr, DialectBytecodeWriter &writer) const; @@ -274,6 +291,8 @@ class MhloBytecodeInterface : public BytecodeDialectInterface { DialectBytecodeWriter &writer) const; void write(TransposeAttr attr, DialectBytecodeWriter &writer) const; void write(TypeExtensionsAttr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const; //===--------------------------------------------------------------------===// // Types @@ -341,6 +360,10 @@ Attribute MhloBytecodeInterface::readAttribute( return readTransposeAttr(reader); case mhlo_encoding::kTypeExtensionsAttr: return readTypeExtensionsAttr(reader); + case mhlo_encoding::kResultAccuracyModeAttr: + return readResultAccuracyModeAttr(reader); + case mhlo_encoding::kResultAccuracyAttr: + return readResultAccuracyAttr(reader); default: reader.emitError() << "unknown mhlo attribute code: " << code; @@ -582,8 +605,9 @@ LogicalResult MhloBytecodeInterface::writeAttribute( ConvDimensionNumbersAttr, ChannelHandleAttr, DomainKindAttr, DotDimensionNumbersAttr, FftTypeAttr, FusionKindAttr, GatherDimensionNumbersAttr, OutputOperandAliasAttr, PrecisionAttr, - RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr, - TransposeAttr, TypeExtensionsAttr>([&](auto attr) { + ResultAccuracyAttr, ResultAccuracyModeAttr, RngAlgorithmAttr, + RngDistributionAttr, ScatterDimensionNumbersAttr, TransposeAttr, + TypeExtensionsAttr>([&](auto attr) { LOG_WRITE_CALL; write(attr, writer); return success(); @@ -594,6 +618,21 @@ LogicalResult MhloBytecodeInterface::writeAttribute( }); } +void MhloBytecodeInterface::write(ResultAccuracyModeAttr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(mhlo_encoding::kResultAccuracyModeAttr); + hlo::bytecode::writeEnumAttribute(attr, writer); +} + +void MhloBytecodeInterface::write(ResultAccuracyAttr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(mhlo_encoding::kResultAccuracyAttr); + writer.writeAPFloatWithKnownSemantics(attr.getAtol()); + writer.writeAPFloatWithKnownSemantics(attr.getRtol()); + writer.writeSignedVarInt(attr.getUlps()); + writer.writeAttribute(attr.getMode()); +} + void MhloBytecodeInterface::write(ArgResultAliasAttr attr, DialectBytecodeWriter &writer) const { writer.writeVarInt(mhlo_encoding::kArgResultAliasAttr); @@ -790,6 +829,39 @@ void MhloBytecodeInterface::write(TokenType type, writer.writeVarInt(mhlo_encoding::kTokenType); } +//===----------------------------------------------------------------------===// +// ResultAccuracyModeAttr + +ResultAccuracyModeAttr MhloBytecodeInterface::readResultAccuracyModeAttr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + return hlo::bytecode::readEnumAttribute( + reader, getContext(), + [](uint32_t val) { return symbolizeResultAccuracyMode(val); }); +} + +//===----------------------------------------------------------------------===// +// ResultAccuracyAttr + +ResultAccuracyAttr MhloBytecodeInterface::readResultAccuracyAttr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + FailureOr atol; + FailureOr rtol; + int64_t ulps; + ResultAccuracyModeAttr mode; + if (failed(atol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(rtol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(reader.readSignedVarInt(ulps)) || + failed(reader.readAttribute(mode))) { + reader.emitError() << "failed to read APFloat for atol"; + return ResultAccuracyAttr(); + } + return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode); +} + } // namespace void addBytecodeInterface(MhloDialect *dialect) { diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 3931db159ec81c..3ecc9c9ed5b6ba 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "stablehlo/dialect/StablehloOps.h" +#include "third_party/stablehlo/stablehlo/dialect/StablehloOps.h" namespace mlir { namespace stablehlo { @@ -227,6 +228,20 @@ Attribute convertDenseArray(mlir::StringAttr hloName, Attribute hloAttr) { if (!stablehloValue.has_value()) return {}; \ return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) +stablehlo::ResultAccuracyMode convertResultAccuracyMode( + mhlo::ResultAccuracyMode mode) { + switch (mode) { + case mhlo::ResultAccuracyMode::DEFAULT: + return stablehlo::ResultAccuracyMode::DEFAULT; + case mhlo::ResultAccuracyMode::HIGHEST: + return stablehlo::ResultAccuracyMode::HIGHEST; + case mhlo::ResultAccuracyMode::TOLERANCE: + return stablehlo::ResultAccuracyMode::TOLERANCE; + default: + return {}; + } +} + Attribute convertAttr(Attribute hloAttr) { // Handle MHLO attributes. // The logic that handles attributes from other dialects (e.g. builtin @@ -301,6 +316,19 @@ Attribute convertAttr(Attribute hloAttr) { if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(Transpose); } + if (auto attr = mlir::dyn_cast(hloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode); + } + if (auto attr = mlir::dyn_cast(hloAttr)) { + stablehlo::ResultAccuracyModeAttr modeAttr; + modeAttr = stablehlo::ResultAccuracyModeAttr::get( + attr.getContext(), + convertResultAccuracyMode(attr.getMode().getValue())); + + return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), + attr.getRtol(), attr.getUlps(), + modeAttr); + } if (hloAttr.getDialect().getNamespace() == mhlo::MhloDialect::getDialectNamespace()) { // Our guiding principle is to support all StableHLO functionality in MHLO. diff --git a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index 7570d34ace0bc1..1b79b19585e1e0 100644 --- a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" +// #include "third_party/stablehlo/stablehlo/dialect/StablehloOps.h" namespace mlir { namespace stablehlo { @@ -44,6 +45,19 @@ namespace { if (!hloValue.has_value()) return {}; \ return mhlo::Name##Attr::get(attr.getContext(), hloValue.value()) +mhlo::ResultAccuracyMode convertResultAccuracyMode( + stablehlo::ResultAccuracyMode mode) { + switch (mode) { + case stablehlo::ResultAccuracyMode::DEFAULT: + return mhlo::ResultAccuracyMode::DEFAULT; + case stablehlo::ResultAccuracyMode::HIGHEST: + return mhlo::ResultAccuracyMode::HIGHEST; + case stablehlo::ResultAccuracyMode::TOLERANCE: + return mhlo::ResultAccuracyMode::TOLERANCE; + default: + return {}; + } +} Attribute convertAttr(Attribute stablehloAttr) { // StableHLO uses DenseArray for some attributes, MHLO is in the process // of integrating this change. In the meantime, convert DenseArray to @@ -139,6 +153,20 @@ Attribute convertAttr(Attribute stablehloAttr) { if (auto attr = mlir::dyn_cast(stablehloAttr)) { RETURN_CONVERTED_ENUM_ATTR(Transpose); } + if (auto attr = + mlir::dyn_cast(stablehloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode); + } + if (auto attr = + mlir::dyn_cast(stablehloAttr)) { + mhlo::ResultAccuracyModeAttr modeAttr = mhlo::ResultAccuracyModeAttr::get( + attr.getContext(), + convertResultAccuracyMode(attr.getMode().getValue())); + + return mhlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), + attr.getRtol(), attr.getUlps(), + modeAttr); + } if (stablehloAttr.getDialect().getNamespace() == stablehlo::StablehloDialect::getDialectNamespace()) { // Our guiding principle is to support all StableHLO functionality in MHLO. diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 92b59bda4c1c05..d4d9410556b1a4 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -335,6 +335,15 @@ func.func @attr_type_extensions_bounds( func.return %arg0 : tensor> } + +// CHECK-LABEL: "attr_result_accuracy_mode" +func.func @attr_result_accuracy_mode(%arg0: tensor) -> tensor { + %0 = "mhlo.exponential"(%arg0) { + result_accuracy = #mhlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} + // ============ OPS ============ // CHECK-LABEL: "op_abs" diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index d07a178c6c4e7f..1317d4bb1fc886 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -1638,6 +1638,12 @@ func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi3 func.return %0: tensor<2x2xi32> } +// ----- +// CHECK-LABEL: func @exponential_result_accuracy +func.func @exponential_result_accuracy(%arg0: tensor) -> tensor { + %0 = "mhlo.exponential"(%arg0) {result_accuracy = #mhlo.result_accuracy>} : (tensor) -> tensor + func.return %0: tensor +} // ----- func.func @dot_more_dynamic_output_type(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index fdf12a56cefb08..1792619b4ff21b 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -222,6 +222,15 @@ func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor< func.return %0 : tensor<8x8xf32> } +// CHECK-LABEL: "attr_result_accuracy" +func.func @attr_result_accuracy(%arg0: tensor) -> tensor { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #mhlo.result_accuracy> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "attr_rng_algorithm_default" func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { %0:2 = "stablehlo.rng_bit_generator"(%arg0) {