From efe7187060c9d6d1105f92c985ed220408c12fd3 Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Tue, 5 Mar 2024 19:50:44 +0100 Subject: [PATCH] convert `PyArrayDescr` to `Bound` API --- src/array.rs | 20 +- src/array_like.rs | 2 +- src/datetime.rs | 12 +- src/dtype.rs | 456 ++++++++++++++++++++++++++++++++++--------- src/lib.rs | 6 +- src/strings.rs | 14 +- src/untyped_array.rs | 8 +- tests/array.rs | 18 +- 8 files changed, 411 insertions(+), 125 deletions(-) diff --git a/src/array.rs b/src/array.rs index e000c9201..c4a1c9539 100644 --- a/src/array.rs +++ b/src/array.rs @@ -26,7 +26,7 @@ use pyo3::{ use crate::borrow::{PyReadonlyArray, PyReadwriteArray}; use crate::cold; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; -use crate::dtype::Element; +use crate::dtype::{Element, PyArrayDescrMethods}; use crate::error::{ BorrowError, DimensionalityError, FromVecError, IgnoreError, NotContiguousError, TypeError, DIMENSIONALITY_MISMATCH_ERR, MAX_DIMENSIONALITY_ERR, @@ -278,10 +278,10 @@ impl PyArray { } // Check if the element type matches `T`. - let src_dtype = arr_gil_ref.dtype(); - let dst_dtype = T::get_dtype(ob.py()); - if !src_dtype.is_equiv_to(dst_dtype) { - return Err(TypeError::new(src_dtype, dst_dtype).into()); + let src_dtype = array.dtype(); + let dst_dtype = T::get_dtype_bound(ob.py()); + if !src_dtype.is_equiv_to(&dst_dtype) { + return Err(TypeError::new(src_dtype.into_gil_ref(), dst_dtype.into_gil_ref()).into()); } Ok(array) @@ -354,7 +354,7 @@ impl PyArray { let ptr = PY_ARRAY_API.PyArray_NewFromDescr( py, PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type), - T::get_dtype(py).into_dtype_ptr(), + T::get_dtype_bound(py).into_dtype_ptr(), dims.ndim_cint(), dims.as_dims_ptr(), strides as *mut npy_intp, // strides @@ -380,7 +380,7 @@ impl PyArray { let ptr = PY_ARRAY_API.PyArray_NewFromDescr( py, PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type), - T::get_dtype(py).into_dtype_ptr(), + T::get_dtype_bound(py).into_dtype_ptr(), dims.ndim_cint(), dims.as_dims_ptr(), strides as *mut npy_intp, // strides @@ -500,7 +500,7 @@ impl PyArray { py, dims.ndim_cint(), dims.as_dims_ptr(), - T::get_dtype(py).into_dtype_ptr(), + T::get_dtype_bound(py).into_dtype_ptr(), if is_fortran { -1 } else { 0 }, ); Self::from_owned_ptr(py, ptr) @@ -1315,7 +1315,7 @@ impl PyArray { PY_ARRAY_API.PyArray_CastToType( self.py(), self.as_array_ptr(), - U::get_dtype(self.py()).into_dtype_ptr(), + U::get_dtype_bound(self.py()).into_dtype_ptr(), if is_fortran { -1 } else { 0 }, ) }; @@ -1461,7 +1461,7 @@ impl> PyArray { start.as_(), stop.as_(), step.as_(), - T::get_dtype(py).num(), + T::get_dtype_bound(py).num(), ); Self::from_owned_ptr(py, ptr) } diff --git a/src/array_like.rs b/src/array_like.rs index a06122af3..ef070ceb1 100644 --- a/src/array_like.rs +++ b/src/array_like.rs @@ -164,7 +164,7 @@ where let kwargs = if C::VAL { let kwargs = PyDict::new(py); - kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?; + kwargs.set_item(intern!(py, "dtype"), T::get_dtype_bound(py))?; Some(kwargs) } else { None diff --git a/src/datetime.rs b/src/datetime.rs index 3800e2a32..adcd77cf5 100644 --- a/src/datetime.rs +++ b/src/datetime.rs @@ -63,10 +63,10 @@ use std::fmt; use std::hash::Hash; use std::marker::PhantomData; -use pyo3::{sync::GILProtected, Py, Python}; +use pyo3::{sync::GILProtected, Bound, Py, Python}; use rustc_hash::FxHashMap; -use crate::dtype::{Element, PyArrayDescr}; +use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods}; use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES}; /// Represents the [datetime units][datetime-units] supported by NumPy @@ -156,7 +156,7 @@ impl From> for i64 { unsafe impl Element for Datetime { const IS_COPY: bool = true; - fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr { + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_DATETIME) }; DTYPES.from_unit(py, U::UNIT) @@ -191,7 +191,7 @@ impl From> for i64 { unsafe impl Element for Timedelta { const IS_COPY: bool = true; - fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr { + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_TIMEDELTA) }; DTYPES.from_unit(py, U::UNIT) @@ -220,7 +220,7 @@ impl TypeDescriptors { } #[allow(clippy::wrong_self_convention)] - fn from_unit<'py>(&'py self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> &'py PyArrayDescr { + fn from_unit<'py>(&self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> Bound<'py, PyArrayDescr> { let mut dtypes = self.dtypes.get(py).borrow_mut(); let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) { @@ -241,7 +241,7 @@ impl TypeDescriptors { } }; - dtype.clone().into_ref(py) + dtype.clone().into_bound(py) } } diff --git a/src/dtype.rs b/src/dtype.rs index 8ba095af8..a1576ec80 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -12,10 +12,11 @@ use pyo3::{ ffi::{self, PyTuple_Size}, pyobject_native_type_extract, pyobject_native_type_named, types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType}, - AsPyPointer, Borrowed, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo, Python, ToPyObject, + AsPyPointer, Borrowed, Bound, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo, Python, + ToPyObject, }; #[cfg(feature = "half")] -use pyo3::{sync::GILOnceCell, IntoPy, Py}; +use pyo3::{sync::GILOnceCell, Py}; use crate::npyffi::{ NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, @@ -68,8 +69,17 @@ unsafe impl PyTypeInfo for PyArrayDescr { pyobject_native_type_extract!(PyArrayDescr); /// Returns the type descriptor ("dtype") for a registered type. +#[deprecated( + since = "0.21.0", + note = "This will be replaced by `dtype_bound` in the future." +)] pub fn dtype<'py, T: Element>(py: Python<'py>) -> &'py PyArrayDescr { - T::get_dtype(py) + T::get_dtype_bound(py).into_gil_ref() +} + +/// Returns the type descriptor ("dtype") for a registered type. +pub fn dtype_bound<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> { + T::get_dtype_bound(py) } impl PyArrayDescr { @@ -78,14 +88,30 @@ impl PyArrayDescr { /// Equivalent to invoking the constructor of [`numpy.dtype`][dtype]. /// /// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html - #[inline] + #[deprecated( + since = "0.21.0", + note = "This will be replace by `new_bound` in the future." + )] pub fn new<'py, T: ToPyObject + ?Sized>(py: Python<'py>, ob: &T) -> PyResult<&'py Self> { - fn inner<'py>(py: Python<'py>, obj: PyObject) -> PyResult<&'py PyArrayDescr> { + Self::new_bound(py, ob).map(Bound::into_gil_ref) + } + /// Creates a new type descriptor ("dtype") object from an arbitrary object. + /// + /// Equivalent to invoking the constructor of [`numpy.dtype`][dtype]. + /// + /// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html + #[inline] + pub fn new_bound<'py, T: ToPyObject + ?Sized>( + py: Python<'py>, + ob: &T, + ) -> PyResult> { + fn inner(py: Python<'_>, obj: PyObject) -> PyResult> { let mut descr: *mut PyArray_Descr = ptr::null_mut(); unsafe { // None is an invalid input here and is not converted to NPY_DEFAULT_TYPE - PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr as *mut _); - py.from_owned_ptr_or_err(descr as _) + PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr); + Bound::from_owned_ptr_or_err(py, descr.cast()) + .map(|any| any.downcast_into_unchecked()) } } @@ -94,48 +120,60 @@ impl PyArrayDescr { /// Returns `self` as `*mut PyArray_Descr`. pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr { - self.as_ptr() as _ + self.as_borrowed().as_dtype_ptr() } /// Returns `self` as `*mut PyArray_Descr` while increasing the reference count. /// /// Useful in cases where the descriptor is stolen by the API. pub fn into_dtype_ptr(&self) -> *mut PyArray_Descr { - self.into_ptr() as _ + (*self.as_borrowed()).clone().into_dtype_ptr() } /// Shortcut for creating a type descriptor of `object` type. + #[deprecated( + since = "0.21.0", + note = "This will be replaced by `object_bound` in the future." + )] pub fn object<'py>(py: Python<'py>) -> &'py Self { + Self::object_bound(py).into_gil_ref() + } + + /// Shortcut for creating a type descriptor of `object` type. + pub fn object_bound(py: Python<'_>) -> Bound<'_, Self> { Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT) } /// Returns the type descriptor for a registered type. + #[deprecated( + since = "0.21.0", + note = "This will be replaced by `of_bound` in the future." + )] pub fn of<'py, T: Element>(py: Python<'py>) -> &'py Self { - T::get_dtype(py) + Self::of_bound::(py).into_gil_ref() + } + + /// Returns the type descriptor for a registered type. + pub fn of_bound<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> { + T::get_dtype_bound(py) } /// Returns true if two type descriptors are equivalent. pub fn is_equiv_to(&self, other: &Self) -> bool { - let self_ptr = self.as_dtype_ptr(); - let other_ptr = other.as_dtype_ptr(); - - unsafe { - self_ptr == other_ptr - || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0 - } + self.as_borrowed().is_equiv_to(&other.as_borrowed()) } - fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> &'py Self { + fn from_npy_type(py: Python<'_>, npy_type: NPY_TYPES) -> Bound<'_, Self> { unsafe { let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _); - py.from_owned_ptr(descr as _) + Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked() } } - pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> &'py Self { + pub(crate) fn new_from_npy_type(py: Python<'_>, npy_type: NPY_TYPES) -> Bound<'_, Self> { unsafe { let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _); - py.from_owned_ptr(descr as _) + Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked() } } @@ -146,8 +184,7 @@ impl PyArrayDescr { /// [arrays-scalars]: https://numpy.org/doc/stable/reference/arrays.scalars.html /// [dtype-type]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.type.html pub fn typeobj(&self) -> &PyType { - let dtype_type_ptr = unsafe { *self.as_dtype_ptr() }.typeobj; - unsafe { PyType::from_type_ptr(self.py(), dtype_type_ptr) } + self.as_borrowed().typeobj().into_gil_ref() } /// Returns a unique number for each of the 21 different built-in @@ -160,7 +197,7 @@ impl PyArrayDescr { /// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types /// [dtype-num]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.num.html pub fn num(&self) -> c_int { - unsafe { *self.as_dtype_ptr() }.type_num + self.as_borrowed().num() } /// Returns the element size of this type descriptor. @@ -169,7 +206,7 @@ impl PyArrayDescr { /// /// [dtype-itemsiize]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html pub fn itemsize(&self) -> usize { - unsafe { *self.as_dtype_ptr() }.elsize.max(0) as _ + self.as_borrowed().itemsize() } /// Returns the required alignment (bytes) of this type descriptor according to the compiler. @@ -178,7 +215,7 @@ impl PyArrayDescr { /// /// [dtype-alignment]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html pub fn alignment(&self) -> usize { - unsafe { *self.as_dtype_ptr() }.alignment.max(0) as _ + self.as_borrowed().alignment() } /// Returns an ASCII character indicating the byte-order of this type descriptor object. @@ -189,7 +226,7 @@ impl PyArrayDescr { /// /// [dtype-byteorder]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html pub fn byteorder(&self) -> u8 { - unsafe { *self.as_dtype_ptr() }.byteorder.max(0) as _ + self.as_borrowed().byteorder() } /// Returns a unique ASCII character for each of the 21 different built-in types. @@ -200,7 +237,7 @@ impl PyArrayDescr { /// /// [dtype-char]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html pub fn char(&self) -> u8 { - unsafe { *self.as_dtype_ptr() }.type_.max(0) as _ + self.as_borrowed().char() } /// Returns an ASCII character (one of `biufcmMOSUV`) identifying the general kind of data. @@ -211,7 +248,7 @@ impl PyArrayDescr { /// /// [dtype-kind]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html pub fn kind(&self) -> u8 { - unsafe { *self.as_dtype_ptr() }.kind.max(0) as _ + self.as_borrowed().kind() } /// Returns bit-flags describing how this type descriptor is to be interpreted. @@ -220,7 +257,7 @@ impl PyArrayDescr { /// /// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html pub fn flags(&self) -> c_char { - unsafe { *self.as_dtype_ptr() }.flags + self.as_borrowed().flags() } /// Returns the number of dimensions if this type descriptor represents a sub-array, and zero otherwise. @@ -229,10 +266,7 @@ impl PyArrayDescr { /// /// [dtype-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html pub fn ndim(&self) -> usize { - if !self.has_subarray() { - return 0; - } - unsafe { PyTuple_Size((*((*self.as_dtype_ptr()).subarray)).shape).max(0) as _ } + self.as_borrowed().ndim() } /// Returns the type descriptor for the base element of subarrays, regardless of their dimension or shape. @@ -243,15 +277,7 @@ impl PyArrayDescr { /// /// [dtype-base]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html pub fn base(&self) -> &PyArrayDescr { - if !self.has_subarray() { - self - } else { - #[allow(deprecated)] - unsafe { - use pyo3::FromPyPointer; - Self::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).base as _) - } - } + self.as_borrowed().base().into_gil_ref() } /// Returns the shape of the sub-array. @@ -262,22 +288,210 @@ impl PyArrayDescr { /// /// [dtype-shape]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html pub fn shape(&self) -> Vec { + self.as_borrowed().shape() + } + + /// Returns true if the type descriptor contains any reference-counted objects in any fields or sub-dtypes. + /// + /// Equivalent to [`numpy.dtype.hasobject`][dtype-hasobject]. + /// + /// [dtype-hasobject]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html + pub fn has_object(&self) -> bool { + self.as_borrowed().has_object() + } + + /// Returns true if the type descriptor is a struct which maintains field alignment. + /// + /// This flag is sticky, so when combining multiple structs together, it is preserved + /// and produces new dtypes which are also aligned. + /// + /// Equivalent to [`numpy.dtype.isalignedstruct`][dtype-isalignedstruct]. + /// + /// [dtype-isalignedstruct]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html + pub fn is_aligned_struct(&self) -> bool { + self.as_borrowed().is_aligned_struct() + } + + /// Returns true if the type descriptor is a sub-array. + pub fn has_subarray(&self) -> bool { + self.as_borrowed().has_subarray() + } + + /// Returns true if the type descriptor is a structured type. + pub fn has_fields(&self) -> bool { + self.as_borrowed().has_fields() + } + + /// Returns true if type descriptor byteorder is native, or `None` if not applicable. + pub fn is_native_byteorder(&self) -> Option { + self.as_borrowed().is_native_byteorder() + } + + /// Returns an ordered list of field names, or `None` if there are no fields. + /// + /// The names are ordered according to increasing byte offset. + /// + /// Equivalent to [`numpy.dtype.names`][dtype-names]. + /// + /// [dtype-names]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html + pub fn names(&self) -> Option> { + if !self.has_fields() { + return None; + } + let names = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).names) }; + names.extract().ok() + } + + /// Returns the type descriptor and offset of the field with the given name. + /// + /// This method will return an error if this type descriptor is not structured, + /// or if it does not contain a field with a given name. + /// + /// The list of all names can be found via [`PyArrayDescr::names`]. + /// + /// Equivalent to retrieving a single item from [`numpy.dtype.fields`][dtype-fields]. + /// + /// [dtype-fields]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html + pub fn get_field(&self, name: &str) -> PyResult<(&PyArrayDescr, usize)> { + self.as_borrowed() + .get_field(name) + .map(|(descr, n)| (descr.into_gil_ref(), n)) + } +} + +/// Implementation of functionality for [`PyArrayDescr`]. +#[doc(alias = "PyArrayDescr")] +// TODO: seal this trait +pub trait PyArrayDescrMethods<'py> { + /// Returns `self` as `*mut PyArray_Descr`. + fn as_dtype_ptr(&self) -> *mut PyArray_Descr; + + /// Returns `self` as `*mut PyArray_Descr` while increasing the reference count. + /// + /// Useful in cases where the descriptor is stolen by the API. + fn into_dtype_ptr(self) -> *mut PyArray_Descr; + + /// Returns true if two type descriptors are equivalent. + fn is_equiv_to(&self, other: &Self) -> bool; + + /// Returns the [array scalar][arrays-scalars] corresponding to this type descriptor. + /// + /// Equivalent to [`numpy.dtype.type`][dtype-type]. + /// + /// [arrays-scalars]: https://numpy.org/doc/stable/reference/arrays.scalars.html + /// [dtype-type]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.type.html + fn typeobj(&self) -> Bound<'py, PyType>; + + /// Returns a unique number for each of the 21 different built-in + /// [enumerated types][enumerated-types]. + /// + /// These are roughly ordered from least-to-most precision. + /// + /// Equivalent to [`numpy.dtype.num`][dtype-num]. + /// + /// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types + /// [dtype-num]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.num.html + fn num(&self) -> c_int { + unsafe { *self.as_dtype_ptr() }.type_num + } + + /// Returns the element size of this type descriptor. + /// + /// Equivalent to [`numpy.dtype.itemsize`][dtype-itemsize]. + /// + /// [dtype-itemsiize]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html + + fn itemsize(&self) -> usize { + unsafe { *self.as_dtype_ptr() }.elsize.max(0) as _ + } + + /// Returns the required alignment (bytes) of this type descriptor according to the compiler. + /// + /// Equivalent to [`numpy.dtype.alignment`][dtype-alignment]. + /// + /// [dtype-alignment]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html + fn alignment(&self) -> usize { + unsafe { *self.as_dtype_ptr() }.alignment.max(0) as _ + } + + /// Returns an ASCII character indicating the byte-order of this type descriptor object. + /// + /// All built-in data-type objects have byteorder either `=` or `|`. + /// + /// Equivalent to [`numpy.dtype.byteorder`][dtype-byteorder]. + /// + /// [dtype-byteorder]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html + fn byteorder(&self) -> u8 { + unsafe { *self.as_dtype_ptr() }.byteorder.max(0) as _ + } + + /// Returns a unique ASCII character for each of the 21 different built-in types. + /// + /// Note that structured data types are categorized as `V` (void). + /// + /// Equivalent to [`numpy.dtype.char`][dtype-char]. + /// + /// [dtype-char]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html + fn char(&self) -> u8 { + unsafe { *self.as_dtype_ptr() }.type_.max(0) as _ + } + + /// Returns an ASCII character (one of `biufcmMOSUV`) identifying the general kind of data. + /// + /// Note that structured data types are categorized as `V` (void). + /// + /// Equivalent to [`numpy.dtype.kind`][dtype-kind]. + /// + /// [dtype-kind]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html + fn kind(&self) -> u8 { + unsafe { *self.as_dtype_ptr() }.kind.max(0) as _ + } + + /// Returns bit-flags describing how this type descriptor is to be interpreted. + /// + /// Equivalent to [`numpy.dtype.flags`][dtype-flags]. + /// + /// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html + fn flags(&self) -> c_char { + unsafe { *self.as_dtype_ptr() }.flags + } + + /// Returns the number of dimensions if this type descriptor represents a sub-array, and zero otherwise. + /// + /// Equivalent to [`numpy.dtype.ndim`][dtype-ndim]. + /// + /// [dtype-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html + fn ndim(&self) -> usize { if !self.has_subarray() { - Vec::new() - } else { - // NumPy guarantees that shape is a tuple of non-negative integers so this should never panic. - unsafe { Borrowed::from_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape) } - .extract() - .unwrap() + return 0; } + unsafe { PyTuple_Size((*((*self.as_dtype_ptr()).subarray)).shape).max(0) as _ } } + /// Returns the type descriptor for the base element of subarrays, regardless of their dimension or shape. + /// + /// If the dtype is not a subarray, returns self. + /// + /// Equivalent to [`numpy.dtype.base`][dtype-base]. + /// + /// [dtype-base]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html + fn base(&self) -> Bound<'py, PyArrayDescr>; + + /// Returns the shape of the sub-array. + /// + /// If the dtype is not a sub-array, an empty vector is returned. + /// + /// Equivalent to [`numpy.dtype.shape`][dtype-shape]. + /// + /// [dtype-shape]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html + fn shape(&self) -> Vec; + /// Returns true if the type descriptor contains any reference-counted objects in any fields or sub-dtypes. /// /// Equivalent to [`numpy.dtype.hasobject`][dtype-hasobject]. /// /// [dtype-hasobject]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html - pub fn has_object(&self) -> bool { + fn has_object(&self) -> bool { self.flags() & NPY_ITEM_HASOBJECT != 0 } @@ -289,24 +503,24 @@ impl PyArrayDescr { /// Equivalent to [`numpy.dtype.isalignedstruct`][dtype-isalignedstruct]. /// /// [dtype-isalignedstruct]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html - pub fn is_aligned_struct(&self) -> bool { + fn is_aligned_struct(&self) -> bool { self.flags() & NPY_ALIGNED_STRUCT != 0 } /// Returns true if the type descriptor is a sub-array. - pub fn has_subarray(&self) -> bool { + fn has_subarray(&self) -> bool { // equivalent to PyDataType_HASSUBARRAY(self) unsafe { !(*self.as_dtype_ptr()).subarray.is_null() } } /// Returns true if the type descriptor is a structured type. - pub fn has_fields(&self) -> bool { + fn has_fields(&self) -> bool { // equivalent to PyDataType_HASFIELDS(self) unsafe { !(*self.as_dtype_ptr()).names.is_null() } } /// Returns true if type descriptor byteorder is native, or `None` if not applicable. - pub fn is_native_byteorder(&self) -> Option { + fn is_native_byteorder(&self) -> Option { // based on PyArray_ISNBO(self->byteorder) match self.byteorder() { b'=' => Some(true), @@ -322,13 +536,7 @@ impl PyArrayDescr { /// Equivalent to [`numpy.dtype.names`][dtype-names]. /// /// [dtype-names]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html - pub fn names(&self) -> Option> { - if !self.has_fields() { - return None; - } - let names = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).names) }; - names.extract().ok() - } + fn names(&self) -> Option>; /// Returns the type descriptor and offset of the field with the given name. /// @@ -340,7 +548,64 @@ impl PyArrayDescr { /// Equivalent to retrieving a single item from [`numpy.dtype.fields`][dtype-fields]. /// /// [dtype-fields]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html - pub fn get_field(&self, name: &str) -> PyResult<(&PyArrayDescr, usize)> { + fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>; +} + +impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> { + fn as_dtype_ptr(&self) -> *mut PyArray_Descr { + self.as_ptr() as _ + } + + fn into_dtype_ptr(self) -> *mut PyArray_Descr { + self.into_ptr() as _ + } + + fn is_equiv_to(&self, other: &Self) -> bool { + let self_ptr = self.as_dtype_ptr(); + let other_ptr = other.as_dtype_ptr(); + + unsafe { + self_ptr == other_ptr + || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0 + } + } + + fn typeobj(&self) -> Bound<'py, PyType> { + let dtype_type_ptr = unsafe { *self.as_dtype_ptr() }.typeobj; + unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) } + } + + fn base(&self) -> Bound<'py, PyArrayDescr> { + if !self.has_subarray() { + self.clone() + } else { + unsafe { + Bound::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).base.cast()) + .downcast_into_unchecked() + } + } + } + + fn shape(&self) -> Vec { + if !self.has_subarray() { + Vec::new() + } else { + // NumPy guarantees that shape is a tuple of non-negative integers so this should never panic. + unsafe { Borrowed::from_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape) } + .extract() + .unwrap() + } + } + + fn names(&self) -> Option> { + if !self.has_fields() { + return None; + } + let names = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).names) }; + names.extract().ok() + } + + fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> { if !self.has_fields() { return Err(PyValueError::new_err( "cannot get field information: type descriptor has no fields", @@ -361,7 +626,7 @@ impl PyArrayDescr { .downcast_into::() .unwrap(); let offset = tuple.get_item(1).unwrap().extract().unwrap(); - Ok((dtype.into_gil_ref(), offset)) + Ok((dtype, offset)) } } @@ -413,7 +678,16 @@ pub unsafe trait Element: Clone + Send { const IS_COPY: bool; /// Returns the associated type descriptor ("dtype") for the given element type. - fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr; + #[deprecated( + since = "0.21.0", + note = "This will be replaced by `get_dtype_bound` in the future." + )] + fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr { + Self::get_dtype_bound(py).into_gil_ref() + } + + /// Returns the associated type descriptor ("dtype") for the given element type. + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr>; } fn npy_int_type_lookup(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES { @@ -467,7 +741,7 @@ macro_rules! impl_element_scalar { unsafe impl Element for $ty { const IS_COPY: bool = true; - fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr { + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { PyArrayDescr::from_npy_type(py, $npy_type) } } @@ -495,15 +769,15 @@ impl_element_scalar!(f16 => NPY_HALF); unsafe impl Element for bf16 { const IS_COPY: bool = true; - fn get_dtype<'py>(py: Python<'py>) -> &PyArrayDescr { + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { static DTYPE: GILOnceCell> = GILOnceCell::new(); DTYPE .get_or_init(py, || { - PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").into_py(py) + PyArrayDescr::new_bound(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").unbind() }) .clone() - .into_ref(py) + .into_bound(py) } } @@ -518,8 +792,8 @@ impl_element_scalar!(usize, isize); unsafe impl Element for PyObject { const IS_COPY: bool = false; - fn get_dtype<'py>(py: Python<'py>) -> &PyArrayDescr { - PyArrayDescr::object(py) + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { + PyArrayDescr::object_bound(py) } } @@ -527,31 +801,35 @@ unsafe impl Element for PyObject { mod tests { use super::*; - use pyo3::py_run; + use pyo3::{py_run, types::PyTypeMethods}; use crate::npyffi::NPY_NEEDS_PYAPI; #[test] fn test_dtype_new() { Python::with_gil(|py| { - assert!(PyArrayDescr::new(py, "float64") + assert!(PyArrayDescr::new_bound(py, "float64") .unwrap() - .is(dtype::(py))); + .is(&dtype_bound::(py))); - let dt = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap(); + let dt = PyArrayDescr::new_bound(py, [("a", "O"), ("b", "?")].as_ref()).unwrap(); assert_eq!(dt.names(), Some(vec!["a", "b"])); assert!(dt.has_object()); - assert!(dt.get_field("a").unwrap().0.is(dtype::(py))); - assert!(dt.get_field("b").unwrap().0.is(dtype::(py))); + assert!(dt + .get_field("a") + .unwrap() + .0 + .is(&dtype_bound::(py))); + assert!(dt.get_field("b").unwrap().0.is(&dtype_bound::(py))); - assert!(PyArrayDescr::new(py, &123_usize).is_err()); + assert!(PyArrayDescr::new_bound(py, &123_usize).is_err()); }); } #[test] fn test_dtype_names() { fn type_name<'py, T: Element>(py: Python<'py>) -> String { - dtype::(py).typeobj().qualname().unwrap() + dtype_bound::(py).typeobj().qualname().unwrap() } Python::with_gil(|py| { assert_eq!(type_name::(py), "bool_"); @@ -587,7 +865,7 @@ mod tests { #[test] fn test_dtype_methods_scalar() { Python::with_gil(|py| { - let dt = dtype::(py); + let dt = dtype_bound::(py); assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int); assert_eq!(dt.flags(), 0); @@ -603,7 +881,7 @@ mod tests { assert!(!dt.has_fields()); assert!(!dt.is_aligned_struct()); assert!(!dt.has_subarray()); - assert!(dt.base().is_equiv_to(dt)); + assert!(dt.base().is_equiv_to(&dt)); assert_eq!(dt.ndim(), 0); assert_eq!(dt.shape(), vec![]); }); @@ -612,7 +890,7 @@ mod tests { #[test] fn test_dtype_methods_subarray() { Python::with_gil(|py| { - let locals = PyDict::new(py); + let locals = PyDict::new_bound(py); py_run!( py, *locals, @@ -622,7 +900,7 @@ mod tests { .get_item("dtype") .unwrap() .unwrap() - .downcast::() + .downcast_into::() .unwrap(); assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int); @@ -641,14 +919,14 @@ mod tests { assert!(dt.has_subarray()); assert_eq!(dt.ndim(), 2); assert_eq!(dt.shape(), vec![2, 3]); - assert!(dt.base().is_equiv_to(dtype::(py))); + assert!(dt.base().is_equiv_to(&dtype_bound::(py))); }); } #[test] fn test_dtype_methods_record() { Python::with_gil(|py| { - let locals = PyDict::new(py); + let locals = PyDict::new_bound(py); py_run!( py, *locals, @@ -658,7 +936,7 @@ mod tests { .get_item("dtype") .unwrap() .unwrap() - .downcast::() + .downcast_into::() .unwrap(); assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int); @@ -679,15 +957,15 @@ mod tests { assert!(!dt.has_subarray()); assert_eq!(dt.ndim(), 0); assert_eq!(dt.shape(), vec![]); - assert!(dt.base().is_equiv_to(dt)); + assert!(dt.base().is_equiv_to(&dt)); let x = dt.get_field("x").unwrap(); - assert!(x.0.is_equiv_to(dtype::(py))); + assert!(x.0.is_equiv_to(&dtype_bound::(py))); assert_eq!(x.1, 0); let y = dt.get_field("y").unwrap(); - assert!(y.0.is_equiv_to(dtype::(py))); + assert!(y.0.is_equiv_to(&dtype_bound::(py))); assert_eq!(y.1, 8); let z = dt.get_field("z").unwrap(); - assert!(z.0.is_equiv_to(dtype::(py))); + assert!(z.0.is_equiv_to(&dtype_bound::(py))); assert_eq!(z.1, 16); }); } diff --git a/src/lib.rs b/src/lib.rs index 8e538366e..753d65caf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,7 +106,11 @@ pub use crate::borrow::{ PyReadwriteArray5, PyReadwriteArray6, PyReadwriteArrayDyn, }; pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; -pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr}; +#[allow(deprecated)] +pub use crate::dtype::dtype; +pub use crate::dtype::{ + dtype_bound, Complex32, Complex64, Element, PyArrayDescr, PyArrayDescrMethods, +}; pub use crate::error::{BorrowError, FromVecError, NotContiguousError}; pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API}; pub use crate::strings::{PyFixedString, PyFixedUnicode}; diff --git a/src/strings.rs b/src/strings.rs index cd1316b80..ba9f14236 100644 --- a/src/strings.rs +++ b/src/strings.rs @@ -13,11 +13,11 @@ use std::str; use pyo3::{ ffi::{Py_UCS1, Py_UCS4}, sync::GILProtected, - Py, Python, + Bound, Py, Python, }; use rustc_hash::FxHashMap; -use crate::dtype::{Element, PyArrayDescr}; +use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods}; use crate::npyffi::NPY_TYPES; /// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`byte` scalars][numpy-bytes] while satisfying coherence. @@ -76,7 +76,7 @@ impl From<[Py_UCS1; N]> for PyFixedString { unsafe impl Element for PyFixedString { const IS_COPY: bool = true; - fn get_dtype<'py>(py: Python<'py>) -> &PyArrayDescr { + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { static DTYPES: TypeDescriptors = TypeDescriptors::new(); unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_STRING, b'|' as _, size_of::()) } @@ -147,7 +147,7 @@ impl From<[Py_UCS4; N]> for PyFixedUnicode { unsafe impl Element for PyFixedUnicode { const IS_COPY: bool = true; - fn get_dtype<'py>(py: Python<'py>) -> &PyArrayDescr { + fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { static DTYPES: TypeDescriptors = TypeDescriptors::new(); unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_UNICODE, b'=' as _, size_of::()) } @@ -169,12 +169,12 @@ impl TypeDescriptors { /// `npy_type` must be either `NPY_STRING` or `NPY_UNICODE` with matching `byteorder` and `size` #[allow(clippy::wrong_self_convention)] unsafe fn from_size<'py>( - &'py self, + &self, py: Python<'py>, npy_type: NPY_TYPES, byteorder: c_char, size: usize, - ) -> &'py PyArrayDescr { + ) -> Bound<'py, PyArrayDescr> { let mut dtypes = self.dtypes.get(py).borrow_mut(); let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) { @@ -190,7 +190,7 @@ impl TypeDescriptors { } }; - dtype.clone().into_ref(py) + dtype.clone().into_bound(py) } } diff --git a/src/untyped_array.rs b/src/untyped_array.rs index fa5ae08ee..cf5bd385c 100644 --- a/src/untyped_array.rs +++ b/src/untyped_array.rs @@ -98,13 +98,13 @@ impl PyUntypedArray { /// # Example /// /// ``` - /// use numpy::{dtype, PyArray}; + /// use numpy::{dtype_bound, PyArray}; /// use pyo3::Python; /// /// Python::with_gil(|py| { /// let array = PyArray::from_vec(py, vec![1_i32, 2, 3]); /// - /// assert!(array.dtype().is_equiv_to(dtype::(py))); + /// assert!(array.dtype().is_equiv_to(dtype_bound::(py).as_gil_ref())); /// }); /// ``` /// @@ -268,13 +268,13 @@ pub trait PyUntypedArrayMethods<'py>: sealed::Sealed { /// # Example /// /// ``` - /// use numpy::{dtype, PyArray}; + /// use numpy::{dtype_bound, PyArray}; /// use pyo3::Python; /// /// Python::with_gil(|py| { /// let array = PyArray::from_vec(py, vec![1_i32, 2, 3]); /// - /// assert!(array.dtype().is_equiv_to(dtype::(py))); + /// assert!(array.dtype().is_equiv_to(dtype_bound::(py).as_gil_ref())); /// }); /// ``` /// diff --git a/tests/array.rs b/tests/array.rs index 48654479c..41d7f05db 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -4,12 +4,12 @@ use std::mem::size_of; use half::{bf16, f16}; use ndarray::{array, s, Array1, Dim}; use numpy::{ - dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr, - PyArrayDyn, PyFixedString, PyFixedUnicode, ToPyArray, + dtype_bound, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, + PyArrayDescr, PyArrayDescrMethods, PyArrayDyn, PyFixedString, PyFixedUnicode, ToPyArray, }; use pyo3::{ py_run, pyclass, pymethods, - types::{IntoPyDict, PyDict, PyList}, + types::{IntoPyDict, PyAnyMethods, PyDict, PyList}, IntoPy, Py, PyAny, PyCell, PyResult, Python, }; @@ -376,13 +376,17 @@ fn dtype_via_python_attribute() { let arr = array![[2, 3], [4, 5u32]]; let pyarr = arr.to_pyarray(py); - let dt: &PyArrayDescr = py - .eval("a.dtype", Some([("a", pyarr)].into_py_dict(py)), None) + let dt = py + .eval_bound( + "a.dtype", + Some(&[("a", pyarr)].into_py_dict_bound(py)), + None, + ) .unwrap() - .downcast() + .downcast_into::() .unwrap(); - assert!(dt.is_equiv_to(dtype::(py))); + assert!(dt.is_equiv_to(&dtype_bound::(py))); }); }