Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a minimal but viable implementation of string arrays (with numpy.dtypes.StringDType) in JAX. Currently this only supports making of a string array by means of either jax.numpy.asarray or jax.device_put and reading it back with jax.device_get. #21503

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ cc_library(
features = ["-use_header_modules"],
visibility = [":friends"],
deps = [
":nb_helpers",
":nb_numpy",
"//xla:literal",
"//xla:shape_util",
Expand Down Expand Up @@ -351,6 +350,7 @@ cc_library(
# placeholder for index annotation deps
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
Expand Down
83 changes: 83 additions & 0 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -97,6 +98,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -1714,15 +1716,96 @@ absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
} else {
TF_RETURN_IF_ERROR(ready_.Await());
}
if (string_array_contents_ != nullptr) {
TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array));
}
return value_;
}

absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray(
ifrt::Array* ifrt_array) {
#ifdef NPY_2_0_API_VERSION
if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) {
return absl::FailedPreconditionError(
absl::StrCat("String arrays are not supported in NumPy version: ",
PyArray_RUNTIME_VERSION));
}
auto numpy_dtype = nb::steal<nb_dtype>(
reinterpret_cast<PyObject*>(PyArray_DescrFromType(NPY_VSTRING)));
value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(),
/*strides=*/std::nullopt);

auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr());
auto iter =
nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(dst_py_array_obj)));
for (auto& cord : *string_array_contents_) {
absl::string_view input_str_view = cord.Flatten();
auto py_unicode = nb::steal(PyUnicode_FromStringAndSize(
input_str_view.data(), input_str_view.size()));
if (py_unicode.ptr() == nullptr) {
return absl::InternalError("PyUnicode_FromStringAndSize failed");
}
if (PyArray_SETITEM(dst_py_array_obj,
static_cast<char*>(PyArray_ITER_DATA(iter.ptr())),
py_unicode.ptr()) != 0) {
return absl::InternalError("PyArray_SETITEM failed");
}
PyArray_ITER_NEXT(iter.ptr());
}

value_.attr("flags").attr("writeable") = nb::bool_(false);

string_array_contents_.reset();

return absl::OkStatus();
#else
return absl::FailedPreconditionError(
"String arrays are not supported in this NumPy version.");
#endif
}

absl::Status PyHostValue::CopyStringArrayToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
auto transfer_guard_formatter = [ifrt_array] {
return absl::StrCat(
"shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
"), dtype=", ifrt_array->dtype().DebugString(), ", device=",
ifrt_array->sharding().devices()->devices().front()->DebugString());
};
TF_RETURN_IF_ERROR(
jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));

TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype()));
auto shape = ifrt_array->shape();

// Allocate a vector of cords to hold the contents of the array until
// they are until they are ultimately converted to a numpy array as part
// of the `AsNumPyArray` call.
string_array_contents_ =
std::make_shared<std::vector<absl::Cord>>(shape.num_elements());
ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(),
/*byte_strides=*/std::nullopt,
ifrt::ArrayCopySemantics::kAlwaysCopy);

ready_.OnReady(
[string_array_contents = string_array_contents_](absl::Status) {
}); // Keeps the cords alive until the copy is done.

return absl::OkStatus();
}

absl::Status PyHostValue::CopyToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
if (ready_.IsValid()) {
// The array value has been populated, so CopyToHostAsync has been called.
return absl::OkStatus();
}

// Copying in Arrays of type kString requires some special handling
if (ifrt_array->dtype().kind() == ifrt::DType::kString) {
return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array);
}

auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() &&
IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) {
Expand Down
12 changes: 12 additions & 0 deletions xla/python/py_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
Expand Down Expand Up @@ -69,8 +70,19 @@ class PyHostValue {
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

private:
absl::Status CopyStringArrayToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);

ifrt::Future<> ready_;
nb_numpy_ndarray value_;

// Optional field, only used for arrays of type kString. This vector of cords
// serves as input buffer for the CopyToHostBuffer call. It holds these
// contents until it is lazily converted it to a numpy array when the user
// calls `AsNumPyArray`.
std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
};

// Private to PyArray, but you cannot forward declare member classes.
Expand Down
71 changes: 71 additions & 0 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@ limitations under the License.
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -66,6 +71,32 @@ namespace xla {

namespace {

absl::StatusOr<std::vector<absl::Cord>> StringDTypeArrayToCords(
PyArrayObject* py_array_obj) {
if (PyArray_SIZE(py_array_obj) == 0) {
return absl::InvalidArgumentError("empty numpy array");
}

std::vector<absl::Cord> cords;
cords.reserve(PyArray_SIZE(py_array_obj));

auto iter =
nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(py_array_obj)));
while (PyArray_ITER_NOTDONE(iter.ptr())) {
auto* iter_data = PyArray_ITER_DATA(iter.ptr());
auto* item = PyArray_GETITEM(py_array_obj, static_cast<char*>(iter_data));
if (!item) {
return absl::InternalError(
"Failed to get elements out of the ndarray iter.");
}
Py_ssize_t len;
auto str = PyUnicode_AsUTF8AndSize(item, &len);
cords.push_back(absl::Cord(absl::string_view(str, len)));
PyArray_ITER_NEXT(iter.ptr());
}
return cords;
}

using DevicePutFunc = std::function<absl::StatusOr<DevicePutResultFn>(
nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options,
ifrt::MemoryKind to_memory_kind)>;
Expand Down Expand Up @@ -252,10 +283,50 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
};
}

absl::StatusOr<DevicePutResultFn> HandleStringNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
auto py_array_obj = reinterpret_cast<PyArrayObject*>(array.ptr());
TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj));

// Assemble all the parameters of MakeArrayFromHostBuffer
void* data = cords.data();
ifrt::Shape shape(
absl::MakeSpan(static_cast<const int64_t*>(array.shape()), array.ndim()));
std::shared_ptr<xla::ifrt::Sharding> sharding =
xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind);

auto on_done_with_host_buffer = [cords = std::move(cords)] {};

return [client, data = data, shape = std::move(shape),
sharding = std::move(sharding),
on_done_with_host_buffer =
std::move(on_done_with_host_buffer)]() mutable
-> absl::StatusOr<DevicePutResult> {
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
data, ifrt::DType(ifrt::DType::kString), std::move(shape),
/*byte_strides=*/std::nullopt, std::move(sharding),
ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
std::move(on_done_with_host_buffer)));

return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
};
}

absl::StatusOr<DevicePutResultFn> HandleNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);

// String numpy arrays require substantially different processing.
if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') {
return HandleStringNumpyArray(h, client, to_device, options,
to_memory_kind);
}

TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

PrimitiveType squashed_type;
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 310
_version = 311

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
8 changes: 5 additions & 3 deletions xla/tsl/python/lib/core/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ limitations under the License.
#include <Python.h>
// clang-format on

#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export
#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/ndarraytypes.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/numpyconfig.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export

namespace tsl {

Expand Down
Loading