Skip to content

Commit

Permalink
[xla:python] Add method to get python callback capsule without requir…
Browse files Browse the repository at this point in the history
…ing operand or result shapes / returning capsule descriptor.

PiperOrigin-RevId: 723571011
  • Loading branch information
danielsuo authored and Google-ML-Automation committed Feb 5, 2025
1 parent a85a5fe commit d758b65
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
19 changes: 19 additions & 0 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,8 @@ absl::StatusOr<nb::object> PyClient::MakePythonCallbackUsingHostSendAndRecv(
return callback_capsule;
}

// TODO(b/394595987): Remove this API method once we remove the call from
// mlir.py's get_emit_python_callback.
absl::StatusOr<std::pair<uint64_t, nb::object>>
PyClient::GetEmitPythonCallbackDescriptor(
nb::callable callable, absl::Span<Shape const> operand_shapes,
Expand All @@ -647,6 +649,20 @@ PyClient::GetEmitPythonCallbackDescriptor(
return std::make_pair(descriptor, nb::object(std::move(callback_capsule)));
}

// TODO(b/394595987): Deprecate / clean up this API method to remove the need
// for `operand_shapes` and `result_shapes` once we can remove
// xla::PyClient::GetEmitPythonCallbackDescriptor (called by mlir.py's
// get_emit_python_callback for CPU/GPU devices).
absl::StatusOr<nb::object> PyClient::GetEmitPythonCallback(
nb::callable callable) {
absl::Span<const Shape> operand_shapes;
absl::Span<const Shape> result_shapes;
TF_ASSIGN_OR_RETURN(auto descriptor_and_callback,
GetEmitPythonCallbackDescriptor(
std::move(callable), operand_shapes, result_shapes));
return nb::object(std::move(descriptor_and_callback.second));
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
&XlaPythonCpuCallback);

Expand Down Expand Up @@ -764,6 +780,9 @@ PyType_Slot PyClient::slots_[] = {
xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor),
nb::arg("callable"), nb::arg("operand_shapes"),
nb::arg("result_shapes").none() = nb::none())
.def("get_emit_python_callback",
xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallback),
nb::arg("callable"))
.def("make_python_callback_from_host_send_and_recv",
xla::ValueOrThrowWrapper(
&PyClient::MakePythonCallbackUsingHostSendAndRecv),
Expand Down
9 changes: 9 additions & 0 deletions xla/python/py_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ class PyClient {
absl::Span<Shape const> operand_shapes,
absl::Span<Shape const> result_shapes);

// `GetEmitPythonCallback` takes in an input Python callable. It returns a
// Python object whose reference will keep the Python callback alive.
//
// The callable receives as arguments NumPy arrays for arguments with array
// types, and None for Token argument. The callable must return a tuple of
// either arrays or None values.
absl::StatusOr<nanobind::object> GetEmitPythonCallback(
nanobind::callable callable);

// `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable
// that takes in arguments of shapes `operand_shapes` and returns results of
// shapes `result_shapes`. The arguments correspond to Send ops in the HLO
Expand Down

0 comments on commit d758b65

Please sign in to comment.