[XLA:CPU] Add more comprehensive tests for dot emitter
PiperOrigin-RevId: 720131590
WillFroom authored and Google-ML-Automation committed Jan 27, 2025
1 parent 2421f74 commit 8d83352
Showing 5 changed files with 195 additions and 74 deletions.
77 changes: 39 additions & 38 deletions xla/backends/cpu/codegen/BUILD
@@ -343,30 +343,6 @@ cc_library(
    ],
)

cc_library(
    name = "object_loader",
    srcs = ["object_loader.cc"],
    hdrs = ["object_loader.h"],
    deps = [
        ":compiled_function_library",
        ":contiguous_section_memory_manager",
        "//xla/backends/cpu/runtime:function_library",
        "//xla/service/cpu:orc_jit_memory_mapper",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:JITLink",
        "@llvm-project//llvm:OrcShared",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
    ],
)

cc_library(
    name = "dot_kernel_emitter",
    srcs = ["dot_kernel_emitter.cc"],
@@ -394,6 +370,45 @@ cc_library(
    ],
)

py_strict_test(
    name = "dot_kernel_emitter_test",
    srcs = ["dot_kernel_emitter_test.py"],
    tags = [
        "no_oss",
    ],
    deps = [
        "//third_party/py/numpy",
        "//xla/backends/cpu/testlib",
        "//xla/codegen/testlib",
        "@absl_py//absl/testing:absltest",
        "@absl_py//absl/testing:parameterized",
    ],
)

cc_library(
    name = "object_loader",
    srcs = ["object_loader.cc"],
    hdrs = ["object_loader.h"],
    deps = [
        ":compiled_function_library",
        ":contiguous_section_memory_manager",
        "//xla/backends/cpu/runtime:function_library",
        "//xla/service/cpu:orc_jit_memory_mapper",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:JITLink",
        "@llvm-project//llvm:OrcShared",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
    ],
)

xla_cc_test(
    name = "object_loader_test",
    srcs = ["object_loader_test.cc"],
@@ -425,17 +440,3 @@ xla_cc_test(
        "@tsl//tsl/platform:statusor",
    ],
)

py_strict_test(
    name = "dot_kernel_emitter_test",
    srcs = ["dot_kernel_emitter_test.py"],
    tags = [
        "no_oss",
    ],
    deps = [
        "//third_party/py/numpy",
        "//xla/backends/cpu/testlib",
        "//xla/codegen/testlib",
        "@absl_py//absl/testing:absltest",
    ],
)
129 changes: 120 additions & 9 deletions xla/backends/cpu/codegen/dot_kernel_emitter_test.py
@@ -13,34 +13,82 @@
# limitations under the License.
# ==============================================================================

from collections.abc import Sequence
import os

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from xla.backends.cpu import testlib as testlib_cpu
from xla.backends.cpu.testlib import utilities
from xla.codegen import testlib as testlib_base
from xla.codegen.testlib import utilities as base_utilities

# We have some checks in the dot emitter which will fail to emit for certain
# shapes if multi-threading is enabled.
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false"

create_literal = base_utilities.create_literal_from_np
HloInstruction = testlib_base.HloInstruction
HloOpcode = testlib_base.HloOpcode


class DotKernelRunnerTest(absltest.TestCase):

  def test_dot_kernel_emitter(self):
    lhs_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    rhs_np = np.array([[5, 6], [7, 8]], dtype=np.float32)
def create_input(
    value_range: tuple[float, float],
    shape: Sequence[int],
    dtype: np.dtype,
) -> np.ndarray:
  size = np.prod(shape) if shape else 1
  result = np.linspace(
      value_range[0], value_range[1], size, dtype=dtype
  ).reshape(shape)

  return result


emitter_types = [
    testlib_cpu.ElementalKernelEmitter,
    testlib_cpu.DotKernelEmitter,
]


dtypes_to_test = [
    np.dtype(np.uint8),
    np.dtype(np.uint16),
    np.dtype(np.uint32),
    np.dtype(np.uint64),
    np.dtype(np.int8),
    np.dtype(np.int16),
    np.dtype(np.int32),
    np.dtype(np.int64),
    np.dtype(np.float16),
    np.dtype(np.float32),
    np.dtype(np.float64),
]


class DotKernelTest(parameterized.TestCase):

  @parameterized.product(
      emitter_type=emitter_types,
      rhs_shape=[(4,), (4, 3), (4, 3, 10), (500, 10, 123)],
      dtype=dtypes_to_test,
  )
  def test_vector_matrix_dot(self, emitter_type, rhs_shape, dtype):
    value_range = (0.0, 20.0)
    lhs_np = create_input(value_range, rhs_shape[0], dtype)
    rhs_np = create_input(value_range, rhs_shape, dtype)

    lhs_literal = create_literal(lhs_np)
    rhs_literal = create_literal(rhs_np)

    output_literal = create_literal(np.ndarray((2, 2), dtype=np.float32))
    output_literal = create_literal(np.ndarray(rhs_shape[1:], dtype=dtype))

    lhs_param = HloInstruction.create_parameter(0, lhs_literal.shape(), "lhs")
    rhs_param = HloInstruction.create_parameter(1, rhs_literal.shape(), "rhs")

    dot_dimension_numbers = testlib_base.DotDimensionNumbers([1], [0], [], [])
    dot_dimension_numbers = testlib_base.DotDimensionNumbers([0], [0])
    hlo_op = HloInstruction.create_dot(
        output_literal.shape(), lhs_param, rhs_param, dot_dimension_numbers
    )
@@ -50,7 +98,7 @@ def test_dot_kernel_emitter(self):
    )
    jit_compiler = testlib_cpu.JitCompiler()

    emitter = testlib_cpu.DotKernelEmitter(
    emitter = emitter_type(
        hlo_module.get_root_instruction(),
        buffer_assignment,
        jit_compiler.get_target_machine(),
Expand All @@ -61,7 +109,70 @@ def test_dot_kernel_emitter(self):
    )

    runner.call([lhs_literal, rhs_literal, output_literal])
    np.testing.assert_equal(np.asarray(output_literal), lhs_np @ rhs_np)

    np_result = np.tensordot(lhs_np, rhs_np, axes=(0, 0))
    np.testing.assert_array_max_ulp(
        np.asarray(output_literal),
        np_result,
        maxulp=10,
    )

  @parameterized.product(
      emitter_type=emitter_types,
      shapes=[
          ((1, 1), (1, 1)),
          ((1, 1), (1, 10)),
          ((2, 2), (2, 2)),
          ((2, 2), (2, 3)),
          ((10, 10), (10, 10)),
          ((15, 13), (13, 17)),
      ],
      dtype=dtypes_to_test,
  )
  def test_matrix_multiplication(self, emitter_type, shapes, dtype):
    if dtype == np.float16 and emitter_type is testlib_cpu.DotKernelEmitter:
      self.skipTest("float16 is not supported by the dot emitter")

    value_range = (0.0, 20.0)
    lhs_np = create_input(value_range, shapes[0], dtype)
    rhs_np = create_input(value_range, shapes[1], dtype)

    lhs_literal = create_literal(lhs_np)
    rhs_literal = create_literal(rhs_np)

    output_shape = shapes[0][:-1] + shapes[1][1:]
    output_literal = create_literal(np.ndarray(output_shape, dtype=dtype))

    lhs_param = HloInstruction.create_parameter(0, lhs_literal.shape(), "lhs")
    rhs_param = HloInstruction.create_parameter(1, rhs_literal.shape(), "rhs")

    dot_dimension_numbers = testlib_base.DotDimensionNumbers([1], [0])
    hlo_op = HloInstruction.create_dot(
        output_literal.shape(), lhs_param, rhs_param, dot_dimension_numbers
    )

    hlo_module, buffer_assignment = utilities.build_hlo_module(
        hlo_op, lhs_param, rhs_param
    )
    jit_compiler = testlib_cpu.JitCompiler()

    emitter = emitter_type(
        hlo_module.get_root_instruction(),
        buffer_assignment,
        jit_compiler.get_target_machine(),
    )

    kernel_definition = emitter.emit_kernel_definition()
    runner = testlib_cpu.KernelRunner.create(kernel_definition, jit_compiler)

    runner.call([lhs_literal, rhs_literal, output_literal])

    np_result = lhs_np @ rhs_np
    np.testing.assert_array_max_ulp(
        np.asarray(output_literal),
        np_result,
        maxulp=10,
    )


if __name__ == "__main__":
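As a quick cross-check of the dimension numbers used above, here is a standalone NumPy sketch, independent of the XLA testlib, showing why DotDimensionNumbers([0], [0]) corresponds to np.tensordot(..., axes=(0, 0)) and DotDimensionNumbers([1], [0]) to plain matrix multiplication (the array values are arbitrary):

import numpy as np

# Vector-times-tensor: contract lhs axis 0 against rhs axis 0, as in
# test_vector_matrix_dot with DotDimensionNumbers([0], [0]).
lhs = np.linspace(0.0, 20.0, 4, dtype=np.float32)                 # shape (4,)
rhs = np.linspace(0.0, 20.0, 12, dtype=np.float32).reshape(4, 3)  # shape (4, 3)
vec_result = np.tensordot(lhs, rhs, axes=(0, 0))
assert vec_result.shape == (3,)  # rhs.shape[1:], matching the output literal

# Matrix-times-matrix: contract lhs axis 1 against rhs axis 0, as in
# test_matrix_multiplication with DotDimensionNumbers([1], [0]).
a = np.linspace(0.0, 20.0, 15 * 13, dtype=np.float32).reshape(15, 13)
b = np.linspace(0.0, 20.0, 13 * 17, dtype=np.float32).reshape(13, 17)
np.testing.assert_allclose(np.tensordot(a, b, axes=(1, 0)), a @ b)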
12 changes: 3 additions & 9 deletions xla/backends/cpu/testlib/elemental_kernel_emitter_test.py
@@ -115,10 +115,7 @@ def __repr__(self):
    shape=[(4,), (4, 3), (4, 3, 10)],
    dtype=[np.dtype(np.float32), np.dtype(np.float64)],
)
class ElementalKernelRunnerTest(absltest.TestCase):

  def id(self):
    return self._test_params_reprs.get(self._testMethodName, "")
class ElementalKernelRunnerTest(parameterized.TestCase):

  def test_elemental_kernel_emitter(
      self,
@@ -204,7 +201,7 @@ def test_elemental_kernel_emitter(
        np.dtype(np.float64),
    ],
)
class ElementalComparisonKernelRunnerTest(absltest.TestCase):
class ElementalComparisonKernelRunnerTest(parameterized.TestCase):

  def test_elemental_comparision_kernel_emitter(self, op_def, shape, dtype):
    [direction, np_op] = op_def
@@ -264,10 +261,7 @@ def test_elemental_comparision_kernel_emitter(self, op_def, shape, dtype):
        np.dtype(np.float64),
    ],
)
class HloModuleKernelRunnerTest(absltest.TestCase):

  def id(self):
    return self._test_params_reprs.get(self._testMethodName, "")
class HloModuleKernelRunnerTest(parameterized.TestCase):

  def test_map(self, input_dimensions, dtype):
    scalar_shape = xla_extension.Shape.scalar_shape(dtype)
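The base-class swap above works because parameterized.TestCase already derives test ids from _test_params_reprs, which is exactly what the deleted id() overrides duplicated. A minimal sketch of the pattern (class and method names here are illustrative, not from the XLA tree):

from absl.testing import absltest
from absl.testing import parameterized


class SquareTest(parameterized.TestCase):

  @parameterized.product(x=[1, 2, 3], scale=[1.0, -2.0])
  def test_square_is_nonnegative(self, x, scale):
    # Each (x, scale) combination runs and reports as its own test case.
    self.assertGreaterEqual((scale * x) ** 2, 0.0)


if __name__ == "__main__":
  absltest.main()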
1 change: 1 addition & 0 deletions xla/codegen/testlib/BUILD
@@ -48,6 +48,7 @@ tsl_pybind_extension(
"@nanobind",
"@local_config_python//:python_headers", # buildcleaner: keep
"//xla:comparison_util",
"//xla:debug_options_flags",
"//xla:literal",
"//xla:shape_util",
"//xla:util",
50 changes: 32 additions & 18 deletions xla/codegen/testlib/kernel_runner_extension.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/testlib/kernel_runner.h"
#include "xla/comparison_util.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
@@ -88,6 +89,12 @@ std::unique_ptr<HloInstruction> CreateComparisonHloInstruction(
  return HloInstruction::CreateCompare(shape, lhs, rhs, direction);
}

HloModuleConfig DefaultHloModuleConfigWithDebugOptions() {
  HloModuleConfig config;
  config.set_debug_options(GetDebugOptionsFromFlags());
  return config;
}

// A dummy kernel runner that implements a simple elementwise add.
class DummyAddKernelRunner final : public KernelRunner {
 public:
@@ -182,22 +189,27 @@ NB_MODULE(_extension, kernel_runner_module) {
.value("kLt", Comparison::Direction::kLt);

  nb::class_<DotDimensionNumbers>(kernel_runner_module, "DotDimensionNumbers")
      .def("__init__", [](DotDimensionNumbers* self,
                          std::vector<int64_t> lhs_contracting_dims,
                          std::vector<int64_t> rhs_contracting_dims,
                          std::vector<int64_t> lhs_batch_dims,
                          std::vector<int64_t> rhs_batch_dims) {
        new (self) DotDimensionNumbers();
        self->mutable_lhs_contracting_dimensions()->Assign(
            lhs_contracting_dims.begin(), lhs_contracting_dims.end());
        self->mutable_rhs_contracting_dimensions()->Assign(
            rhs_contracting_dims.begin(), rhs_contracting_dims.end());

        self->mutable_lhs_batch_dimensions()->Assign(lhs_batch_dims.begin(),
                                                     lhs_batch_dims.end());
        self->mutable_rhs_batch_dimensions()->Assign(rhs_batch_dims.begin(),
                                                     rhs_batch_dims.end());
      });
      .def(
          "__init__",
          [](DotDimensionNumbers* self,
             std::vector<int64_t> lhs_contracting_dims,
             std::vector<int64_t> rhs_contracting_dims,
             std::vector<int64_t> lhs_batch_dims,
             std::vector<int64_t> rhs_batch_dims) {
            new (self) DotDimensionNumbers();
            self->mutable_lhs_contracting_dimensions()->Assign(
                lhs_contracting_dims.begin(), lhs_contracting_dims.end());
            self->mutable_rhs_contracting_dimensions()->Assign(
                rhs_contracting_dims.begin(), rhs_contracting_dims.end());

            self->mutable_lhs_batch_dimensions()->Assign(lhs_batch_dims.begin(),
                                                         lhs_batch_dims.end());
            self->mutable_rhs_batch_dimensions()->Assign(rhs_batch_dims.begin(),
                                                         rhs_batch_dims.end());
          },
          nb::arg("lhs_contracting_dims"), nb::arg("rhs_contracting_dims"),
          nb::arg("lhs_batch_dims") = std::vector<int64_t>{},
          nb::arg("rhs_batch_dims") = std::vector<int64_t>{});

  nb::class_<HloInstruction> hlo_instruction(kernel_runner_module,
                                             "HloInstruction");
@@ -237,7 +249,8 @@ NB_MODULE(_extension, kernel_runner_module) {
      .def_static("parse_from_string",
                  [](absl::string_view str) {
                    absl::StatusOr<std::unique_ptr<HloModule>> hlo_module =
                        ParseAndReturnUnverifiedModule(str);
                        ParseAndReturnUnverifiedModule(
                            str, DefaultHloModuleConfigWithDebugOptions());

                    if (!hlo_module.ok()) {
                      throw std::runtime_error(
@@ -250,7 +263,8 @@
"build",
[](std::unique_ptr<HloInstruction> root, nb::args instructions) {
auto hlo_module = std::make_unique<HloModule>(
absl::StrCat(root->name(), "_module"), HloModuleConfig());
absl::StrCat(root->name(), "_module"),
DefaultHloModuleConfigWithDebugOptions());

HloComputation::Builder builder(
absl::StrCat(root->name(), "_computation"));
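Two usage notes on the bindings changed above. First, with the keyword arguments and defaults now attached to __init__, Python callers can omit the batch dimensions. A sketch of both call forms, assuming the import alias used in the tests:

from xla.codegen import testlib as testlib_base

# Old form: all four dimension lists had to be passed positionally.
dnums = testlib_base.DotDimensionNumbers([1], [0], [], [])

# New form: batch dimensions default to empty, and keywords are accepted.
dnums = testlib_base.DotDimensionNumbers(
    lhs_contracting_dims=[1], rhs_contracting_dims=[0]
)

Second, the DefaultHloModuleConfigWithDebugOptions helper is what makes the os.environ["XLA_FLAGS"] assignment in dot_kernel_emitter_test.py effective: modules created through the parse_from_string and build bindings now carry flag-derived DebugOptions rather than a default-constructed HloModuleConfig.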
