diff --git a/xla/python/BUILD b/xla/python/BUILD index 8bdad8107b7e7..a858ebafe1730 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -202,7 +202,6 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], deps = [ - ":nb_helpers", ":nb_numpy", "//xla:literal", "//xla:shape_util", @@ -351,6 +350,7 @@ cc_library( # placeholder for index annotation deps "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index fb20f253ba222..f3959bed20a4a 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -97,6 +98,7 @@ limitations under the License. 
#include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -1714,15 +1716,96 @@ absl::StatusOr PyHostValue::AsNumPyArray( } else { TF_RETURN_IF_ERROR(ready_.Await()); } + if (string_array_contents_ != nullptr) { + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); + } return value_; } +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array* ifrt_array) { +#ifdef NPY_2_0_API_VERSION + if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) { + return absl::FailedPreconditionError( + absl::StrCat("String arrays are not supported in NumPy version: ", + PyArray_RUNTIME_VERSION)); + } + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr()); + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto& cord : *string_array_contents_) { + absl::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); +#else + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#endif +} + +absl::Status 
PyHostValue::CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype())); + auto shape = ifrt_array->shape(); + + // Allocate a vector of cords to hold the contents of the array until + // they are ultimately converted to a numpy array as part + // of the `AsNumPyArray` call. + string_array_contents_ = + std::make_shared>(shape.num_elements()); + ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(), + /*byte_strides=*/std::nullopt, + ifrt::ArrayCopySemantics::kAlwaysCopy); + + ready_.OnReady( + [string_array_contents = string_array_contents_](absl::Status) { + }); // Keeps the cords alive until the copy is done. + + return absl::OkStatus(); +} + absl::Status PyHostValue::CopyToHostAsync( std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { if (ready_.IsValid()) { // The array value has been populated, so CopyToHostAsync has been called. return absl::OkStatus(); } + + // Copying in Arrays of type kString requires some special handling + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { diff --git a/xla/python/py_array.h b/xla/python/py_array.h index 0333b51f04cb1..9424592e60600 100644 --- a/xla/python/py_array.h +++ b/xla/python/py_array.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" @@ -69,8 +70,19 @@ class PyHostValue { std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); private: + absl::Status CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array); + ifrt::Future<> ready_; nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as the host buffer passed to the CopyToHostBuffer call. It holds + // these contents until they are lazily converted to a numpy array when the + // user calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; }; // Private to PyArray, but you cannot forward declare member classes. diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 45baa4abf7935..5843cf4282d6b 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -21,14 +21,19 @@ limitations under the License. 
#include #include #include +#include #include #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -66,6 +71,32 @@ namespace xla { namespace { +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto* iter_data = PyArray_ITER_DATA(iter.ptr()); + auto* item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + using DevicePutFunc = std::function( nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind)>; @@ -252,10 +283,50 @@ absl::StatusOr HandleNumpyScalar( }; } +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + + // Assemble all the parameters of MakeArrayFromHostBuffer + void* data = cords.data(); + ifrt::Shape shape( + absl::MakeSpan(static_cast(array.shape()), array.ndim())); + std::shared_ptr sharding = + 
xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [client, data = data, shape = std::move(shape), + sharding = std::move(sharding), + on_done_with_host_buffer = + std::move(on_done_with_host_buffer)]() mutable + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(sharding), + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes, + std::move(on_done_with_host_buffer))); + + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + absl::StatusOr HandleNumpyArray( nb::handle h, ifrt::Client* client, ifrt::Device* to_device, const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. + if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { + return HandleStringNumpyArray(h, client, to_device, options, + to_memory_kind); + } + TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); PrimitiveType squashed_type; diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index fdc475192c2ef..44706ee28b865 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 310 +_version = 311 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/tsl/python/lib/core/numpy.h b/xla/tsl/python/lib/core/numpy.h index ca57a0370548e..307c253d111fc 100644 --- a/xla/tsl/python/lib/core/numpy.h +++ b/xla/tsl/python/lib/core/numpy.h @@ -43,9 +43,11 @@ limitations under the License. 
#include // clang-format on -#include "numpy/arrayobject.h" // IWYU pragma: export -#include "numpy/npy_common.h" // IWYU pragma: export -#include "numpy/ufuncobject.h" // IWYU pragma: export +#include "numpy/arrayobject.h" // IWYU pragma: export +#include "numpy/ndarraytypes.h" // IWYU pragma: export +#include "numpy/npy_common.h" // IWYU pragma: export +#include "numpy/numpyconfig.h" // IWYU pragma: export +#include "numpy/ufuncobject.h" // IWYU pragma: export namespace tsl {