Skip to content

Commit

Permalink
Add result accuracy attribute to ExpOp in StableHlo.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685857742
  • Loading branch information
hanrach9 authored and Google-ML-Automation committed Jan 24, 2025
1 parent 6557723 commit d243224
Show file tree
Hide file tree
Showing 12 changed files with 4,827 additions and 149 deletions.
4,681 changes: 4,534 additions & 147 deletions third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,39 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
return output;
}

// Converts ResultAccuracyAttr to XLA ResultAccuracy proto.
static xla::ResultAccuracy Convert_result_accuracy(
std::optional<mlir::mhlo::ResultAccuracyAttr>
optional_result_accuracy_attr) {
if (!optional_result_accuracy_attr.has_value()) return xla::ResultAccuracy();

auto result_accuracy = xla::ResultAccuracy();
if (optional_result_accuracy_attr.value().getMode().getValue() ==
mlir::mhlo::ResultAccuracyMode::TOLERANCE) {
result_accuracy.mutable_tolerance()->set_atol(
optional_result_accuracy_attr.value().getAtol().convertToFloat());
result_accuracy.mutable_tolerance()->set_rtol(
optional_result_accuracy_attr.value().getRtol().convertToFloat());
result_accuracy.mutable_tolerance()->set_ulps(
optional_result_accuracy_attr.value().getUlps());
} else {
xla::ResultAccuracy::Mode mode;
auto result_accuracy_mode =
::mlir::mhlo::stringifyResultAccuracyMode(
optional_result_accuracy_attr.value().getMode().getValue())
.str();
if (xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) {
result_accuracy.set_mode(mode);
} else {
auto* context = optional_result_accuracy_attr.value().getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "unexpected result accuracy mode " << result_accuracy_mode;
return xla::ResultAccuracy();
}
}
return result_accuracy;
}

// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns std::nullopt.
Expand Down
37 changes: 37 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,29 @@ LogicalResult SparseDotOp::verify() {
return success();
}

// ===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//

LogicalResult ResultAccuracyAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol,
APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) {
return hlo::verifyResultAccuracyAttr(
emitError, std::move(atol), std::move(rtol), ulps,
stringifyResultAccuracyMode(mode.getValue()));
}

LogicalResult ExpOp::verify() {
if (auto attr = getResultAccuracyAttr()) {
if (failed(ResultAccuracyAttr::verify([&] { return emitError(); },
attr.getAtol(), attr.getRtol(),
attr.getUlps(), attr.getMode()))) {
return failure();
}
}
return success();
}

//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -7261,6 +7284,20 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
return success();
}

//===----------------------------------------------------------------------===//
// Custom unary op
//===----------------------------------------------------------------------===//

void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const {
hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(),
getMode());
}

Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) {
return hlo::parseResultAccuracyAttr<ResultAccuracyAttr,
ResultAccuracyModeAttr>(parser, type);
}

// Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex
// (1) the parameter must be valid
// (2) there must be a subshape at the given indices
Expand Down
4 changes: 4 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential",
%result = mhlo.exponential %operand : tensor<2x2xf64>
```
}];
let arguments = (ins MHLO_FpComplexOrQuantizedIntTensor:$operand,
DefaultValuedOptionalAttr<MHLO_ResultAccuracyAttr, "::mlir::mhlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
let results = (outs MHLO_FpComplexOrQuantizedIntTensor:$result);
let hasVerifier = 1;
let hasFolder = 1;
}
def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one",
Expand Down
32 changes: 32 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,36 @@ def MHLO_BoolElementsAttr :
let convertFromStorage = "$_self";
}

//==----------------------------------------------------------------------===//
// Result Accuracy attributes
//===----------------------------------------------------------------------===//

def MHLO_APFloatV1 : APFloatParameter<""> {
let parser = [{
[&]() -> FailureOr<llvm::APFloat> {
double value;
if (failed($_parser.parseFloat(value))) {
return failure();
}
return APFloat(value);
}()
}];
let printer = "$_printer.printFloat($_self);";
}

def MHLO_ResultAccuracyAttr : AttrDef<MHLO_Dialect, "ResultAccuracy"> {
let mnemonic = "result_accuracy";
let summary = "The requested accuracy for unary ops.";
let parameters = (ins
"APFloat":$atol,
"APFloat":$rtol,
"int64_t":$ulps,
MHLO_ResultAccuracyModeAttr:$mode

);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let constBuilderCall = "mlir::mhlo::ResultAccuracyAttr::get($_builder.getContext(), llvm::APFloat(0.0), llvm::APFloat(0.0), 0, mlir::mhlo::ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ATTRS
24 changes: 24 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,28 @@ def MHLO_RngAlgorithmAttr : EnumAttr<MHLO_Dialect, MHLO_RngAlgorithm, "rng_algor
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Result Accuracy enum definitions.
//===----------------------------------------------------------------------===//

def MHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
def MHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>;
def MHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>;

def MHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode",
"XLA result accuracy mode.",
[
MHLO_RESULT_ACCURACY_DEFAULT,
MHLO_RESULT_ACCURACY_HIGHEST,
MHLO_RESULT_ACCURACY_TOLERANCE
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mhlo";
}

def MHLO_ResultAccuracyModeAttr : EnumAttr<MHLO_Dialect, MHLO_ResultAccuracyMode, "result_accuracy_mode"> {
let assemblyFormat = "`<` $value `>`";
}


#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS
78 changes: 76 additions & 2 deletions xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ limitations under the License.

#include "mhlo/IR/mhlo_bytecode.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/Base.h"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -173,6 +176,18 @@ enum AttributeCode {
/// operandTupleIndices : svarint[]
/// }
kOutputOperandAlias = 16,

// ResultAccuracyModeAttr {
// mode: varint (encoded enum)
// }
kResultAccuracyModeAttr = 17,

// ResultAccuracyAttr {
// atol: APFloat
// rtol: APFloat
// ulps: svarint
// }
kResultAccuracyAttr = 18,
};

/// This enum contains marker codes used to indicate which type is
Expand Down Expand Up @@ -243,6 +258,10 @@ class MhloBytecodeInterface : public BytecodeDialectInterface {
OutputOperandAliasAttr readOutputOperandAliasAttr(
DialectBytecodeReader &reader) const;
PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const;
ResultAccuracyAttr readResultAccuracyAttr(
DialectBytecodeReader &reader) const;
ResultAccuracyModeAttr readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const;
RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const;
RngDistributionAttr readRngDistributionAttr(
DialectBytecodeReader &reader) const;
Expand All @@ -268,6 +287,8 @@ class MhloBytecodeInterface : public BytecodeDialectInterface {
DialectBytecodeWriter &writer) const;
void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const;
void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const;
void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const;
void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const;
void write(ScatterDimensionNumbersAttr attr,
Expand Down Expand Up @@ -331,6 +352,10 @@ Attribute MhloBytecodeInterface::readAttribute(
return readOutputOperandAliasAttr(reader);
case mhlo_encoding::kPrecisionAttr:
return readPrecisionAttr(reader);
case mhlo_encoding::kResultAccuracyAttr:
return readResultAccuracyAttr(reader);
case mhlo_encoding::kResultAccuracyModeAttr:
return readResultAccuracyModeAttr(reader);
case mhlo_encoding::kRngAlgorithmAttr:
return readRngAlgorithmAttr(reader);
case mhlo_encoding::kRngDistributionAttr:
Expand Down Expand Up @@ -582,8 +607,9 @@ LogicalResult MhloBytecodeInterface::writeAttribute(
ConvDimensionNumbersAttr, ChannelHandleAttr, DomainKindAttr,
DotDimensionNumbersAttr, FftTypeAttr, FusionKindAttr,
GatherDimensionNumbersAttr, OutputOperandAliasAttr, PrecisionAttr,
RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr,
TransposeAttr, TypeExtensionsAttr>([&](auto attr) {
ResultAccuracyAttr, ResultAccuracyModeAttr, RngAlgorithmAttr,
RngDistributionAttr, ScatterDimensionNumbersAttr, TransposeAttr,
TypeExtensionsAttr>([&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
Expand All @@ -594,6 +620,21 @@ LogicalResult MhloBytecodeInterface::writeAttribute(
});
}

void MhloBytecodeInterface::write(ResultAccuracyModeAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kResultAccuracyModeAttr);
hlo::bytecode::writeEnumAttribute<ResultAccuracyMode>(attr, writer);
}

void MhloBytecodeInterface::write(ResultAccuracyAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kResultAccuracyAttr);
writer.writeAPFloatWithKnownSemantics(attr.getAtol());
writer.writeAPFloatWithKnownSemantics(attr.getRtol());
writer.writeSignedVarInt(attr.getUlps());
writer.writeAttribute(attr.getMode());
}

void MhloBytecodeInterface::write(ArgResultAliasAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kArgResultAliasAttr);
Expand Down Expand Up @@ -790,6 +831,39 @@ void MhloBytecodeInterface::write(TokenType type,
writer.writeVarInt(mhlo_encoding::kTokenType);
}

//===----------------------------------------------------------------------===//
// ResultAccuracyModeAttr

ResultAccuracyModeAttr MhloBytecodeInterface::readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
return hlo::bytecode::readEnumAttribute<ResultAccuracyModeAttr>(
reader, getContext(),
[](uint32_t val) { return symbolizeResultAccuracyMode(val); });
}

//===----------------------------------------------------------------------===//
// ResultAccuracyAttr

ResultAccuracyAttr MhloBytecodeInterface::readResultAccuracyAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
FailureOr<APFloat> atol;
FailureOr<APFloat> rtol;
int64_t ulps = 0;
ResultAccuracyModeAttr mode;
if (failed(atol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(rtol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(reader.readSignedVarInt(ulps)) ||
failed(reader.readAttribute(mode))) {
reader.emitError() << "failed to read APFloat for atol";
return ResultAccuracyAttr();
}
return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode);
}

} // namespace

void addBytecodeInterface(MhloDialect *dialect) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,20 @@ Attribute convertDenseArray(mlir::StringAttr hloName, Attribute hloAttr) {
if (!stablehloValue.has_value()) return {}; \
return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value())

stablehlo::ResultAccuracyMode convertResultAccuracyMode(
mhlo::ResultAccuracyMode mode) {
switch (mode) {
case mhlo::ResultAccuracyMode::DEFAULT:
return stablehlo::ResultAccuracyMode::DEFAULT;
case mhlo::ResultAccuracyMode::HIGHEST:
return stablehlo::ResultAccuracyMode::HIGHEST;
case mhlo::ResultAccuracyMode::TOLERANCE:
return stablehlo::ResultAccuracyMode::TOLERANCE;
default:
return {};
}
}

Attribute convertAttr(Attribute hloAttr) {
// Handle MHLO attributes.
// The logic that handles attributes from other dialects (e.g. builtin
Expand Down Expand Up @@ -301,6 +315,19 @@ Attribute convertAttr(Attribute hloAttr) {
if (auto attr = mlir::dyn_cast<mhlo::TransposeAttr>(hloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(Transpose);
}
if (auto attr = mlir::dyn_cast<mhlo::ResultAccuracyModeAttr>(hloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode);
}
if (auto attr = mlir::dyn_cast<mhlo::ResultAccuracyAttr>(hloAttr)) {
stablehlo::ResultAccuracyModeAttr modeAttr;
modeAttr = stablehlo::ResultAccuracyModeAttr::get(
attr.getContext(),
convertResultAccuracyMode(attr.getMode().getValue()));

return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(),
attr.getRtol(), attr.getUlps(),
modeAttr);
}
if (hloAttr.getDialect().getNamespace() ==
mhlo::MhloDialect::getDialectNamespace()) {
// Our guiding principle is to support all StableHLO functionality in MHLO.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ namespace {
if (!hloValue.has_value()) return {}; \
return mhlo::Name##Attr::get(attr.getContext(), hloValue.value())

mhlo::ResultAccuracyMode convertResultAccuracyMode(
stablehlo::ResultAccuracyMode mode) {
switch (mode) {
case stablehlo::ResultAccuracyMode::DEFAULT:
return mhlo::ResultAccuracyMode::DEFAULT;
case stablehlo::ResultAccuracyMode::HIGHEST:
return mhlo::ResultAccuracyMode::HIGHEST;
case stablehlo::ResultAccuracyMode::TOLERANCE:
return mhlo::ResultAccuracyMode::TOLERANCE;
default:
return {};
}
}
Attribute convertAttr(Attribute stablehloAttr) {
// StableHLO uses DenseArray for some attributes, MHLO is in the process
// of integrating this change. In the meantime, convert DenseArray to
Expand Down Expand Up @@ -140,6 +153,20 @@ Attribute convertAttr(Attribute stablehloAttr) {
if (auto attr = mlir::dyn_cast<stablehlo::TransposeAttr>(stablehloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(Transpose);
}
if (auto attr =
mlir::dyn_cast<stablehlo::ResultAccuracyModeAttr>(stablehloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode);
}
if (auto attr =
mlir::dyn_cast<stablehlo::ResultAccuracyAttr>(stablehloAttr)) {
mhlo::ResultAccuracyModeAttr modeAttr = mhlo::ResultAccuracyModeAttr::get(
attr.getContext(),
convertResultAccuracyMode(attr.getMode().getValue()));

return mhlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(),
attr.getRtol(), attr.getUlps(),
modeAttr);
}
if (stablehloAttr.getDialect().getNamespace() ==
stablehlo::StablehloDialect::getDialectNamespace()) {
// Our guiding principle is to support all StableHLO functionality in MHLO.
Expand Down
Loading

0 comments on commit d243224

Please sign in to comment.