From 5c3f21721d921751f3c64451f9be3b59f9cac8c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 Apr 2024 03:51:55 -0700 Subject: [PATCH] Share StableHLO/MHLO pretty printers for ConstantOp PiperOrigin-RevId: 625273473 --- third_party/stablehlo/temporary.patch | 123 ++++++++++++++++++++++++++ xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 41 +-------- 2 files changed, 125 insertions(+), 39 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 81c73942d657f..f9229c0d13765 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -174,6 +174,80 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp +--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp ++++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp +@@ -15,6 +15,7 @@ + + #include "stablehlo/dialect/AssemblyFormat.h" + ++#include + #include + #include + #include +@@ -130,6 +131,42 @@ + for (Type& t : opTypes) typePtrs.push_back(&t); + + return detail::parseSameOperandsAndResultTypeImpl(parser, typePtrs, result); ++} ++ ++void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value) { ++ assert(op->getNumResults() == 1); ++ // If not all types are the same, use generic form. ++ if (value.getType() != op->getResultTypes().front()) { ++ p.printGenericOp(op, /*printOpName=*/false); ++ return; ++ } ++ ++ p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); ++ p << ' '; ++ p.printStrippedAttrOrType(value); ++} ++ ++ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) { ++ // Parse the generic form. ++ if (succeeded(parser.parseOptionalLParen())) { ++ if (parser.parseRParen()) return failure(); ++ if (parser.parseOptionalAttrDict(result.attributes)) return failure(); ++ if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || ++ parser.parseArrow()) ++ return failure(); ++ Type resultTy; ++ if (parser.parseType(resultTy)) return failure(); ++ result.addTypes(resultTy); ++ return success(); ++ } ++ ++ ElementsAttr valueAttr; ++ if (parser.parseOptionalAttrDict(result.attributes)) return failure(); ++ if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", ++ result.attributes)) ++ return failure(); ++ result.addTypes(valueAttr.getType()); ++ return success(); + } + + void printTupleOpType(OpAsmPrinter& p, Operation*, TypeRange operands, +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h +--- stablehlo/stablehlo/dialect/AssemblyFormat.h ++++ stablehlo/stablehlo/dialect/AssemblyFormat.h +@@ -101,6 +101,16 @@ + SmallVectorImpl& operands, + SmallVectorImpl& opTypes, Type& result); + ++// Print a `constant` op. ++// ++// op ::= attr-dict $value ++// ++// When the `value` and `output` have different type, it just uses the default ++// operator assembly format as a fallback. ++void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value); ++ ++ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result); ++ + // TuplesOp - only print result type. Operand type is trivially inferrable. + // + // Inferring operand types from tuple type: diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp --- stablehlo/stablehlo/dialect/StablehloOps.cpp +++ stablehlo/stablehlo/dialect/StablehloOps.cpp @@ -194,6 +268,55 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/ OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { assert(adaptor.getOperands().empty() && "constant has no operands"); +@@ -311,44 +321,11 @@ + } + + ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { +- // Parse the generic form. +- if (succeeded(parser.parseOptionalLParen())) { +- if (parser.parseRParen()) return failure(); +- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); +- if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || +- parser.parseArrow()) +- return failure(); +- Type resultTy; +- if (parser.parseType(resultTy)) return failure(); +- result.addTypes(resultTy); +- return success(); +- } +- +- ElementsAttr valueAttr; +- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); +- if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", +- result.attributes)) +- return failure(); +- result.addTypes(valueAttr.getType()); +- return success(); +-} +- +-/// Print a `constant` op. +-/// +-/// op ::= attr-dict $value +-/// +-/// When the `value` and `output` have different type, it just uses the default +-/// operator assembly format as a fallback. ++ return hlo::parseConstantOp(parser, result); ++} ++ + void ConstantOp::print(::mlir::OpAsmPrinter& p) { +- // If not all types are the same, use generic form. +- if (getValue().getType() != getType()) { +- p.printGenericOp(getOperation(), /*printOpName=*/false); +- return; +- } +- +- p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); +- p << ' '; +- p.printStrippedAttrOrType(getValueAttr()); ++ hlo::printConstantOp(p, getOperation(), getValue()); + } + + //===----------------------------------------------------------------------===// diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td --- stablehlo/stablehlo/dialect/StablehloOps.td +++ stablehlo/stablehlo/dialect/StablehloOps.td diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index c5ef02a62e14d..d12f6599463d2 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -650,48 +650,11 @@ bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { } ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { - // Parse the generic form. - if (succeeded(parser.parseOptionalLParen())) { - if (parser.parseRParen()) return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || - parser.parseArrow()) - return failure(); - Type resultTy; - if (parser.parseType(resultTy)) { - return failure(); - } - result.addTypes(resultTy); - return success(); - } - - ElementsAttr valueAttr; - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", - result.attributes)) { - return failure(); - } - result.addTypes(valueAttr.getType()); - return success(); + return hlo::parseConstantOp(parser, result); } -/// Print a `constant` op. -/// -/// op ::= attr-dict $value -/// -/// When the `value` and `output` have different type, it just uses the default -/// operator assembly format as a fallback. void ConstantOp::print(::mlir::OpAsmPrinter& p) { - // If not all types are the same, use generic form. - if (getValue().getType() != getType()) { - p.printGenericOp(getOperation(), /*printOpName=*/false); - return; - } - - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); - p << ' '; - p.printStrippedAttrOrType(getValueAttr()); + hlo::printConstantOp(p, getOperation(), getValue()); } //===----------------------------------------------------------------------===//