Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add result accuracy attribute to ExpOp in StableHlo. #20388

Merged
merged 1 commit into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,681 changes: 4,534 additions & 147 deletions third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,39 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
return output;
}

// Converts ResultAccuracyAttr to XLA ResultAccuracy proto.
static xla::ResultAccuracy Convert_result_accuracy(
std::optional<mlir::mhlo::ResultAccuracyAttr>
optional_result_accuracy_attr) {
if (!optional_result_accuracy_attr.has_value()) return xla::ResultAccuracy();

auto result_accuracy = xla::ResultAccuracy();
if (optional_result_accuracy_attr.value().getMode().getValue() ==
mlir::mhlo::ResultAccuracyMode::TOLERANCE) {
result_accuracy.mutable_tolerance()->set_atol(
optional_result_accuracy_attr.value().getAtol().convertToFloat());
result_accuracy.mutable_tolerance()->set_rtol(
optional_result_accuracy_attr.value().getRtol().convertToFloat());
result_accuracy.mutable_tolerance()->set_ulps(
optional_result_accuracy_attr.value().getUlps());
} else {
xla::ResultAccuracy::Mode mode;
auto result_accuracy_mode =
::mlir::mhlo::stringifyResultAccuracyMode(
optional_result_accuracy_attr.value().getMode().getValue())
.str();
if (xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) {
result_accuracy.set_mode(mode);
} else {
auto* context = optional_result_accuracy_attr.value().getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "unexpected result accuracy mode " << result_accuracy_mode;
return xla::ResultAccuracy();
}
}
return result_accuracy;
}

// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns std::nullopt.
Expand Down
37 changes: 37 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,29 @@ LogicalResult SparseDotOp::verify() {
return success();
}

// ===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//

LogicalResult ResultAccuracyAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol,
APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) {
return hlo::verifyResultAccuracyAttr(
emitError, std::move(atol), std::move(rtol), ulps,
stringifyResultAccuracyMode(mode.getValue()));
}

LogicalResult ExpOp::verify() {
if (auto attr = getResultAccuracyAttr()) {
if (failed(ResultAccuracyAttr::verify([&] { return emitError(); },
attr.getAtol(), attr.getRtol(),
attr.getUlps(), attr.getMode()))) {
return failure();
}
}
return success();
}

//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -7261,6 +7284,20 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
return success();
}

//===----------------------------------------------------------------------===//
// Custom unary op
//===----------------------------------------------------------------------===//

void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const {
hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(),
getMode());
}

Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) {
return hlo::parseResultAccuracyAttr<ResultAccuracyAttr,
ResultAccuracyModeAttr>(parser, type);
}

// Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex
// (1) the parameter must be valid
// (2) there must be a subshape at the given indices
Expand Down
4 changes: 4 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential",
%result = mhlo.exponential %operand : tensor<2x2xf64>
```
}];
let arguments = (ins MHLO_FpComplexOrQuantizedIntTensor:$operand,
DefaultValuedOptionalAttr<MHLO_ResultAccuracyAttr, "::mlir::mhlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
let results = (outs MHLO_FpComplexOrQuantizedIntTensor:$result);
let hasVerifier = 1;
let hasFolder = 1;
}
def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one",
Expand Down
32 changes: 32 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,36 @@ def MHLO_BoolElementsAttr :
let convertFromStorage = "$_self";
}

//==----------------------------------------------------------------------===//
// Result Accuracy attributes
//===----------------------------------------------------------------------===//

def MHLO_APFloatV1 : APFloatParameter<""> {
let parser = [{
[&]() -> FailureOr<llvm::APFloat> {
double value;
if (failed($_parser.parseFloat(value))) {
return failure();
}
return APFloat(value);
}()
}];
let printer = "$_printer.printFloat($_self);";
}

def MHLO_ResultAccuracyAttr : AttrDef<MHLO_Dialect, "ResultAccuracy"> {
let mnemonic = "result_accuracy";
let summary = "The requested accuracy for unary ops.";
let parameters = (ins
"APFloat":$atol,
"APFloat":$rtol,
"int64_t":$ulps,
MHLO_ResultAccuracyModeAttr:$mode

);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let constBuilderCall = "mlir::mhlo::ResultAccuracyAttr::get($_builder.getContext(), llvm::APFloat(0.0), llvm::APFloat(0.0), 0, mlir::mhlo::ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ATTRS
24 changes: 24 additions & 0 deletions xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,28 @@ def MHLO_RngAlgorithmAttr : EnumAttr<MHLO_Dialect, MHLO_RngAlgorithm, "rng_algor
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Result Accuracy enum definitions.
//===----------------------------------------------------------------------===//

def MHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
def MHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>;
def MHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>;

def MHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode",
"XLA result accuracy mode.",
[
MHLO_RESULT_ACCURACY_DEFAULT,
MHLO_RESULT_ACCURACY_HIGHEST,
MHLO_RESULT_ACCURACY_TOLERANCE
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mhlo";
}

def MHLO_ResultAccuracyModeAttr : EnumAttr<MHLO_Dialect, MHLO_ResultAccuracyMode, "result_accuracy_mode"> {
let assemblyFormat = "`<` $value `>`";
}


#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS
78 changes: 76 additions & 2 deletions xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ limitations under the License.

#include "mhlo/IR/mhlo_bytecode.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/Base.h"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -173,6 +176,18 @@ enum AttributeCode {
/// operandTupleIndices : svarint[]
/// }
kOutputOperandAlias = 16,

// ResultAccuracyModeAttr {
// mode: varint (encoded enum)
// }
kResultAccuracyModeAttr = 17,

// ResultAccuracyAttr {
// atol: APFloat
// rtol: APFloat
// ulps: svarint
// }
kResultAccuracyAttr = 18,
};

/// This enum contains marker codes used to indicate which type is
Expand Down Expand Up @@ -243,6 +258,10 @@ class MhloBytecodeInterface : public BytecodeDialectInterface {
OutputOperandAliasAttr readOutputOperandAliasAttr(
DialectBytecodeReader &reader) const;
PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const;
ResultAccuracyAttr readResultAccuracyAttr(
DialectBytecodeReader &reader) const;
ResultAccuracyModeAttr readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const;
RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const;
RngDistributionAttr readRngDistributionAttr(
DialectBytecodeReader &reader) const;
Expand All @@ -268,6 +287,8 @@ class MhloBytecodeInterface : public BytecodeDialectInterface {
DialectBytecodeWriter &writer) const;
void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const;
void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const;
void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const;
void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const;
void write(ScatterDimensionNumbersAttr attr,
Expand Down Expand Up @@ -331,6 +352,10 @@ Attribute MhloBytecodeInterface::readAttribute(
return readOutputOperandAliasAttr(reader);
case mhlo_encoding::kPrecisionAttr:
return readPrecisionAttr(reader);
case mhlo_encoding::kResultAccuracyAttr:
return readResultAccuracyAttr(reader);
case mhlo_encoding::kResultAccuracyModeAttr:
return readResultAccuracyModeAttr(reader);
case mhlo_encoding::kRngAlgorithmAttr:
return readRngAlgorithmAttr(reader);
case mhlo_encoding::kRngDistributionAttr:
Expand Down Expand Up @@ -582,8 +607,9 @@ LogicalResult MhloBytecodeInterface::writeAttribute(
ConvDimensionNumbersAttr, ChannelHandleAttr, DomainKindAttr,
DotDimensionNumbersAttr, FftTypeAttr, FusionKindAttr,
GatherDimensionNumbersAttr, OutputOperandAliasAttr, PrecisionAttr,
RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr,
TransposeAttr, TypeExtensionsAttr>([&](auto attr) {
ResultAccuracyAttr, ResultAccuracyModeAttr, RngAlgorithmAttr,
RngDistributionAttr, ScatterDimensionNumbersAttr, TransposeAttr,
TypeExtensionsAttr>([&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
Expand All @@ -594,6 +620,21 @@ LogicalResult MhloBytecodeInterface::writeAttribute(
});
}

void MhloBytecodeInterface::write(ResultAccuracyModeAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kResultAccuracyModeAttr);
hlo::bytecode::writeEnumAttribute<ResultAccuracyMode>(attr, writer);
}

void MhloBytecodeInterface::write(ResultAccuracyAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kResultAccuracyAttr);
writer.writeAPFloatWithKnownSemantics(attr.getAtol());
writer.writeAPFloatWithKnownSemantics(attr.getRtol());
writer.writeSignedVarInt(attr.getUlps());
writer.writeAttribute(attr.getMode());
}

void MhloBytecodeInterface::write(ArgResultAliasAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(mhlo_encoding::kArgResultAliasAttr);
Expand Down Expand Up @@ -790,6 +831,39 @@ void MhloBytecodeInterface::write(TokenType type,
writer.writeVarInt(mhlo_encoding::kTokenType);
}

//===----------------------------------------------------------------------===//
// ResultAccuracyModeAttr

ResultAccuracyModeAttr MhloBytecodeInterface::readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
return hlo::bytecode::readEnumAttribute<ResultAccuracyModeAttr>(
reader, getContext(),
[](uint32_t val) { return symbolizeResultAccuracyMode(val); });
}

//===----------------------------------------------------------------------===//
// ResultAccuracyAttr

ResultAccuracyAttr MhloBytecodeInterface::readResultAccuracyAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
FailureOr<APFloat> atol;
FailureOr<APFloat> rtol;
int64_t ulps = 0;
ResultAccuracyModeAttr mode;
if (failed(atol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(rtol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(reader.readSignedVarInt(ulps)) ||
failed(reader.readAttribute(mode))) {
reader.emitError() << "failed to read APFloat for atol";
return ResultAccuracyAttr();
}
return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode);
}

} // namespace

void addBytecodeInterface(MhloDialect *dialect) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,20 @@ Attribute convertDenseArray(mlir::StringAttr hloName, Attribute hloAttr) {
if (!stablehloValue.has_value()) return {}; \
return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value())

stablehlo::ResultAccuracyMode convertResultAccuracyMode(
mhlo::ResultAccuracyMode mode) {
switch (mode) {
case mhlo::ResultAccuracyMode::DEFAULT:
return stablehlo::ResultAccuracyMode::DEFAULT;
case mhlo::ResultAccuracyMode::HIGHEST:
return stablehlo::ResultAccuracyMode::HIGHEST;
case mhlo::ResultAccuracyMode::TOLERANCE:
return stablehlo::ResultAccuracyMode::TOLERANCE;
default:
return {};
}
}

Attribute convertAttr(Attribute hloAttr) {
// Handle MHLO attributes.
// The logic that handles attributes from other dialects (e.g. builtin
Expand Down Expand Up @@ -301,6 +315,19 @@ Attribute convertAttr(Attribute hloAttr) {
if (auto attr = mlir::dyn_cast<mhlo::TransposeAttr>(hloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(Transpose);
}
if (auto attr = mlir::dyn_cast<mhlo::ResultAccuracyModeAttr>(hloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode);
}
if (auto attr = mlir::dyn_cast<mhlo::ResultAccuracyAttr>(hloAttr)) {
stablehlo::ResultAccuracyModeAttr modeAttr;
modeAttr = stablehlo::ResultAccuracyModeAttr::get(
attr.getContext(),
convertResultAccuracyMode(attr.getMode().getValue()));

return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(),
attr.getRtol(), attr.getUlps(),
modeAttr);
}
if (hloAttr.getDialect().getNamespace() ==
mhlo::MhloDialect::getDialectNamespace()) {
// Our guiding principle is to support all StableHLO functionality in MHLO.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ namespace {
if (!hloValue.has_value()) return {}; \
return mhlo::Name##Attr::get(attr.getContext(), hloValue.value())

mhlo::ResultAccuracyMode convertResultAccuracyMode(
stablehlo::ResultAccuracyMode mode) {
switch (mode) {
case stablehlo::ResultAccuracyMode::DEFAULT:
return mhlo::ResultAccuracyMode::DEFAULT;
case stablehlo::ResultAccuracyMode::HIGHEST:
return mhlo::ResultAccuracyMode::HIGHEST;
case stablehlo::ResultAccuracyMode::TOLERANCE:
return mhlo::ResultAccuracyMode::TOLERANCE;
default:
return {};
}
}
Attribute convertAttr(Attribute stablehloAttr) {
// StableHLO uses DenseArray for some attributes, MHLO is in the process
// of integrating this change. In the meantime, convert DenseArray to
Expand Down Expand Up @@ -140,6 +153,20 @@ Attribute convertAttr(Attribute stablehloAttr) {
if (auto attr = mlir::dyn_cast<stablehlo::TransposeAttr>(stablehloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(Transpose);
}
if (auto attr =
mlir::dyn_cast<stablehlo::ResultAccuracyModeAttr>(stablehloAttr)) {
RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode);
}
if (auto attr =
mlir::dyn_cast<stablehlo::ResultAccuracyAttr>(stablehloAttr)) {
mhlo::ResultAccuracyModeAttr modeAttr = mhlo::ResultAccuracyModeAttr::get(
attr.getContext(),
convertResultAccuracyMode(attr.getMode().getValue()));

return mhlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(),
attr.getRtol(), attr.getUlps(),
modeAttr);
}
if (stablehloAttr.getDialect().getNamespace() ==
stablehlo::StablehloDialect::getDialectNamespace()) {
// Our guiding principle is to support all StableHLO functionality in MHLO.
Expand Down
Loading
Loading