From 9f9d7f6ba6b410ca340e3d81e3df289c39f7e9d7 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 13 May 2024 07:30:48 -0700 Subject: [PATCH] Reshape multi-dimensional constants to 1d. The LLVM lowering doesn't support arbitrary shapes. PiperOrigin-RevId: 633203497 --- xla/service/gpu/fusions/mlir/lower_tensors.cc | 7 +++++++ xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/lower_tensors.cc b/xla/service/gpu/fusions/mlir/lower_tensors.cc index a529e6d86d00e..5677b45e5342b 100644 --- a/xla/service/gpu/fusions/mlir/lower_tensors.cc +++ b/xla/service/gpu/fusions/mlir/lower_tensors.cc @@ -334,6 +334,13 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, mlir::ModuleOp module, bool is_constant, int addr_space, mlir::ImplicitLocOpBuilder& b) { + if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) { + // The lowering to LLVM only works for 1d tensors or those with trailing + // unit dimensions. + value = elements.reshape(mlir::RankedTensorType::get( + {elements.getNumElements()}, elements.getElementType())); + } + Type element_type = shaped_ty.getElementType(); // Needed to support complex element type. 
mlir::LLVMTypeConverter converter(b.getContext()); diff --git a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir index 707cfc6f34196..d93f3ecd514c2 100644 --- a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir +++ b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -163,8 +163,8 @@ module { return %0 : f32 } } -// CHECK: llvm.mlir.global private constant @global_cst_0(dense<[ -// CHECK-SAME: [1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32> +// CHECK: llvm.mlir.global private constant @global_cst_0(dense< +// CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32> // CHECK: @extract_from_constant // CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32