Skip to content

Commit

Permalink
Use extension trait for iterators to allow code to flow better. Make …
Browse files Browse the repository at this point in the history
…struct private. Don't reference unneeded generic T. Make doc tests compile.
  • Loading branch information
tylerhawkes committed Dec 24, 2024
1 parent afe6b19 commit 4028058
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 40 deletions.
70 changes: 36 additions & 34 deletions sqlx-postgres/src/bind_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use sqlx_core::{

use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};

/// A wrapper enabling iterators to encode arrays in Postgres.
// not exported but pub because it is used in the extension trait
pub struct PgBindIter<I>(I);

/// Iterator extension trait enabling iterators to encode arrays in Postgres.
///
/// Because of the blanket impl of `PgHasArrayType` for all references
/// we can borrow instead of needing to clone or copy in the iterators
Expand All @@ -18,62 +21,61 @@ use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Pos
/// iterating over them again to encode.
///
/// This now requires only iterating over the array once for each field
/// while using less memory giving both speed and memory usage improvements.
/// while using less memory giving both speed and memory usage improvements
/// along with allowing much more flexibility in the underlying collection.
///
/// ```rust,ignore
/// ```rust,no_run
/// # async fn test_bind_iter() {
/// # use sqlx::types::chrono::{DateTime, Utc}
/// # fn people() -> &'static [Person] {
/// # &[]
/// # }
/// # let mut conn = sqlx::Postgres::Connection::connect("dummyurl").await;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
/// id: i64,
/// name: String,
/// birthdate: DateTime<Utc>,
/// }
///
/// let people: &[Person] = people();
///
/// sqlx::query(
/// "insert into person(id, name, birthdate) select * from unnest($1, $2, $3)"
/// )
/// .bind(PgBindIter::from(people.iter().map(|p|p.id)))
/// .bind(PgBindIter::from(people.iter().map(|p|&p.name)))
/// .bind(PgBindIter::from(people.iter().map(|p|&p.birthdate)))
/// .execute(pool)
/// .await?;
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
/// .bind(people.iter().map(|p| p.id).bind_iter())
/// .bind(people.iter().map(|p| &p.name).bind_iter())
/// .bind(people.iter().map(|p| &p.birthdate).bind_iter())
/// .execute(&mut conn)
/// .await?;
/// # }
/// ```
pub struct PgBindIter<I>(I);

impl<I> PgBindIter<I> {
pub fn new(inner: I) -> Self {
Self(inner)
}
pub trait PgBindIterExt: Iterator + Sized {
fn bind_iter(self) -> PgBindIter<Self>;
}

impl<I> From<I> for PgBindIter<I> {
fn from(inner: I) -> Self {
Self::new(inner)
impl<I: Iterator + Sized> PgBindIterExt for I {
fn bind_iter(self) -> PgBindIter<I> {
PgBindIter(self)
}
}

impl<T, I> Type<Postgres> for PgBindIter<I>
impl<I> Type<Postgres> for PgBindIter<I>
where
T: Type<Postgres> + PgHasArrayType,
I: Iterator<Item = T>,
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
{
fn type_info() -> <Postgres as Database>::TypeInfo {
T::array_type_info()
<I as Iterator>::Item::array_type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
T::array_compatible(ty)
<I as Iterator>::Item::array_compatible(ty)
}
}

impl<'q, T, I> PgBindIter<I>
impl<'q, I> PgBindIter<I>
where
I: Iterator<Item = T>,
T: Type<Postgres> + Encode<'q, Postgres>,
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_inner(
// need ownership to iterate
Expand All @@ -85,7 +87,7 @@ where
let type_info = first
.as_ref()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);
.unwrap_or_else(<I as Iterator>::Item::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags
Expand Down Expand Up @@ -129,11 +131,11 @@ where
}
}

impl<'q, T, I> Encode<'q, Postgres> for PgBindIter<I>
impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
where
T: Type<Postgres> + Encode<'q, Postgres>,
// Clone is required for the encode_by_ref call since we can't iterate with a shared reference
I: Iterator<Item = T> + Clone,
I: Iterator + Clone,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
Self::encode_inner(self.0.clone(), buf)
Expand Down
2 changes: 1 addition & 1 deletion sqlx-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub(crate) use sqlx_core::driver_prelude::*;

pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use bind_iter::PgBindIter;
pub use bind_iter::PgBindIterExt;
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{PgCopyIn, PgPoolCopyExt};
Expand Down
2 changes: 1 addition & 1 deletion sqlx-postgres/src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ where
self.len()
)
})?;
crate::bind_iter::PgBindIter::new(self.iter()).encode(buf)
crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
}
}

Expand Down
8 changes: 4 additions & 4 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2045,7 +2045,7 @@ async fn test_issue_3052() {

#[sqlx_macros::test]
async fn test_bind_iter() -> anyhow::Result<()> {
use sqlx::postgres::PgBindIter;
use sqlx::postgres::PgBindIterExt;
use sqlx::types::chrono::{DateTime, Utc};

let mut conn = new::<Postgres>().await?;
Expand Down Expand Up @@ -2084,10 +2084,10 @@ create temporary table person(
let rows_affected =
sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
// owned value
.bind(PgBindIter::from(people.iter().map(|p| p.id)))
.bind(people.iter().map(|p| p.id).bind_iter())
// borrowed value
.bind(PgBindIter::from(people.iter().map(|p| &p.name)))
.bind(PgBindIter::from(people.iter().map(|p| &p.birthdate)))
.bind(people.iter().map(|p| &p.name).bind_iter())
.bind(people.iter().map(|p| &p.birthdate).bind_iter())
.execute(&mut conn)
.await?
.rows_affected();
Expand Down

0 comments on commit 4028058

Please sign in to comment.