Skip to content

Commit

Permalink
Add result accuracy attribute to ExpOp in StableHlo.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685857742
  • Loading branch information
hanrach9 authored and Google-ML-Automation committed Dec 24, 2024
1 parent eaf4c75 commit 2914231
Show file tree
Hide file tree
Showing 11 changed files with 359 additions and 2 deletions.
34 changes: 34 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "mhlo/IR/hlo_ops.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -638,6 +639,39 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
return output;
}

// Converts ResultAccuracyAttr to XLA ResultAccuracy proto.
static xla::ResultAccuracy Convert_result_accuracy(
std::optional<mlir::mhlo::ResultAccuracyAttr>
optional_result_accuracy_attr) {
if (!optional_result_accuracy_attr.has_value()) return xla::ResultAccuracy();

auto result_accuracy = xla::ResultAccuracy();
if (optional_result_accuracy_attr.value().getMode().getValue() ==
mlir::mhlo::ResultAccuracyMode::TOLERANCE) {
result_accuracy.mutable_tolerance()->set_atol(
optional_result_accuracy_attr.value().getAtol().convertToFloat());
result_accuracy.mutable_tolerance()->set_rtol(
optional_result_accuracy_attr.value().getRtol().convertToFloat());
result_accuracy.mutable_tolerance()->set_ulps(
optional_result_accuracy_attr.value().getUlps());
} else {
xla::ResultAccuracy::Mode mode;
auto result_accuracy_mode =
::mlir::mhlo::stringifyResultAccuracyMode(
optional_result_accuracy_attr.value().getMode().getValue())
.str();
if (xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) {
result_accuracy.set_mode(mode);
} else {
auto* context = optional_result_accuracy_attr.value().getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "unexpected result accuracy mode " << result_accuracy_mode;
return xla::ResultAccuracy();
}
}
return result_accuracy;
}

// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns std::nullopt.
Expand Down
113 changes: 113 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,29 @@ LogicalResult SparseDotOp::verify() {
return success();
}

// ===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//

LogicalResult ResultAccuracyAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol,
APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) {
return hlo::verifyResultAccuracyAttr(
emitError, atol, rtol, ulps,
stringifyResultAccuracyMode(mode.getValue()));
}

LogicalResult ExpOp::verify() {
if (auto attr = getResultAccuracyAttr()) {
if (failed(ResultAccuracyAttr::verify([&] { return this->emitError(); },
attr.getAtol(), attr.getRtol(),
attr.getUlps(), attr.getMode()))) {
return failure();
}
}
return success();
}

//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -7212,6 +7235,96 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
return success();
}

//===----------------------------------------------------------------------===//
// Custom unary op
//===----------------------------------------------------------------------===//

void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const {
odsPrinter << "<atol = ";
odsPrinter.printFloat(getAtol());
odsPrinter << ",";
odsPrinter << " rtol = ";
odsPrinter.printFloat(getRtol());
odsPrinter << ",";
odsPrinter << " ulps = ";
odsPrinter << getUlps();
odsPrinter << ",";
odsPrinter << " mode = ";
odsPrinter.printAttribute(getMode());
odsPrinter << ">";
}

Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) {
// ResultAccuractAttr ::= `<` AtolAccuracy `,` RtolAccuracy `,
// ` UlpAccuracy `,` ModeAccuracy `>`
// AtolAccuracy ::= `atol` `=` APFloat
// RtolAccuracy ::= `rtol` `=` APFloat
// UlpAccuracy ::= `ulps` `=` int64_t
// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr

// Parse literal '<'
if (parser.parseLess()) return {};
// Parse AtolAccuracy
if (parser.parseKeyword("atol") || parser.parseEqual()) return {};
FailureOr<APFloat> _result_atol = [&]() -> FailureOr<llvm::APFloat> {
double value;
if (failed(parser.parseFloat(value))) {
return failure();
}
return APFloat(value);
}();
if (failed(_result_atol)) {
parser.emitError(parser.getCurrentLocation(),
"failed to parse StableHLO_ResultAccuracyAttr parameter "
"'atol' which is to be a `::llvm::APFloat`");
return {};
}
// Parse literal ','
if (parser.parseComma()) return {};
// Parse RtolAccuracy
if (parser.parseKeyword("rtol") || parser.parseEqual()) return {};
FailureOr<APFloat> _result_rtol = [&]() -> FailureOr<llvm::APFloat> {
double value;
if (failed(parser.parseFloat(value))) {
return failure();
}
return APFloat(value);
}();
if (failed(_result_rtol)) {
parser.emitError(parser.getCurrentLocation(),
"failed to parse StableHLO_ResultAccuracyAttr parameter "
"'rtol' which is to be a `::llvm::APFloat`");
return {};
}
// Parse literal ','
if (parser.parseComma()) return {};
// Parse UlpAccuracy
if (parser.parseKeyword("ulps") || parser.parseEqual()) return {};
int64_t _result_ulps;
if (failed(parser.parseInteger(_result_ulps))) {
parser.emitError(parser.getCurrentLocation(),
"failed to parse StableHLO_ResultAccuracyAttr parameter "
"'ulps' which is to be a `int64_t`");
return {};
}
// Parse literal ','
if (parser.parseComma()) return {};
// Parse ModeAccuracy
if (parser.parseKeyword("mode") || parser.parseEqual()) return {};
ResultAccuracyModeAttr mode_attr;
if (failed(parser.parseAttribute(mode_attr))) {
parser.emitError(
parser.getCurrentLocation(),
"failed to parse StableHLO_ResultAccuracyAttr parameter 'mode' which "
"is to be a `::mlir::mhlo::ResultAccuracyModeAttr`");
return {};
}
// Parse literal '>'
if (parser.parseGreater()) return {};
return ResultAccuracyAttr::get(parser.getContext(), *_result_atol,
*_result_rtol, _result_ulps, mode_attr);
}

// Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex
// (1) the parameter must be valid
// (2) there must be a subshape at the given indices
Expand Down
4 changes: 4 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential",
%result = mhlo.exponential %operand : tensor<2x2xf64>
```
}];
let arguments = (ins MHLO_FpComplexOrQuantizedIntTensor:$operand,
DefaultValuedOptionalAttr<MHLO_ResultAccuracyAttr, "ResultAccuracyMode::DEFAULT">:$result_accuracy);
let results = (outs MHLO_FpComplexOrQuantizedIntTensor:$result);
let hasVerifier = 1;
let hasFolder = 1;
}
def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one",
Expand Down
32 changes: 32 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,36 @@ def MHLO_BoolElementsAttr :
let convertFromStorage = "$_self";
}

//==----------------------------------------------------------------------===//
// Result Accuracy attributes
//===----------------------------------------------------------------------===//

def MHLO_APFloatV1 : APFloatParameter<""> {
let parser = [{
[&]() -> FailureOr<llvm::APFloat> {
double value;
if (failed($_parser.parseFloat(value))) {
return failure();
}
return APFloat(value);
}()
}];
let printer = "$_printer.printFloat($_self);";
}

def MHLO_ResultAccuracyAttr : AttrDef<MHLO_Dialect, "ResultAccuracy"> {
let mnemonic = "result_accuracy";
let summary = "The requested accuracy for unary ops.";
let parameters = (ins
"APFloat":$atol,
"APFloat":$rtol,
"int64_t":$ulps,
MHLO_ResultAccuracyModeAttr:$mode

);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ATTRS
22 changes: 22 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,26 @@ def MHLO_RngAlgorithmAttr : EnumAttr<MHLO_Dialect, MHLO_RngAlgorithm, "rng_algor
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Result Accuracy enum definitions.
//===----------------------------------------------------------------------===//

def MHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
def MHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>;
def MHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>;

def MHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode",
"XLA result accuracy mode.",
[
MHLO_RESULT_ACCURACY_DEFAULT,
MHLO_RESULT_ACCURACY_HIGHEST,
MHLO_RESULT_ACCURACY_TOLERANCE
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mhlo";
}

def MHLO_ResultAccuracyModeAttr : EnumAttr<MHLO_Dialect, MHLO_ResultAccuracyMode, "result_accuracy_mode">;


#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS
76 changes: 74 additions & 2 deletions xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "mhlo/IR/mhlo_bytecode.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -173,6 +174,18 @@ enum AttributeCode {
/// operandTupleIndices : svarint[]
/// }
kOutputOperandAlias = 16,

// ResultAccuracyModeAttr {
// mode: varint (encoded enum)
// }
kResultAccuracyModeAttr = 17,

// ResultAccuracyAttr {
// atol: APFloat
// rtol: APFloat
// ulps: svarint
// }
kResultAccuracyAttr = 18,
};

/// This enum contains marker codes used to indicate which type is
Expand Down Expand Up @@ -251,6 +264,10 @@ class MhloBytecodeInterface : public BytecodeDialectInterface {
TransposeAttr readTransposeAttr(DialectBytecodeReader &reader) const;
TypeExtensionsAttr readTypeExtensionsAttr(
DialectBytecodeReader &reader) const;
ResultAccuracyModeAttr readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const;
ResultAccuracyAttr readResultAccuracyAttr(
DialectBytecodeReader &reader) const;

// TO ADD ATTRIBUTE: Include a write method for each attribute in StableHLO
// Ex: void write(SomeAttr attr, DialectBytecodeWriter &writer) const;
Expand All @@ -274,6 +291,8 @@ class MhloBytecodeInterface : public BytecodeDialectInterface {
DialectBytecodeWriter &writer) const;
void write(TransposeAttr attr, DialectBytecodeWriter &writer) const;
void write(TypeExtensionsAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const;

//===--------------------------------------------------------------------===//
// Types
Expand Down Expand Up @@ -341,6 +360,10 @@ Attribute MhloBytecodeInterface::readAttribute(
return readTransposeAttr(reader);
case mhlo_encoding::kTypeExtensionsAttr:
return readTypeExtensionsAttr(reader);
case mhlo_encoding::kResultAccuracyModeAttr:
return readResultAccuracyModeAttr(reader);
case mhlo_encoding::kResultAccuracyAttr:
return readResultAccuracyAttr(reader);

default:
reader.emitError() << "unknown mhlo attribute code: " << code;
Expand Down Expand Up @@ -582,8 +605,9 @@ LogicalResult MhloBytecodeInterface::writeAttribute(
ConvDimensionNumbersAttr, ChannelHandleAttr, DomainKindAttr,
DotDimensionNumbersAttr, FftTypeAttr, FusionKindAttr,
GatherDimensionNumbersAttr, OutputOperandAliasAttr, PrecisionAttr,
RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr,
TransposeAttr, TypeExtensionsAttr>([&](auto attr) {
ResultAccuracyAttr, ResultAccuracyModeAttr, RngAlgorithmAttr,
RngDistributionAttr, ScatterDimensionNumbersAttr, TransposeAttr,
TypeExtensionsAttr>([&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
Expand All @@ -594,6 +618,21 @@ LogicalResult MhloBytecodeInterface::writeAttribute(
});
}

void MhloBytecodeInterface::write(ResultAccuracyModeAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kResultAccuracyModeAttr);
hlo::bytecode::writeEnumAttribute<ResultAccuracyMode>(attr, writer);
}

void MhloBytecodeInterface::write(ResultAccuracyAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kResultAccuracyAttr);
writer.writeAPFloatWithKnownSemantics(attr.getAtol());
writer.writeAPFloatWithKnownSemantics(attr.getRtol());
writer.writeSignedVarInt(attr.getUlps());
writer.writeAttribute(attr.getMode());
}

void MhloBytecodeInterface::write(ArgResultAliasAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kArgResultAliasAttr);
Expand Down Expand Up @@ -790,6 +829,39 @@ void MhloBytecodeInterface::write(TokenType type,
writer.writeVarInt(mhlo_encoding::kTokenType);
}

//===----------------------------------------------------------------------===//
// ResultAccuracyModeAttr

ResultAccuracyModeAttr MhloBytecodeInterface::readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
return hlo::bytecode::readEnumAttribute<ResultAccuracyModeAttr>(
reader, getContext(),
[](uint32_t val) { return symbolizeResultAccuracyMode(val); });
}

//===----------------------------------------------------------------------===//
// ResultAccuracyAttr

ResultAccuracyAttr MhloBytecodeInterface::readResultAccuracyAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
FailureOr<APFloat> atol;
FailureOr<APFloat> rtol;
int64_t ulps;
ResultAccuracyModeAttr mode;
if (failed(atol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(rtol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(reader.readSignedVarInt(ulps)) ||
failed(reader.readAttribute(mode))) {
reader.emitError() << "failed to read APFloat for atol";
return ResultAccuracyAttr();
}
return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode);
}

} // namespace

void addBytecodeInterface(MhloDialect *dialect) {
Expand Down
Loading

0 comments on commit 2914231

Please sign in to comment.