Skip to content

Commit

Permalink
Added int2/uint2 dtypes to various parts of the XLA runtime
Browse files Browse the repository at this point in the history
See jax-ml/jax#21369.

PiperOrigin-RevId: 636667084
  • Loading branch information
superbobry authored and copybara-github committed May 23, 2024
1 parent 11da7f4 commit ff042fb
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 5 deletions.
4 changes: 4 additions & 0 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,10 @@ typedef enum {
PJRT_Buffer_Type_U4,

PJRT_Buffer_Type_TOKEN,

// 2-bit integer types
PJRT_Buffer_Type_S2,
PJRT_Buffer_Type_U2,
} PJRT_Buffer_Type;

typedef enum {
Expand Down
8 changes: 8 additions & 0 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) {
return PJRT_Buffer_Type::PJRT_Buffer_Type_PRED;
case xla::PrimitiveType::TOKEN:
return PJRT_Buffer_Type::PJRT_Buffer_Type_TOKEN;
case xla::PrimitiveType::S2:
return PJRT_Buffer_Type::PJRT_Buffer_Type_S2;
case xla::PrimitiveType::S4:
return PJRT_Buffer_Type::PJRT_Buffer_Type_S4;
case xla::PrimitiveType::S8:
Expand All @@ -272,6 +274,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) {
return PJRT_Buffer_Type::PJRT_Buffer_Type_S32;
case xla::PrimitiveType::S64:
return PJRT_Buffer_Type::PJRT_Buffer_Type_S64;
case xla::PrimitiveType::U2:
return PJRT_Buffer_Type::PJRT_Buffer_Type_U2;
case xla::PrimitiveType::U4:
return PJRT_Buffer_Type::PJRT_Buffer_Type_U4;
case xla::PrimitiveType::U8:
Expand Down Expand Up @@ -317,6 +321,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) {
return xla::PrimitiveType::PRED;
case PJRT_Buffer_Type::PJRT_Buffer_Type_TOKEN:
return xla::PrimitiveType::TOKEN;
case PJRT_Buffer_Type::PJRT_Buffer_Type_S2:
return xla::PrimitiveType::S2;
case PJRT_Buffer_Type::PJRT_Buffer_Type_S4:
return xla::PrimitiveType::S4;
case PJRT_Buffer_Type::PJRT_Buffer_Type_S8:
Expand All @@ -327,6 +333,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) {
return xla::PrimitiveType::S32;
case PJRT_Buffer_Type::PJRT_Buffer_Type_S64:
return xla::PrimitiveType::S64;
case PJRT_Buffer_Type::PJRT_Buffer_Type_U2:
return xla::PrimitiveType::U2;
case PJRT_Buffer_Type::PJRT_Buffer_Type_U4:
return xla::PrimitiveType::U4;
case PJRT_Buffer_Type::PJRT_Buffer_Type_U8:
Expand Down
2 changes: 2 additions & 0 deletions xla/python/ifrt/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ class DType {
kPred = 1,

// Signed integral values of fixed width.
kS2 = 26,
kS4 = 21,
kS8 = 2,
kS16 = 3,
kS32 = 4,
kS64 = 5,

// Unsigned integral values of fixed width.
kU2 = 27,
kU4 = 22,
kU8 = 6,
kU16 = 7,
Expand Down
2 changes: 2 additions & 0 deletions xla/python/ifrt/dtype.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ message DTypeProto {
KIND_PRED = 1;

// Signed integral values of fixed width.
KIND_S2 = 26;
KIND_S4 = 21;
KIND_S8 = 2;
KIND_S16 = 3;
KIND_S32 = 4;
KIND_S64 = 5;

// Unsigned integral values of fixed width.
KIND_U2 = 27;
KIND_U4 = 22;
KIND_U8 = 6;
KIND_U16 = 7;
Expand Down
4 changes: 4 additions & 0 deletions xla/python/pjrt_ifrt/pjrt_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,13 @@ absl::StatusOr<xla::PrimitiveType> ToPrimitiveType(DType dtype) {
return PT
CASE(DType::kInvalid, xla::PrimitiveType::PRIMITIVE_TYPE_INVALID);
CASE(DType::kPred, xla::PrimitiveType::PRED);
CASE(DType::kS2, xla::PrimitiveType::S2);
CASE(DType::kS4, xla::PrimitiveType::S4);
CASE(DType::kS8, xla::PrimitiveType::S8);
CASE(DType::kS16, xla::PrimitiveType::S16);
CASE(DType::kS32, xla::PrimitiveType::S32);
CASE(DType::kS64, xla::PrimitiveType::S64);
CASE(DType::kU2, xla::PrimitiveType::U2);
CASE(DType::kU4, xla::PrimitiveType::U4);
CASE(DType::kU8, xla::PrimitiveType::U8);
CASE(DType::kU16, xla::PrimitiveType::U16);
Expand Down Expand Up @@ -171,11 +173,13 @@ absl::StatusOr<DType> ToDType(xla::PrimitiveType primitive_type) {
switch (primitive_type) {
case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID:
case xla::PrimitiveType::PRED:
case xla::PrimitiveType::S2:
case xla::PrimitiveType::S4:
case xla::PrimitiveType::S8:
case xla::PrimitiveType::S16:
case xla::PrimitiveType::S32:
case xla::PrimitiveType::S64:
case xla::PrimitiveType::U2:
case xla::PrimitiveType::U4:
case xla::PrimitiveType::U8:
case xla::PrimitiveType::U16:
Expand Down
14 changes: 13 additions & 1 deletion xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,15 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
std::variant<T, SquashedT, void*> data;
PrimitiveType type;
// For extension types, ScalarAsCtype returns a pointer to the data.
if (std::is_same<T, xla::s4>()) {
if (std::is_same<T, xla::s2>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = S2;
} else if (std::is_same<T, xla::s4>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = S4;
} else if (std::is_same<T, xla::u2>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = U2;
} else if (std::is_same<T, xla::u4>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = U4;
Expand Down Expand Up @@ -370,10 +376,16 @@ absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
// Python types (np_int64, np_float64, np_complex128).
(*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
(*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar<xla::s4>;
if (dtypes.np_int2.has_value()) {
(*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar<xla::s2>;
}
(*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
(*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
(*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
(*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
if (dtypes.np_uint2.has_value()) {
(*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar<xla::u2>;
}
(*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar<xla::u4>;
(*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
(*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
Expand Down
49 changes: 45 additions & 4 deletions xla/python/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ struct CustomDtypes {
nb_dtype float8_e4m3fnuz;
nb_dtype float8_e5m2;
nb_dtype float8_e5m2fnuz;
std::optional<nb_dtype> int2;
nb_dtype int4;
std::optional<nb_dtype> uint2;
nb_dtype uint4;
};

Expand All @@ -84,6 +86,12 @@ const CustomDtypes& GetCustomDtypes() {
nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz"));
dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4"));
dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4"));
if (nb::hasattr(ml_dtypes, "int2")) {
dtypes->int2 = nb_dtype::from_args(ml_dtypes.attr("int2"));
}
if (nb::hasattr(ml_dtypes, "uint2")) {
dtypes->uint2 = nb_dtype::from_args(ml_dtypes.attr("uint2"));
}
return dtypes;
}();
return custom_dtypes;
Expand Down Expand Up @@ -137,7 +145,13 @@ absl::StatusOr<PrimitiveType> DtypeToPrimitiveType(const nb_dtype& np_type) {
map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ);
map->emplace(custom_dtypes.float8_e5m2, F8E5M2);
map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ);
if (custom_dtypes.int2.has_value()) {
map->emplace(*custom_dtypes.int2, S2);
}
map->emplace(custom_dtypes.int4, S4);
if (custom_dtypes.uint2.has_value()) {
map->emplace(*custom_dtypes.uint2, U2);
}
map->emplace(custom_dtypes.uint4, U4);
return map;
}();
Expand All @@ -160,6 +174,11 @@ absl::StatusOr<nb_dtype> PrimitiveTypeToNbDtype(PrimitiveType type) {
switch (type) {
case PRED:
return to_nb_dtype(NPY_BOOL);
case S2:
if (custom_dtypes.int2.has_value()) {
return *custom_dtypes.int2;
}
break;
case S4:
return custom_dtypes.int4;
case S8:
Expand All @@ -170,6 +189,11 @@ absl::StatusOr<nb_dtype> PrimitiveTypeToNbDtype(PrimitiveType type) {
return to_nb_dtype(NPY_INT32);
case S64:
return to_nb_dtype(NPY_INT64);
case U2:
if (custom_dtypes.uint2.has_value()) {
return *custom_dtypes.uint2;
}
break;
case U4:
return custom_dtypes.uint4;
case U8:
Expand Down Expand Up @@ -203,9 +227,10 @@ absl::StatusOr<nb_dtype> PrimitiveTypeToNbDtype(PrimitiveType type) {
case C128:
return to_nb_dtype(NPY_COMPLEX128);
default:
return Unimplemented("Unimplemented primitive type %s",
PrimitiveType_Name(type));
break;
}
return Unimplemented("Unimplemented primitive type %s",
PrimitiveType_Name(type));
}

absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
Expand All @@ -217,6 +242,11 @@ absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
switch (dtype.kind()) {
case ifrt::DType::kPred:
return to_nb_dtype(NPY_BOOL);
case ifrt::DType::kS2:
if (custom_dtypes.int2.has_value()) {
return *custom_dtypes.int2;
}
break;
case ifrt::DType::kS4:
return custom_dtypes.int4;
case ifrt::DType::kS8:
Expand All @@ -227,6 +257,11 @@ absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
return to_nb_dtype(NPY_INT32);
case ifrt::DType::kS64:
return to_nb_dtype(NPY_INT64);
case ifrt::DType::kU2:
if (custom_dtypes.uint2.has_value()) {
return *custom_dtypes.uint2;
}
break;
case ifrt::DType::kU4:
return custom_dtypes.uint4;
case ifrt::DType::kU8:
Expand Down Expand Up @@ -268,9 +303,9 @@ absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
// logic (see `TF_DataType_to_PyArray_TYPE`).
return to_nb_dtype(NPY_OBJECT);
default:
return Unimplemented("Unimplemented primitive type %s",
dtype.DebugString());
break;
}
return Unimplemented("Unimplemented primitive type %s", dtype.DebugString());
}

absl::StatusOr<ifrt::DType> DtypeToIfRtDType(nb_dtype dtype) {
Expand All @@ -295,11 +330,17 @@ const NumpyScalarTypes& GetNumpyScalarTypes() {
nb::module_ numpy = nb::module_::import_("numpy");
nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes");
dtypes->np_bool = nb::object(numpy.attr("bool_"));
if (nb::hasattr(ml_dtypes, "int2")) {
dtypes->np_int2 = nb::object(ml_dtypes.attr("int2"));
}
dtypes->np_int4 = nb::object(ml_dtypes.attr("int4"));
dtypes->np_int8 = nb::object(numpy.attr("int8"));
dtypes->np_int16 = nb::object(numpy.attr("int16"));
dtypes->np_int32 = nb::object(numpy.attr("int32"));
dtypes->np_int64 = nb::object(numpy.attr("int64"));
if (nb::hasattr(ml_dtypes, "uint2")) {
dtypes->np_uint2 = nb::object(ml_dtypes.attr("uint2"));
}
dtypes->np_uint4 = nb::object(ml_dtypes.attr("uint4"));
dtypes->np_uint8 = nb::object(numpy.attr("uint8"));
dtypes->np_uint16 = nb::object(numpy.attr("uint16"));
Expand Down
4 changes: 4 additions & 0 deletions xla/python/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,15 @@ absl::StatusOr<nanobind::str> TypeDescriptorForPrimitiveType(

struct NumpyScalarTypes {
nanobind::object np_bool;
// Remove std::optional once the minimum ml_dtypes in JAX is >= 0.4.1.
std::optional<nanobind::object> np_int2;
nanobind::object np_int4;
nanobind::object np_int8;
nanobind::object np_int16;
nanobind::object np_int32;
nanobind::object np_int64;
// Remove std::optional once the minimum ml_dtypes in JAX is >= 0.4.1.
std::optional<nanobind::object> np_uint2;
nanobind::object np_uint4;
nanobind::object np_uint8;
nanobind::object np_uint16;
Expand Down
2 changes: 2 additions & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ class XlaRuntimeError(RuntimeError):
class PrimitiveType(enum.IntEnum):
PRIMITIVE_TYPE_INVALID: PrimitiveType
PRED: PrimitiveType
S2: PrimitiveType
S4: PrimitiveType
S8: PrimitiveType
S16: PrimitiveType
S32: PrimitiveType
S64: PrimitiveType
U2: PrimitiveType
U4: PrimitiveType
U8: PrimitiveType
U16: PrimitiveType
Expand Down
3 changes: 3 additions & 0 deletions xla/service/llvm_ir/llvm_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
llvm::Module* module) {
switch (element_type) {
case S2:
case U2:
return llvm::Type::getIntNTy(module->getContext(), 2);
case S4:
case U4:
return llvm::Type::getIntNTy(module->getContext(), 4);
Expand Down

0 comments on commit ff042fb

Please sign in to comment.