Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ODML] Pass expand-tuple : Migrate from MHLO to StableHLO #21778

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/mlir_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@ cc_library(
"stablehlo_ext/transforms/chlo_recompose_ops.cpp",
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_entry_function_tuples.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp",
"stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
Expand Down
9 changes: 9 additions & 0 deletions xla/mlir_hlo/stablehlo_ext/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,12 @@ def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncO
"support both tuple and variadic type.";
let constructor = "createStablehloFlattenTuplePass()";
}

def StablehloFlattenEntryFunctionTuplesPass : Pass<"stablehlo-ext-expand-flatten-entry-function-tuples", "ModuleOp"> {
let summary = "Flatten HLO tuple for the entry function of the module.";
let options = [
Option<"entryFunctionNameOption", "entry-function", "std::string",
/*default=*/"", "the name of entry function of the module">,
];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/* Copyright 2021 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc

namespace mlir {
namespace stablehlo_ext {

#define GEN_PASS_DEF_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
#include "stablehlo_ext/transforms/passes.h.inc"

namespace {

// This pass assumes the function to be expanded has no callees, to be specific,
// the function is more like the main function.
class StablehloFlattenEntryFunctionTuplesPass
: public impl::StablehloFlattenEntryFunctionTuplesPassBase<
StablehloFlattenEntryFunctionTuplesPass> {
public:
StablehloFlattenEntryFunctionTuplesPass()
: StablehloFlattenEntryFunctionTuplesPassBase<
StablehloFlattenEntryFunctionTuplesPass>() {}
explicit StablehloFlattenEntryFunctionTuplesPass(
const StablehloFlattenEntryFunctionTuplesPassOptions &opts)
: StablehloFlattenEntryFunctionTuplesPassBase<
StablehloFlattenEntryFunctionTuplesPass>(opts) {}

// Expands the mhlo.tuple used in return op. Also updates function
// signature accordingly.
void expandTupledTensorInReturnOp(func::FuncOp func) {
FunctionType oldFuncType = func.getFunctionType();
// Update input signatures.
// We will flatten the tuples for the function inputs as well.
// So if an input is tuple, will be flattened and packed as following:
// func_1(%arg0: tuple<input1, input2>) =>
//
// func_1(%arg0: <input1>, %arg1: <input2>) {
// %0 = mhlo.tuple(%arg0, %arg1)
// }
SmallVector<Type, 4> expandedInputTypes;
SmallVector<BlockArgument, 20> funcArguments(func.getArguments().begin(),
func.getArguments().end());
for (auto argument : funcArguments) {
auto type = argument.getType();
auto tupleType = mlir::dyn_cast_or_null<TupleType>(type);
if (!tupleType) {
expandedInputTypes.push_back(type);
} else {
// We need to
// 1) expand the tuple
// 2) insert a new tuple
// 3) rewire the new tuple
int originalArgumentIndex = argument.getArgNumber();
int argumentIndex = originalArgumentIndex;
SmallVector<Value, 4> flattenedOperands;
// insert the flattened tuples after the original tuple.
Location loc = func.getBody().getLoc();
for (auto flattenedType : tupleType.getTypes()) {
expandedInputTypes.push_back(flattenedType);
func.insertArgument(++argumentIndex, flattenedType, {}, loc);
flattenedOperands.push_back(func.getArgument(argumentIndex));
}

// Construct a new tuple and rewire it.
OpBuilder builder(func.getBody());
builder.setInsertionPointToStart(&func.getBody().front());
auto newTuple = builder.create<stablehlo::TupleOp>(loc, tupleType,
flattenedOperands);
func.getArgument(originalArgumentIndex).replaceAllUsesWith(newTuple);

// Now the original argument has been rewired, we should be able to
// safely erase it.
func.eraseArgument(originalArgumentIndex);
}
}

// Update output signatures.
auto returnOp = cast<mlir::func::ReturnOp>(func.getBody().back().back());
OpBuilder builder(returnOp);

// Expand all tuples in old return operands.
SmallVector<Value, 4> expandedReturnOperands;
SmallVector<Type, 4> expandedResultTypes;
for (auto value : returnOp.getOperands()) {
if (auto tupleTy = mlir::dyn_cast<TupleType>(value.getType())) {
llvm::copy(tupleTy.getTypes(), std::back_inserter(expandedResultTypes));
for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) {
expandedReturnOperands.push_back(
builder.createOrFold<stablehlo::GetTupleElementOp>(
value.getLoc(), ty, value, index));
}
} else {
expandedReturnOperands.push_back(value);
expandedResultTypes.push_back(value.getType());
}
}

if (returnOp.getOperands() == expandedReturnOperands) return;

builder.create<mlir::func::ReturnOp>(returnOp.getLoc(),
expandedReturnOperands);
returnOp.erase();
auto newFuncType = FunctionType::get(
oldFuncType.getContext(), expandedInputTypes, expandedResultTypes);
func.setType(newFuncType);
}

void runOnOperation() override {
auto module = getOperation();
// Find `main` function.
auto entryFunction =
module.lookupSymbol<func::FuncOp>(entryFunctionNameOption);
if (!entryFunction) {
return;
}

// Recursively expand tuples until all of them are gone.
while (
llvm::any_of(llvm::concat<const Type>(entryFunction.getArgumentTypes(),
entryFunction.getResultTypes()),
[](Type type) { return mlir::isa<TupleType>(type); })) {
expandTupledTensorInReturnOp(entryFunction);
}
}
};

} // namespace

} // namespace stablehlo_ext
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: mlir-hlo-opt %s -split-input-file -stablehlo-ext-expand-flatten-entry-function-tuples='entry-function=main' -allow-unregistered-dialect | FileCheck %s

// CHECK-LABEL: func @main
func.func @main(%arg0: tensor<1x224x224x3xf16>, %arg1: tensor<f32>) -> tensor<1x224x224x3xf16> {
// CHECK: return %arg0 : tensor<1x224x224x3xf16>
func.return %arg0 : tensor<1x224x224x3xf16>
}

// -----

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32>, %[[ARG1:.*]]: tensor<1x8x8x16xf32>) -> (tensor<1024xf32>, tensor<1xf32>)
func.func @main(%arg0: tensor<1x1xf32>, %arg1: tensor<1x8x8x16xf32>) -> tuple<tensor<1024xf32>, tensor<1xf32>> {
// CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[ARG0]] : (tensor<1x1xf32>) -> tensor<1xf32>
%0 = stablehlo.reshape %arg0 : (tensor<1x1xf32>) -> tensor<1xf32>
// CHECK-NEXT: %[[RESHAPE1:.*]] = stablehlo.reshape %[[ARG1]] : (tensor<1x8x8x16xf32>) -> tensor<1024xf32>
%1 = stablehlo.reshape %arg1 : (tensor<1x8x8x16xf32>) -> tensor<1024xf32>
// CHECK-NEXT: %[[TUPLE:.*]] = stablehlo.tuple %[[RESHAPE1]], %[[RESHAPE0]] {name = "tuple.374"} : tuple<tensor<1024xf32>, tensor<1xf32>>
%2 = stablehlo.tuple %1, %0 {name = "tuple.374"} : tuple<tensor<1024xf32>, tensor<1xf32>>
// CHECK-NEXT: %[[RES0:.*]] = stablehlo.get_tuple_element %[[TUPLE]][0] : (tuple<tensor<1024xf32>, tensor<1xf32>>) -> tensor<1024xf32>
// CHECK-NEXT: %[[RES1:.*]] = stablehlo.get_tuple_element %[[TUPLE]][1] : (tuple<tensor<1024xf32>, tensor<1xf32>>) -> tensor<1xf32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : tensor<1024xf32>, tensor<1xf32>
return %2 : tuple<tensor<1024xf32>, tensor<1xf32>>
}

// -----

// CHECK-LABEL: func @main
// CEHCK-SAME: () -> (tensor<1xf32>, tensor<1xi32>)
func.func @main() -> tuple<tensor<1xf32>, tensor<1xi32>> {
// CHECK-NEXT: %[[TUPLE:.*]] = "test.dummy"() : () -> tuple<tensor<1xf32>, tensor<1xi32>>
%0 = "test.dummy"() : () -> tuple<tensor<1xf32>, tensor<1xi32>>
// CHECK-NEXT: %[[RES0:.*]] = stablehlo.get_tuple_element %[[TUPLE]][0] : (tuple<tensor<1xf32>, tensor<1xi32>>) -> tensor<1xf32>
// CHECK-NEXT: %[[RES1:.*]] = stablehlo.get_tuple_element %[[TUPLE]][1] : (tuple<tensor<1xf32>, tensor<1xi32>>) -> tensor<1xi32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : tensor<1xf32>, tensor<1xi32>
func.return %0 : tuple<tensor<1xf32>, tensor<1xi32>>
}

// -----

// CHECK-LABEL: func @main
func.func @main() -> tuple<> {
// CHECK-NEXT: %[[TUPLE:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<>
%0 = "stablehlo.tuple"() {xla_shape = "()"} : () -> tuple<>
// CHECK-NEXT: return{{$}}
func.return %0 : tuple<>
}

// -----

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: tensor<1024xf32>, %[[ARG1:.*]]: tensor<1xf32>) -> (tensor<1024xf32>, tensor<1xf32>)
func.func @main(%arg0: tuple<tensor<1024xf32>, tensor<1xf32>>) -> tuple<tensor<1024xf32>, tensor<1xf32>> {
// CHECK-NEXT: %[[TUPLE:.*]] = stablehlo.tuple %[[ARG0]], %[[ARG1]] : tuple<tensor<1024xf32>, tensor<1xf32>>
// CHECK-NEXT: %[[RES0:.*]] = stablehlo.get_tuple_element %[[TUPLE]][0] : (tuple<tensor<1024xf32>, tensor<1xf32>>) -> tensor<1024xf32>
// CHECK-NEXT: %[[RES1:.*]] = stablehlo.get_tuple_element %[[TUPLE]][1] : (tuple<tensor<1024xf32>, tensor<1xf32>>) -> tensor<1xf32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : tensor<1024xf32>, tensor<1xf32>
func.return %arg0 : tuple<tensor<1024xf32>, tensor<1xf32>>
}

// -----

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>) -> (tensor<1xf32>, tensor<1xi32>)
func.func @main(%arg0: tuple<tuple<tensor<1xi8>>>) -> tuple<tuple<tensor<1xf32>>, tensor<1xi32>> {
// CHECK: %[[T0:.*]] = stablehlo.tuple %[[ARG]] : tuple<tensor<1xi8>>
// CHECK: %[[T1:.*]] = stablehlo.tuple %[[T0]] : tuple<tuple<tensor<1xi8>>>
// CHECK: %[[T:.*]] = "test.dummy"(%[[T1]]) : (tuple<tuple<tensor<1xi8>>>) -> tuple<tuple<tensor<1xf32>>, tensor<1xi32>>
%0 = "test.dummy"(%arg0) : (tuple<tuple<tensor<1xi8>>>) -> tuple<tuple<tensor<1xf32>>, tensor<1xi32>>
// CHECK: %[[GTE0:.*]] = stablehlo.get_tuple_element %[[T]][0] : (tuple<tuple<tensor<1xf32>>, tensor<1xi32>>) -> tuple<tensor<1xf32>>
// CHECK: %[[GTE1:.*]] = stablehlo.get_tuple_element %[[T]][1] : (tuple<tuple<tensor<1xf32>>, tensor<1xi32>>) -> tensor<1xi32>
// CHECK: %[[GTE2:.*]] = stablehlo.get_tuple_element %[[GTE0]][0] : (tuple<tensor<1xf32>>) -> tensor<1xf32>
// CHECK: return %[[GTE2]], %[[GTE1]] : tensor<1xf32>, tensor<1xi32>
func.return %0 : tuple<tuple<tensor<1xf32>>, tensor<1xi32>>
}
Loading