Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize nth and nth_back for BoundListIterator #4810

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cc9cabd
Optimize nth and nth_back for BoundListIterator. Add unit test and be…
Owen-CH-Leung Dec 21, 2024
3a0c196
Fix fmt and newsfragment CI
Owen-CH-Leung Dec 21, 2024
40d38f3
Fix clippy and changelog CI
Owen-CH-Leung Dec 21, 2024
1b19616
Merge branch 'main' into optimize_nth_and_nthback_for_boundlistiter
Owen-CH-Leung Jan 9, 2025
f6e95a8
Revise Impl of nth and nth_back. Impl advance_by
Owen-CH-Leung Jan 9, 2025
b0c749b
Fix failing fmt
Owen-CH-Leung Jan 9, 2025
b2bf973
Fix failing ruff test
Owen-CH-Leung Jan 9, 2025
4e86709
Merge branch 'main' into optimize_nth_and_nthback_for_boundlistiter
Owen-CH-Leung Jan 10, 2025
e4269c2
branch out nth, nth_unchecked, nth_back, nth_back_unchecked.
Owen-CH-Leung Jan 10, 2025
6e18229
Fix fmt
Owen-CH-Leung Jan 10, 2025
5bab05b
Revise advance_by impl. add advance_by unittest.
Owen-CH-Leung Jan 10, 2025
0b23173
Fix fmt
Owen-CH-Leung Jan 10, 2025
e88f8be
Fix clippy unused function warning
Owen-CH-Leung Jan 10, 2025
3a7a171
Set appropriate Py_LIMITED_API flag
Owen-CH-Leung Jan 10, 2025
ed8dba6
Rewrite nth & nth_back using conditional compilation. Rearrange flags…
Owen-CH-Leung Jan 14, 2025
51104a1
fix fmt
Owen-CH-Leung Jan 14, 2025
cae0981
fix failing CI
Owen-CH-Leung Jan 14, 2025
00e4802
Impl advance_back_by. Remove cfg flag for with_critical_section
Owen-CH-Leung Jan 15, 2025
b7373aa
refactor advance_by and advance_back_by. Add back cfg for with_critic…
Owen-CH-Leung Jan 15, 2025
7751a1c
Put allow deadcode for with_critical_section
Owen-CH-Leung Jan 15, 2025
a735850
Remove use of get_item. Revise changelog
Owen-CH-Leung Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/4810.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimizes `nth` and `nth_back` for `BoundListIterator`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe also mention that advance_by and advance_by_back are implemented on nightly

30 changes: 29 additions & 1 deletion pyo3-benches/benches/bench_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,33 @@ fn list_get_item(b: &mut Bencher<'_>) {
});
}

#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))]
fn list_nth(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
const LEN: usize = 50;
let list = PyList::new_bound(py, 0..LEN);
let mut sum = 0;
b.iter(|| {
for i in 0..LEN {
sum += list.iter().nth(i).unwrap().extract::<usize>().unwrap();
}
});
});
}

fn list_nth_back(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
const LEN: usize = 50;
let list = PyList::new_bound(py, 0..LEN);
let mut sum = 0;
b.iter(|| {
for i in 0..LEN {
sum += list.iter().nth_back(i).unwrap().extract::<usize>().unwrap();
}
});
});
}

#[cfg(not(Py_LIMITED_API))]
fn list_get_item_unchecked(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
const LEN: usize = 50_000;
Expand All @@ -66,6 +92,8 @@ fn sequence_from_list(b: &mut Bencher<'_>) {
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("iter_list", iter_list);
c.bench_function("list_new", list_new);
c.bench_function("list_nth", list_nth);
c.bench_function("list_nth_back", list_nth_back);
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
c.bench_function("list_get_item", list_get_item);
#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))]
c.bench_function("list_get_item_unchecked", list_get_item_unchecked);
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![warn(missing_docs)]
#![cfg_attr(
feature = "nightly",
feature(auto_traits, negative_impls, try_trait_v2)
feature(auto_traits, negative_impls, try_trait_v2, iter_advance_by)
)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
// Deny some lints in doctests.
Expand Down
283 changes: 278 additions & 5 deletions src/types/list.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::iter::FusedIterator;

use crate::err::{self, PyResult};
use crate::ffi::{self, Py_ssize_t};
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::internal_tricks::get_ssize_index;
use crate::types::any::PyAnyMethods;
use crate::types::sequence::PySequenceMethods;
use crate::types::{PySequence, PyTuple};
use crate::{
Borrowed, Bound, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyObject, Python,
};

use crate::types::any::PyAnyMethods;
use crate::types::sequence::PySequenceMethods;
use std::iter::FusedIterator;
#[cfg(all(not(Py_LIMITED_API), feature = "nightly"))]
use std::num::NonZero;

/// Represents a Python `list`.
///
Expand Down Expand Up @@ -547,6 +547,46 @@ impl<'py> BoundListIterator<'py> {
}
}

#[inline]
#[cfg(all(not(Py_LIMITED_API), feature = "nightly"))]
#[deny(unsafe_op_in_unsafe_fn)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused - shouldn't this only be defined on not(feature = "nightly")? Because on nightly this is defined in terms of advance_by in the standard library implementation, so we don't need to override nth. It'll just do the right thing on nightly by calling advance_by and then next(). On not nightly you need this override because there's no way to override advance_by.

Ditto for all the other BoundListIterator methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed - I spent hours reading the stdlib again to figure out that on stable toolchain, we should override nth, and override advance_by on nightly. Thanks for spotting this. I've made the changes to compile it on stable

unsafe fn nth_unchecked(
index: &mut Index,
length: &mut Length,
list: &Bound<'py, PyList>,
n: usize,
) -> Option<Bound<'py, PyAny>> {
let length = length.0.min(list.len());
let target_index = index.0 + n;
if index.0 + n < length {
let item = unsafe { list.get_item_unchecked(target_index) };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if index.0 + n < length {
if target_index < length {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks & revised

index.0 = target_index + 1;
Some(item)
} else {
None
}
}

#[inline]
#[cfg(all(Py_LIMITED_API, feature = "nightly"))]
#[deny(unsafe_op_in_unsafe_fn)]
fn nth(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[deny(unsafe_op_in_unsafe_fn)]

there's no unsafe code in this function so this is unnecessary

index: &mut Index,
length: &mut Length,
list: &Bound<'py, PyList>,
n: usize,
) -> Option<Bound<'py, PyAny>> {
let length = length.0.min(list.len());
let target_index = index.0 + n;
if index.0 + n < length {
let item = list.get_item(target_index).expect("get-item failed");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if index.0 + n < length {
if target_index < length {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks & revised

index.0 = target_index + 1;
Some(item)
} else {
None
}
}

/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
Expand Down Expand Up @@ -589,6 +629,45 @@ impl<'py> BoundListIterator<'py> {
}
}

#[inline]
#[cfg(all(not(Py_LIMITED_API), feature = "nightly"))]
#[deny(unsafe_op_in_unsafe_fn)]
unsafe fn nth_back_unchecked(
index: &mut Index,
length: &mut Length,
list: &Bound<'py, PyList>,
n: usize,
) -> Option<Bound<'py, PyAny>> {
let length_size = length.0.min(list.len());
if index.0 + n < length_size {
let target_index = length_size - n - 1;
let item = unsafe { list.get_item_unchecked(target_index) };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally this logic feels a little backwards to me, and I would probably prefer to do this using signed integers and testing if the index is less than zero. That said, this is totally equivalent and if it makes sense to you as-is then no need to change it.

*length = Length(target_index);
Some(item)
} else {
None
}
}

#[inline]
#[cfg(all(Py_LIMITED_API, feature = "nightly"))]
fn nth_back(
index: &mut Index,
length: &mut Length,
list: &Bound<'py, PyList>,
n: usize,
) -> Option<Bound<'py, PyAny>> {
let length_size = length.0.min(list.len());
if index.0 + n < length_size {
let target_index = length_size - n - 1;
let item = list.get_item(target_index).expect("get-item failed");
*length = Length(target_index);
Some(item)
} else {
None
}
}

ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
#[cfg(not(Py_LIMITED_API))]
fn with_critical_section<R>(
&mut self,
Expand Down Expand Up @@ -625,6 +704,26 @@ impl<'py> Iterator for BoundListIterator<'py> {
}
}

#[inline]
#[cfg(feature = "nightly")]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(feature = "nightly")]
#[cfg(not(feature = "nightly"))]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks & revised

#[cfg(not(Py_LIMITED_API))]
{
self.with_critical_section(|index, length, list| unsafe {
Self::nth_unchecked(index, length, list, n)
})
}
#[cfg(Py_LIMITED_API)]
{
let Self {
index,
length,
list,
} = self;
Self::nth(index, length, list, n)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(not(Py_LIMITED_API))]
{
self.with_critical_section(|index, length, list| unsafe {
Self::nth_unchecked(index, length, list, n)
})
}
#[cfg(Py_LIMITED_API)]
{
let Self {
index,
length,
list,
} = self;
Self::nth(index, length, list, n)
}
self.with_critical_section(|index, length, list| unsafe {
Self::nth(index, length, list, n)
})

If you implement my suggestion to only have BoundListIterator::nth then you can simplify this a lot.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and you can do a similar refactor for the nth_back implementation for the impl DoubleEndedIterator for BoundListIterator block below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I've adopted this idea for nth and nth_back


#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
Expand Down Expand Up @@ -750,6 +849,32 @@ impl<'py> Iterator for BoundListIterator<'py> {
None
})
}

#[inline]
#[cfg(all(not(Py_LIMITED_API), feature = "nightly"))]
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(all(not(Py_LIMITED_API), feature = "nightly"))]
#[cfg(feature = "nightly")]

Unless I'm missing something...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we need the flag not(Py_LIMITED_API. Otherwise we will compile the code without access to with_critical_section

https://github.com/PyO3/pyo3/actions/runs/12769998628/job/35593851381?pr=4810

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can delete the #[cfg(not(Py_LIMITED_API))] on BoundListIterator::with_critical_section, it was only there to avoid a clippy lint about unused code. pyo3::sync::with_critical_section is unconditionally exposed in the PyO3 API, it's just a no-op on GIL-enabled builds. We wanted to allow people to write code that avoids conditional compilation, where possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I deleted #[cfg(not(Py_LIMITED_API))] but was hitting the unused code lint warning, so I put in a allow lint dead_code. Lemme know your thoughts.

I also pushed a few commits and implemented advance_back_by.

self.with_critical_section(|index, length, list| {
let max_len = length.0.min(list.len());
let currently_at = index.0;
if currently_at >= max_len {
if n == 0 {
return Ok(());
} else {
return Err(unsafe { NonZero::new_unchecked(n) });
}
}

let items_left = max_len - currently_at;
if n <= items_left {
index.0 += n;
Ok(())
} else {
index.0 = max_len;
let remainder = n - items_left;
Err(unsafe { NonZero::new_unchecked(remainder) })
}
})
}
}

impl DoubleEndedIterator for BoundListIterator<'_> {
Expand All @@ -772,6 +897,26 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
}
}

#[inline]
#[cfg(feature = "nightly")]
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
#[cfg(not(Py_LIMITED_API))]
{
self.with_critical_section(|index, length, list| unsafe {
Self::nth_back_unchecked(index, length, list, n)
})
}
#[cfg(Py_LIMITED_API)]
{
let Self {
index,
length,
list,
} = self;
Self::nth_back(index, length, list, n)
}
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn rfold<B, F>(mut self, init: B, mut f: F) -> B
Expand Down Expand Up @@ -839,6 +984,8 @@ mod tests {
use crate::types::sequence::PySequenceMethods;
use crate::types::{PyList, PyTuple};
use crate::{ffi, IntoPyObject, PyResult, Python};
#[cfg(feature = "nightly")]
use std::num::NonZero;

#[test]
fn test_new() {
Expand Down Expand Up @@ -1502,4 +1649,130 @@ mod tests {
assert!(tuple.eq(tuple_expected).unwrap());
})
}

#[test]
fn test_iter_nth() {
Python::with_gil(|py| {
let v = vec![6, 7, 8, 9, 10];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
iter.next();
assert_eq!(iter.nth(1).unwrap().extract::<i32>().unwrap(), 8);
assert_eq!(iter.nth(1).unwrap().extract::<i32>().unwrap(), 10);
assert!(iter.nth(1).is_none());

let v: Vec<i32> = vec![];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
iter.next();
assert!(iter.nth(1).is_none());

let v = vec![1, 2, 3];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert!(iter.nth(10).is_none());

let v = vec![6, 7, 8, 9, 10];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();
let mut iter = list.iter();
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 9);
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 10);

let mut iter = list.iter();
iter.nth_back(1);
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 8);
assert!(iter.next().is_none());
});
}

#[test]
fn test_iter_nth_back() {
Python::with_gil(|py| {
let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert_eq!(iter.nth_back(0).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
assert!(iter.nth_back(2).is_none());

let v: Vec<i32> = vec![];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert!(iter.nth_back(0).is_none());
assert!(iter.nth_back(1).is_none());

let v = vec![1, 2, 3];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert!(iter.nth_back(5).is_none());

let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
iter.next_back(); // Consume the last element
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(iter.next_back().unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(iter.nth_back(0).unwrap().extract::<i32>().unwrap(), 1);

let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(iter.nth_back(2).unwrap().extract::<i32>().unwrap(), 1);

let mut iter2 = list.iter();
iter2.next_back();
assert_eq!(iter2.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(iter2.next_back().unwrap().extract::<i32>().unwrap(), 2);

let mut iter3 = list.iter();
iter3.nth(1);
assert_eq!(iter3.nth_back(2).unwrap().extract::<i32>().unwrap(), 3);
assert!(iter3.nth_back(0).is_none());
});
}

#[cfg(feature = "nightly")]
#[test]
fn test_iter_advance_by() {
Python::with_gil(|py| {
let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert_eq!(iter.advance_by(2), Ok(()));
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(iter.advance_by(0), Ok(()));
assert_eq!(iter.advance_by(100), Err(NonZero::new(98).unwrap()));

let mut iter2 = list.iter();
assert_eq!(iter2.advance_by(6), Err(NonZero::new(1).unwrap()));

let mut iter3 = list.iter();
assert_eq!(iter3.advance_by(5), Ok(()));

let mut iter4 = list.iter();
assert_eq!(iter4.advance_by(0), Ok(()));
assert_eq!(iter4.next().unwrap().extract::<i32>().unwrap(), 1);
})
}
}
Loading