From b3fc9ae301c0c7ae39e9f1f58e314fbbe8cc7559 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Fri, 24 Jan 2025 16:10:28 -0800 Subject: [PATCH] Add result accuracy attribute to ExpOp in StableHlo. PiperOrigin-RevId: 719465588 --- third_party/stablehlo/temporary.patch | 4681 ++++++++++++++++- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 33 + xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 37 + xla/mlir_hlo/mhlo/IR/hlo_ops.td | 4 + xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td | 32 + xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td | 24 + xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc | 78 +- .../hlo_legalize_to_stablehlo.cc | 27 + .../stablehlo_legalize_to_hlo.cc | 27 + .../mhlo/hlo-legalize-to-stablehlo.mlir | 13 + xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 8 + .../mhlo/stablehlo-legalize-to-hlo.mlir | 12 + 12 files changed, 4827 insertions(+), 149 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index e9cad1d6ecfa0..dcea9111ece40 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,3 +1,15 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -1547,7 +1547,7 @@ + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", +- td_file = "stablehlo/dialect/VhloAttrs.td", ++ td_file = "stablehlo/dialect/VhloEnums.td", + deps = [ + ":vhlo_ops_td_files", + ], diff --ruN a/stablehlo/examples/c++/ExampleAdd.cpp b/stablehlo/examples/c++/ExampleAdd.cpp --- stablehlo/examples/c++/ExampleAdd.cpp +++ stablehlo/examples/c++/ExampleAdd.cpp @@ -56,6 +68,462 @@ diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeTo return success(); } }; +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp +--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp ++++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp +@@ -860,6 +860,29 @@ + return parser.parseSymbolName(target); + } + ++void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol, ++ int64_t ulps, Attribute mode) { ++ odsPrinter << "<"; ++ if (!atol.isZero()) { ++ odsPrinter << "atol = "; ++ odsPrinter.printFloat(atol); ++ odsPrinter << ", "; ++ } ++ if (!rtol.isZero()) { ++ odsPrinter << "rtol = "; ++ odsPrinter.printFloat(rtol); ++ odsPrinter << ", "; ++ } ++ if (ulps != 0) { ++ odsPrinter << "ulps = "; ++ odsPrinter << ulps; ++ odsPrinter << ", "; ++ } ++ odsPrinter << "mode = "; ++ odsPrinter.printAttribute(mode); ++ odsPrinter << ">"; ++} ++ + void printTypeExtensions(BoundedAttrInterface attr, DialectAsmPrinter& os) { + os << "bounds<"; + llvm::interleaveComma(attr.getBounds(), os, +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h +--- stablehlo/stablehlo/dialect/AssemblyFormat.h ++++ stablehlo/stablehlo/dialect/AssemblyFormat.h +@@ -378,6 +378,65 @@ + return success(); + } + ++// ResultAccuracyAttr - Custom printing and parsing for ResultAccuracyAttr. 
++//
++// ResultAccuracyAttr ::= `<` OptAtolAccuracy OptRtolAccuracy
++//                        OptUlpAccuracy ModeAccuracy `>`
++// OptAtolAccuracy ::= `atol` `=` APFloat `, ` | eps
++// OptRtolAccuracy ::= `rtol` `=` APFloat `, ` | eps
++// OptUlpAccuracy ::= `ulps` `=` int64_t `, ` | eps
++// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr
++void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
++                             int64_t ulps, Attribute mode);
++
++template <typename AttrTy, typename ModeTy>
++Attribute parseResultAccuracyAttr(AsmParser& parser, Type type) {
++  APFloat resultAtol = APFloat::getZero(APFloat::IEEEdouble());
++  APFloat resultRtol = APFloat::getZero(APFloat::IEEEdouble());
++  int64_t resultUlps = 0;
++
++  // Parse literal '<'
++  if (parser.parseLess()) return {};
++
++  // OptAtolAccuracy
++  if (succeeded(parser.parseOptionalKeyword("atol"))) {
++    double value;
++    if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
++      return {};
++    resultAtol = APFloat(value);
++  }
++
++  // OptRtolAccuracy
++  if (succeeded(parser.parseOptionalKeyword("rtol"))) {
++    double value;
++    if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
++      return {};
++    resultRtol = APFloat(value);
++  }
++
++  // OptUlpAccuracy
++  if (succeeded(parser.parseOptionalKeyword("ulps"))) {
++    int64_t value;
++    if (parser.parseEqual() || parser.parseInteger(value) ||
++        parser.parseComma())
++      return {};
++    resultUlps = value;
++  }
++
++  // ModeAccuracy
++  ModeTy modeAttr;
++  if (parser.parseKeyword("mode") || parser.parseEqual() ||
++      parser.parseAttribute(modeAttr)) {
++    return {};
++  }
++
++  // Parse literal '>'
++  if (parser.parseGreater()) return {};
++  return parser.getChecked<AttrTy>(
++      parser.getCurrentLocation(), parser.getContext(), resultAtol, resultRtol,
++      resultUlps, modeAttr);
++}
++
+ }  // namespace hlo
+ }  // namespace mlir
+
+diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt
+--- stablehlo/stablehlo/dialect/CMakeLists.txt
++++ stablehlo/stablehlo/dialect/CMakeLists.txt
+@@ -190,7 +190,7 @@
+ set(LLVM_TARGET_DEFINITIONS VhloOps.td)
+ mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls)
+ mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs)
+-set(LLVM_TARGET_DEFINITIONS VhloAttrs.td)
++set(LLVM_TARGET_DEFINITIONS VhloEnums.td)
+ mlir_tablegen(VhloAttrInterfaces.h.inc -gen-attr-interface-decls)
+ mlir_tablegen(VhloAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+ set(LLVM_TARGET_DEFINITIONS VhloTypes.td)
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloAttrs.td b/stablehlo/stablehlo/dialect/StablehloAttrs.td
+--- stablehlo/stablehlo/dialect/StablehloAttrs.td
++++ stablehlo/stablehlo/dialect/StablehloAttrs.td
+@@ -19,6 +19,7 @@
+
+ include "mlir/IR/OpBase.td"
+ include "mlir/IR/TensorEncoding.td"
++include "stablehlo/dialect/StablehloTypes.td"
+
+ def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
+   let parser = "parseDimSizes($_parser)";
+@@ -209,4 +210,18 @@
+   let hasCustomAssemblyFormat = 1;
+ }
+
++def StableHLO_ResultAccuracyAttr : AttrDef<StableHLO_Dialect, "ResultAccuracy"> {
++  let mnemonic = "result_accuracy";
++  let summary = "The requested accuracy for transcendental unary ops.";
++  let parameters = (ins
++    "APFloat":$atol,
++    "APFloat":$rtol,
++    "int64_t":$ulps,
++    StableHLO_ResultAccuracyModeAttr:$mode
++  );
++  let hasCustomAssemblyFormat = 1;
++  let genVerifyDecl = 1;
++  let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
++}
++
+ #endif //
STABLEHLO_DIALECT_STABLEHLO_ATTRS +diff --ruN a/stablehlo/stablehlo/dialect/StablehloBytecode.cpp b/stablehlo/stablehlo/dialect/StablehloBytecode.cpp +--- stablehlo/stablehlo/dialect/StablehloBytecode.cpp ++++ stablehlo/stablehlo/dialect/StablehloBytecode.cpp +@@ -18,6 +18,7 @@ + #include + #include + ++#include "llvm/ADT/APFloat.h" + #include "llvm/ADT/SmallVector.h" + #include "llvm/ADT/StringRef.h" + #include "llvm/ADT/TypeSwitch.h" +@@ -180,6 +181,18 @@ + /// allowImpreciseAccumulation : svarint + /// } + kDotAlgorithmAttr = 15, ++ ++ // ResultAccuracyModeAttr { ++ // mode: varint (encoded enum) ++ // } ++ kResultAccuracyModeAttr = 16, ++ ++ // ResultAccuracyAttr { ++ // atol: APFloat ++ // rtol: APFloat ++ // ulps: svarint ++ // } ++ kResultAccuracyAttr = 17, + }; + + /// This enum contains marker codes used to indicate which type is +@@ -241,6 +254,10 @@ + OutputOperandAliasAttr readOutputOperandAliasAttr( + DialectBytecodeReader &reader) const; + PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const; ++ ResultAccuracyAttr readResultAccuracyAttr( ++ DialectBytecodeReader &reader) const; ++ ResultAccuracyModeAttr readResultAccuracyModeAttr( ++ DialectBytecodeReader &reader) const; + RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const; + RngDistributionAttr readRngDistributionAttr( + DialectBytecodeReader &reader) const; +@@ -264,6 +281,8 @@ + DialectBytecodeWriter &writer) const; + void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const; + void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const; ++ void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const; ++ void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const; + void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const; + void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const; + void write(ScatterDimensionNumbersAttr attr, +@@ -327,6 +346,10 @@ + return readOutputOperandAliasAttr(reader); + case stablehlo_encoding::kPrecisionAttr: + return readPrecisionAttr(reader); ++ case stablehlo_encoding::kResultAccuracyAttr: ++ return readResultAccuracyAttr(reader); ++ case stablehlo_encoding::kResultAccuracyModeAttr: ++ return readResultAccuracyModeAttr(reader); + case stablehlo_encoding::kRngAlgorithmAttr: + return readRngAlgorithmAttr(reader); + case stablehlo_encoding::kRngDistributionAttr: +@@ -352,13 +375,13 @@ + .Case( +- [&](auto attr) { +- LOG_WRITE_CALL; +- write(attr, writer); +- return success(); +- }) ++ PrecisionAttr, ResultAccuracyAttr, ResultAccuracyModeAttr, ++ RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr, ++ TransposeAttr, TypeExtensionsAttr>([&](auto attr) { ++ LOG_WRITE_CALL; ++ write(attr, writer); ++ return success(); ++ }) + .Default([&](Attribute) { + LOG_NOT_IMPLEMENTED; + return failure(); +@@ -806,6 +829,55 @@ + } + } + ++//===----------------------------------------------------------------------===// ++// ResultAccuracyModeAttr ++ ++ResultAccuracyModeAttr StablehloBytecodeInterface::readResultAccuracyModeAttr( ++ DialectBytecodeReader &reader) const { ++ LOG_READ_CALL; ++ return hlo::bytecode::readEnumAttribute( ++ reader, getContext(), ++ [](uint32_t val) { return symbolizeResultAccuracyMode(val); }); ++} ++ ++void StablehloBytecodeInterface::write(ResultAccuracyModeAttr attr, ++ DialectBytecodeWriter &writer) const { ++ writer.writeVarInt(stablehlo_encoding::kResultAccuracyModeAttr); ++ hlo::bytecode::writeEnumAttribute(attr, writer); ++} ++ 
++//===----------------------------------------------------------------------===// ++// ResultAccuracyAttr ++ ++ResultAccuracyAttr StablehloBytecodeInterface::readResultAccuracyAttr( ++ DialectBytecodeReader &reader) const { ++ LOG_READ_CALL; ++ FailureOr atol; ++ FailureOr rtol; ++ int64_t ulps; ++ ResultAccuracyModeAttr mode; ++ if (failed(atol = ++ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || ++ failed(rtol = ++ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || ++ failed(reader.readSignedVarInt(ulps)) || ++ failed(reader.readAttribute(mode))) { ++ mlir::emitWarning(mlir::UnknownLoc::get(getContext())) ++ << "failed to read APFloat for atol"; ++ return ResultAccuracyAttr(); ++ } ++ return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode); ++} ++ ++void StablehloBytecodeInterface::write(ResultAccuracyAttr attr, ++ DialectBytecodeWriter &writer) const { ++ writer.writeVarInt(stablehlo_encoding::kResultAccuracyAttr); ++ writer.writeAPFloatWithKnownSemantics(attr.getAtol()); ++ writer.writeAPFloatWithKnownSemantics(attr.getRtol()); ++ writer.writeSignedVarInt(attr.getUlps()); ++ writer.writeAttribute(attr.getMode()); ++} ++ + } // namespace + + void addBytecodeInterface(StablehloDialect *dialect) { +diff --ruN a/stablehlo/stablehlo/dialect/StablehloEnums.td b/stablehlo/stablehlo/dialect/StablehloEnums.td +--- stablehlo/stablehlo/dialect/StablehloEnums.td ++++ stablehlo/stablehlo/dialect/StablehloEnums.td +@@ -45,6 +45,29 @@ + // TODO(b/129153247) See if it's possible to also validate the size. + def StableHLO_PrecisionConfigAttr: + TypedArrayAttrBase; ++ ++//===----------------------------------------------------------------------===// ++// Result Accuracy enum definitions. ++//===----------------------------------------------------------------------===// ++ ++def STABLEHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; ++def STABLEHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; ++def STABLEHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; ++ ++def StableHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode", ++ "XLA result accuracy mode.", ++ [ ++ STABLEHLO_RESULT_ACCURACY_DEFAULT, ++ STABLEHLO_RESULT_ACCURACY_HIGHEST, ++ STABLEHLO_RESULT_ACCURACY_TOLERANCE ++ ]> { ++ let genSpecializedAttr = 0; ++ let cppNamespace = "::mlir::stablehlo"; ++} ++ ++def StableHLO_ResultAccuracyModeAttr : EnumAttr { ++ let assemblyFormat = "`<` $value `>`"; ++} + + //===----------------------------------------------------------------------===// + // Fast Fourier Transform Type enum definitions. 
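For reference, the printer, parser, and enum definitions above combine into a compact textual form. A minimal sketch of what round-trips (illustrative, not part of the patch; the `result_accuracy` and `result_accuracy_mode` mnemonics follow the tablegen definitions above, and the printer elides any tolerance that is zero):

```mlir
// Mode-only form: atol/rtol/ulps are all zero, so only `mode` is printed.
%0 = "stablehlo.exponential"(%arg0) {
  result_accuracy = #stablehlo.result_accuracy<mode = #stablehlo.result_accuracy_mode<DEFAULT>>
} : (tensor<4xf32>) -> tensor<4xf32>

// Tolerance form: nonzero fields print in grammar order, each with a trailing comma.
%1 = "stablehlo.exponential"(%0) {
  result_accuracy = #stablehlo.result_accuracy<atol = 1.000000e-05, rtol = 1.000000e-03, ulps = 2, mode = #stablehlo.result_accuracy_mode<TOLERANCE>>
} : (tensor<4xf32>) -> tensor<4xf32>
```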
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -792,6 +792,29 @@ + allowImpreciseAccumulation); + } + ++// ===----------------------------------------------------------------------===// ++// ExpOp ++//===----------------------------------------------------------------------===// ++ ++LogicalResult ResultAccuracyAttr::verify( ++ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, ++ APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) { ++ return hlo::verifyResultAccuracyAttr( ++ emitError, atol, rtol, ulps, ++ stringifyResultAccuracyMode(mode.getValue())); ++} ++ ++LogicalResult ExpOp::verify() { ++ if (auto attr = getResultAccuracyAttr()) { ++ if (failed(ResultAccuracyAttr::verify([&] { return emitError(); }, ++ attr.getAtol(), attr.getRtol(), ++ attr.getUlps(), attr.getMode()))) { ++ return failure(); ++ } ++ } ++ return success(); ++} ++ + //===----------------------------------------------------------------------===// + // FftOp + //===----------------------------------------------------------------------===// +@@ -3127,6 +3150,20 @@ + lhsContractingDimensions, rhsContractingDimensions); + } + ++// ===----------------------------------------------------------------------===// ++// Custom unary op ++// ===----------------------------------------------------------------------===// ++ ++void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const { ++ hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(), ++ getMode()); ++} ++ ++Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) { ++ return hlo::parseResultAccuracyAttr(parser, type); ++} ++ + namespace { + enum NonSpatialDim : int64_t { + IOBatch = -1, // Input or output batch dimension +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td +--- stablehlo/stablehlo/dialect/StablehloOps.td ++++ stablehlo/stablehlo/dialect/StablehloOps.td +@@ -327,6 +327,23 @@ + ```mlir + %result = stablehlo.exponential %operand : tensor<2x2xf64> + ``` ++ }]; ++ let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand, ++ DefaultValuedOptionalAttr:$result_accuracy); ++ let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result); ++ let extraClassDeclaration = commonClassDeclaration # [{ ++ LogicalResult reifyReturnTypeShapes( ++ OpBuilder& builder, ValueRange operands, ++ SmallVectorImpl& reifiedReturnShapes) { ++ return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(), ++ operands.front(), ++ &reifiedReturnShapes); ++ } ++ }]; ++ let hasVerifier = 1; ++ ++ let assemblyFormat = [{ ++ $operand attr-dict `:` custom(type($operand), type($result)) + }]; + } + +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp +--- stablehlo/stablehlo/dialect/TypeInference.cpp ++++ stablehlo/stablehlo/dialect/TypeInference.cpp +@@ -5057,5 +5057,30 @@ + return success(); + } + ++LogicalResult verifyResultAccuracyCombination( ++ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, ++ APFloat rtol, int64_t ulps, StringRef mode) { ++ if (mode == "DEFAULT" || mode == "HIGHEST") { ++ bool all_zero = atol.isZero() && rtol.isZero() && ulps == 0; ++ if (!all_zero) { ++ return emitError() ++ << "Invalid tolerances for ResultAccuracyAttr with mode " << mode ++ << ", must be all zero."; ++ } ++ } ++ return success(); ++} 
++ ++LogicalResult verifyResultAccuracyAttr( ++ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, ++ APFloat rtol, int64_t ulps, StringRef mode) { ++ if (atol.isNegative() || rtol.isNegative() || ulps < 0) ++ return emitError() << "Negative tolerance"; ++ if (failed( ++ verifyResultAccuracyCombination(emitError, atol, rtol, ulps, mode))) ++ return failure(); ++ return success(); ++} ++ + } // end namespace hlo + } // end namespace mlir +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/dialect/TypeInference.h +--- stablehlo/stablehlo/dialect/TypeInference.h ++++ stablehlo/stablehlo/dialect/TypeInference.h +@@ -26,6 +26,7 @@ + #include "mlir/IR/SymbolTable.h" + #include "mlir/IR/Types.h" + #include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LLVM.h" + #include "mlir/Support/LogicalResult.h" + #include "stablehlo/dialect/Base.h" + +@@ -596,6 +597,14 @@ + + LogicalResult verifyWhileOp(std::optional location, + ValueRange operand, Region& cond, Region& body); ++ ++LogicalResult verifyResultAccuracyCombination( ++ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, ++ APFloat rtol, int64_t ulps, StringRef mode); ++ ++LogicalResult verifyResultAccuracyAttr( ++ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, ++ APFloat rtol, int64_t ulps, StringRef mode); + } // end namespace hlo + } // end namespace mlir + diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp --- stablehlo/stablehlo/dialect/Version.cpp +++ stablehlo/stablehlo/dialect/Version.cpp @@ -68,6 +536,250 @@ diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/diale // The time frames used are from the date that the release was tagged on, not // merged. The tag date is when the version has been verified and exported to // XLA. See: https://github.com/openxla/stablehlo/tags +diff --ruN a/stablehlo/stablehlo/dialect/Version.h b/stablehlo/stablehlo/dialect/Version.h +--- stablehlo/stablehlo/dialect/Version.h ++++ stablehlo/stablehlo/dialect/Version.h +@@ -38,7 +38,7 @@ + static FailureOr fromString(llvm::StringRef versionRef); + + /// Return a Version representing the current VHLO dialect version. +- static Version getCurrentVersion() { return Version(1, 8, 10); } ++ static Version getCurrentVersion() { return Version(1, 9, 0); } + + /// Return a Version representing the minimum supported VHLO dialect version. 
+ static Version getMinimumVersion() { return Version(0, 9, 0); } +diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td +--- stablehlo/stablehlo/dialect/VhloAttrs.td ++++ stablehlo/stablehlo/dialect/VhloAttrs.td +@@ -21,18 +21,8 @@ + include "stablehlo/dialect/VhloBase.td" + include "stablehlo/dialect/VhloDialect.td" + include "stablehlo/dialect/VhloTypes.td" +- +-def VHLO_VersionedAttrInterface : AttrInterface<"VersionedAttrInterface"> { +- let cppNamespace = "::mlir::vhlo"; +- let methods = [ +- InterfaceMethod< +- "Returns the minimum version of the VHLO dialect an attribute is supported in.", +- "mlir::vhlo::Version", "getMinVersion">, +- InterfaceMethod< +- "Returns the maximum version (inclusive) of the VHLO dialect an attribute is supported in.", +- "mlir::vhlo::Version", "getMaxVersion">, +- ]; +-} ++include "stablehlo/dialect/VhloEnums.td" ++ + + class VHLO_AttrDef + : AttrDef { +@@ -190,4 +180,27 @@ + let assemblyFormat = "`<` struct(params) `>`"; + } + ++ ++def VHLO_ResultAccuracyAttrV1 : VHLO_AttrDef<"ResultAccuracyV1", "1.9.0", "current"> { ++ let mnemonic = "result_accuracy_v1"; ++ let summary = "The requested accuracy for transcendental unary ops."; ++ let parameters = (ins ++ VHLO_APFloatV1:$atol, ++ VHLO_APFloatV1:$rtol, ++ "int64_t":$ulps, ++ "mlir::Attribute":$mode ++ ); ++ let assemblyFormat = "`<` struct(params) `>`"; ++ let genVerifyDecl = 1; ++ let extraClassDefinition = [{ ++ LogicalResult ResultAccuracyV1Attr::verify( ++ llvm::function_ref errFn, ++ APFloat atol, APFloat rtol, int64_t ulps, ++ mlir::Attribute mode) { ++ if (!isFromVhlo(mode)) return errFn() << "expected VHLO result accuracy mode"; ++ return success(); ++ } ++ }]; ++} ++ + #endif // STABLEHLO_DIALECT_VHLO_ATTRS +diff --ruN a/stablehlo/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/stablehlo/dialect/VhloBytecode.cpp +--- stablehlo/stablehlo/dialect/VhloBytecode.cpp ++++ stablehlo/stablehlo/dialect/VhloBytecode.cpp +@@ -178,6 +178,18 @@ + /// bounds : svarint[] + /// } + kTypeExtensionsV1Attr = 18, ++ ++ // ResultAccuracyModeV1Attr { ++ // mode: varint (encoded enum) ++ // } ++ kResultAccuracyModeV1Attr = 19, ++ ++ // ResultAccuracyV1Attr { ++ // atol: APFloat ++ // rtol: APFloat ++ // ulps: svarint ++ // } ++ kResultAccuracyV1Attr = 20, + }; + + /// This enum contains marker codes used to indicate which type is +@@ -433,6 +445,10 @@ + TypeV1Attr readTypeV1Attr(DialectBytecodeReader &reader) const; + TypeExtensionsV1Attr readTypeExtensionsV1Attr( + DialectBytecodeReader &reader) const; ++ ResultAccuracyModeV1Attr readResultAccuracyModeV1Attr( ++ DialectBytecodeReader &reader) const; ++ ResultAccuracyV1Attr readResultAccuracyV1Attr( ++ DialectBytecodeReader &reader) const; + + // TO ADD ATTRIBUTE: Include a write method for each attribute in VHLO + // Ex: void write(SomeAttr attr, DialectBytecodeWriter &writer) const; +@@ -457,6 +473,9 @@ + void write(TransposeV1Attr attr, DialectBytecodeWriter &writer) const; + void write(TypeV1Attr attr, DialectBytecodeWriter &writer) const; + void write(TypeExtensionsV1Attr attr, DialectBytecodeWriter &writer) const; ++ void write(ResultAccuracyModeV1Attr attr, ++ DialectBytecodeWriter &writer) const; ++ void write(ResultAccuracyV1Attr attr, DialectBytecodeWriter &writer) const; + + //===--------------------------------------------------------------------===// + // Types +@@ -541,6 +560,10 @@ + return readTypeV1Attr(reader); + case vhlo_encoding::kTypeExtensionsV1Attr: + return readTypeExtensionsV1Attr(reader); ++ case 
vhlo_encoding::kResultAccuracyModeV1Attr: ++ return readResultAccuracyModeV1Attr(reader); ++ case vhlo_encoding::kResultAccuracyV1Attr: ++ return readResultAccuracyV1Attr(reader); + default: + reader.emitError() << "unknown vhlo attribute code: " << code; + return Attribute(); +@@ -558,7 +581,8 @@ + FftTypeV1Attr, FloatV1Attr, IntegerV1Attr, OutputOperandAliasV1Attr, + PrecisionV1Attr, RngAlgorithmV1Attr, RngDistributionV1Attr, + StringV1Attr, TensorV1Attr, TransposeV1Attr, TypeV1Attr, +- TypeExtensionsV1Attr>([&](auto attr) { ++ TypeExtensionsV1Attr, ResultAccuracyV1Attr, ++ ResultAccuracyModeV1Attr>([&](auto attr) { + LOG_WRITE_CALL; + write(attr, writer); + return success(); +@@ -1450,6 +1474,55 @@ + writer.writeType(type.getElementType()); + } + ++//===----------------------------------------------------------------------===// ++// ResultAccuracyModeAttr ++ ++ResultAccuracyModeV1Attr VhloBytecodeInterface::readResultAccuracyModeV1Attr( ++ DialectBytecodeReader &reader) const { ++ LOG_READ_CALL; ++ return hlo::bytecode::readEnumAttribute( ++ reader, getContext(), ++ [](uint32_t val) { return symbolizeResultAccuracyModeV1(val); }); ++} ++ ++void VhloBytecodeInterface::write(ResultAccuracyModeV1Attr attr, ++ DialectBytecodeWriter &writer) const { ++ writer.writeVarInt(vhlo_encoding::kResultAccuracyModeV1Attr); ++ hlo::bytecode::writeEnumAttribute(attr, writer); ++} ++ ++//===----------------------------------------------------------------------===// ++// ResultAccuracyAttr ++ ++ResultAccuracyV1Attr VhloBytecodeInterface::readResultAccuracyV1Attr( ++ DialectBytecodeReader &reader) const { ++ LOG_READ_CALL; ++ FailureOr atol; ++ FailureOr rtol; ++ int64_t ulps; ++ ResultAccuracyModeV1Attr mode; ++ if (failed(atol = ++ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || ++ failed(rtol = ++ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || ++ failed(reader.readSignedVarInt(ulps)) || ++ failed(reader.readAttribute(mode))) { ++ mlir::emitWarning(mlir::UnknownLoc::get(getContext())) ++ << "failed to read APFloat for atol"; ++ return ResultAccuracyV1Attr(); ++ } ++ return ResultAccuracyV1Attr::get(getContext(), *atol, *rtol, ulps, mode); ++} ++ ++void VhloBytecodeInterface::write(ResultAccuracyV1Attr attr, ++ DialectBytecodeWriter &writer) const { ++ writer.writeVarInt(vhlo_encoding::kResultAccuracyV1Attr); ++ writer.writeAPFloatWithKnownSemantics(attr.getAtol()); ++ writer.writeAPFloatWithKnownSemantics(attr.getRtol()); ++ writer.writeSignedVarInt(attr.getUlps()); ++ writer.writeAttribute(attr.getMode()); ++} ++ + } // namespace + + void addBytecodeInterface(VhloDialect *dialect) { +diff --ruN a/stablehlo/stablehlo/dialect/VhloDialect.td b/stablehlo/stablehlo/dialect/VhloDialect.td +--- stablehlo/stablehlo/dialect/VhloDialect.td ++++ stablehlo/stablehlo/dialect/VhloDialect.td +@@ -47,6 +47,7 @@ + 1.6.0: Add DotAlgorithm specificaiton to `dot_general`. + 1.7.0: Introduce `f8E4M3` and `f8E3M4` types. + 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types. ++ 1.9.0: Add `ResultAccuracy` attribute to `exp` op. 
+ }]; + + let useDefaultAttributePrinterParser = 0; +diff --ruN a/stablehlo/stablehlo/dialect/VhloEnums.td b/stablehlo/stablehlo/dialect/VhloEnums.td +--- stablehlo/stablehlo/dialect/VhloEnums.td ++++ stablehlo/stablehlo/dialect/VhloEnums.td +@@ -20,7 +20,20 @@ + include "mlir/IR/EnumAttr.td" + include "mlir/IR/PatternBase.td" + include "stablehlo/dialect/VhloBase.td" +-include "stablehlo/dialect/VhloAttrs.td" ++include "stablehlo/dialect/VhloDialect.td" ++include "mlir/IR/AttrTypeBase.td" ++ ++def VHLO_VersionedAttrInterface : AttrInterface<"VersionedAttrInterface"> { ++ let cppNamespace = "::mlir::vhlo"; ++ let methods = [ ++ InterfaceMethod< ++ "Returns the minimum version of the VHLO dialect an attribute is supported in.", ++ "mlir::vhlo::Version", "getMinVersion">, ++ InterfaceMethod< ++ "Returns the maximum version (inclusive) of the VHLO dialect an attribute is supported in.", ++ "mlir::vhlo::Version", "getMaxVersion">, ++ ]; ++} + + class VHLO_I32EnumAttr cases> : + I32EnumAttr { +@@ -198,4 +211,23 @@ + def VHLO_TransposeAttrV1 + : VHLO_EnumAttr; + ++//===----------------------------------------------------------------------===// ++// ResultAccuracyMode ++//===----------------------------------------------------------------------===// ++ ++def VHLO_RESULT_V1_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; ++def VHLO_RESULT_V1_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; ++def VHLO_RESULT_V1_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; ++ ++def VHLO_ResultAccuracyModeV1 : VHLO_I32EnumAttr<"ResultAccuracyModeV1", ++ [ ++ VHLO_RESULT_V1_ACCURACY_DEFAULT, ++ VHLO_RESULT_V1_ACCURACY_HIGHEST, ++ VHLO_RESULT_V1_ACCURACY_TOLERANCE ++ ]> {} ++ ++def VHLO_ResultAccuracyModeV1Attr ++ : VHLO_EnumAttr; ++ ++ + #endif // STABLEHLO_DIALECT_VHLO_ENUMS diff --ruN a/stablehlo/stablehlo/dialect/VhloOps.cpp b/stablehlo/stablehlo/dialect/VhloOps.cpp --- stablehlo/stablehlo/dialect/VhloOps.cpp +++ stablehlo/stablehlo/dialect/VhloOps.cpp @@ -109,6 +821,243 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloOps.cpp b/stablehlo/stablehlo/diale } // Parse tensor elements using DenseIntOrFPElementsAttr printing. 
+diff --ruN a/stablehlo/stablehlo/dialect/VhloOps.td b/stablehlo/stablehlo/dialect/VhloOps.td +--- stablehlo/stablehlo/dialect/VhloOps.td ++++ stablehlo/stablehlo/dialect/VhloOps.td +@@ -618,8 +618,15 @@ + let results = (outs VHLO_AnyType:$result); + } + +-def VHLO_ExpOpV1 : VHLO_Op<"exponential_v1", "0.9.0", "current"> { +- let arguments = (ins VHLO_AnyType:$operand); ++def VHLO_ExpOpV1 : VHLO_Op<"exponential_v1", "0.9.0", "1.8.0"> { ++ let arguments = (ins VHLO_AnyType:$operand); ++ let results = (outs VHLO_AnyType:$result); ++} ++ ++def VHLO_ExpOpV2 : VHLO_Op<"exponential_v2", "1.9.0", "current"> { ++ let arguments = (ins ++ VHLO_AnyType:$operand, ++ VHLO_AnyAttr:$result_accuracy); + let results = (outs VHLO_AnyType:$result); + } + +diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp b/stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp +--- stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp ++++ stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp +@@ -16,6 +16,7 @@ + #include + #include + ++#include "llvm/ADT/APFloat.h" + #include "llvm/ADT/ArrayRef.h" + #include "llvm/Support/Casting.h" + #include "llvm/Support/ErrorHandling.h" +@@ -687,3 +688,69 @@ + return llvm::cast(unwrap(attr)) + .getBounds()[pos]; + } ++ ++//===----------------------------------------------------------------------===// ++// ResultAccuracyModeAttr ++//===----------------------------------------------------------------------===// ++ ++MlirAttribute stablehloResultAccuracyModeAttrGet(MlirContext ctx, ++ MlirStringRef value) { ++ std::optional accuracyMode = ++ mlir::stablehlo::symbolizeResultAccuracyMode(unwrap(value)); ++ if (!accuracyMode) llvm::report_fatal_error("Invalid value."); ++ return wrap(mlir::stablehlo::ResultAccuracyModeAttr::get( ++ unwrap(ctx), accuracyMode.value())); ++} ++ ++bool stablehloAttributeIsAResultAccuracyModeAttr(MlirAttribute attr) { ++ return llvm::isa(unwrap(attr)); ++} ++ ++MlirStringRef stablehloResultAccuracyModeAttrGetValue(MlirAttribute attr) { ++ return wrap(mlir::stablehlo::stringifyResultAccuracyMode( ++ llvm::cast(unwrap(attr)) ++ .getValue())); ++} ++//===----------------------------------------------------------------------===// ++// ResultAccuracyAttr ++//===----------------------------------------------------------------------===// ++ ++MlirAttribute stablehloResultAccuracyAttrGet(MlirContext ctx, double atol, ++ double rtol, int64_t ulps, ++ MlirStringRef mode) { ++ std::optional accuracyMode = ++ mlir::stablehlo::symbolizeResultAccuracyMode(unwrap(mode)); ++ if (!accuracyMode) llvm::report_fatal_error("Invalid value."); ++ mlir::stablehlo::ResultAccuracyModeAttr modeAttr = ++ mlir::stablehlo::ResultAccuracyModeAttr::get(unwrap(ctx), ++ accuracyMode.value()); ++ return wrap(mlir::stablehlo::ResultAccuracyAttr::get( ++ unwrap(ctx), llvm::APFloat(atol), llvm::APFloat(rtol), ulps, modeAttr)); ++} ++ ++bool stablehloAttributeIsAResultAccuracyAttr(MlirAttribute attr) { ++ return llvm::isa(unwrap(attr)); ++} ++ ++double stablehloResultAccuracyAttrGetAtol(MlirAttribute attr) { ++ llvm::APFloat result = ++ llvm::cast(unwrap(attr)).getAtol(); ++ return result.convertToDouble(); ++} ++ ++double stablehloResultAccuracyAttrGetRtol(MlirAttribute attr) { ++ llvm::APFloat result = ++ llvm::cast(unwrap(attr)).getRtol(); ++ return result.convertToDouble(); ++} ++ ++int64_t stablehloResultAccuracyAttrGetUlps(MlirAttribute attr) { ++ return llvm::cast(unwrap(attr)) ++ .getUlps(); ++} ++ ++MlirAttribute 
stablehloResultAccuracyAttrGetMode(MlirAttribute attr) { ++ mlir::stablehlo::ResultAccuracyModeAttr modeAttr = ++ llvm::cast(unwrap(attr)).getMode(); ++ return wrap(modeAttr); ++} +diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloAttributes.h b/stablehlo/stablehlo/integrations/c/StablehloAttributes.h +--- stablehlo/stablehlo/integrations/c/StablehloAttributes.h ++++ stablehlo/stablehlo/integrations/c/StablehloAttributes.h +@@ -13,6 +13,7 @@ + #ifndef STABLEHLO_INTEGRATIONS_C_STABLEHLO_ATTRIBUTES_H + #define STABLEHLO_INTEGRATIONS_C_STABLEHLO_ATTRIBUTES_H + ++#include + #include + #include + +@@ -376,6 +377,42 @@ + MLIR_CAPI_EXPORTED int64_t + stablehloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos); + ++// ===---------------------------------------------------------------------===// ++// ResultAccuracyModeAttr ++//===----------------------------------------------------------------------===// ++ ++MLIR_CAPI_EXPORTED MlirAttribute ++stablehloResultAccuracyModeAttrGet(MlirContext ctx, MlirStringRef value); ++ ++MLIR_CAPI_EXPORTED bool stablehloAttributeIsAResultAccuracyModeAttr( ++ MlirAttribute attr); ++ ++MLIR_CAPI_EXPORTED MlirStringRef ++stablehloResultAccuracyModeAttrGetValue(MlirAttribute attr); ++ ++// ===---------------------------------------------------------------------===// ++// ResultAccuracyAttr ++//===----------------------------------------------------------------------===// ++ ++MLIR_CAPI_EXPORTED MlirAttribute ++stablehloResultAccuracyAttrGet(MlirContext ctx, double atol, double rtol, ++ int64_t ulps, MlirStringRef value); ++ ++MLIR_CAPI_EXPORTED bool stablehloAttributeIsAResultAccuracyAttr( ++ MlirAttribute attr); ++ ++MLIR_CAPI_EXPORTED double stablehloResultAccuracyAttrGetAtol( ++ MlirAttribute attr); ++ ++MLIR_CAPI_EXPORTED double stablehloResultAccuracyAttrGetRtol( ++ MlirAttribute attr); ++ ++MLIR_CAPI_EXPORTED int64_t ++stablehloResultAccuracyAttrGetUlps(MlirAttribute attr); ++ ++MLIR_CAPI_EXPORTED MlirAttribute ++stablehloResultAccuracyAttrGetMode(MlirAttribute attr); ++ + #ifdef __cplusplus + } + #endif +diff --ruN a/stablehlo/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/stablehlo/integrations/python/StablehloModule.cpp +--- stablehlo/stablehlo/integrations/python/StablehloModule.cpp ++++ stablehlo/stablehlo/integrations/python/StablehloModule.cpp +@@ -599,6 +599,50 @@ + stablehloTypeExtensionsGetBoundsElem); + }); + ++ mlir::python::nanobind_adaptors::mlir_attribute_subclass( ++ m, "ResultAccuracyAttr", stablehloAttributeIsAResultAccuracyAttr) ++ .def_classmethod( ++ "get", ++ [](nb::object cls, double atol, double rtol, int64_t ulps, ++ const std::string &mode, MlirContext ctx) { ++ return cls(stablehloResultAccuracyAttrGet( ++ ctx, atol, rtol, ulps, ++ mlirStringRefCreate(mode.c_str(), mode.size()))); ++ }, ++ nb::arg("cls"), nb::arg("atol"), nb::arg("rtol"), nb::arg("ulps"), ++ nb::arg("mode"), nb::arg("context") = nb::none(), ++ "Creates a ResultAccuracyAttr with the given values.") ++ .def_property_readonly("atol", ++ [](MlirAttribute self) { ++ return stablehloResultAccuracyAttrGetAtol(self); ++ }) ++ .def_property_readonly("rtol", ++ [](MlirAttribute self) { ++ return stablehloResultAccuracyAttrGetRtol(self); ++ }) ++ .def_property_readonly("ulps", ++ [](MlirAttribute self) { ++ return stablehloResultAccuracyAttrGetUlps(self); ++ }) ++ .def_property_readonly("mode", [](MlirAttribute self) { ++ return toPyString(stablehloResultAccuracyModeAttrGetValue( ++ stablehloResultAccuracyAttrGetMode(self))); ++ }); ++ ++ 
mlir::python::nanobind_adaptors::mlir_attribute_subclass( ++ m, "ResultAccuracyModeAttr", stablehloAttributeIsAResultAccuracyModeAttr) ++ .def_classmethod( ++ "get", ++ [](nb::object cls, const std::string &value, MlirContext ctx) { ++ return cls(stablehloResultAccuracyModeAttrGet( ++ ctx, mlirStringRefCreate(value.c_str(), value.size()))); ++ }, ++ nb::arg("cls"), nb::arg("value"), nb::arg("context") = nb::none(), ++ "Creates a ResultAccuracyModeAttr with the given values.") ++ .def_property_readonly("value", [](MlirAttribute self) { ++ return toPyString(stablehloResultAccuracyModeAttrGetValue(self)); ++ }); ++ + // + // StableHLO APIs + // +diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py +--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py ++++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py +@@ -386,3 +386,23 @@ + cloned_module = module.operation.clone() + pipeline.run(cloned_module.operation) + assert str(module) == str(cloned_module) ++ ++ ++@run ++def test_result_accuracy_attr_default(): ++ attr = stablehlo.ResultAccuracyAttr.get(atol=0, rtol=0, ulps=0, mode="DEFAULT") ++ assert attr is not None ++ assert attr.mode == "DEFAULT" ++ assert attr.atol == 0 ++ assert attr.rtol == 0 ++ assert attr.ulps == 0 ++ ++@run ++def test_result_accuracy_attr_tolerance(): ++ attr = stablehlo.ResultAccuracyAttr.get(atol=1e-5, rtol=1.0, ++ ulps=2, mode="TOLERANCE") ++ assert attr is not None ++ assert attr.mode == "TOLERANCE" ++ assert attr.atol == 1e-5 ++ assert attr.rtol == 1.0 ++ assert attr.ulps == 2 diff --ruN a/stablehlo/stablehlo/reference/Types.cpp b/stablehlo/stablehlo/reference/Types.cpp --- stablehlo/stablehlo/reference/Types.cpp +++ stablehlo/stablehlo/reference/Types.cpp @@ -276,160 +1225,3332 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl + func.return +} + -+// ----- ++// ----- ++ ++func.func @ragged_dot_mode_1_batching() { ++ %lhs = stablehlo.constant dense<[ ++ [ ++ [ -0.0999976546, -0.0605386607, 0.126681596, 0.0375950411, 0.0598301813 ], ++ [ -0.0343122408, -0.0858866125, 0.103659429, 0.103788935, 0.180407882 ], ++ [ 0.0150506198, 0.055824928, 0.149289608, -0.0896283686, -0.0839615092 ], ++ [ 0.0589100644, 0.101344816, -0.097690545, 0.0150246918, -0.0799473301 ], ++ [ 0.0252457932, 0.106031813, 0.076692991, 0.179130971, 0.153850079 ], ++ [ 0.0580786392, -0.0724105313, 0.0961757079, 0.0247998089, 0.110357188 ], ++ [ 0.173096269, 0.128659427, -0.0212640986, -0.0857606456, 0.120824583 ], ++ [ -0.00152973086, 0.0897915736, 0.126923144, 0.197311223, 0.00960160792 ], ++ [ -0.0258883312, 0.194765091, 0.11679814, 0.126006752, 0.0954555795 ], ++ [ -0.0781942382, 0.0894904211, 0.165412158, -0.0181870088, 0.0309234336 ], ++ [ 0.129948437, 0.0433195308, -0.028667666, -0.0175279453, 0.00777949393 ] ++ ], ++ [ ++ [ -0.0500478409, 0.0459552184, 0.16929689, 0.172762454, -0.0818307 ], ++ [ 0.171395928, 0.0513568744, 0.0548876, -0.00429011881, 0.195992649 ], ++ [ 0.0481930152, -0.0201566443, -0.0727801323, 0.184329301, -0.0778752789 ], ++ [ 0.0502121374, 0.0152426511, -0.0168754607, 0.174145252, 0.0589242205 ], ++ [ 0.0393337533, 0.182294011, -0.0849748, 0.128454268, 0.131061375 ], ++ [ 0.148345202, -0.0623903871, -0.0952396914, 0.10653659, 0.160474151 ], ++ [ 0.0888630375, 0.120867364, 0.117623605, 0.199837387, 0.166571677 ], ++ [ -0.0300415382, -0.00810345262, 0.00530457497, 0.0539821163, 0.0773340687 ], ++ [ 0.153794467, 0.0236242339, 0.152453214, 
-0.0192048177, 0.0246183872 ], ++ [ 0.0611911938, 0.0403752252, -0.013836287, -0.0465016849, -0.053884007 ], ++ [ 0.0714964494, 0.140721709, -0.0900838748, 0.0603349432, 0.0495440438 ] ++ ]]> : tensor<2x11x5xf32> ++ %rhs = stablehlo.constant dense<[ ++ [ ++ [ ++ [ 0.186608255, 0.124487795, 0.0663751587, 0.167221248, 0.0874548, 0.152611881, -0.0520697422 ], ++ [ -0.0361745432, 0.114412986, -0.0608718246, -0.0727029, -0.0176235586, -0.0991001204, 0.0242879838 ], ++ [ -0.0919371173, 0.112945892, 0.181369215, -0.0280267522, -0.0457312278, -0.00473813713, 0.166097224 ], ++ [ 0.0956176, -0.0548994839, 0.104403876, 0.0157444105, 0.0163175985, 0.0499223098, -0.0557401 ], ++ [ 0.076156, 0.153672695, 0.0770325884, 0.186622649, 0.066843845, -0.0555545315, 0.194991559 ] ++ ], ++ [ ++ [ 0.0226300061, -0.0574540682, 0.0694696084, -0.0243620798, 0.0465543643, 0.0392091647, 0.188328564 ], ++ [ -0.0621907599, -0.0400728397, -0.0042250976, 0.0887807682, -0.0619863532, 0.0953761414, 0.0864902064 ], ++ [ 0.140921891, -0.0256474689, 0.0429295525, 0.0167942569, -0.0390249, -0.0914874449, 0.170502067 ], ++ [ 0.0279492214, -0.0573936924, 0.184246033, 0.0230939165, -0.060643442, 0.165694535, -0.0723479092 ], ++ [ -0.051340431, -0.0786809325, 0.00960171223, -0.0240827873, -0.059467189, 0.134945959, 0.0365921929 ] ++ ] ++ ], ++ [ ++ [ ++ [ 0.00485724211, 0.0356900468, 0.142683387, 0.179502338, 0.0954938307, -0.0354254842, 0.103877716 ], ++ [ 0.172676593, -0.0249623209, 0.158257961, 0.0413787, 0.0517867729, 0.0801181123, 0.14526847 ], ++ [ 0.126753062, 0.0386734977, 0.185410261, 0.0898216143, 0.0317991, 0.14740923, 0.106694289 ], ++ [ 0.110662006, 0.196143657, 0.186324477, 0.155380905, -0.0132051334, 0.0612277314, 0.054330416 ], ++ [ -0.0689698234, 0.0242085531, 0.073015, 0.162969738, 0.0320116058, 0.118924297, 0.160779119 ] ++ ], ++ [ ++ [ 0.11469271, 0.140216112, 0.111960642, 0.122514777, -0.0942722782, 0.165809333, 0.0574962273 ], ++ [ 0.0389968231, -0.08044184, 0.114026703, 0.0466829464, 0.100303732, 0.104614742, -0.0401335768 ], ++ [ 0.174990177, 0.159764826, 0.167005628, 0.0631844923, -0.0582415, 0.0351042375, 0.196808755 ], ++ [ -0.035340406, 0.0338070318, -0.00528027117, 0.0543978438, 0.164451241, 0.0319176689, 0.0402595326 ], ++ [ 0.141994983, 0.00954742, -0.0365443081, 0.199735016, -0.053918656, 0.0891464874, 0.0849051103 ] ++ ] ++ ], ++ [ ++ [ ++ [ -0.0998214856, -0.0997363, 0.132005602, 0.118200503, -0.00424671918, 0.025317125, 0.104748271 ], ++ [ 0.104168601, -0.0384214334, 0.150926, 0.112676181, 0.14861238, -0.071635358, -0.0754787177 ], ++ [ 0.129201442, 0.088871561, -0.0358443409, -0.0359359607, -0.0756817609, 0.0166469738, 0.185647905 ], ++ [ 0.184263527, 0.0169560835, -0.0192355737, 0.10765069, -0.0147894919, 0.13305977, 0.135159582 ], ++ [ 0.0267379507, -0.0153532401, -0.0418097563, -0.096605137, -0.0424528457, 0.194970757, -0.0267837271 ] ++ ], ++ [ ++ [ 0.145917833, -0.0590635166, 0.0194431096, 0.0803030357, -0.0469358861, 0.148506433, -0.0526806451 ], ++ [ 0.196381122, -0.0228494033, -0.0299202427, -0.069508791, -0.0341768041, 0.0904152468, 0.108802207 ], ++ [ 0.138430953, 0.108872853, 0.125882119, 0.100856192, 0.0900289789, -0.0830678046, 0.0794649944 ], ++ [ -0.0318976864, -0.00436662883, 0.109950341, -0.0647689179, 0.128771216, 0.0578369871, 0.0661734 ], ++ [ 0.0763966814, -0.00110008568, 0.110896833, -0.057086423, -0.0514936894, 0.0455975607, 0.158067733 ] ++ ] ++ ]]> : tensor<3x2x5x7xf32> ++ %group_sizes = stablehlo.constant dense<[4, 4, 3]> : tensor<3xi64> ++ %result = 
"chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { ++ ragged_dot_dimension_numbers = #chlo.ragged_dot< ++ lhs_batching_dimensions = [0], ++ rhs_batching_dimensions = [1], ++ lhs_contracting_dimensions = [2], ++ rhs_contracting_dimensions = [2], ++ lhs_ragged_dimensions = [1], ++ rhs_group_dimensions = [0] ++ >, ++ precision_config = [#chlo, #chlo] ++ } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> ++ check.expect_almost_eq_const %result, dense<[ ++ [ ++ [-0.0199659951, 0.00206358638, 0.0285578221, -0.00411329232, -0.00885893404, -0.0113086831, 0.0343487822], ++ [0.0108370036, 0.0196357146, 0.0464844741, 0.032903526, 0.00752512738, -0.00205732603, 0.0463109687], ++ [-0.0279003512, 0.0171403233, 0.00885203853, -0.022806216, -0.0135696121, -0.00375272054, 0.0139928926], ++ [0.0116565451, -0.00521556707, -0.0245668497, -0.00946252606, 2.734600e-03, 0.00460146647, -0.0332586318], ++ [0.0373648889, 0.040080104, 0.0792120546, 0.0687142611, 0.0129001699, 0.048170276, 6.067640e-02], ++ [-0.00489785476, 0.0151357278, 0.0273378156, 0.0379059538, 0.0080597708, 0.0209609158, 0.0248660222], ++ [0.00253825542, -1.175260e-02, 0.0339594558, 0.0408501513, 0.0275165718, 0.0101594552, 0.0491689071], ++ [5.275800e-02, 0.0415463448, 0.0749897882, 0.0470644757, 0.00624182029, 0.0391805507, 0.03869069], ++ [0.0637338459, 0.00614991458, 0.0153763723, 0.0190313365, 0.0142990183, 0.0227143262, 0.0187453162], ++ [0.0359746702, 0.0182777364, -0.00368779944, -0.0100486111, 6.89582666E-5, -0.00202751439, 0.0124766938], ++ [-0.0151847685, -0.0175893605, 0.0247314386, 0.018632818, 0.00798455066, -0.00110600982, 0.00244264561] ++ ], ++ [ ++ [0.0288968664, -0.00678509939, 0.0346419513, 0.0141028976, -0.017396003, 0.00451522879, 0.00792134088], ++ [-0.0017626211, -0.0284877941, 0.0151375476, -0.00351338694, -0.00874114502, 0.0323345512, 0.0535612516], ++ [0.00123786228, -0.00454656407, 0.0335229039, 0.0019464466, -2.14070082E-4, 0.0266590156, -0.0212618597], ++ [-3.47743975E-4, -0.017693948, 0.0353507064, 0.00244920771, -0.0120135043, 0.0417729542, -0.0025454592], ++ [0.0108208582, -0.0171308704, 0.00553112756, 0.0411250815, 0.0335835591, 0.038393192, -0.00547906291], ++ [0.0169365555, 0.0157370344, -0.0128378682, 0.0470919088, -0.00582840201, 0.0324328542, 0.010203423], ++ [0.0520783663, 0.0298755895, 0.0362326317, 0.0681023895, 0.0207777359, 0.052735541, 0.0455959477], ++ [0.00623999349, -1.49650674E-4, -0.00651274621, 0.0146591738, 0.00641800836, 0.00297434814, 0.00838128477], ++ [0.0506783053, 0.00703135319, 0.0220930576, 0.0259224195, 0.001958607, 0.0123232938, 0.00920359604], ++ [0.0123091843, -5.780780e-03, -0.0128484722, 0.00679983944, -0.00871101767, 0.0087406747, -0.0115246754], ++ [0.0274577513, -0.0175638888, -0.00203213934, -0.0198616516, -0.0110571291, 0.0365728177, 0.0162097216] ++ ] ++ ]> : tensor<2x11x7xf32> ++ func.return ++} +diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir +--- stablehlo/stablehlo/tests/ops_stablehlo.mlir ++++ stablehlo/stablehlo/tests/ops_stablehlo.mlir +@@ -1775,6 +1775,30 @@ + // expected-error@+1 {{'precision_config' failed to satisfy constraint}} + %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #stablehlo]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + func.return %0: tensor<2x2xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @exponential_result_accuracy ++func.func @exponential_result_accuracy(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.exponential"(%arg0) 
{result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor ++ func.return %0: tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @exponential_result_accuracy_tol ++func.func @exponential_result_accuracy_tol(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor ++ func.return %0: tensor ++} ++ ++// ----- ++ ++func.func @exponential_result_accuracy_tol(%arg0: tensor) -> tensor { ++ // expected-error@+1 {{Invalid tolerances for ResultAccuracyAttr with mode HIGHEST, must be all zero.}} ++ %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor ++ func.return %0: tensor + } + + // ----- +diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir +--- stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir ++++ stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir +@@ -766,6 +766,11 @@ + func.return %0 : tensor<3x4xf32> + } + ++func.func @test_unary_result_accuracy(%arg0: tensor<2xf32>) -> tensor<2xf32> { ++ %exp = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor<2xf32>) -> tensor<2xf32> ++ func.return %exp : tensor<2xf32> ++} ++ + func.func @test_unary_round_nearest_even(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +diff --ruN a/stablehlo/stablehlo/tests/print_stablehlo.mlir b/stablehlo/stablehlo/tests/print_stablehlo.mlir +--- stablehlo/stablehlo/tests/print_stablehlo.mlir ++++ stablehlo/stablehlo/tests/print_stablehlo.mlir +@@ -406,3 +406,16 @@ + %slice6 = stablehlo.slice %arg0 [1:3:1, 4:8:2] : (tensor<3x8xf32>) -> tensor<2x2xf32> + return %slice1, %slice2, %slice3, %slice4, %slice5, %slice6 : tensor<1xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>, tensor<2x2xf32>, tensor<2x2xf32> + } ++ ++func.func @result_accuracy_default() -> () attributes { ++ // CHECK: mode.default = #stablehlo.result_accuracy> ++ // CHECK: mode.highest = #stablehlo.result_accuracy> ++ // CHECK: mode.tolerance_full = #stablehlo.result_accuracy> ++ // CHECK: mode.tolerance_partial = #stablehlo.result_accuracy> ++ mode.default = #stablehlo.result_accuracy>, ++ mode.highest = #stablehlo.result_accuracy>, ++ mode.tolerance_full = #stablehlo.result_accuracy>, ++ mode.tolerance_partial = #stablehlo.result_accuracy> ++} { ++ func.return ++} +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +@@ -1940,6 +1940,17 @@ + return %1 : tensor<12xi64> + } + ++// ----- ++ ++// CHECK-LABEL: @reorder_invalid_with_dynamic_shape ++func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor) { ++ // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> ++ // CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor ++ // CHECK: return %[[CONVERT]] ++ %0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> ++ %1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor ++ return %1 : tensor ++} + + // ----- + +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir 
b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +@@ -36,7 +36,7 @@ + + // ----- + +-// expected-error @+1 {{number of refinements must match number of function operands 6 vs 1}} ++// expected-error @+1 {{number of refinements must match number of op operands 6 vs 1}} + func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor) { + return + } +diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir +--- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir ++++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir +@@ -0,0 +1,2966 @@ ++// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s ++// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=1.9.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s ++// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-opt > %t.0 ++// RUN: stablehlo-opt --strip-debuginfo %s > %t.1 ++// RUN: diff %t.0 %t.1 ++// RUN: stablehlo-translate --serialize --target=1.9.0 --strip-debuginfo %s > %t.2 ++// RUN: diff %s.bc %t.2 ++// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode -debug-only=vhlo-bytecode %s 2>&1 | FileCheck --check-prefix=CHECK-WARN %s ++// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode %s | stablehlo-opt -debug-only=vhlo-bytecode 2>&1 | FileCheck --check-prefix=CHECK-WARN %s ++ ++// CHECK-WARN-NOT: Not Implemented ++ ++// ============ ATTRIBUTES ============ ++ ++// CHECK-LABEL: "attr_comparison_direction_eq" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ // CHECK: comparison_direction = #vhlo ++ comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_direction_ne" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ // CHECK: comparison_direction = #vhlo ++ comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_direction_ge" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ // CHECK: comparison_direction = #vhlo ++ comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_direction_gt" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ // CHECK: comparison_direction = #vhlo ++ comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_direction_le" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ // CHECK: comparison_direction = #vhlo ++ 
comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_direction_lt" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ // CHECK: comparison_direction = #vhlo ++ comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_type_notype" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo ++ // CHECK: compare_type = #vhlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_type_float" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo, ++ // CHECK: compare_type = #vhlo, ++ compare_type = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_type_totalorder" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo, ++ // CHECK: compare_type = #vhlo, ++ compare_type = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_type_signed" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo, ++ // CHECK: compare_type = #vhlo, ++ compare_type = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_comparison_type_unsigned" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo, ++ // CHECK: compare_type = #vhlo, ++ compare_type = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// ConvDimensionNumbers aka #stablehlo.conv is covered below. 
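The round-trip tests in this file write ops in generic form. With the custom assembly format added to StablehloOps.td earlier in this patch, `stablehlo.exponential` also gains a pretty form; a hypothetical equivalent pair, assuming the `$operand attr-dict : type` layout defined above:

```mlir
// Generic form, as written throughout this test file:
%0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy<mode = #stablehlo.result_accuracy_mode<DEFAULT>>} : (tensor<16xf32>) -> tensor<16xf32>
// Equivalent pretty form produced by the new assemblyFormat:
%1 = stablehlo.exponential %0 {result_accuracy = #stablehlo.result_accuracy<mode = #stablehlo.result_accuracy_mode<DEFAULT>>} : tensor<16xf32>
```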
++ ++// CHECK-LABEL: "attr_custom_call_api_version_unspecified" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ // CHECK: api_version = #vhlo ++ api_version = 0 : i32 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_custom_call_api_version_original" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ // CHECK: api_version = #vhlo ++ api_version = 1 : i32 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_custom_call_api_version_status_returning" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ // CHECK: api_version = #vhlo ++ api_version = 2 : i32 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ // CHECK: api_version = #vhlo ++ api_version = 3 : i32 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "attr_dict" ++// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} ++func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { ++ return ++} ++ ++// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++// CHECK: api_version = #vhlo ++// CHECK-SAME: backend_config = #vhlo.dict_v1<{#vhlo.string_v1<"bar"> = #vhlo.integer_v1<42 : i32>}> ++func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ backend_config= {bar = 42 : i32}, ++ api_version = 4 : i32 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++ ++// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi_no_backend_config" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++// CHECK: api_version = #vhlo ++// CHECK-SAME: backend_config = #vhlo.dict_v1<{}> ++func.func @attr_custom_call_api_version_typed_ffi_no_backend_config(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ api_version = 4 : i32 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// DotDimensionNumbers aka #stablehlo.dot is covered below. 
++
++// CHECK-LABEL: "attr_fft_type_fft"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>> {
++  %0 = "stablehlo.fft"(%arg0) {
++    // CHECK: fft_type = #vhlo<fft_type_v1 FFT>
++    fft_type = #stablehlo<fft_type FFT>,
++    fft_length = array<i64: 16>
++  } : (tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>>
++  func.return %0 : tensor<16xcomplex<f32>>
++}
++
++// CHECK-LABEL: "attr_fft_type_ifft"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>> {
++  %0 = "stablehlo.fft"(%arg0) {
++    // CHECK: fft_type = #vhlo<fft_type_v1 IFFT>
++    fft_type = #stablehlo<fft_type IFFT>,
++    fft_length = array<i64: 16>
++  } : (tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>>
++  func.return %0 : tensor<16xcomplex<f32>>
++}
++
++// CHECK-LABEL: "attr_fft_type_rfft"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex<f32>> {
++  %0 = "stablehlo.fft"(%arg0) {
++    // CHECK: fft_type = #vhlo<fft_type_v1 RFFT>
++    fft_type = #stablehlo<fft_type RFFT>,
++    fft_length = array<i64: 16>
++  } : (tensor<16xf32>) -> tensor<9xcomplex<f32>>
++  func.return %0 : tensor<9xcomplex<f32>>
++}
++
++// CHECK-LABEL: "attr_fft_type_irfft"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex<f32>>) -> tensor<16xf32> {
++  %0 = "stablehlo.fft"(%arg0) {
++    // CHECK: fft_type = #vhlo<fft_type_v1 IRFFT>
++    fft_type = #stablehlo<fft_type IRFFT>,
++    fft_length = array<i64: 16>
++  } : (tensor<9xcomplex<f32>>) -> tensor<16xf32>
++  func.return %0 : tensor<16xf32>
++}
++
++// CHECK-LABEL: "attr_result_accuracy_HIGHEST"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}
++func.func @attr_result_accuracy_HIGHEST(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
++  %0 = "stablehlo.exponential"(%arg0) {
++    // CHECK: result_accuracy = #vhlo.result_accuracy_v1<mode = #vhlo.result_accuracy_mode_v1<HIGHEST>>
++    result_accuracy = #stablehlo.result_accuracy<mode = #stablehlo.result_accuracy_mode<HIGHEST>>
++  } : (tensor<8x16xf32>) -> tensor<8x16xf32>
++  func.return %0 : tensor<8x16xf32>
++}
++
++// CHECK-LABEL: "attr_result_accuracy_TOLERANCE"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}
++func.func @attr_result_accuracy_TOLERANCE(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
++  %0 = "stablehlo.exponential"(%arg0) {
++    // CHECK: result_accuracy = #vhlo.result_accuracy_v1<mode = #vhlo.result_accuracy_mode_v1<TOLERANCE>>
++    result_accuracy = #stablehlo.result_accuracy<mode = #stablehlo.result_accuracy_mode<TOLERANCE>>
++  } : (tensor<8x16xf32>) -> tensor<8x16xf32>
++  func.return %0 : tensor<8x16xf32>
++}
++
++// CHECK-LABEL: "attr_result_accuracy_DEFAULT"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}
++func.func @attr_result_accuracy_DEFAULT(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
++  %0 = "stablehlo.exponential"(%arg0) {
++    // CHECK: result_accuracy = #vhlo.result_accuracy_v1<mode = #vhlo.result_accuracy_mode_v1<DEFAULT>>
++    result_accuracy = #stablehlo.result_accuracy<mode = #stablehlo.result_accuracy_mode<DEFAULT>>
++  } : (tensor<8x16xf32>) -> tensor<8x16xf32>
++  func.return %0 : tensor<8x16xf32>
++}
++
++// GatherDimensionNumbers aka #stablehlo.gather is covered below.
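++// The result_accuracy tests above only pin down the mode field; per the
++// custom printer and parser (printResultAccuracyAttr /
++// parseResultAccuracyAttr in AssemblyFormat), atol, rtol, and ulps are
++// elided whenever they are zero and default to zero when absent. A sketch
++// of the full grammar, with illustrative (not authoritative) tolerances:
++//
++//   result_accuracy = #stablehlo.result_accuracy<
++//       atol = 1.000000e-05, rtol = 1.000000e-06, ulps = 2,
++//       mode = #stablehlo.result_accuracy_mode<TOLERANCE>>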
++
++// CHECK-LABEL: "attr_precision_config_default"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> {
++  %0 = "stablehlo.dot"(%arg0, %arg1) {
++    // CHECK: precision_config = #vhlo.array_v1<[#vhlo<precision_v1 DEFAULT>, #vhlo<precision_v1 DEFAULT>]>
++  } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
++  func.return %0 : tensor<8x8xf32>
++}
++
++// CHECK-LABEL: "attr_precision_config_high"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> {
++  %0 = "stablehlo.dot"(%arg0, %arg1) {
++    // CHECK: precision_config = #vhlo.array_v1<[#vhlo<precision_v1 HIGH>, #vhlo<precision_v1 HIGH>]>
++    precision_config = [#stablehlo<precision HIGH>, #stablehlo<precision HIGH>]
++  } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
++  func.return %0 : tensor<8x8xf32>
++}
++
++// CHECK-LABEL: "attr_precision_config_highest"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> {
++  %0 = "stablehlo.dot"(%arg0, %arg1) {
++    // CHECK: precision_config = #vhlo.array_v1<[#vhlo<precision_v1 HIGHEST>, #vhlo<precision_v1 HIGHEST>]>
++    precision_config = [#stablehlo<precision HIGHEST>, #stablehlo<precision HIGHEST>]
++  } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
++  func.return %0 : tensor<8x8xf32>
++}
++
++// CHECK-LABEL: "attr_rng_algorithm_default"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_rng_algorithm_default(%arg0: tensor<ui64>) -> (tensor<ui64>, tensor<ui64>) {
++  %0:2 = "stablehlo.rng_bit_generator"(%arg0) {
++    // CHECK: rng_algorithm = #vhlo<rng_algorithm_v1 DEFAULT>
++    rng_algorithm = #stablehlo<rng_algorithm DEFAULT>
++  } : (tensor<ui64>) -> (tensor<ui64>, tensor<ui64>)
++  func.return %0#0, %0#1 : tensor<ui64>, tensor<ui64>
++}
++
++// CHECK-LABEL: "attr_rng_algorithm_three_fry"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_rng_algorithm_three_fry(%arg0: tensor<ui64>) -> (tensor<ui64>, tensor<ui64>) {
++  %0:2 = "stablehlo.rng_bit_generator"(%arg0) {
++    // CHECK: rng_algorithm = #vhlo<rng_algorithm_v1 THREE_FRY>
++    rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
++  } : (tensor<ui64>) -> (tensor<ui64>, tensor<ui64>)
++  func.return %0#0, %0#1 : tensor<ui64>, tensor<ui64>
++}
++
++// CHECK-LABEL: "attr_rng_algorithm_philox"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @attr_rng_algorithm_philox(%arg0: tensor<ui64>) -> (tensor<ui64>, tensor<ui64>) {
++  %0:2 = "stablehlo.rng_bit_generator"(%arg0) {
++    // CHECK: rng_algorithm = #vhlo<rng_algorithm_v1 PHILOX>
++    rng_algorithm = #stablehlo<rng_algorithm PHILOX>
++  } : (tensor<ui64>) -> (tensor<ui64>, tensor<ui64>)
++  func.return %0#0, %0#1 : tensor<ui64>, tensor<ui64>
++}
++
++// CHECK-LABEL: "attr_rng_distribution_uniform"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}})
++func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
++  %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
++    // CHECK: rng_distribution = #vhlo<rng_distribution_v1 UNIFORM>
++    rng_distribution = #stablehlo<rng_distribution UNIFORM>
++  } : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
++  func.return %0 : tensor<f32>
++}
++
++// CHECK-LABEL: "attr_rng_distribution_normal"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}})
++func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
++  %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
++    // CHECK: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
++    rng_distribution = #stablehlo<rng_distribution NORMAL>
++  } : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
++  func.return %0 : tensor<f32>
++}
++
++// ScatterDimensionNumbers aka #stablehlo.scatter is covered below.
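++// Note that the attr_precision_config_default test above passes no
++// precision_config at all, yet still CHECKs for one: legalization to VHLO
++// materializes default attribute values, the same convention the DEFAULTS
++// section below relies on. Sketch of the equivalence, with hypothetical
++// operands:
++//
++//   "stablehlo.dot"(%lhs, %rhs) : ...   // no precision_config on the input
++//   ==> precision_config = #vhlo.array_v1<[#vhlo<precision_v1 DEFAULT>, #vhlo<precision_v1 DEFAULT>]>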
++ ++// CHECK-LABEL: "attr_transpose_no_transpose" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { ++ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { ++ left_side = true, ++ lower = true, ++ unit_diagonal = true, ++ // transpose_a = #vhlo, ++ transpose_a = #stablehlo ++ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "attr_transpose_transpose" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { ++ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { ++ left_side = true, ++ lower = true, ++ unit_diagonal = true, ++ // transpose_a = #vhlo, ++ transpose_a = #stablehlo ++ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "attr_transpose_adjoint" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { ++ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { ++ left_side = true, ++ lower = true, ++ unit_diagonal = true, ++ // transpose_a = #vhlo, ++ transpose_a = #stablehlo ++ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. ++ ++// CHECK-LABEL: "attr_type_extensions_bounds" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { ++ // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> () ++ func.return %arg0 : tensor> ++} ++ ++ ++// ============ DEFAULTS ============ ++ ++// CHECK-LABEL: "default_all_gather" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { ++ // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ ++ // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, ++ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.all_gather"(%arg0) { ++ all_gather_dim = 1 : i64, ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> ++ } : (tensor<16x8xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "default_all_gather_variadic" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_all_gather_variadic(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) { ++ %0:2 = "stablehlo.all_gather"(%arg0, %arg1) { ++ all_gather_dim = 1 : i64, ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> ++ } : (tensor<16x8xf32>, tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) ++ func.return %0#0, %0#1 : tensor<16x16xf32>, tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "default_all_reduce" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_all_reduce(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) ++ // CHECK-SAME: <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, ++ // CHECK-SAME: 
use_global_device_ids = #vhlo.bool_v1 ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ ++ %0 = "stablehlo.all_reduce"(%arg0) ({ ++ ^bb0(%arg1: tensor, %arg2: tensor): ++ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "default_all_to_all" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { ++ // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, ++ // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> ++ // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> ++ %0 = "stablehlo.all_to_all"(%arg0) { ++ split_dimension = 1 : i64, ++ concat_dimension = 0 : i64, ++ split_count = 4 : i64, ++ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> ++ } : (tensor<4x16xf32>) -> tensor<16x4xf32> ++ func.return %0 : tensor<16x4xf32> ++} ++ ++// CHECK-LABEL: "default_all_to_all_variadic" ++func.func @default_all_to_all_variadic(%arg0: tensor<4x16xf32>, %arg1: tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) { ++ %0:2 = "stablehlo.all_to_all"(%arg0, %arg1) { ++ split_dimension = 1 : i64, ++ concat_dimension = 0 : i64, ++ split_count = 4 : i64, ++ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, ++ channel_handle = #stablehlo.channel_handle ++ } : (tensor<4x16xf32>, tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) ++ func.return %0#0, %0#1 : tensor<16x4xf32>, tensor<20x4xf32> ++} ++ ++// CHECK-LABEL: "default_cholesky" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { ++ // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: lower = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> ++ func.return %0 : tensor<1x16x16xf32> ++} ++ ++// CHECK-LABEL: "default_collective_permute" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { ++ // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> ++ %0 = "stablehlo.collective_permute"(%arg0) { ++ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> ++ } : (tensor<16x8xf32>) -> tensor<16x8xf32> ++ func.return %0 : tensor<16x8xf32> ++} ++ ++// CHECK-LABEL: "default_collective_broadcast" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> 
tensor<16x8xf32> { ++ // CHECK: "vhlo.collective_broadcast_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> ++ %0 = "stablehlo.collective_broadcast"(%arg0) { ++ replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> ++ } : (tensor<16x8xf32>) -> tensor<16x8xf32> ++ func.return %0 : tensor<16x8xf32> ++} ++ ++// CHECK-LABEL: "default_compare" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: compare_type = #vhlo, ++ // CHECK-SAME: comparison_direction = #vhlo ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "default_composite" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_composite(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{}> ++ // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> ++ // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> ++ // CHECK-SAME: version = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.composite"(%arg0) { ++ name = "stablehlo.composite_target", ++ decomposition = @composite_target ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "default_convolution" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { ++ // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> ++ %0 = "stablehlo.convolution"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, ++ feature_group_count = 1 : i64, ++ batch_group_count = 1 : 
i64 ++ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> ++ func.return %0 : tensor<1x6x6x16xf32> ++} ++ ++// CHECK-LABEL: "default_custom_call" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_custom_call(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: api_version = #vhlo, ++ // CHECK-SAME: backend_config = #vhlo.string_v1<"">, ++ // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, ++ // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, ++ // CHECK-SAME: has_side_effect = #vhlo.bool_v1, ++ // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, ++ // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> ++ // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo" ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "default_dot_general" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { ++ // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: accumulation_type = #vhlo.type_v1, ++ // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, ++ // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: lhs_component_count = #vhlo.type_v1, ++ // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, ++ // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++ // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: rhs_component_count = #vhlo.type_v1, ++ // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dot_general"(%arg0, %arg1) { ++ dot_dimension_numbers = #stablehlo.dot< ++ lhs_batching_dimensions = [0], ++ lhs_contracting_dimensions = [2], ++ rhs_batching_dimensions = [0], ++ rhs_contracting_dimensions = [1] ++ > ++ } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> ++ func.return %0 : tensor<8x8x8xf32> ++} ++ ++// CHECK-LABEL: "dot_general_algorithm" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { ++// CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ ++// CHECK-SAME: accumulation_type = #vhlo.type_v1, ++// CHECK-SAME: allow_imprecise_accumulation = #vhlo.bool_v1, ++// CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++// CHECK-SAME: lhs_component_count = #vhlo.integer_v1<1 : i64>, ++// CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++// CHECK-SAME: lhs_precision_type = #vhlo.type_v1, ++// CHECK-SAME: num_primitive_operations = #vhlo.integer_v1<1 : i64>, ++// CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++// CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++// CHECK-SAME: rhs_component_count = #vhlo.integer_v1<1 : i64>, ++// CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++// CHECK-SAME: rhs_precision_type 
= #vhlo.type_v1 ++// CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dot_general"(%arg0, %arg1) { ++ dot_dimension_numbers = #stablehlo.dot< ++ lhs_batching_dimensions = [0], ++ lhs_contracting_dimensions = [2], ++ rhs_batching_dimensions = [0], ++ rhs_contracting_dimensions = [1] ++ >, ++ algorithm = #stablehlo.dot_algorithm< ++ lhs_precision_type = tf32, ++ rhs_precision_type = tf32, ++ accumulation_type = f32, ++ lhs_component_count = 1, ++ rhs_component_count = 1, ++ num_primitive_operations = 1, ++ allow_imprecise_accumulation = false ++ > ++ } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> ++ func.return %0 : tensor<8x8x8xf32> ++} ++ ++// CHECK-LABEL: "default_dynamic_broadcast_in_dim" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { ++ // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { ++ broadcast_dimensions = array ++ } : (tensor, tensor<2xindex>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "default_dynamic_conv" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { ++ // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { ++ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, ++ feature_group_count = 1 : i64, ++ batch_group_count = 1 : i64 ++ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> ++ func.return %0 : 
tensor<1x?x?x16xf32> ++} ++ ++// CHECK-LABEL: "default_dynamic_gather" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { ++ // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [2], ++ collapsed_slice_dims = [0, 1], ++ start_index_map = [0, 1], ++ index_vector_dim = 2 ++ > ++ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> ++ func.return %0 : tensor<1x5x8xf32> ++} ++ ++func.func @default_func(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.func_v1"() <{ ++ // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, ++ // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, ++ // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, ++ // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, ++ // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : () -> () ++ func.return %arg0 : tensor ++} ++ ++// CHECK-LABEL: "default_gather" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { ++ // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, ++ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [2], ++ collapsed_slice_dims = [0, 1], ++ start_index_map = [0, 1], ++ index_vector_dim = 2 ++ >, ++ slice_sizes = array ++ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> ++ func.return %0 : tensor<1x5x1xf32> ++} ++ ++// CHECK-LABEL: "default_infeed" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { ++ // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, ++ // CHECK-SAME{LITERAL}: layout = 
#vhlo.array_v1<[]> ++ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) ++ %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) ++ func.return %0#0, %0#1 : tensor, !stablehlo.token ++} ++ ++// CHECK-LABEL: "default_outfeed" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 ++ %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token ++ func.return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: "default_recv" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { ++ // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) ++ %0:2 = "stablehlo.recv"(%arg0) { ++ channel_handle = #stablehlo.channel_handle ++ } : (!stablehlo.token) -> (tensor, !stablehlo.token) ++ func.return %0#0, %0#1 : tensor, !stablehlo.token ++} ++ ++// CHECK-LABEL: "default_send" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 ++ %0 = "stablehlo.send"(%arg0, %arg1) { ++ channel_handle = #stablehlo.channel_handle ++ } : (tensor, !stablehlo.token) -> !stablehlo.token ++ func.return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: "default_reduce_scatter" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { ++ // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, ++ // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.reduce_scatter"(%arg0) ({ ++ ^bb0(%arg1: tensor, %arg2: tensor): ++ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ scatter_dimension = 0 : i64, ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> ++ } : (tensor<16xf32>) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "default_reduce_window" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { ++ // CHECK: 
"vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, ++ // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> ++ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ ++ ^bb0(%arg2: tensor, %arg3: tensor): ++ %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ window_dimensions = array ++ } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> ++ func.return %0 : tensor<2x16x30x7xf32> ++} ++ ++// CHECK-LABEL: "default_scatter" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { ++ // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: unique_indices = #vhlo.bool_v1, ++ // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [1], ++ inserted_window_dims = [0, 1], ++ scatter_dims_to_operand_dims = [0, 1], ++ index_vector_dim = 1 ++ > ++ } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> ++ func.return %0 : tensor<200x100x300xf32> ++} ++ ++// CHECK-LABEL: "default_select_and_scatter" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { ++ // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: padding = #vhlo.tensor_v1 : 
tensor<4x2xi64>>, ++ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }, { ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> ++ %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }, { ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ window_dimensions = array ++ } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> ++ func.return %0 : tensor<10x24x24x64xf32> ++} ++ ++// CHECK-LABEL: "default_sort" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { ++ // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> ++ // CHECK-SAME: is_stable = #vhlo.bool_v1 ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.sort"(%arg0) ({ ++ ^bb0(%arg1: tensor, %arg2: tensor): ++ %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) : (tensor<16xf32>) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// ============ OPS ============ ++ ++// CHECK-LABEL: "op_abs" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_abs(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_add" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_after_all" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.after_all_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> !vhlo.token_v1 ++ %0 = 
"stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token ++ func.return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: "op_all_gather" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { ++ // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ ++ // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, ++ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.all_gather"(%arg0) { ++ all_gather_dim = 1 : i64, ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, ++ channel_handle = #stablehlo.channel_handle, ++ use_global_device_ids ++ } : (tensor<16x8xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "op_all_reduce" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_all_reduce(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, ++ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.all_reduce"(%arg0) ({ ++ ^bb0(%arg1: tensor, %arg2: tensor): ++ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, ++ channel_handle = #stablehlo.channel_handle, ++ use_global_device_ids ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_all_reduce_with_promotable_types" ++func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { ++ // CHECK: "vhlo.all_reduce_v2"(%[[ARG0:.*]]) ++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () ++ // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %result = "stablehlo.all_reduce"(%operand) ({ ++ ^bb0(%arg0: tensor, %arg1: tensor): ++ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%0) : (tensor) -> () ++ }) { ++ replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, ++ channel_handle = #stablehlo.channel_handle, ++ use_global_device_ids ++ } : (tensor) -> tensor ++ ++ func.return %result : tensor ++} ++ ++// CHECK-LABEL: "default_all_reduce_variadic" ++func.func @default_all_reduce_variadic(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { ++ %0:2 = "stablehlo.all_reduce"(%arg0, %arg1) ({ ++ ^bb0(%arg2: tensor, %arg3: tensor): ++ %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> (tensor) ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> ++ } : (tensor, tensor) -> (tensor, tensor) ++ func.return %0#0, %0#1 : tensor, tensor ++} ++ ++// CHECK-LABEL: "op_all_to_all" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { ++ // 
CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, ++ // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> ++ // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> ++ %0 = "stablehlo.all_to_all"(%arg0) { ++ split_dimension = 1 : i64, ++ concat_dimension = 0 : i64, ++ split_count = 4 : i64, ++ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, ++ channel_handle = #stablehlo.channel_handle ++ } : (tensor<4x16xf32>) -> tensor<16x4xf32> ++ func.return %0 : tensor<16x4xf32> ++} ++ ++// CHECK-LABEL: "op_and" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_atan2" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.atan2_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_batch_norm_grad" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) ++func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { ++ // CHECK: "vhlo.batch_norm_grad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ ++ // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, ++ // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) ++ %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { ++ epsilon = 0.001 : f32, ++ feature_index = 0 : i64 ++ } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) ++ func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_batch_norm_inference" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) ++func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { ++ // CHECK: "vhlo.batch_norm_inference_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ ++ // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, ++ // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, 
!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { ++ epsilon = 0.001 : f32, ++ feature_index = 0 : i64 ++ } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> ++ func.return %0 : tensor<16x16x16x16xf32> ++} ++ ++// CHECK-LABEL: "op_batch_norm_training" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { ++ // CHECK: "vhlo.batch_norm_training_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, ++ // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) ++ %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { ++ epsilon = 0.001 : f32, ++ feature_index = 0 : i64 ++ } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) ++ func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_bitcast_convert" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_bitcast_convert(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.bitcast_convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_broadcast_in_dim" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { ++ // CHECK: "vhlo.broadcast_in_dim_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.broadcast_in_dim"(%arg0) { ++ broadcast_dimensions = array ++ } : (tensor<16xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "op_broadcast" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { ++ // CHECK: "vhlo.broadcast_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.broadcast"(%arg0) { ++ broadcast_sizes = array ++ } : (tensor<16xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "op_case" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.case"(%arg0) ({ ++ "stablehlo.return"(%arg1) : (tensor) -> () ++ }) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_cbrt" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_cbrt(%arg0: tensor) -> tensor { ++ // CHECK: 
"vhlo.cbrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_ceil" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_ceil(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.ceil_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_cholesky" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { ++ // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: lower = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.cholesky"(%arg0) { ++ lower = true ++ } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> ++ func.return %0 : tensor<1x16x16xf32> ++} ++ ++// CHECK-LABEL: "op_clamp" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { ++ // CHECK: "vhlo.clamp_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_count_leading_zeros" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.count_leading_zeros_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_collective_permute" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { ++ // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> ++ %0 = "stablehlo.collective_permute"(%arg0) { ++ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, ++ channel_handle = #stablehlo.channel_handle ++ } : (tensor<16x8xf32>) -> tensor<16x8xf32> ++ func.return %0 : tensor<16x8xf32> ++} ++ ++// CHECK-LABEL: "op_compare" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: compare_type = #vhlo, ++ // CHECK-SAME: comparison_direction = #vhlo ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.compare"(%arg0, %arg1) { ++ comparison_direction = #stablehlo, ++ compare_type = #stablehlo ++ } : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_complex" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { ++ // CHECK: "vhlo.complex_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> ++ %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> ++ func.return %0 : tensor> ++} ++ ++// CHECK-LABEL: "op_composite" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_composite(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ ++ // 
CHECK-SAME: composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_int"> = #vhlo.integer_v1<1 : i64>, #vhlo.string_v1<"my_string"> = #vhlo.string_v1<"foo">}> ++ // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> ++ // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> ++ // CHECK-SAME: version = #vhlo.integer_v1<1 : i32> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.composite"(%arg0) { ++ name = "stablehlo.composite_target", ++ decomposition = @composite_target, ++ version = 1 : i32, ++ composite_attributes = { ++ my_string = "foo", ++ my_int = 1 : i64 ++ } ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_concatenate" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { ++ // CHECK: "vhlo.concatenate_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.concatenate"(%arg0, %arg1) { ++ dimension = 0 : i64 ++ } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_constant" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_constant(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.constant_v1"() <{ ++ // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> ++ // CHECK-SAME: }> : () -> !vhlo.tensor_v1 ++ %0 = "stablehlo.constant"() { ++ value = dense<0.0> : tensor ++ } : () -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_convert" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_convert(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_convolution" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { ++ // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> 
!vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> ++ %0 = "stablehlo.convolution"(%arg0, %arg1) { ++ window_strides = array, ++ padding = dense<1> : tensor<2x2xi64>, ++ lhs_dilation = array, ++ rhs_dilation = array, ++ window_reversal = array, ++ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, ++ feature_group_count = 1 : i64, ++ batch_group_count = 1 : i64, ++ precision_config = [#stablehlo, #stablehlo] ++ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> ++ func.return %0 : tensor<1x7x7x16xf32> ++} ++ ++// CHECK-LABEL: "op_cosine" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_cosine(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.cosine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_create_token" ++func.func @op_create_token() -> !stablehlo.token { ++ // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 ++ %0 = "stablehlo.create_token"() : () -> !stablehlo.token ++ func.return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: "op_cross_replica_sum" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.cross-replica-sum_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.cross-replica-sum"(%arg0) { ++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_custom_call" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_custom_call(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: api_version = #vhlo, ++ // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, ++ // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, ++ // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, ++ // CHECK-SAME: has_side_effect = #vhlo.bool_v1, ++ // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, ++ // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ ++ // CHECK-SAME: #vhlo.output_operand_alias_v1< ++ // CHECK-SAME: outputTupleIndices = [], ++ // CHECK-SAME: operandIndex = 0, ++ // CHECK-SAME: operandTupleIndices = []>]> ++ // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo", ++ has_side_effect = true, ++ backend_config = "\08\03\1A\02", ++ api_version = 2 : i32, ++ called_computations = [@foo], ++ operand_layouts = [dense<> : tensor<0xindex>], ++ output_operand_aliases = [ ++ #stablehlo.output_operand_alias], ++ result_layouts = [dense<> : tensor<0xindex>] ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_custom_call_empty_result_layout" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func public @op_custom_call_empty_result_layout(%arg0: tensor) -> tensor { ++ // %0 = "vhlo.custom_call_v1"(%arg0) <{>}> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> ++ // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: api_version = #vhlo, ++ // CHECK-SAME: backend_config = #vhlo.string_v1<"">, ++ // CHECK-SAME: call_target_name = #vhlo.string_v1<"empty_output">, ++ // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, ++ // CHECK-SAME: has_side_effect = #vhlo.bool_v1, 
++ // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, ++ // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]>, ++ // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> ++ %0 = "stablehlo.custom_call"(%arg0) <{ ++ api_version = 2 : i32, ++ call_target_name = "empty_output", ++ has_side_effect = true, ++ operand_layouts = [dense<> : tensor<0xindex>], ++ result_layouts = [] ++ }> : (tensor) -> tuple<> ++ return %arg0 : tensor ++} ++ ++// CHECK-LABEL: "op_divide" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.divide_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_dot_general" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { ++ // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: accumulation_type = #vhlo.type_v1, ++ // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, ++ // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: lhs_component_count = #vhlo.type_v1, ++ // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, ++ // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++ // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: rhs_component_count = #vhlo.type_v1, ++ // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dot_general"(%arg0, %arg1) { ++ dot_dimension_numbers = #stablehlo.dot< ++ lhs_batching_dimensions = [0], ++ lhs_contracting_dimensions = [2], ++ rhs_batching_dimensions = [0], ++ rhs_contracting_dimensions = [1] ++ >, ++ precision_config = [#stablehlo, #stablehlo] ++ } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> ++ func.return %0 : tensor<8x8x8xf32> ++} ++ ++// CHECK-LABEL: "op_dot" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { ++ // CHECK: "vhlo.dot_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dot"(%arg0, %arg1) { ++ precision_config = [#stablehlo, #stablehlo] ++ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> ++ func.return %0 : tensor<8x8xf32> ++} ++ ++// CHECK-LABEL: "op_dynamic_broadcast_in_dim" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { ++ // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // 
CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { ++ broadcast_dimensions = array, ++ known_expanding_dimensions = array, ++ known_nonexpanding_dimensions = array ++ } : (tensor, tensor<2xindex>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_dynamic_conv" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { ++ // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, ++ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { ++ window_strides = array, ++ lhs_dilation = array, ++ rhs_dilation = array, ++ window_reversal = array, ++ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, ++ feature_group_count = 1 : i64, ++ batch_group_count = 1 : i64, ++ precision_config = [#stablehlo, #stablehlo] ++ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> ++ func.return %0 : tensor<1x?x?x16xf32> ++} ++ ++// CHECK-LABEL: "op_dynamic_gather" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { ++ // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> 
!vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [2], ++ collapsed_slice_dims = [0, 1], ++ start_index_map = [0, 1], ++ index_vector_dim = 2 ++ >, ++ indices_are_sorted = true ++ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> ++ func.return %0 : tensor<1x5x8xf32> ++} ++ ++// CHECK-LABEL: "op_dynamic_gather_with_batching_dims" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_dynamic_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<4xi32>) -> tensor<1x5x8xf32> { ++ // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [2], ++ collapsed_slice_dims = [1, 2], ++ operand_batching_dims = [0], ++ start_indices_batching_dims = [1], ++ start_index_map = [1, 2], ++ index_vector_dim = 2 ++ >, ++ indices_are_sorted = true ++ } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>, tensor<4xi32>) -> tensor<1x5x8xf32> ++ func.return %0 : tensor<1x5x8xf32> ++} ++ ++// CHECK-LABEL: "op_dynamic_iota" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { ++ // CHECK: "vhlo.dynamic_iota_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.dynamic_iota"(%arg0) { ++ iota_dimension = 0 : i64 ++ } : (tensor<1xindex>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_dynamic_pad" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) ++func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { ++ // CHECK: "vhlo.dynamic_pad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_dynamic_reshape" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { ++ // CHECK: "vhlo.dynamic_reshape_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor ++ func.return %0 : tensor ++} ++ 
++// CHECK-LABEL: "op_dynamic_slice" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { ++ // CHECK: "vhlo.dynamic_slice_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { ++ slice_sizes = array ++ } : (tensor<16xf32>, tensor) -> tensor<4xf32> ++ func.return %0 : tensor<4xf32> ++} ++ ++// CHECK-LABEL: "op_dynamic_update_slice" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { ++ // CHECK: "vhlo.dynamic_update_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_einsum" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { ++ // CHECK: "vhlo.einsum_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> ++ %0 = "stablehlo.einsum"(%arg0, %arg1) { ++ einsum_config = "ab,bc->ac" ++ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> ++ func.return %0 : tensor<8x8xf32> ++} ++ ++// CHECK-LABEL: "op_exponential_minus_one" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.exponential_minus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_exponential" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_exponential(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.result_accuracy_v1>}> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_fft" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { ++ // CHECK: "vhlo.fft_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: fft_type = #vhlo ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> ++ %0 = "stablehlo.fft"(%arg0) { ++ fft_type = #stablehlo, ++ fft_length = array ++ } : (tensor<16xcomplex>) -> tensor<16xcomplex> ++ func.return %0 : tensor<16xcomplex> ++} ++ ++// CHECK-LABEL: "op_floor" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_floor(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.floor_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { ++ // CHECK: "vhlo.func_v1"() <{ ++ // CHECK-SAME: arg_attrs = 
#vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, ++ // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, ++ // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, ++ // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, ++ // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : () -> () ++ ++ func.return %arg0 : tensor ++} ++ ++// CHECK-LABEL: "op_gather" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { ++ // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, ++ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [2], ++ collapsed_slice_dims = [0, 1], ++ start_index_map = [0, 1], ++ index_vector_dim = 2 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> ++ func.return %0 : tensor<1x5x1xf32> ++} ++ ++// CHECK-LABEL: "op_gather_with_batching_dims" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { ++ // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [2], ++ collapsed_slice_dims = [1, 2], ++ operand_batching_dims = [0], ++ start_indices_batching_dims = [1], ++ start_index_map = [1, 2], ++ index_vector_dim = 2 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> ++ func.return %0 : tensor<1x5x1xf32> ++} ++ ++// CHECK-LABEL: "op_get_dimension_size" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_get_dimension_size(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.get_dimension_size_v1"(%[[ARG0]]) <{ 
++ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.get_dimension_size"(%arg0) { ++ dimension = 0 : i64 ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_get_tuple_element" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { ++ // CHECK: "vhlo.get_tuple_element_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> ++ // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.get_tuple_element"(%arg0) { ++ index = 0 : i32 ++ } : (tuple, tensor>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_if" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { ++ // CHECK: "vhlo.if_v1"(%[[ARG0]]) ({ ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }, { ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG2]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.if"(%arg0) ({ ++ "stablehlo.return"(%arg1) : (tensor) -> () ++ }, { ++ "stablehlo.return"(%arg2) : (tensor) -> () ++ }) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_imag" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_imag(%arg0: tensor>) -> tensor { ++ // CHECK: "vhlo.imag_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_infeed" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { ++ // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, ++ // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> ++ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) ++ %0:2 = "stablehlo.infeed"(%arg0) { ++ infeed_config = "foo", ++ layout = [[]] ++ } : (!stablehlo.token) -> (tensor, !stablehlo.token) ++ func.return %0#0, %0#1 : tensor, !stablehlo.token ++} ++ ++// CHECK-LABEL: "op_iota" ++func.func @op_iota() -> tensor<16xf32> { ++ // CHECK: "vhlo.iota_v1"() <{ ++ // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.iota"() { ++ iota_dimension = 0 : i64 ++ } : () -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_is_finite" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_is_finite(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.is_finite_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_log" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_log(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.log_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_log_plus_one" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_log_plus_one(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.log_plus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_logistic" ++// CHECK-NEXT: (%[[ARG0:.*]]: 
{{.*}}) ++func.func @op_logistic(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.logistic_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_map" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { ++ // CHECK: "vhlo.map_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.map"(%arg0) ({ ++ ^bb0(%arg1: tensor): ++ %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ dimensions = array ++ } : (tensor<16xf32>) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_maximum" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.maximum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_minimum" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.minimum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_multiply" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.multiply_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_negate" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_negate(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.negate_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_not" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_not(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.not_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_optimization_barrier" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_optimization_barrier(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.optimization_barrier_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_or" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.or_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_outfeed" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, 
%[[ARG1:.*]]: {{.*}}) ++func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 ++ %0 = "stablehlo.outfeed"(%arg0, %arg1) { ++ outfeed_config = "foo" ++ } : (tensor, !stablehlo.token) -> !stablehlo.token ++ func.return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: "op_pad" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { ++ // CHECK: "vhlo.pad_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.pad"(%arg0, %arg1) { ++ edge_padding_high = array, ++ edge_padding_low = array, ++ interior_padding = array ++ } : (tensor<8xf32>, tensor) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_popcnt" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_popcnt(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.popcnt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_power" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.power_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_real_dynamic_slice" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) ++func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { ++ // CHECK: "vhlo.real_dynamic_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_real" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_real(%arg0: tensor>) -> tensor { ++ // CHECK: "vhlo.real_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_recv" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { ++ // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, ++ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) ++ %0:2 = "stablehlo.recv"(%arg0) { ++ channel_handle = #stablehlo.channel_handle, ++ is_host_transfer = true ++ } : (!stablehlo.token) -> (tensor, !stablehlo.token) ++ func.return %0#0, %0#1 : tensor, !stablehlo.token ++} ++ ++// CHECK-LABEL: "op_reduce" ++// CHECK-NEXT: 
(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor {
++ // CHECK: "vhlo.reduce_v1"(%[[ARG0]], %[[ARG1]])
++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1):
++ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> ()
++ // CHECK: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1
++ %0 = "stablehlo.reduce"(%arg0, %arg1) ({
++ ^bb0(%arg2: tensor, %arg3: tensor):
++ %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor
++ "stablehlo.return"(%1) : (tensor) -> ()
++ }) {
++ dimensions = array
++ } : (tensor<16xf32>, tensor) -> tensor
++ func.return %0 : tensor
++}
++
++// CHECK-LABEL: "op_reduce_precision"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @op_reduce_precision(%arg0: tensor) -> tensor {
++ // CHECK: "vhlo.reduce_precision_v1"(%[[ARG0]]) <{
++ // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32>
++ // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32>
++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1
++ %0 = "stablehlo.reduce_precision"(%arg0) {
++ exponent_bits = 8 : i32,
++ mantissa_bits = 10 : i32
++ } : (tensor) -> tensor
++ func.return %0 : tensor
++}
++
++// CHECK-LABEL: "op_reduce_with_promotable_types"
++func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor)
++ -> (tensor<4xf64>) {
++ // CHECK: "vhlo.reduce_v1"(%[[ARG0:.*]], %[[ARG1:.*]])
++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1):
++ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> ()
++ // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1>
++ %0 = "stablehlo.reduce"(%arg0, %arg1) ({
++ ^bb0(%arg2: tensor, %arg3: tensor):
++ %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor
++ "stablehlo.return"(%1) : (tensor) -> ()
++
++ }) {dimensions = array} : (tensor<4x4xf32>, tensor) -> tensor<4xf64>
++
++ func.return %0 : tensor<4xf64>
++}
++
++// CHECK-LABEL: "op_reduce_scatter"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}})
++func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> {
++ // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{
++ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>,
++ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>,
++ // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64>
++ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1
++ // CHECK-SAME: }> ({
++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1):
++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1
++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> ()
++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
++ %0 = "stablehlo.reduce_scatter"(%arg0) ({
++ ^bb0(%arg1: tensor, %arg2: tensor):
++ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor
++ "stablehlo.return"(%1) : (tensor) -> ()
++ }) {
++ scatter_dimension = 0 : i64,
++ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>,
++ channel_handle = #stablehlo.channel_handle,
++ use_global_device_ids
++ } : (tensor<16xf32>) -> tensor<16xf32>
++ func.return %0 : tensor<16xf32>
++}
++
++// CHECK-LABEL: "op_reduce_scatter_with_promotable_types"
++func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> {
++ //
CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0:.*]]) ++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () ++ // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> ++ %0 = "stablehlo.reduce_scatter"(%data) ({ ++ ^bb0(%arg2: tensor, %arg3: tensor): ++ %1 = stablehlo.add %arg2, %arg3 : tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, ++ scatter_dimension = 1 : i64, ++ channel_handle = #stablehlo.channel_handle, ++ use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> ++ func.return %0 : tensor<4x4xf64> ++} ++ ++ ++// CHECK-LABEL: "op_reduce_window" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { ++ // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, ++ // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, ++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> ++ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ ++ ^bb0(%arg2: tensor, %arg3: tensor): ++ %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ window_dimensions = array, ++ window_strides = array, ++ base_dilations = array, ++ window_dilations = array, ++ padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> ++ } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> ++ func.return %0 : tensor<2x9x16x7xf32> ++} ++ ++// CHECK-LABEL: "op_reduce_window_with_promotable_types" ++func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, ++ %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> ++ (tensor<2x2xf64>, tensor<2x2xf32>) { ++ // CHECK: "vhlo.reduce_window_v1"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) ++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): ++ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () ++ // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) ++ %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ++ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, ++ %b1: tensor): ++ %2 = stablehlo.add %a0, %b0 : tensor ++ %3 = stablehlo.add %a1, %b1 : tensor ++ "stablehlo.return"(%2,%3) : (tensor, tensor) -> () ++ }) ++ { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, ++ window_dimensions = array, ++ window_strides = array } ++ : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> ++ 
(tensor<2x2xf64>, tensor<2x2xf32>) ++ func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> ++} ++ ++// CHECK-LABEL: "op_remainder" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.remainder_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_replica_id" ++func.func @op_replica_id() -> tensor { ++ // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 ++ %0 = "stablehlo.replica_id"() : () -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_partition_id" ++func.func @op_partition_id() -> tensor { ++ // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 ++ %0 = "stablehlo.partition_id"() : () -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_reshape" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { ++ // CHECK: "vhlo.reshape_v1"(%[[ARG0]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> ++ %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> ++ func.return %0 : tensor<4x4xf32> ++} ++ ++// CHECK-LABEL: "op_return" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.case"(%arg0) ({ ++ "stablehlo.return"(%arg1) : (tensor) -> () ++ }) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_reverse" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { ++ // CHECK: "vhlo.reverse_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.reverse"(%arg0) { ++ dimensions = array ++ } : (tensor<16xf32>) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_rng_bit_generator" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { ++ // CHECK: "vhlo.rng_bit_generator_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: rng_algorithm = #vhlo ++ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) ++ %0:2 = "stablehlo.rng_bit_generator"(%arg0) { ++ rng_algorithm = #stablehlo ++ } : (tensor) -> (tensor, tensor) ++ func.return %0#0, %0#1 : tensor, tensor ++} ++ ++// CHECK-LABEL: "op_rng" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { ++ // CHECK: "vhlo.rng_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: rng_distribution = #vhlo ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { ++ rng_distribution = #stablehlo ++ } : (tensor, tensor, tensor<0xindex>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_round_nearest_afz" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.round_nearest_afz_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> 
!vhlo.tensor_v1 ++ %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_round_nearest_even" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_round_nearest_even(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.round_nearest_even_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_rsqrt" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_rsqrt(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.rsqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_scatter" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { ++ // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, ++ // CHECK-SAME: unique_indices = #vhlo.bool_v1, ++ // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [1], ++ inserted_window_dims = [0, 1], ++ scatter_dims_to_operand_dims = [0, 1], ++ index_vector_dim = 1 ++ >, ++ indices_are_sorted = true, ++ unique_indices = true ++ } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> ++ func.return %0 : tensor<200x100x300xf32> ++} ++ ++// CHECK-LABEL: "op_scatter_with_batching_dims" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_scatter_with_batching_dims(%arg0: tensor<10x200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<10x200x100x300xf32> { ++ // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ ++ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, ++ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, ++ // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, ++ // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: unique_indices = 
#vhlo.bool_v1,
++ // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>>
++ // CHECK-SAME: }> ({
++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1):
++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1
++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> ()
++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1>
++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({
++ ^bb0(%arg3: tensor, %arg4: tensor):
++ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor
++ "stablehlo.return"(%1) : (tensor) -> ()
++ }) {
++ scatter_dimension_numbers = #stablehlo.scatter<
++ update_window_dims = [1],
++ inserted_window_dims = [1, 2],
++ input_batching_dims = [0],
++ scatter_dims_to_operand_dims = [1, 2],
++ scatter_indices_batching_dims = [0],
++ index_vector_dim = 1
++ >,
++ indices_are_sorted = true,
++ unique_indices = true
++ } : (tensor<10x200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<10x200x100x300xf32>
++ func.return %0 : tensor<10x200x100x300xf32>
++}
++
++// CHECK-LABEL: "op_scatter_with_promotable_types"
++func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>,
++ %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) ->
++ tensor<200x100x300xf64> {
++ // CHECK: "vhlo.scatter_v2"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]])
++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1):
++ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> ()
++ // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1>
++ %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
++ ^bb0(%lhs: tensor, %rhs: tensor):
++ %add = stablehlo.add %lhs, %rhs : tensor
++ "stablehlo.return"(%add) : (tensor) -> ()
++ }) {
++ scatter_dimension_numbers = #stablehlo.scatter<
++ update_window_dims = [1],
++ inserted_window_dims = [0, 1],
++ scatter_dims_to_operand_dims = [0, 1],
++ index_vector_dim = 1
++ >,
++ indices_are_sorted = true,
++ unique_indices = true
++ } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) ->
++ tensor<200x100x300xf64>
++ func.return %0 : tensor<200x100x300xf64>
++}
++
++// CHECK-LABEL: "op_select_and_scatter"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}})
++func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> {
++ // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{
++ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>,
++ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>,
++ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>>
++ // CHECK-SAME: }> ({
++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1):
++ // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1
++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> ()
++ // CHECK-NEXT: }, {
++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]:
!vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> ++ %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }, { ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ window_dimensions = array, ++ window_strides = array, ++ padding = dense<1> : tensor<4x2xi64> ++ } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> ++ func.return %0 : tensor<10x24x24x64xf32> ++} ++ ++// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { ++ // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) ++ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK: %[[VAL:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK: "vhlo.return_v1"(%[[VAL]]) : (!vhlo.tensor_v1) -> () ++ // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> ++ %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }, { ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ window_dimensions = array, ++ window_strides = array, ++ padding = dense<1> : tensor<4x2xi64> ++ } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> ++ func.return %0 : tensor<10x24x24x64xf64> ++} ++ ++// CHECK-LABEL: "op_select" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) ++func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { ++ // CHECK: "vhlo.select_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_send" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, ++ // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, ++ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 ++ 
%0 = "stablehlo.send"(%arg0, %arg1) { ++ channel_handle = #stablehlo.channel_handle, ++ is_host_transfer = true ++ } : (tensor, !stablehlo.token) -> !stablehlo.token ++ func.return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: "op_set_dimension_size" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { ++ // CHECK: "vhlo.set_dimension_size_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { ++ dimension = 0 : i64 ++ } : (tensor, tensor) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_shift_left" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.shift_left_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_shift_right_arithmetic" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.shift_right_arithmetic_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_shift_right_logical" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.shift_right_logical_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_sign" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_sign(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.sign_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_sine" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_sine(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.sine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_slice" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { ++ // CHECK: "vhlo.slice_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, ++ // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> ++ %0 = "stablehlo.slice"(%arg0) { ++ start_indices = array, ++ limit_indices = array, ++ strides = array ++ } : (tensor<16xf32>) -> tensor<4xf32> ++ func.return %0 : tensor<4xf32> ++} ++ ++// CHECK-LABEL: "op_sort" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { ++ // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: is_stable = #vhlo.bool_v1 ++ // CHECK-SAME: }> ({ ++ // CHECK-NEXT: 
^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> ++ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> ++ %0 = "stablehlo.sort"(%arg0) ({ ++ ^bb0(%arg1: tensor, %arg2: tensor): ++ %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor ++ "stablehlo.return"(%1) : (tensor) -> () ++ }) { ++ dimension = 0 : i64, ++ is_stable = true ++ } : (tensor<16xf32>) -> tensor<16xf32> ++ func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "op_sqrt" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_sqrt(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.sqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_subtract" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.subtract_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_tan" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_tan(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.tan_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.tan"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_tanh" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_tanh(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.tanh_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_torch_index_select" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { ++ // CHECK: "vhlo.torch_index_select_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> ++ %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { ++ dim = 0 : i64, ++ batch_dims = 0 : i64 ++ } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> ++ func.return %0 : tensor<2x1x5xf32> ++} ++ ++// CHECK-LABEL: "op_transpose" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { ++ // CHECK: "vhlo.transpose_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> ++ %0 = "stablehlo.transpose"(%arg0) { ++ permutation = array ++ } : (tensor<16x8xf32>) -> tensor<8x16xf32> ++ func.return %0 : tensor<8x16xf32> ++} ++ ++// CHECK-LABEL: "op_triangular_solve" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { ++ // CHECK: "vhlo.triangular_solve_v1"(%[[ARG0]], %[[ARG1]]) <{ ++ // CHECK-SAME: left_side = #vhlo.bool_v1, ++ // CHECK-SAME: lower = 
#vhlo.bool_v1, ++ // CHECK-SAME: transpose_a = #vhlo, ++ // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> ++ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { ++ left_side = true, ++ lower = true, ++ unit_diagonal = true, ++ transpose_a = #stablehlo ++ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> ++ func.return %0 : tensor<16x16xf32> ++} ++ ++// CHECK-LABEL: "op_tuple" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_tuple(%arg0: tensor) -> tuple> { ++ // CHECK: "vhlo.tuple_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> ++ %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> ++ func.return %0 : tuple> ++} ++ ++// CHECK-LABEL: "op_unary_einsum" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { ++ // CHECK: "vhlo.unary_einsum_v1"(%[[ARG0]]) <{ ++ // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> ++ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> ++ %0 = "stablehlo.unary_einsum"(%arg0) { ++ einsum_config = "ab->a" ++ } : (tensor<8x16xf32>) -> tensor<8xf32> ++ func.return %0 : tensor<8xf32> ++} ++ ++// CHECK-LABEL: "op_uniform_dequantize" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { ++ // CHECK: "vhlo.uniform_dequantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "op_uniform_quantize" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { ++ // CHECK: "vhlo.uniform_quantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> ++ %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> ++ func.return %0 : tensor> ++} ++ ++// CHECK-LABEL: "op_while" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @op_while(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.while_v1"(%[[ARG0]]) ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }, { ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) ++ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () ++ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.while"(%arg0) ({ ++ ^bb0(%arg1: tensor): ++ "stablehlo.return"(%arg1) : (tensor) -> () ++ }, { ++ ^bb0(%arg1: tensor): ++ "stablehlo.return"(%arg1) : (tensor) -> () ++ }) : (tensor) -> tensor ++ func.return %0: tensor ++} ++ ++// CHECK-LABEL: "op_xor" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.xor_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// ============ TYPES ============ ++ ++// CHECK-LABEL: "type_i1" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "type_i2" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, 
%[[ARG1:.*]]: {{.*}})
++func.func @type_i2(%arg0: tensor<i2>, %arg1: tensor<i2>) -> tensor<i2> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i2_v1>, !vhlo.tensor_v1<!vhlo.i2_v1>) -> !vhlo.tensor_v1<!vhlo.i2_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i2>, tensor<i2>) -> tensor<i2>
++  func.return %0 : tensor<i2>
++}
++
++// CHECK-LABEL: "type_i4"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_i4(%arg0: tensor<i4>, %arg1: tensor<i4>) -> tensor<i4> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i4_v1>, !vhlo.tensor_v1<!vhlo.i4_v1>) -> !vhlo.tensor_v1<!vhlo.i4_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i4>, tensor<i4>) -> tensor<i4>
++  func.return %0 : tensor<i4>
++}
++
++// CHECK-LABEL: "type_i8"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_i8(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i8_v1>, !vhlo.tensor_v1<!vhlo.i8_v1>) -> !vhlo.tensor_v1<!vhlo.i8_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i8>, tensor<i8>) -> tensor<i8>
++  func.return %0 : tensor<i8>
++}
++
++// CHECK-LABEL: "type_i16"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_i16(%arg0: tensor<i16>, %arg1: tensor<i16>) -> tensor<i16> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i16_v1>, !vhlo.tensor_v1<!vhlo.i16_v1>) -> !vhlo.tensor_v1<!vhlo.i16_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i16>, tensor<i16>) -> tensor<i16>
++  func.return %0 : tensor<i16>
++}
++
++// CHECK-LABEL: "type_i32"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_i32(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i32_v1>, !vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
++  func.return %0 : tensor<i32>
++}
++
++// CHECK-LABEL: "type_i64"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_i64(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<i64> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i64_v1>, !vhlo.tensor_v1<!vhlo.i64_v1>) -> !vhlo.tensor_v1<!vhlo.i64_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
++  func.return %0 : tensor<i64>
++}
++
++// CHECK-LABEL: "type_ui2"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_ui2(%arg0: tensor<ui2>, %arg1: tensor<ui2>) -> tensor<ui2> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui2_v1>, !vhlo.tensor_v1<!vhlo.ui2_v1>) -> !vhlo.tensor_v1<!vhlo.ui2_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui2>, tensor<ui2>) -> tensor<ui2>
++  func.return %0 : tensor<ui2>
++}
++
++// CHECK-LABEL: "type_ui4"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_ui4(%arg0: tensor<ui4>, %arg1: tensor<ui4>) -> tensor<ui4> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui4_v1>, !vhlo.tensor_v1<!vhlo.ui4_v1>) -> !vhlo.tensor_v1<!vhlo.ui4_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui4>, tensor<ui4>) -> tensor<ui4>
++  func.return %0 : tensor<ui4>
++}
++
++// CHECK-LABEL: "type_ui8"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_ui8(%arg0: tensor<ui8>, %arg1: tensor<ui8>) -> tensor<ui8> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui8_v1>, !vhlo.tensor_v1<!vhlo.ui8_v1>) -> !vhlo.tensor_v1<!vhlo.ui8_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
++  func.return %0 : tensor<ui8>
++}
++
++// CHECK-LABEL: "type_ui16"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_ui16(%arg0: tensor<ui16>, %arg1: tensor<ui16>) -> tensor<ui16> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui16_v1>, !vhlo.tensor_v1<!vhlo.ui16_v1>) -> !vhlo.tensor_v1<!vhlo.ui16_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui16>, tensor<ui16>) -> tensor<ui16>
++  func.return %0 : tensor<ui16>
++}
++
++// CHECK-LABEL: "type_ui32"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_ui32(%arg0: tensor<ui32>, %arg1: tensor<ui32>) -> tensor<ui32> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui32_v1>, !vhlo.tensor_v1<!vhlo.ui32_v1>) -> !vhlo.tensor_v1<!vhlo.ui32_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui32>, tensor<ui32>) -> tensor<ui32>
++  func.return %0 : tensor<ui32>
++}
++
++// CHECK-LABEL: "type_ui64"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_ui64(%arg0: tensor<ui64>, %arg1: tensor<ui64>) -> tensor<ui64> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui64_v1>, !vhlo.tensor_v1<!vhlo.ui64_v1>) -> !vhlo.tensor_v1<!vhlo.ui64_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui64>, tensor<ui64>) -> tensor<ui64>
++  func.return %0 : tensor<ui64>
++}
++
++// CHECK-LABEL: "type_f4E2M1FN"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f4E2M1FN(%arg0: tensor<f4E2M1FN>, %arg1: tensor<f4E2M1FN>) -> tensor<f4E2M1FN> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f4E2M1FN_v1>, !vhlo.tensor_v1<!vhlo.f4E2M1FN_v1>) -> !vhlo.tensor_v1<!vhlo.f4E2M1FN_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f4E2M1FN>, tensor<f4E2M1FN>) -> tensor<f4E2M1FN>
++  func.return %0 : tensor<f4E2M1FN>
++}
++
++// CHECK-LABEL: "type_f6E2M3FN"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f6E2M3FN(%arg0: tensor<f6E2M3FN>, %arg1: tensor<f6E2M3FN>) -> tensor<f6E2M3FN> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f6E2M3FN_v1>, !vhlo.tensor_v1<!vhlo.f6E2M3FN_v1>) -> !vhlo.tensor_v1<!vhlo.f6E2M3FN_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f6E2M3FN>, tensor<f6E2M3FN>) -> tensor<f6E2M3FN>
++  func.return %0 : tensor<f6E2M3FN>
++}
++
++// CHECK-LABEL: "type_f6E3M2FN"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f6E3M2FN(%arg0: tensor<f6E3M2FN>, %arg1: tensor<f6E3M2FN>) -> tensor<f6E3M2FN> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f6E3M2FN_v1>, !vhlo.tensor_v1<!vhlo.f6E3M2FN_v1>) -> !vhlo.tensor_v1<!vhlo.f6E3M2FN_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f6E3M2FN>, tensor<f6E3M2FN>) -> tensor<f6E3M2FN>
++  func.return %0 : tensor<f6E3M2FN>
++}
++
++// CHECK-LABEL: "type_f8E3M4"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E3M4(%arg0: tensor<f8E3M4>, %arg1: tensor<f8E3M4>) -> tensor<f8E3M4> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E3M4_v1>, !vhlo.tensor_v1<!vhlo.f8E3M4_v1>) -> !vhlo.tensor_v1<!vhlo.f8E3M4_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
++  func.return %0 : tensor<f8E3M4>
++}
++
++// CHECK-LABEL: "type_f8E4M3"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E4M3(%arg0: tensor<f8E4M3>, %arg1: tensor<f8E4M3>) -> tensor<f8E4M3> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E4M3_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
++  func.return %0 : tensor<f8E4M3>
++}
++
++// CHECK-LABEL: "type_f8E4M3FN"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E4M3FN(%arg0: tensor<f8E4M3FN>, %arg1: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E4M3FN_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3FN_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
++  func.return %0 : tensor<f8E4M3FN>
++}
++
++// CHECK-LABEL: "type_f8E5M2"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E5M2(%arg0: tensor<f8E5M2>, %arg1: tensor<f8E5M2>) -> tensor<f8E5M2> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E5M2_v1>, !vhlo.tensor_v1<!vhlo.f8E5M2_v1>) -> !vhlo.tensor_v1<!vhlo.f8E5M2_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E5M2>, tensor<f8E5M2>) -> tensor<f8E5M2>
++  func.return %0 : tensor<f8E5M2>
++}
++
++// CHECK-LABEL: "type_f8E4M3FNUZ"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E4M3FNUZ(%arg0: tensor<f8E4M3FNUZ>, %arg1: tensor<f8E4M3FNUZ>) -> tensor<f8E4M3FNUZ> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3FNUZ>, tensor<f8E4M3FNUZ>) -> tensor<f8E4M3FNUZ>
++  func.return %0 : tensor<f8E4M3FNUZ>
++}
++
++// CHECK-LABEL: "type_f8E4M3B11FNUZ"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E4M3B11FNUZ(%arg0: tensor<f8E4M3B11FNUZ>, %arg1: tensor<f8E4M3B11FNUZ>) -> tensor<f8E4M3B11FNUZ> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E4M3B11FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3B11FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3B11FNUZ_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3B11FNUZ>, tensor<f8E4M3B11FNUZ>) -> tensor<f8E4M3B11FNUZ>
++  func.return %0 : tensor<f8E4M3B11FNUZ>
++}
++
++// CHECK-LABEL: "type_f8E5M2FNUZ"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E5M2FNUZ(%arg0: tensor<f8E5M2FNUZ>, %arg1: tensor<f8E5M2FNUZ>) -> tensor<f8E5M2FNUZ> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E5M2FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E5M2FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E5M2FNUZ_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E5M2FNUZ>, tensor<f8E5M2FNUZ>) -> tensor<f8E5M2FNUZ>
++  func.return %0 : tensor<f8E5M2FNUZ>
++}
++
++// CHECK-LABEL: "type_f8E8M0FNU"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f8E8M0FNU(%arg0: tensor<f8E8M0FNU>, %arg1: tensor<f8E8M0FNU>) -> tensor<f8E8M0FNU> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f8E8M0FNU_v1>, !vhlo.tensor_v1<!vhlo.f8E8M0FNU_v1>) -> !vhlo.tensor_v1<!vhlo.f8E8M0FNU_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E8M0FNU>, tensor<f8E8M0FNU>) -> tensor<f8E8M0FNU>
++  func.return %0 : tensor<f8E8M0FNU>
++}
++
++// CHECK-LABEL: "type_bf16"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_bf16(%arg0: tensor<bf16>, %arg1: tensor<bf16>) -> tensor<bf16> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.bf16_v1>, !vhlo.tensor_v1<!vhlo.bf16_v1>) -> !vhlo.tensor_v1<!vhlo.bf16_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16>
++  func.return %0 : tensor<bf16>
++}
++
++// CHECK-LABEL: "type_f16"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f16_v1>, !vhlo.tensor_v1<!vhlo.f16_v1>) -> !vhlo.tensor_v1<!vhlo.f16_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f16>, tensor<f16>) -> tensor<f16>
++  func.return %0 : tensor<f16>
++}
++
++// CHECK-LABEL: "type_f32"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f32(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
++  func.return %0 : tensor<f32>
++}
++
++// CHECK-LABEL: "type_f64"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_f64(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f64_v1>, !vhlo.tensor_v1<!vhlo.f64_v1>) -> !vhlo.tensor_v1<!vhlo.f64_v1>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
++  func.return %0 : tensor<f64>
++}
++
++// CHECK-LABEL: "type_complex_f32"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_complex_f32(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>) -> tensor<complex<f32>> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>, !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>) -> !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
++  func.return %0 : tensor<complex<f32>>
++}
++
++// CHECK-LABEL: "type_complex_f64"
++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
++func.func @type_complex_f64(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> tensor<complex<f64>> {
++  // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f64_v1>>, !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f64_v1>>) -> !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f64_v1>>
++  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<complex<f64>>, tensor<complex<f64>>) -> tensor<complex<f64>>
++  func.return %0 : tensor<complex<f64>>
++}
+
-+func.func @ragged_dot_mode_1_batching() {
-+  %lhs =
stablehlo.constant dense<[ -+ [ -+ [ -0.0999976546, -0.0605386607, 0.126681596, 0.0375950411, 0.0598301813 ], -+ [ -0.0343122408, -0.0858866125, 0.103659429, 0.103788935, 0.180407882 ], -+ [ 0.0150506198, 0.055824928, 0.149289608, -0.0896283686, -0.0839615092 ], -+ [ 0.0589100644, 0.101344816, -0.097690545, 0.0150246918, -0.0799473301 ], -+ [ 0.0252457932, 0.106031813, 0.076692991, 0.179130971, 0.153850079 ], -+ [ 0.0580786392, -0.0724105313, 0.0961757079, 0.0247998089, 0.110357188 ], -+ [ 0.173096269, 0.128659427, -0.0212640986, -0.0857606456, 0.120824583 ], -+ [ -0.00152973086, 0.0897915736, 0.126923144, 0.197311223, 0.00960160792 ], -+ [ -0.0258883312, 0.194765091, 0.11679814, 0.126006752, 0.0954555795 ], -+ [ -0.0781942382, 0.0894904211, 0.165412158, -0.0181870088, 0.0309234336 ], -+ [ 0.129948437, 0.0433195308, -0.028667666, -0.0175279453, 0.00777949393 ] -+ ], -+ [ -+ [ -0.0500478409, 0.0459552184, 0.16929689, 0.172762454, -0.0818307 ], -+ [ 0.171395928, 0.0513568744, 0.0548876, -0.00429011881, 0.195992649 ], -+ [ 0.0481930152, -0.0201566443, -0.0727801323, 0.184329301, -0.0778752789 ], -+ [ 0.0502121374, 0.0152426511, -0.0168754607, 0.174145252, 0.0589242205 ], -+ [ 0.0393337533, 0.182294011, -0.0849748, 0.128454268, 0.131061375 ], -+ [ 0.148345202, -0.0623903871, -0.0952396914, 0.10653659, 0.160474151 ], -+ [ 0.0888630375, 0.120867364, 0.117623605, 0.199837387, 0.166571677 ], -+ [ -0.0300415382, -0.00810345262, 0.00530457497, 0.0539821163, 0.0773340687 ], -+ [ 0.153794467, 0.0236242339, 0.152453214, -0.0192048177, 0.0246183872 ], -+ [ 0.0611911938, 0.0403752252, -0.013836287, -0.0465016849, -0.053884007 ], -+ [ 0.0714964494, 0.140721709, -0.0900838748, 0.0603349432, 0.0495440438 ] -+ ]]> : tensor<2x11x5xf32> -+ %rhs = stablehlo.constant dense<[ -+ [ -+ [ -+ [ 0.186608255, 0.124487795, 0.0663751587, 0.167221248, 0.0874548, 0.152611881, -0.0520697422 ], -+ [ -0.0361745432, 0.114412986, -0.0608718246, -0.0727029, -0.0176235586, -0.0991001204, 0.0242879838 ], -+ [ -0.0919371173, 0.112945892, 0.181369215, -0.0280267522, -0.0457312278, -0.00473813713, 0.166097224 ], -+ [ 0.0956176, -0.0548994839, 0.104403876, 0.0157444105, 0.0163175985, 0.0499223098, -0.0557401 ], -+ [ 0.076156, 0.153672695, 0.0770325884, 0.186622649, 0.066843845, -0.0555545315, 0.194991559 ] -+ ], -+ [ -+ [ 0.0226300061, -0.0574540682, 0.0694696084, -0.0243620798, 0.0465543643, 0.0392091647, 0.188328564 ], -+ [ -0.0621907599, -0.0400728397, -0.0042250976, 0.0887807682, -0.0619863532, 0.0953761414, 0.0864902064 ], -+ [ 0.140921891, -0.0256474689, 0.0429295525, 0.0167942569, -0.0390249, -0.0914874449, 0.170502067 ], -+ [ 0.0279492214, -0.0573936924, 0.184246033, 0.0230939165, -0.060643442, 0.165694535, -0.0723479092 ], -+ [ -0.051340431, -0.0786809325, 0.00960171223, -0.0240827873, -0.059467189, 0.134945959, 0.0365921929 ] -+ ] -+ ], -+ [ -+ [ -+ [ 0.00485724211, 0.0356900468, 0.142683387, 0.179502338, 0.0954938307, -0.0354254842, 0.103877716 ], -+ [ 0.172676593, -0.0249623209, 0.158257961, 0.0413787, 0.0517867729, 0.0801181123, 0.14526847 ], -+ [ 0.126753062, 0.0386734977, 0.185410261, 0.0898216143, 0.0317991, 0.14740923, 0.106694289 ], -+ [ 0.110662006, 0.196143657, 0.186324477, 0.155380905, -0.0132051334, 0.0612277314, 0.054330416 ], -+ [ -0.0689698234, 0.0242085531, 0.073015, 0.162969738, 0.0320116058, 0.118924297, 0.160779119 ] -+ ], -+ [ -+ [ 0.11469271, 0.140216112, 0.111960642, 0.122514777, -0.0942722782, 0.165809333, 0.0574962273 ], -+ [ 0.0389968231, -0.08044184, 0.114026703, 0.0466829464, 0.100303732, 
0.104614742, -0.0401335768 ], -+ [ 0.174990177, 0.159764826, 0.167005628, 0.0631844923, -0.0582415, 0.0351042375, 0.196808755 ], -+ [ -0.035340406, 0.0338070318, -0.00528027117, 0.0543978438, 0.164451241, 0.0319176689, 0.0402595326 ], -+ [ 0.141994983, 0.00954742, -0.0365443081, 0.199735016, -0.053918656, 0.0891464874, 0.0849051103 ] -+ ] -+ ], -+ [ -+ [ -+ [ -0.0998214856, -0.0997363, 0.132005602, 0.118200503, -0.00424671918, 0.025317125, 0.104748271 ], -+ [ 0.104168601, -0.0384214334, 0.150926, 0.112676181, 0.14861238, -0.071635358, -0.0754787177 ], -+ [ 0.129201442, 0.088871561, -0.0358443409, -0.0359359607, -0.0756817609, 0.0166469738, 0.185647905 ], -+ [ 0.184263527, 0.0169560835, -0.0192355737, 0.10765069, -0.0147894919, 0.13305977, 0.135159582 ], -+ [ 0.0267379507, -0.0153532401, -0.0418097563, -0.096605137, -0.0424528457, 0.194970757, -0.0267837271 ] -+ ], -+ [ -+ [ 0.145917833, -0.0590635166, 0.0194431096, 0.0803030357, -0.0469358861, 0.148506433, -0.0526806451 ], -+ [ 0.196381122, -0.0228494033, -0.0299202427, -0.069508791, -0.0341768041, 0.0904152468, 0.108802207 ], -+ [ 0.138430953, 0.108872853, 0.125882119, 0.100856192, 0.0900289789, -0.0830678046, 0.0794649944 ], -+ [ -0.0318976864, -0.00436662883, 0.109950341, -0.0647689179, 0.128771216, 0.0578369871, 0.0661734 ], -+ [ 0.0763966814, -0.00110008568, 0.110896833, -0.057086423, -0.0514936894, 0.0455975607, 0.158067733 ] -+ ] -+ ]]> : tensor<3x2x5x7xf32> -+ %group_sizes = stablehlo.constant dense<[4, 4, 3]> : tensor<3xi64> -+ %result = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { -+ ragged_dot_dimension_numbers = #chlo.ragged_dot< -+ lhs_batching_dimensions = [0], -+ rhs_batching_dimensions = [1], -+ lhs_contracting_dimensions = [2], -+ rhs_contracting_dimensions = [2], -+ lhs_ragged_dimensions = [1], -+ rhs_group_dimensions = [0] -+ >, -+ precision_config = [#chlo, #chlo] -+ } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> -+ check.expect_almost_eq_const %result, dense<[ -+ [ -+ [-0.0199659951, 0.00206358638, 0.0285578221, -0.00411329232, -0.00885893404, -0.0113086831, 0.0343487822], -+ [0.0108370036, 0.0196357146, 0.0464844741, 0.032903526, 0.00752512738, -0.00205732603, 0.0463109687], -+ [-0.0279003512, 0.0171403233, 0.00885203853, -0.022806216, -0.0135696121, -0.00375272054, 0.0139928926], -+ [0.0116565451, -0.00521556707, -0.0245668497, -0.00946252606, 2.734600e-03, 0.00460146647, -0.0332586318], -+ [0.0373648889, 0.040080104, 0.0792120546, 0.0687142611, 0.0129001699, 0.048170276, 6.067640e-02], -+ [-0.00489785476, 0.0151357278, 0.0273378156, 0.0379059538, 0.0080597708, 0.0209609158, 0.0248660222], -+ [0.00253825542, -1.175260e-02, 0.0339594558, 0.0408501513, 0.0275165718, 0.0101594552, 0.0491689071], -+ [5.275800e-02, 0.0415463448, 0.0749897882, 0.0470644757, 0.00624182029, 0.0391805507, 0.03869069], -+ [0.0637338459, 0.00614991458, 0.0153763723, 0.0190313365, 0.0142990183, 0.0227143262, 0.0187453162], -+ [0.0359746702, 0.0182777364, -0.00368779944, -0.0100486111, 6.89582666E-5, -0.00202751439, 0.0124766938], -+ [-0.0151847685, -0.0175893605, 0.0247314386, 0.018632818, 0.00798455066, -0.00110600982, 0.00244264561] -+ ], -+ [ -+ [0.0288968664, -0.00678509939, 0.0346419513, 0.0141028976, -0.017396003, 0.00451522879, 0.00792134088], -+ [-0.0017626211, -0.0284877941, 0.0151375476, -0.00351338694, -0.00874114502, 0.0323345512, 0.0535612516], -+ [0.00123786228, -0.00454656407, 0.0335229039, 0.0019464466, -2.14070082E-4, 0.0266590156, -0.0212618597], -+ [-3.47743975E-4, -0.017693948, 
0.0353507064, 0.00244920771, -0.0120135043, 0.0417729542, -0.0025454592], -+ [0.0108208582, -0.0171308704, 0.00553112756, 0.0411250815, 0.0335835591, 0.038393192, -0.00547906291], -+ [0.0169365555, 0.0157370344, -0.0128378682, 0.0470919088, -0.00582840201, 0.0324328542, 0.010203423], -+ [0.0520783663, 0.0298755895, 0.0362326317, 0.0681023895, 0.0207777359, 0.052735541, 0.0455959477], -+ [0.00623999349, -1.49650674E-4, -0.00651274621, 0.0146591738, 0.00641800836, 0.00297434814, 0.00838128477], -+ [0.0506783053, 0.00703135319, 0.0220930576, 0.0259224195, 0.001958607, 0.0123232938, 0.00920359604], -+ [0.0123091843, -5.780780e-03, -0.0128484722, 0.00679983944, -0.00871101767, 0.0087406747, -0.0115246754], -+ [0.0274577513, -0.0175638888, -0.00203213934, -0.0198616516, -0.0110571291, 0.0365728177, 0.0162097216] -+ ] -+ ]> : tensor<2x11x7xf32> -+ func.return ++// CHECK-LABEL: "type_tf32" ++// CHECK: #vhlo.type_v1 ++func.func @type_tf32() attributes {stablehlo.attr = tf32 } { ++ return +} -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir -@@ -1940,6 +1940,17 @@ - return %1 : tensor<12xi64> ++ ++// CHECK-LABEL: "type_none" ++// CHECK: #vhlo.type_v1 ++func.func @type_none() attributes {stablehlo.attr = none } { ++ return ++} ++ ++// CHECK-LABEL: "type_dynamism_ranked" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { ++ // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// CHECK-LABEL: "type_per_tensor_quantization" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) ++func.func @type_per_tensor_quantization(%arg0: tensor>, %arg1: tensor>) -> tensor> { ++ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> ++ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> ++ func.return %0 : tensor> ++} ++ ++// CHECK-LABEL: "type_per_axis_quantization" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { ++ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG0]]) : (!vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>, !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>) -> !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1> ++ %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> ++ func.return %0 : tensor<2x!quant.uniform> ++} ++ ++// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> ++// CHECK-LABEL: "type_token_callee" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> () ++ return %arg0 : !stablehlo.token ++} ++ ++// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> ++// CHECK-LABEL: "type_token_caller" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { ++ // CHECK: "vhlo.call_v1"(%[[ARG0]]) <{callee = #vhlo.string_v1<"type_token_callee">} ++ // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 ++ %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token ++ return %0 : !stablehlo.token ++} ++ ++// CHECK-LABEL: 
"type_tuple" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) ++func.func @type_tuple(%arg0: tuple>) -> tuple { ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "foo" ++ // CHECK: (!vhlo.tuple_v1>) -> !vhlo.tuple_v1 ++ } : (tuple>) -> tuple ++ return %0 : tuple ++} ++ ++// ============ DEPENDENCIES ============ ++ ++func.func @composite_target(%arg0: tensor) -> tensor { ++ return %arg0: tensor ++} +diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +--- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir ++++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +@@ -248,6 +248,36 @@ + fft_length = array + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> ++} ++ ++// CHECK-LABEL: "attr_result_accuracy_HIGHEST" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} ++func.func @attr_result_accuracy_HIGHEST(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { ++ %0 = "stablehlo.exponential"(%arg0) { ++ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> ++ result_accuracy = #stablehlo.result_accuracy> ++ } : (tensor<8x16xf32>) -> tensor<8x16xf32> ++ func.return %0 : tensor<8x16xf32> ++} ++ ++// CHECK-LABEL: "attr_result_accuracy_TOLERANCE" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} ++func.func @attr_result_accuracy_TOLERANCE(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { ++ %0 = "stablehlo.exponential"(%arg0) { ++ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> ++ result_accuracy = #stablehlo.result_accuracy> ++ } : (tensor<8x16xf32>) -> tensor<8x16xf32> ++ func.return %0 : tensor<8x16xf32> ++} ++ ++// CHECK-LABEL: "attr_result_accuracy_DEFAULT" ++// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} ++func.func @attr_result_accuracy_DEFAULT(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { ++ %0 = "stablehlo.exponential"(%arg0) { ++ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> ++ result_accuracy = #stablehlo.result_accuracy> ++ } : (tensor<8x16xf32>) -> tensor<8x16xf32> ++ func.return %0 : tensor<8x16xf32> } + // GatherDimensionNumbers aka #stablehlo.gather is covered below. 
+@@ -1621,7 +1651,7 @@ + // CHECK-LABEL: "op_exponential" + // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) + func.func @op_exponential(%arg0: tensor) -> tensor { +- // CHECK: "vhlo.exponential_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 ++ // CHECK: "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.result_accuracy_v1>}> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor + } +diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +--- stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir ++++ stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +@@ -0,0 +1,26 @@ ++// RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s ++ ++func.func @invalid_array_element() -> () attributes { ++ // expected-error @+1 {{expected array of VHLO attriutes}} ++ vhlo.attr = #vhlo.array_v1<[#stablehlo]> ++} { ++ return ++} ++ +// ----- + -+// CHECK-LABEL: @reorder_invalid_with_dynamic_shape -+func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor) { -+ // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> -+ // CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor -+ // CHECK: return %[[CONVERT]] -+ %0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> -+ %1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor -+ return %1 : tensor ++func.func @invalid_dict_element_value() -> () attributes { ++ // expected-error @+1 {{expected VHLO attribute}} ++ vhlo.attr = #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = 3 : i32}> ++} { ++ return +} - - // ----- - -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir -@@ -36,7 +36,7 @@ - - // ----- - --// expected-error @+1 {{number of refinements must match number of function operands 6 vs 1}} -+// expected-error @+1 {{number of refinements must match number of op operands 6 vs 1}} - func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor) { - return - } ++ ++// ----- ++ ++func.func @invalid_result_accuracy() -> () attributes { ++ // expected-error @+1 {{expected VHLO result accuracy mode}} ++ vhlo.attr = #vhlo.result_accuracy_v1> ++} { ++ return ++} +diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir +--- stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir ++++ stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir +@@ -0,0 +1,24 @@ ++// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.8.0' %s | FileCheck %s ++ ++// ExpOp was changed in v1.9.0 to have ++// result_accuracy attribute. Ensure that serializing for 1.8.0 is valid and targets the ++// v1.8.0 opset. 
++// ++// This will catch issues in op `isLegal` checks: ++// op.minVersion() <= target <= op.maxVersion() ++ ++// CHECK-LABEL: vhlo.func_v1 @exp_op ++func.func public @exp_op(%arg0: tensor) -> tensor { ++ // CHECK: vhlo.exponential_v1 ++ %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor ++ return %0 : tensor ++} ++ ++// CHECK-LABEL: vhlo.func_v1 @exp_op_default ++func.func @exp_op_default(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.exponential"(%arg0) { ++ // CHECK: vhlo.exponential_v1 ++ result_accuracy = #stablehlo.result_accuracy> ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} +diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir +--- stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir ++++ stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir +@@ -0,0 +1,22 @@ ++// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.8.0' --verify-diagnostics --split-input-file %s ++ ++ ++func.func @attr_result_accuracy_default(%arg0: tensor) -> tensor { ++ %0 = "stablehlo.exponential"(%arg0) { ++ // CHECK: vhlo.exponential_v1 ++ result_accuracy = #stablehlo.result_accuracy> ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// expected-error @-3 {{failed to convert VHLO to v1.8.0}} ++func.func @attr_result_accuracy_highest(%arg0: tensor) -> tensor { ++ // expected-error @+1 {{failed to legalize operation 'vhlo.exponential_v2' that was explicitly marked illegal}} ++ %0 = "stablehlo.exponential"(%arg0) { ++ result_accuracy = #stablehlo.result_accuracy> ++ } : (tensor) -> tensor ++ func.return %0 : tensor ++} ++ diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp --- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -701,6 +4822,18 @@ diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stable } } // namespace +diff --ruN a/stablehlo/stablehlo/transforms/MapStablehloToVhlo.h b/stablehlo/stablehlo/transforms/MapStablehloToVhlo.h +--- stablehlo/stablehlo/transforms/MapStablehloToVhlo.h ++++ stablehlo/stablehlo/transforms/MapStablehloToVhlo.h +@@ -94,7 +94,7 @@ + MAP_STABLEHLO_TO_VHLO(DynamicUpdateSliceOp, V1) + MAP_STABLEHLO_TO_VHLO(EinsumOp, V1) + MAP_STABLEHLO_TO_VHLO(Expm1Op, V1) +-MAP_STABLEHLO_TO_VHLO(ExpOp, V1) ++MAP_STABLEHLO_TO_VHLO(ExpOp, V2) + MAP_STABLEHLO_TO_VHLO(FftOp, V1) + MAP_STABLEHLO_TO_VHLO(FloorOp, V1) + MAP_STABLEHLO_TO_VHLO(GatherOp, V2) diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h @@ -804,6 +4937,84 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp + } // namespace stablehlo } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td +--- stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td ++++ stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td +@@ -683,12 +683,15 @@ + // Notice that for `y != 0`, neither `cos(y)` nor `sin(y)` is never + // zero on the set of floating point numbers. 
+ // +-def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexElementType:$z), ++def ConstDefaultResultAccuracyAttr : ++ ConstantAttr; ++ ++def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexElementType:$z, ConstDefaultResultAccuracyAttr), + (StableHLO_ComplexOp + (StableHLO_SelectOp + (StableHLO_CompareOp:$eq_e_constant_posinf + (StableHLO_ExpOp:$e +- (StableHLO_RealOp:$x $z)), ++ (StableHLO_RealOp:$x $z), ConstDefaultResultAccuracyAttr), + (StableHLO_ConstantLikePosInfValue $x), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), +@@ -697,7 +700,7 @@ + (StableHLO_ExpOp:$e2 + (StableHLO_MulOp + $x, +- (StableHLO_ConstantLike<"0.5"> $x))), ++ (StableHLO_ConstantLike<"0.5"> $x)), ConstDefaultResultAccuracyAttr), + (StableHLO_CosineOp:$cs + (StableHLO_ImagOp:$y $z))), + $e2), +diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +--- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ++++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +@@ -19,6 +19,7 @@ + + #include "llvm/Support/Casting.h" + #include "llvm/Support/Debug.h" ++#include "llvm/Support/ErrorHandling.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/IR/Attributes.h" + #include "mlir/IR/Builders.h" +@@ -129,6 +130,16 @@ + } + if (auto attr = dyn_cast(stablehloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(Transpose, V1); ++ } ++ if (auto attr = dyn_cast(stablehloAttr)) { ++ RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode, V1); ++ } ++ if (auto attr = dyn_cast(stablehloAttr)) { ++ auto modeAttr = convertGeneric(attr.getMode(), typeConverter); ++ if (!modeAttr) return {}; ++ return vhlo::ResultAccuracyV1Attr::get(attr.getContext(), attr.getAtol(), ++ attr.getRtol(), attr.getUlps(), ++ modeAttr); + } + if (stablehloAttr.getDialect().getNamespace() == + stablehlo::StablehloDialect::getDialectNamespace()) { +@@ -815,6 +826,19 @@ + } + } + } ++ if constexpr (std::is_same::value) { ++ if (!stablehloOp.getResultAccuracyAttr()) ++ addDefaultAttr("result_accuracy", ++ stablehlo::ResultAccuracyAttr::get( ++ pattern.getContext(), ++ /*atol=*/APFloat(0.0), ++ /*rtol=*/APFloat(0.0), ++ /*ulps=*/0, ++ /*mode=*/ ++ stablehlo::ResultAccuracyModeAttr::get( ++ pattern.getContext(), ++ stablehlo::ResultAccuracyMode::DEFAULT))); ++ } + if constexpr (std::is_same::value) { + if (!stablehloOp.getKnownExpandingDimensionsAttr()) diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp --- stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp +++ stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp @@ -1021,4 +5232,180 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/ // Gets a FuncOp that --stablehlo-refine-shapes will run on. 
// Returns a nullptr and emits appropriate errors if such a function cannot +diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +@@ -23,6 +23,7 @@ + #include "llvm/Support/AllocatorBase.h" + #include "llvm/Support/Casting.h" + #include "llvm/Support/Debug.h" ++#include "llvm/Support/ErrorHandling.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/IR/Attributes.h" + #include "mlir/IR/BuiltinAttributes.h" +@@ -168,6 +169,17 @@ + auto builtinType = typeConverter->convertType(attr.getValue()); + if (!builtinType) return {}; + return TypeAttr::get(builtinType); ++ } ++ if (auto attr = dyn_cast(vhloAttr)) { ++ RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode, V1); ++ } ++ if (auto attr = dyn_cast(vhloAttr)) { ++ auto modeAttr = dyn_cast_or_null( ++ convertGeneric(attr.getMode(), typeConverter)); ++ if (!modeAttr) return {}; ++ return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), ++ attr.getRtol(), attr.getUlps(), ++ modeAttr); + } + + // All VHLO Attributes must be converted by now. +@@ -737,6 +749,13 @@ + }); + } + ++bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) { ++ auto attr = dyn_cast_or_null(vhloAttr); ++ return attr.getAtol().isZero() && attr.getRtol().isZero() && ++ attr.getUlps() == 0 && ++ dyn_cast(attr.getMode()).getValue() == ++ vhlo::ResultAccuracyModeV1::DEFAULT; ++} + template + bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr, + T splatValue) { +@@ -897,6 +916,11 @@ + eraseAttrs(vhloAttrs, "dimension"); + if (isBoolean(vhloOp.getIsStableAttr(), false)) + eraseAttrs(vhloAttrs, "is_stable"); ++ } ++ if constexpr (std::is_same::value) { ++ if (isDefaultResultAccuracyAttribute(vhloOp.getResultAccuracyAttr())) { ++ eraseAttrs(vhloAttrs, "result_accuracy"); ++ } + } + return success(); + } +diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp +--- stablehlo/stablehlo/transforms/VhloToVersion.cpp ++++ stablehlo/stablehlo/transforms/VhloToVersion.cpp +@@ -139,6 +139,8 @@ + return isLegalType(tensorAttr.getType(), targetVersion); + if (auto typeAttr = dyn_cast(attr)) + return isLegalType(typeAttr.getValue(), targetVersion); ++ if (auto resultAccuracyAttr = dyn_cast(attr)) ++ return isLegalAttribute(resultAccuracyAttr.getMode(), targetVersion); + + // Is VHLO and valid version, success. 
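++  // For example, a #vhlo.result_accuracy_v1 attribute is only legal when the
++  // target version supports its enclosing op (vhlo.exponential_v2, added in
++  // v1.9.0); the recursive check above also validates the nested mode
++  // attribute.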
+ return success(); +@@ -324,6 +326,22 @@ + denseElements.getRawData()); + } + ++bool isDefaultResultAccuracy(Attribute attr) { ++ auto resultAccuracy = dyn_cast(attr); ++ auto default_mode = ResultAccuracyModeV1Attr::get( ++ attr.getContext(), ResultAccuracyModeV1::DEFAULT); ++ return resultAccuracy.getAtol().isZero() && ++ resultAccuracy.getRtol().isZero() && resultAccuracy.getUlps() == 0 && ++ resultAccuracy.getMode() == default_mode; ++} ++ ++ResultAccuracyV1Attr getDefaultResultAccuracy(OpBuilder& builder) { ++ return ResultAccuracyV1Attr::get( ++ builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ++ ResultAccuracyModeV1Attr::get(builder.getContext(), ++ ResultAccuracyModeV1::DEFAULT)); ++} ++ + // DRR has limited support for ops with regions + struct ScatterOpV2ToV1 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; +@@ -393,6 +411,40 @@ + } + }; + ++struct ExpOpV1ToV2 : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(ExpOpV1 op, ++ PatternRewriter& rewriter) const override { ++ ResultAccuracyV1Attr defaultResultAccuracy = ResultAccuracyV1Attr::get( ++ rewriter.getContext(), APFloat(0.0), APFloat(0.0), 0, ++ ResultAccuracyModeV1Attr::get(rewriter.getContext(), ++ ResultAccuracyModeV1::DEFAULT)); ++ rewriter.replaceOpWithNewOp( ++ op, op->getResultTypes(), op.getOperand(), defaultResultAccuracy); ++ return success(); ++ } ++}; ++ ++struct ExpOpV2ToV1 : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(ExpOpV2 op, ++ PatternRewriter& rewriter) const override { ++ auto defaultResultAccuracy = ResultAccuracyV1Attr::get( ++ rewriter.getContext(), APFloat(0.0), APFloat(0.0), 0, ++ ResultAccuracyModeV1Attr::get(rewriter.getContext(), ++ ResultAccuracyModeV1::DEFAULT)); ++ if (op.getResultAccuracy() != defaultResultAccuracy) { ++ return rewriter.notifyMatchFailure(op, ++ "non-default result accuracy attr"); ++ } ++ rewriter.replaceOpWithNewOp(op, op->getResultTypes(), ++ op.getOperand()); ++ return success(); ++ } ++}; ++ + #include "stablehlo/transforms/VhloToVersionPatterns.h.inc" + + } // namespace +@@ -405,6 +457,7 @@ + vhlo::populateWithGenerated(*patterns); + patterns->add(context); + patterns->add(context); ++ patterns->add(context); + } + + } // namespace stablehlo +diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersionPatterns.td b/stablehlo/stablehlo/transforms/VhloToVersionPatterns.td +--- stablehlo/stablehlo/transforms/VhloToVersionPatterns.td ++++ stablehlo/stablehlo/transforms/VhloToVersionPatterns.td +@@ -15,6 +15,9 @@ + + include "mlir/IR/OpBase.td" + include "stablehlo/dialect/VhloOps.td" ++include "mlir/IR/CommonAttrConstraints.td" ++include "stablehlo/dialect/VhloEnums.td" ++include "stablehlo/dialect/VhloAttrs.td" + + def VHLO_GetEmptyDims : NativeCodeCall<"getEmptyI64Tensor($_builder)">; + +@@ -31,6 +34,11 @@ + def VHLO_GetFirstOperand : NativeCodeCall<"$0.front()">; + + def VHLO_WrapInVector : NativeCodeCall<"{$0}">; ++ ++def VHLO_GetDefaultResultAccuracyAttr : NativeCodeCall<"getDefaultResultAccuracy($_builder)">; ++ ++ ++def VHLO_DefaultResultAccuracy : AttrConstraint, "Default result accuracy">; + + def DynamicConvUpgradeV1ToV2: + Pat<(VHLO_DynamicConvOpV1 $lhs, $rhs, $d_padding, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, 
$output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config), +@@ -83,3 +91,11 @@ + Pat<(VHLO_DotGeneralOpV1 $lhs, $rhs, $lhs_batching_dimensions, $rhs_batching_dimensions, $lhs_contracting_dimensions, $rhs_contracting_dimensions, $precision_config), + (VHLO_DotGeneralOpV2 $lhs, $rhs, $lhs_batching_dimensions, $rhs_batching_dimensions, $lhs_contracting_dimensions, $rhs_contracting_dimensions, $precision_config, + (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType))>; ++ ++def ExpOpDowngradeV2ToV1 : ++ Pat<(VHLO_ExpOpV2 $operand, VHLO_DefaultResultAccuracy:$result_accuracy), ++ (VHLO_ExpOpV1 $operand)>; ++ ++def ExpOpUpgradeV1ToV2 : ++ Pat<(VHLO_ExpOpV1 $operand), ++ (VHLO_ExpOpV2 $operand, (VHLO_GetDefaultResultAccuracyAttr))>; diff --git a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 0cd634313728c..8d143132452d4 100644 --- a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -638,6 +638,39 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( return output; } +// Converts ResultAccuracyAttr to XLA ResultAccuracy proto. +static xla::ResultAccuracy Convert_result_accuracy( + std::optional + optional_result_accuracy_attr) { + if (!optional_result_accuracy_attr.has_value()) return xla::ResultAccuracy(); + + auto result_accuracy = xla::ResultAccuracy(); + if (optional_result_accuracy_attr.value().getMode().getValue() == + mlir::mhlo::ResultAccuracyMode::TOLERANCE) { + result_accuracy.mutable_tolerance()->set_atol( + optional_result_accuracy_attr.value().getAtol().convertToFloat()); + result_accuracy.mutable_tolerance()->set_rtol( + optional_result_accuracy_attr.value().getRtol().convertToFloat()); + result_accuracy.mutable_tolerance()->set_ulps( + optional_result_accuracy_attr.value().getUlps()); + } else { + xla::ResultAccuracy::Mode mode; + auto result_accuracy_mode = + ::mlir::mhlo::stringifyResultAccuracyMode( + optional_result_accuracy_attr.value().getMode().getValue()) + .str(); + if (xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) { + result_accuracy.set_mode(mode); + } else { + auto* context = optional_result_accuracy_attr.value().getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "unexpected result accuracy mode " << result_accuracy_mode; + return xla::ResultAccuracy(); + } + } + return result_accuracy; +} + // Returns an OpSharding proto from the "sharding" attribute of the op. If the // op doesn't have a sharding attribute or the sharding attribute is invalid, // returns std::nullopt. 
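// A minimal sketch of the mapping performed by Convert_result_accuracy above,
// using illustrative attribute values (not taken from any test in this change):
//
//   #mhlo.result_accuracy<mode = #mhlo.result_accuracy_mode<HIGHEST>>
//     -> xla::ResultAccuracy with mode = HIGHEST
//
//   #mhlo.result_accuracy<atol = 1.000000e-05, ulps = 2,
//                         mode = #mhlo.result_accuracy_mode<TOLERANCE>>
//     -> xla::ResultAccuracy with tolerance { atol: 1e-05 rtol: 0 ulps: 2 }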
diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 925ed003c7e6f..aa9c8fda99b6d 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -1215,6 +1215,29 @@ LogicalResult SparseDotOp::verify() { return success(); } +// ===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +LogicalResult ResultAccuracyAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) { + return hlo::verifyResultAccuracyAttr( + emitError, std::move(atol), std::move(rtol), ulps, + stringifyResultAccuracyMode(mode.getValue())); +} + +LogicalResult ExpOp::verify() { + if (auto attr = getResultAccuracyAttr()) { + if (failed(ResultAccuracyAttr::verify([&] { return emitError(); }, + attr.getAtol(), attr.getRtol(), + attr.getUlps(), attr.getMode()))) { + return failure(); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // FftOp //===----------------------------------------------------------------------===// @@ -7261,6 +7284,20 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, return success(); } +//===----------------------------------------------------------------------===// +// Custom unary op +//===----------------------------------------------------------------------===// + +void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const { + hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(), + getMode()); +} + +Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) { + return hlo::parseResultAccuracyAttr(parser, type); +} + // Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex // (1) the parameter must be valid // (2) there must be a subshape at the given indices diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 99fbf3e26cace..9fa2054f3bc06 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -343,6 +343,10 @@ def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential", %result = mhlo.exponential %operand : tensor<2x2xf64> ``` }]; + let arguments = (ins MHLO_FpComplexOrQuantizedIntTensor:$operand, + DefaultValuedOptionalAttr:$result_accuracy); + let results = (outs MHLO_FpComplexOrQuantizedIntTensor:$result); + let hasVerifier = 1; let hasFolder = 1; } def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one", diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index c2c3b0aca31ff..b44881d6dc36b 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -332,4 +332,36 @@ def MHLO_BoolElementsAttr : let convertFromStorage = "$_self"; } +//==----------------------------------------------------------------------===// +// Result Accuracy attributes +//===----------------------------------------------------------------------===// + +def MHLO_APFloatV1 : APFloatParameter<""> { + let parser = [{ + [&]() -> FailureOr { + double value; + if (failed($_parser.parseFloat(value))) { + return failure(); + } + return APFloat(value); + }() + }]; + let printer = "$_printer.printFloat($_self);"; +} + +def MHLO_ResultAccuracyAttr : AttrDef { + let mnemonic = "result_accuracy"; + let summary = "The requested accuracy for unary ops."; + let parameters = (ins + "APFloat":$atol, + 
"APFloat":$rtol, + "int64_t":$ulps, + MHLO_ResultAccuracyModeAttr:$mode + + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + let constBuilderCall = "mlir::mhlo::ResultAccuracyAttr::get($_builder.getContext(), llvm::APFloat(0.0), llvm::APFloat(0.0), 0, mlir::mhlo::ResultAccuracyModeAttr::get($_builder.getContext(), $0))"; +} + #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ATTRS diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td index 3e4039ef9598a..53903a874fde8 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -260,4 +260,28 @@ def MHLO_RngAlgorithmAttr : EnumAttr; +def MHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; +def MHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; + +def MHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode", + "XLA result accuracy mode.", + [ + MHLO_RESULT_ACCURACY_DEFAULT, + MHLO_RESULT_ACCURACY_HIGHEST, + MHLO_RESULT_ACCURACY_TOLERANCE + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_ResultAccuracyModeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + + #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS diff --git a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc index 0afccba2e587d..b150650ea8930 100644 --- a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc +++ b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc @@ -15,12 +15,15 @@ limitations under the License. #include "mhlo/IR/mhlo_bytecode.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/Base.h" //===----------------------------------------------------------------------===// @@ -173,6 +176,18 @@ enum AttributeCode { /// operandTupleIndices : svarint[] /// } kOutputOperandAlias = 16, + + // ResultAccuracyModeAttr { + // mode: varint (encoded enum) + // } + kResultAccuracyModeAttr = 17, + + // ResultAccuracyAttr { + // atol: APFloat + // rtol: APFloat + // ulps: svarint + // } + kResultAccuracyAttr = 18, }; /// This enum contains marker codes used to indicate which type is @@ -243,6 +258,10 @@ class MhloBytecodeInterface : public BytecodeDialectInterface { OutputOperandAliasAttr readOutputOperandAliasAttr( DialectBytecodeReader &reader) const; PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const; + ResultAccuracyAttr readResultAccuracyAttr( + DialectBytecodeReader &reader) const; + ResultAccuracyModeAttr readResultAccuracyModeAttr( + DialectBytecodeReader &reader) const; RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const; RngDistributionAttr readRngDistributionAttr( DialectBytecodeReader &reader) const; @@ -268,6 +287,8 @@ class MhloBytecodeInterface : public BytecodeDialectInterface { DialectBytecodeWriter &writer) const; void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const; void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const; void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const; void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const; void 
write(ScatterDimensionNumbersAttr attr, @@ -331,6 +352,10 @@ Attribute MhloBytecodeInterface::readAttribute( return readOutputOperandAliasAttr(reader); case mhlo_encoding::kPrecisionAttr: return readPrecisionAttr(reader); + case mhlo_encoding::kResultAccuracyAttr: + return readResultAccuracyAttr(reader); + case mhlo_encoding::kResultAccuracyModeAttr: + return readResultAccuracyModeAttr(reader); case mhlo_encoding::kRngAlgorithmAttr: return readRngAlgorithmAttr(reader); case mhlo_encoding::kRngDistributionAttr: @@ -582,8 +607,9 @@ LogicalResult MhloBytecodeInterface::writeAttribute( ConvDimensionNumbersAttr, ChannelHandleAttr, DomainKindAttr, DotDimensionNumbersAttr, FftTypeAttr, FusionKindAttr, GatherDimensionNumbersAttr, OutputOperandAliasAttr, PrecisionAttr, - RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr, - TransposeAttr, TypeExtensionsAttr>([&](auto attr) { + ResultAccuracyAttr, ResultAccuracyModeAttr, RngAlgorithmAttr, + RngDistributionAttr, ScatterDimensionNumbersAttr, TransposeAttr, + TypeExtensionsAttr>([&](auto attr) { LOG_WRITE_CALL; write(attr, writer); return success(); @@ -594,6 +620,21 @@ LogicalResult MhloBytecodeInterface::writeAttribute( }); } +void MhloBytecodeInterface::write(ResultAccuracyModeAttr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(mhlo_encoding::kResultAccuracyModeAttr); + hlo::bytecode::writeEnumAttribute(attr, writer); +} + +void MhloBytecodeInterface::write(ResultAccuracyAttr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(mhlo_encoding::kResultAccuracyAttr); + writer.writeAPFloatWithKnownSemantics(attr.getAtol()); + writer.writeAPFloatWithKnownSemantics(attr.getRtol()); + writer.writeSignedVarInt(attr.getUlps()); + writer.writeAttribute(attr.getMode()); +} + void MhloBytecodeInterface::write(ArgResultAliasAttr attr, DialectBytecodeWriter &writer) const { writer.writeVarInt(mhlo_encoding::kArgResultAliasAttr); @@ -790,6 +831,39 @@ void MhloBytecodeInterface::write(TokenType type, writer.writeVarInt(mhlo_encoding::kTokenType); } +//===----------------------------------------------------------------------===// +// ResultAccuracyModeAttr + +ResultAccuracyModeAttr MhloBytecodeInterface::readResultAccuracyModeAttr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + return hlo::bytecode::readEnumAttribute( + reader, getContext(), + [](uint32_t val) { return symbolizeResultAccuracyMode(val); }); +} + +//===----------------------------------------------------------------------===// +// ResultAccuracyAttr + +ResultAccuracyAttr MhloBytecodeInterface::readResultAccuracyAttr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + FailureOr atol; + FailureOr rtol; + int64_t ulps = 0; + ResultAccuracyModeAttr mode; + if (failed(atol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(rtol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(reader.readSignedVarInt(ulps)) || + failed(reader.readAttribute(mode))) { + reader.emitError() << "failed to read APFloat for atol"; + return ResultAccuracyAttr(); + } + return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode); +} + } // namespace void addBytecodeInterface(MhloDialect *dialect) { diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 3931db159ec81..fcb185a4d3110 100644 --- 
a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -227,6 +227,20 @@ Attribute convertDenseArray(mlir::StringAttr hloName, Attribute hloAttr) { if (!stablehloValue.has_value()) return {}; \ return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) +stablehlo::ResultAccuracyMode convertResultAccuracyMode( + mhlo::ResultAccuracyMode mode) { + switch (mode) { + case mhlo::ResultAccuracyMode::DEFAULT: + return stablehlo::ResultAccuracyMode::DEFAULT; + case mhlo::ResultAccuracyMode::HIGHEST: + return stablehlo::ResultAccuracyMode::HIGHEST; + case mhlo::ResultAccuracyMode::TOLERANCE: + return stablehlo::ResultAccuracyMode::TOLERANCE; + default: + return {}; + } +} + Attribute convertAttr(Attribute hloAttr) { // Handle MHLO attributes. // The logic that handles attributes from other dialects (e.g. builtin @@ -301,6 +315,19 @@ Attribute convertAttr(Attribute hloAttr) { if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(Transpose); } + if (auto attr = mlir::dyn_cast(hloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode); + } + if (auto attr = mlir::dyn_cast(hloAttr)) { + stablehlo::ResultAccuracyModeAttr modeAttr; + modeAttr = stablehlo::ResultAccuracyModeAttr::get( + attr.getContext(), + convertResultAccuracyMode(attr.getMode().getValue())); + + return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), + attr.getRtol(), attr.getUlps(), + modeAttr); + } if (hloAttr.getDialect().getNamespace() == mhlo::MhloDialect::getDialectNamespace()) { // Our guiding principle is to support all StableHLO functionality in MHLO. diff --git a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index 9e454cdfa7f67..fa3186ce17a94 100644 --- a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -45,6 +45,19 @@ namespace { if (!hloValue.has_value()) return {}; \ return mhlo::Name##Attr::get(attr.getContext(), hloValue.value()) +mhlo::ResultAccuracyMode convertResultAccuracyMode( + stablehlo::ResultAccuracyMode mode) { + switch (mode) { + case stablehlo::ResultAccuracyMode::DEFAULT: + return mhlo::ResultAccuracyMode::DEFAULT; + case stablehlo::ResultAccuracyMode::HIGHEST: + return mhlo::ResultAccuracyMode::HIGHEST; + case stablehlo::ResultAccuracyMode::TOLERANCE: + return mhlo::ResultAccuracyMode::TOLERANCE; + default: + return {}; + } +} Attribute convertAttr(Attribute stablehloAttr) { // StableHLO uses DenseArray for some attributes, MHLO is in the process // of integrating this change. 
In the meantime, convert DenseArray to @@ -140,6 +153,20 @@ Attribute convertAttr(Attribute stablehloAttr) { if (auto attr = mlir::dyn_cast(stablehloAttr)) { RETURN_CONVERTED_ENUM_ATTR(Transpose); } + if (auto attr = + mlir::dyn_cast(stablehloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode); + } + if (auto attr = + mlir::dyn_cast(stablehloAttr)) { + mhlo::ResultAccuracyModeAttr modeAttr = mhlo::ResultAccuracyModeAttr::get( + attr.getContext(), + convertResultAccuracyMode(attr.getMode().getValue())); + + return mhlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), + attr.getRtol(), attr.getUlps(), + modeAttr); + } if (stablehloAttr.getDialect().getNamespace() == stablehlo::StablehloDialect::getDialectNamespace()) { // Our guiding principle is to support all StableHLO functionality in MHLO. diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index e3a50d613cb94..9a8f58d8e763e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -335,6 +335,19 @@ func.func @attr_type_extensions_bounds( func.return %arg0 : tensor> } +// ----- + +// CHECK-LABEL: "attr_result_accuracy_mode" +func.func @attr_result_accuracy_mode(%arg0: tensor) -> tensor { + %0 = "mhlo.exponential"(%arg0) { + // CHECK: result_accuracy = #stablehlo.result_accuracy> + result_accuracy = #mhlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + // ============ OPS ============ // CHECK-LABEL: "op_abs" diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 16a64cdc22b76..64080d89e2e5f 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -1640,6 +1640,14 @@ func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi3 // ----- +// CHECK-LABEL: func @exponential_result_accuracy +func.func @exponential_result_accuracy(%arg0: tensor) -> tensor { + %0 = "mhlo.exponential"(%arg0) {result_accuracy = #mhlo.result_accuracy>} : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + func.func @dot_more_dynamic_output_type(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor { %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<3xf32>, tensor) -> tensor func.return %0 : tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index ffc8be409243d..9549615c457de 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -260,6 +260,18 @@ func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor< // ----- + +// CHECK-LABEL: "attr_result_accuracy" +func.func @attr_result_accuracy(%arg0: tensor) -> tensor { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #mhlo.result_accuracy> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: "attr_rng_algorithm_default" func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { %0:2 = "stablehlo.rng_bit_generator"(%arg0) {