Skip to content

Commit

Permalink
[xla:emitters] drop first operand of functions marked with xla.entry …
Browse files Browse the repository at this point in the history
…and xla.backend_kind=cpu

This paves the way for CPU emitters.

For consistency, this also adds the xla.backend_kind=gpu attribute to the GPU
entry function.

PiperOrigin-RevId: 717583000
  • Loading branch information
cota authored and Google-ML-Automation committed Jan 23, 2025
1 parent e795171 commit aa1410d
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 30 deletions.
1 change: 1 addition & 0 deletions xla/backends/gpu/codegen/emitters/emitter_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> EmitterBase::CreateMLIRModule(
mlir::ArrayAttr::get(&context, arg_attrs),
/*res_attrs=*/mlir::ArrayAttr{});
entry_func->setAttr("xla.entry", mlir::UnitAttr::get(&context));
SetBackendKind(&context, entry_func, BackendKind::kGpu);

TF_RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion));
return module;
Expand Down
1 change: 0 additions & 1 deletion xla/backends/gpu/codegen/emitters/ir/xla_gpu_attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License.
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_dialect.td"
include "xla/codegen/emitters/ir/xla_attrs.td"

class XLAGPU_Attr<string name, list<Trait> traits = []> :
AttrDef<XlaGpuDialect, name, traits> {
Expand Down
1 change: 1 addition & 0 deletions xla/codegen/emitters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ xla_cc_test(
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
12 changes: 10 additions & 2 deletions xla/codegen/emitters/elemental_hlo_to_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1186,8 +1186,16 @@ ValueRange ProvideParameter(const PartitionedComputation& computation,
}

auto callee = call_target_provider(operand);
SmallVector<Value> operands(
this_fn.getArguments().take_front(instr->parent()->num_parameters()));
SmallVector<Value> operands;
if (auto backend_kind = GetBackendKind(this_fn);
backend_kind == xla::BackendKind::kCpu && this_fn->getAttr("xla.entry")) {
operands =
SmallVector<Value>{this_fn.getArguments().drop_front().take_front(
instr->parent()->num_parameters())};
} else {
operands = SmallVector<Value>{
this_fn.getArguments().take_front(instr->parent()->num_parameters())};
}
absl::c_copy(indices, std::back_inserter(operands));
auto results = builder.create<PureCallOp>(callee, operands).getResults();
auto callee_subgraph = computation.FindSubgraph(operand);
Expand Down
100 changes: 74 additions & 26 deletions xla/codegen/emitters/elemental_hlo_to_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand All @@ -32,6 +33,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -71,9 +73,12 @@ class ElementalHloToMlirTest : public HloTestBase {

// Converts the root subgraph of the entry function of the given hlo module to
// MLIR.
absl::Status Run(const std::string& hlo, const std::string& filecheck_str,
absl::Status Run(const absl::string_view hlo,
const absl::string_view filecheck_str,
std::function<EpilogueSpecification(HloComputation* entry)>
epilogue_spec_fn = nullptr) {
epilogue_spec_fn = nullptr,
bool set_xla_entry = false,
std::optional<xla::BackendKind> xla_backend = std::nullopt) {
auto hlo_module = ParseAndReturnVerifiedModule(hlo).value();

mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context_),
Expand All @@ -95,6 +100,12 @@ class ElementalHloToMlirTest : public HloTestBase {
auto entry_func = fns[&partitioned_computations
.FindPartitionedComputation(entry_computation)
.GetRootSubgraph()];
if (set_xla_entry) {
entry_func->setAttr("xla.entry", mlir::UnitAttr::get(&context_));
}
if (xla_backend) {
SetBackendKind(&context_, entry_func, *xla_backend);
}
auto& entry_pc =
partitioned_computations.FindPartitionedComputation(entry_computation);
auto call_targets = partitioned_computations.CreateCallTargetProvider(fns);
Expand Down Expand Up @@ -1390,41 +1401,78 @@ TEST_F(ElementalHloToMlirTest, PopulationCountUnsigned) {
)"));
}

TEST_F(ElementalHloToMlirTest, Epilogue) {
TF_EXPECT_OK(Run(
class ElementalHloToMlirEpilogueTest : public ElementalHloToMlirTest {
protected:
std::function<EpilogueSpecification(HloComputation* entry)> EpilogueSpec() {
return [this](HloComputation* entry) {
EpilogueSpecification epilogue;
epilogue.heroes.push_back(entry->GetInstructionWithName("transpose"));
epilogue.roots.push_back(entry->GetInstructionWithName("add"));
epilogue.index_ranges = {2, 16, 17};
epilogue.root_indexing.push_back(
IndexingMap{mlir::AffineMap::getMultiDimIdentityMap(3, &context_)
.getSubMap({0, 2, 1}),
DimVarsFromTensorSizes({2, 17, 17}),
{},
{}});
return epilogue;
};
}
static constexpr absl::string_view kHlo =
R"(
ENTRY main {
%p0 = f32[2,16,17] parameter(0)
%log = f32[2,16,17] log(%p0)
// Note: %p0 is only used in some of the tests.
%p0 = f32[7] parameter(0)
%p1 = f32[2,16,17] parameter(1)
%log = f32[2,16,17] log(%p1)
%transpose = f32[2,17,16] transpose(%log), dimensions={0,2,1}
%p1 = f32[] parameter(1)
%bc = f32[2,17,16] broadcast(%p1), dimensions={}
%p2 = f32[] parameter(2)
%bc = f32[2,17,16] broadcast(%p2), dimensions={}
ROOT %add = f32[2,17,16] add(%transpose, %bc)
})",
})";
static constexpr absl::string_view kCheck =
R"(
// CHECK: @main_add(
// CHECK-SAME: %[[A0:.*]]: tensor<7xf32>
// CHECK: %[[PURE:.*]] = xla.pure_call @main_transpose(%[[A0]],
// CHECK: @main_transpose(tensor<7xf32>,
// CHECK: @main__epilogue__add(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x17xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>
// CHECK-SAME: %[[ARG0:.*]]: tensor<7xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x16x17xf32>
// CHECK-SAME: %[[ARG2:.*]]: tensor<f32>
// CHECK-SAME: %[[X:.*]]: index {xla.range = [0 : index, 1 :
// CHECK-SAME: %[[Y:.*]]: index {xla.range = [0 : index, 15 :
// CHECK-SAME: %[[Z:.*]]: index {xla.range = [0 : index, 16 :
// CHECK-SAME: %[[TRANSPOSE:.*]]: f32) -> f32
// CHECK: %[[B:.*]] = tensor.extract %[[ARG1]][]
// CHECK: %[[B:.*]] = tensor.extract %[[ARG2]][]
// CHECK: %[[RET:.*]] = arith.addf %[[TRANSPOSE]], %[[B]]
// CHECK: return %[[RET]])",
[this](HloComputation* entry) {
EpilogueSpecification epilogue;
epilogue.heroes.push_back(entry->GetInstructionWithName("transpose"));
epilogue.roots.push_back(entry->GetInstructionWithName("add"));
epilogue.index_ranges = {2, 16, 17};
epilogue.root_indexing.push_back(
IndexingMap{mlir::AffineMap::getMultiDimIdentityMap(3, &context_)
.getSubMap({0, 2, 1}),
DimVarsFromTensorSizes({2, 17, 17}),
{},
{}});
return epilogue;
}));
// CHECK: return %[[RET]]
)";
};

TEST_F(ElementalHloToMlirEpilogueTest, Epilogue) {
TF_EXPECT_OK(Run(kHlo, kCheck, EpilogueSpec()));
}

TEST_F(ElementalHloToMlirEpilogueTest, XlaEntry) {
TF_EXPECT_OK(Run(kHlo, kCheck, EpilogueSpec(), /*set_xla_entry=*/true));
}

TEST_F(ElementalHloToMlirEpilogueTest, XlaGpuEntry) {
TF_EXPECT_OK(Run(kHlo, kCheck, EpilogueSpec(), /*set_xla_entry=*/true,
/*xla_backend=*/xla::BackendKind::kGpu));
}

TEST_F(ElementalHloToMlirEpilogueTest, XlaCpuEntry) {
TF_EXPECT_OK(Run(kHlo,
R"(
// CHECK: @main_add(
// CHECK-SAME: %[[ARG0:.*]]: tensor<7xf32>
// main_transpose must still have arg0, but the pure_call must not.
// CHECK: %[[PURE:.*]] = xla.pure_call @main_transpose(%arg1,
// CHECK: @main_transpose(tensor<7xf32)",
EpilogueSpec(), /*set_xla_entry=*/true,
/*xla_backend=*/xla::BackendKind::kCpu));
}

TEST_F(ElementalHloToMlirTest, ScalarConstant) {
Expand Down
4 changes: 4 additions & 0 deletions xla/codegen/emitters/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,16 @@ xla_test(
":xla",
"//xla/hlo/analysis:indexing_analysis",
"//xla/hlo/analysis:indexing_test_utils",
"//xla/hlo/testlib:filecheck",
"//xla/mlir/utils:error_util",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@tsl//tsl/platform:test",
],
)
10 changes: 9 additions & 1 deletion xla/codegen/emitters/ir/tests/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,12 @@ func.func @atomic_rmw_mismatch_block_arg_vector_type(%in: tensor<16xf32>, %i: in
xla.yield %add : vector<2xi32>
}
return %ret : tensor<16xf32>
}
}

// -----

// expected-error @+2 {{expected ::xla::BackendKind to be one of: cpu, gpu, tpu}}
// expected-error @+1 {{failed to parse XLA_BackendKindAttr parameter 'value' which is to be a `::xla::BackendKind`}}
func.func @test_backend_kind(%arg0: f32) attributes { xla.backend_kind = #xla.backend_kind<foo> } {
func.return
}
24 changes: 24 additions & 0 deletions xla/codegen/emitters/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,27 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32,
// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[EXTRACTED]] : f32
// CHECK: xla.yield %[[ADD]] : f32
// CHECK: } {xla.range = [0 : index, 42 : index]}

// -----

func.func @test_backend_kind(%arg0: f32) attributes { xla.backend_kind = #xla.backend_kind<cpu> } {
func.return
}
// CHECK: @test_backend_kind
// CHECK-SAME: #xla.backend_kind<cpu>

// -----

func.func @test_backend_kind(%arg0: f32) attributes { xla.backend_kind = #xla.backend_kind<gpu> } {
func.return
}
// CHECK: @test_backend_kind
// CHECK-SAME: #xla.backend_kind<gpu>

// -----

func.func @test_backend_kind(%arg0: f32) attributes { xla.backend_kind = #xla.backend_kind<tpu> } {
func.return
}
// CHECK: @test_backend_kind
// CHECK-SAME: #xla.backend_kind<tpu>
15 changes: 15 additions & 0 deletions xla/codegen/emitters/ir/xla_attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,19 @@ def XLA_IndexingMapAttr : XLA_Attr<"IndexingMap"> {
}];
}

// Note: settle for BackendKind to avoid clashing with the existing
// xla::Backend and xla::BackendType types.
def XLA_BackendKind : I32EnumAttr<"BackendKind", "XLA Backend kind (or type)", [
I32EnumAttrCase<"kCpu", 0, "cpu">,
I32EnumAttrCase<"kGpu", 1, "gpu">,
I32EnumAttrCase<"kTpu", 2, "tpu">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::xla";
}
def XLA_BackendKindAttr :
EnumAttr<XlaDialect, XLA_BackendKind, "backend_kind"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // XLA_CODEGEN_EMITTERS_IR_XLA_ATTRS
1 change: 1 addition & 0 deletions xla/codegen/emitters/ir/xla_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
// The order of these includes is important.
#define GET_ATTRDEF_CLASSES
#include "xla/codegen/emitters/ir/xla_attrs.cc.inc"
#include "xla/codegen/emitters/ir/xla_enums.cc.inc"

namespace xla {
namespace {
Expand Down
15 changes: 15 additions & 0 deletions xla/codegen/emitters/ir/xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,21 @@ void LoopOp::getCanonicalizationPatterns(mlir::RewritePatternSet& results,
results.add<FoldConstantDimensions, SimplifyLoopOfApplyIndexing>(context);
}

void SetBackendKind(mlir::MLIRContext* context, mlir::func::FuncOp fn,
xla::BackendKind backend_kind) {
fn->setAttr(xla::BackendKindAttr::name,
xla::BackendKindAttr::get(context, backend_kind));
}

std::optional<xla::BackendKind> GetBackendKind(mlir::func::FuncOp fn) {
auto backend_attr =
fn->getAttrOfType<xla::BackendKindAttr>(xla::BackendKindAttr::name);
if (!backend_attr) {
return std::nullopt;
}
return backend_attr.getValue();
}

} // namespace xla

#define GET_OP_CLASSES
Expand Down
8 changes: 8 additions & 0 deletions xla/codegen/emitters/ir/xla_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep
#include "xla/codegen/emitters/ir/xla_dialect.h.inc"
#include "xla/hlo/analysis/indexing_map.h" // IWYU pragma: keep

// The order of these includes is important.
#include "xla/codegen/emitters/ir/xla_enums.h.inc" // IWYU pragma: keep
#define GET_ATTRDEF_CLASSES
#include "xla/codegen/emitters/ir/xla_attrs.h.inc"
#define GET_OP_CLASSES
Expand Down Expand Up @@ -69,6 +72,11 @@ std::optional<Interval> GetRange(mlir::Value value);
// determined.
std::optional<Interval> GetIVRange(mlir::Value iv);

// Helpers for getting/setting xla::BackendKind attribute given a func::FuncOp.
std::optional<xla::BackendKind> GetBackendKind(mlir::func::FuncOp fn);
void SetBackendKind(mlir::MLIRContext* context, mlir::func::FuncOp fn,
xla::BackendKind backend_kind);

} // namespace xla

#endif // XLA_CODEGEN_EMITTERS_IR_XLA_OPS_H_
Loading

0 comments on commit aa1410d

Please sign in to comment.