diff --git a/xla/backends/cpu/codegen/BUILD b/xla/backends/cpu/codegen/BUILD index 26a1f072d1955..2da3a8723a7ed 100644 --- a/xla/backends/cpu/codegen/BUILD +++ b/xla/backends/cpu/codegen/BUILD @@ -1,3 +1,4 @@ +load("//xla:strict.default.bzl", "py_strict_test") load("//xla:xla.bzl", "xla_cc_test", "xla_internal") load( "//xla/tsl/platform:build_config_root.bzl", @@ -366,6 +367,33 @@ cc_library( ], ) +cc_library( + name = "dot_kernel_emitter", + srcs = ["dot_kernel_emitter.cc"], + hdrs = ["dot_kernel_emitter.h"], + deps = [ + ":kernel_api_ir_builder", + ":target_machine_features", + "//xla:util", + "//xla/codegen:kernel_definition", + "//xla/codegen:kernel_emitter", + "//xla/codegen:kernel_spec", + "//xla/codegen:llvm_ir_kernel_source", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:hlo_module_config", + "//xla/service/cpu:dot_op_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/stream_executor:launch_dim", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:ir_headers", + ], +) + xla_cc_test( name = "object_loader_test", srcs = ["object_loader_test.cc"], @@ -397,3 +425,17 @@ xla_cc_test( "@tsl//tsl/platform:statusor", ], ) + +py_strict_test( + name = "dot_kernel_emitter_test", + srcs = ["dot_kernel_emitter_test.py"], + tags = [ + "no_oss", + ], + deps = [ + "//third_party/py/numpy", + "//xla/backends/cpu/testlib", + "//xla/codegen/testlib", + "@absl_py//absl/testing:absltest", + ], +) diff --git a/xla/backends/cpu/codegen/dot_kernel_emitter.cc b/xla/backends/cpu/codegen/dot_kernel_emitter.cc new file mode 100644 index 0000000000000..08bb828297f16 --- /dev/null +++ b/xla/backends/cpu/codegen/dot_kernel_emitter.cc @@ -0,0 +1,112 @@ +/* Copyright 2025 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/dot_kernel_emitter.h" + +#include <memory> +#include <utility> + +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/codegen/kernel_definition.h" +#include "xla/codegen/kernel_spec.h" +#include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/dot_op_emitter.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla::cpu { + +static bool IsDotCodegenStrategy(DotImplementationStrategy strategy) { + switch (strategy) { + case DotImplementationStrategy::kNaiveLlvmIr: + case DotImplementationStrategy::kTiledLlvmIrGemv: + case DotImplementationStrategy::kTiledLlvmIrGemm: + return true; + default: + return false; + } +} + +DotKernelEmitter::DotKernelEmitter(const HloInstruction* instr, + const BufferAssignment* buffer_assignment, + const TargetMachineFeatures* target_machine) + : instr_(instr), +
buffer_assignment_(buffer_assignment), + target_machine_(target_machine) {} + +absl::StatusOr<KernelDefinition> DotKernelEmitter::EmitKernelDefinition() { + const HloModuleConfig& config = instr_->GetModule()->config(); + + DotImplementationStrategy strategy = + GetDotImplementationStrategy(config, *instr_, *target_machine_); + + if (!IsDotCodegenStrategy(strategy)) { + return Internal("Unsupported dot implementation strategy"); + } + + auto ctx = std::make_unique<llvm::LLVMContext>(); + + const HloModule* hlo_module = instr_->GetModule(); + if (hlo_module == nullptr) { + return Internal("HloModule is null"); + } + + KernelApiIrBuilder kernel_api_ir_builder( + *ctx, + KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config())); + + std::unique_ptr<llvm::Module> llvm_module = KernelApiIrBuilder::CreateModule( + absl::StrCat(instr_->name(), "_elemental_kernel_module"), *ctx); + + TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype, + kernel_api_ir_builder.EmitKernelPrototype( + *llvm_module, instr_, buffer_assignment_, "_kernel")); + + llvm::IRBuilder<> builder(*ctx); + builder.SetInsertPoint( + kernel_prototype.function->getEntryBlock().getTerminator()); + + llvm_ir::IrArray lhs_array = kernel_prototype.arguments[0]; + llvm_ir::IrArray rhs_array = kernel_prototype.arguments[1]; + llvm_ir::IrArray target_array = kernel_prototype.results[0]; + + TF_RETURN_IF_ERROR(EmitDotOperation( + *instr_, target_array, lhs_array, rhs_array, + /*addend_array=*/nullptr, /*executable_run_options_value=*/nullptr, + &builder, config, *target_machine_, + /*allow_runtime_calls=*/false)); + + auto source = std::make_unique<LlvmIrKernelSource>(std::move(ctx), + std::move(llvm_module)); + + KernelSpec spec(kernel_prototype.function->getName(), se::ThreadDim(), + std::move(kernel_prototype.buffer_uses)); + + return KernelDefinition(std::move(spec), std::move(source)); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/codegen/dot_kernel_emitter.h b/xla/backends/cpu/codegen/dot_kernel_emitter.h new file mode 100644 index 
0000000000000..8cb7c3c6577fe --- /dev/null +++ b/xla/backends/cpu/codegen/dot_kernel_emitter.h @@ -0,0 +1,45 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_DOT_KERNEL_EMITTER_H_ +#define XLA_BACKENDS_CPU_CODEGEN_DOT_KERNEL_EMITTER_H_ + +#include "absl/status/statusor.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/codegen/kernel_definition.h" +#include "xla/codegen/kernel_emitter.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" + +namespace xla::cpu { + +class DotKernelEmitter final : public KernelEmitter { + public: + DotKernelEmitter(const HloInstruction* instr, + const BufferAssignment* buffer_assignment, + const TargetMachineFeatures* target_machine); + + absl::StatusOr<KernelDefinition> EmitKernelDefinition() override; + + private: + const HloInstruction* instr_; + + const BufferAssignment* buffer_assignment_ = nullptr; + const TargetMachineFeatures* target_machine_ = nullptr; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_DOT_KERNEL_EMITTER_H_ diff --git a/xla/backends/cpu/codegen/dot_kernel_emitter_test.py b/xla/backends/cpu/codegen/dot_kernel_emitter_test.py new file mode 100644 index 0000000000000..ddb4e22493b90 --- /dev/null +++ b/xla/backends/cpu/codegen/dot_kernel_emitter_test.py @@ -0,0 +1,68 @@ +# Copyright 2025 The OpenXLA Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl.testing import absltest +import numpy as np + +from xla.backends.cpu import testlib as testlib_cpu +from xla.backends.cpu.testlib import utilities +from xla.codegen import testlib as testlib_base +from xla.codegen.testlib import utilities as base_utilities + +create_literal = base_utilities.create_literal_from_np +HloInstruction = testlib_base.HloInstruction +HloOpcode = testlib_base.HloOpcode + + +class DotKernelRunnerTest(absltest.TestCase): + + def test_dot_kernel_emitter(self): + lhs_np = np.array([[1, 2], [3, 4]], dtype=np.float32) + rhs_np = np.array([[5, 6], [7, 8]], dtype=np.float32) + + lhs_literal = create_literal(lhs_np) + rhs_literal = create_literal(rhs_np) + + output_literal = create_literal(np.ndarray((2, 2), dtype=np.float32)) + + lhs_param = HloInstruction.create_parameter(0, lhs_literal.shape(), "lhs") + rhs_param = HloInstruction.create_parameter(1, rhs_literal.shape(), "rhs") + + dot_dimension_numbers = testlib_base.DotDimensionNumbers([1], [0], [], []) + hlo_op = HloInstruction.create_dot( + output_literal.shape(), lhs_param, rhs_param, dot_dimension_numbers + ) + + hlo_module, buffer_assignment = utilities.build_hlo_module( + hlo_op, lhs_param, rhs_param + ) + jit_compiler = testlib_cpu.JitCompiler() + + emitter = testlib_cpu.DotKernelEmitter( + hlo_module.get_root_instruction(), + buffer_assignment, + 
jit_compiler.get_target_machine(), + ) + + runner = testlib_cpu.KernelRunner.create( + emitter.emit_kernel_definition(), jit_compiler + ) + + runner.call([lhs_literal, rhs_literal, output_literal]) + np.testing.assert_equal(np.asarray(output_literal), lhs_np @ rhs_np) + + +if __name__ == "__main__": + absltest.main() diff --git a/xla/backends/cpu/testlib/BUILD b/xla/backends/cpu/testlib/BUILD index c170dd9c67a51..ccd4763b83ef2 100644 --- a/xla/backends/cpu/testlib/BUILD +++ b/xla/backends/cpu/testlib/BUILD @@ -103,6 +103,7 @@ tsl_pybind_extension( "@com_google_absl//absl/strings:string_view", "@nanobind", "@local_config_python//:python_headers", # buildcleaner: keep + "//xla/backends/cpu/codegen:dot_kernel_emitter", "//xla/backends/cpu/codegen:elemental_kernel_emitter", "//xla/backends/cpu/codegen:jit_compiler", "//xla/backends/cpu/codegen:target_machine_features", diff --git a/xla/backends/cpu/testlib/__init__.py b/xla/backends/cpu/testlib/__init__.py index 9a101f0e9b634..2bcfe8a2342ea 100644 --- a/xla/backends/cpu/testlib/__init__.py +++ b/xla/backends/cpu/testlib/__init__.py @@ -17,6 +17,7 @@ from xla.backends.cpu.testlib import _extension # go/keep-sorted start +DotKernelEmitter = _extension.DotKernelEmitter ElementalKernelEmitter = _extension.ElementalKernelEmitter HloCompiler = _extension.HloCompiler JitCompiler = _extension.JitCompiler diff --git a/xla/backends/cpu/testlib/kernel_runner_extension.cc b/xla/backends/cpu/testlib/kernel_runner_extension.cc index bfc0c412f1130..2292efa39c6b5 100644 --- a/xla/backends/cpu/testlib/kernel_runner_extension.cc +++ b/xla/backends/cpu/testlib/kernel_runner_extension.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/backends/cpu/codegen/dot_kernel_emitter.h" #include "xla/backends/cpu/codegen/elemental_kernel_emitter.h" #include "xla/backends/cpu/codegen/jit_compiler.h" #include "xla/backends/cpu/codegen/target_machine_features.h" @@ -121,6 +122,13 @@ NB_MODULE(_extension, kernel_runner_module) { nb::keep_alive<1, 2>(), nb::keep_alive<1, 3>(), nb::keep_alive<1, 4>()); + nb::class_<DotKernelEmitter, KernelEmitter>(kernel_runner_module, + "DotKernelEmitter") + .def(nb::init<const HloInstruction*, const BufferAssignment*, + const TargetMachineFeatures*>(), + nb::keep_alive<1, 2>(), nb::keep_alive<1, 3>(), + nb::keep_alive<1, 4>()); + nb::class_<JitCompiler>(kernel_runner_module, "JitCompiler") .def(nb::new_([]() { absl::StatusOr<JitCompiler> compiler = diff --git a/xla/codegen/testlib/__init__.py b/xla/codegen/testlib/__init__.py index ea75a34464d27..52a91e351203c 100644 --- a/xla/codegen/testlib/__init__.py +++ b/xla/codegen/testlib/__init__.py @@ -20,6 +20,7 @@ # go/keep-sorted start BufferAssignment = _extension.BufferAssignment ComparisonDirection = _extension.ComparisonDirection +DotDimensionNumbers = _extension.DotDimensionNumbers HloInstruction = _extension.HloInstruction HloModule = _extension.HloModule HloOpcode = _extension.HloOpcode diff --git a/xla/codegen/testlib/kernel_runner_extension.cc b/xla/codegen/testlib/kernel_runner_extension.cc index c6a7b576bedf1..ce174a90f0019 100644 --- a/xla/codegen/testlib/kernel_runner_extension.cc +++ b/xla/codegen/testlib/kernel_runner_extension.cc @@ -69,12 +69,19 @@ void KernelRunnerCall(KernelRunner* kernel_runner, } } -// Need this helper as Literal rquires an explicit clone. +// Need this helper as Literal requires an explicit clone. 
std::unique_ptr<HloInstruction> CreateConstantHloInstruction( const Literal& literal) { return HloInstruction::CreateConstant(literal.Clone()); } +std::unique_ptr<HloInstruction> CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + return HloInstruction::CreateDot(shape, lhs, rhs, dimension_numbers, + PrecisionConfig()); +} + std::unique_ptr<HloInstruction> CreateComparisonHloInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, Comparison::Direction direction) { @@ -174,12 +181,32 @@ NB_MODULE(_extension, kernel_runner_module) { .value("kLe", Comparison::Direction::kLe) .value("kLt", Comparison::Direction::kLt); + nb::class_<DotDimensionNumbers>(kernel_runner_module, "DotDimensionNumbers") + .def("__init__", [](DotDimensionNumbers* self, + std::vector<int64_t> lhs_contracting_dims, + std::vector<int64_t> rhs_contracting_dims, + std::vector<int64_t> lhs_batch_dims, + std::vector<int64_t> rhs_batch_dims) { + new (self) DotDimensionNumbers(); + self->mutable_lhs_contracting_dimensions()->Assign( + lhs_contracting_dims.begin(), lhs_contracting_dims.end()); + self->mutable_rhs_contracting_dimensions()->Assign( + rhs_contracting_dims.begin(), rhs_contracting_dims.end()); + + self->mutable_lhs_batch_dimensions()->Assign(lhs_batch_dims.begin(), + lhs_batch_dims.end()); + self->mutable_rhs_batch_dimensions()->Assign(rhs_batch_dims.begin(), + rhs_batch_dims.end()); + }); + nb::class_<HloInstruction> hlo_instruction(kernel_runner_module, "HloInstruction"); // Factory methods hlo_instruction .def_static("create_parameter", &HloInstruction::CreateParameter) .def_static("create_constant", &CreateConstantHloInstruction) + .def_static("create_dot", &CreateDot, nb::keep_alive<0, 2>(), + nb::keep_alive<0, 3>()) .def_static("create_unary", &HloInstruction::CreateUnary, nb::keep_alive<0, 3>()) .def_static("create_binary", &HloInstruction::CreateBinary, diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index f3caa3fd5a023..20148fab3c865 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ 
-847,6 +847,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/backends/cpu:xnn_emitter", "//xla/backends/cpu:xnn_fusion", + "//xla/backends/cpu/codegen:dot_kernel_emitter", "//xla/backends/cpu/codegen:elemental_kernel_emitter", "//xla/backends/cpu/codegen:target_machine_features", "//xla/backends/cpu/runtime:all_gather_thunk", diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index 1890d5377bfb4..89a3325899737 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -220,38 +220,6 @@ static bool IsDotCodegenStrategy(DotImplementationStrategy strategy) { kDotCodegenStrategies.end(); } -absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitDotHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit dot host kernel: " << instr->name(); - - DotImplementationStrategy strategy = GetDotImplementationStrategy( - hlo_module_.config(), *instr, - nested_ir_emitter_->target_machine_features()); - - if (!IsDotCodegenStrategy(strategy)) { - return Internal("Unsupported dot implementation strategy"); - } - - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - - llvm_ir::IrArray lhs_array = kernel_prototype.arguments[0]; - llvm_ir::IrArray rhs_array = kernel_prototype.arguments[1]; - llvm_ir::IrArray target_array = kernel_prototype.results[0]; - - TF_RETURN_IF_ERROR(EmitDotOperation( - *instr, target_array, lhs_array, rhs_array, - /*addend_array=*/nullptr, /*executable_run_options_value=*/nullptr, &b, - hlo_module_.config(), nested_ir_emitter_->target_machine_features(), - /*allow_runtime_calls=*/false)); - - return kernels_.emplace_back( - KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); -} - absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitConcatenateHostKernel( const HloInstruction* instr) { VLOG(2) << "Emit concatenate host kernel: " << instr->name(); diff --git 
a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index 77ea6647d4ec9..9dd1188f24105 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -105,11 +105,6 @@ class IrEmitter2 { absl::StatusOr<KernelInfo> EmitFusionHostKernel( const HloFusionInstruction* fusion); - // Emits a host kernel for the given dot instruction. Small dot operations - // are emitted as LLVM IR directly, while larger ones are emitted as a dot - // thunk that calls into libraries. - absl::StatusOr<KernelInfo> EmitDotHostKernel(const HloInstruction* instr); - // Emits a host kernel for the given concatenate instruction. absl::StatusOr<KernelInfo> EmitConcatenateHostKernel( const HloInstruction* instr); diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 15254d8f19fe4..5dc4bb1cafff6 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/codegen/dot_kernel_emitter.h" #include "xla/backends/cpu/codegen/elemental_kernel_emitter.h" #include "xla/backends/cpu/codegen/target_machine_features.h" @@ -820,12 +821,23 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk( case DotImplementationStrategy::kNaiveLlvmIr: case DotImplementationStrategy::kTiledLlvmIrGemm: case DotImplementationStrategy::kTiledLlvmIrGemv: { - TF_ASSIGN_OR_RETURN(auto kernel, - ir_emitter_.EmitDotHostKernel(instruction)); - TF_ASSIGN_OR_RETURN(auto buffers, - GetHostKernelAllocationSlices(instruction)); - - return MakeKernelThunkSequence(instruction, buffers, kernel); + DotKernelEmitter emitter(instruction, &buffer_assignment_, + &target_machine_features_); + TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition, + emitter.EmitKernelDefinition()); + + auto [kernel_spec, kernel_source] = + std::move(kernel_definition).release(); + auto 
llvm_ir_kernel_source = absl::WrapUnique( tsl::down_cast<LlvmIrKernelSource*>(kernel_source.release())); + + kernels_.push_back( + {kernel_spec.name(), + std::move(*llvm_ir_kernel_source).thread_safe_module()}); + + return MakeKernelThunkSequence( + instruction, std::move(kernel_spec), + /*min_alignment=*/cpu_function_runtime::MinAlign()); } // Emit DotThunk implementing dot instruction as a library call.