Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PgBindIter for encoding and use it as the implementation encoding &[T] #3651

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
114 changes: 114 additions & 0 deletions sqlx-postgres/src/bind_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use sqlx_core::{
database::Database,
encode::{Encode, IsNull},
types::Type,
};

use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};

pub struct PgBindIter<I>(I);

impl<I> PgBindIter<I> {
pub fn new(inner: I) -> Self {
Self(inner)
}
}

impl<I> From<I> for PgBindIter<I> {
fn from(inner: I) -> Self {
Self::new(inner)
}
}

impl<T, I> Type<Postgres> for PgBindIter<I>
where
T: Type<Postgres> + PgHasArrayType,
I: Iterator<Item = T>,
{
fn type_info() -> <Postgres as Database>::TypeInfo {
T::array_type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
T::array_compatible(ty)
}
}

impl<'q, T, I> PgBindIter<I>
where
I: Iterator<Item = T>,
T: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_inner(
// need ownership to iterate
mut iter: I,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
let first = iter.next();
let type_info = first
.as_ref()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let len_start = buf.len();
buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
buf.extend(1_i32.to_be_bytes()); // lower bound

match first {
Some(first) => buf.encode(first)?,
None => return Ok(IsNull::No),
}

let mut count = 1_i32;
const MAX: usize = i32::MAX as usize;
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved

for value in (&mut iter).take(MAX) {
buf.encode(value)?;
count += 1;
}

const OVERFLOW: usize = MAX + 1;
if iter.next().is_some() {
return Err(format!("encoded iterator is too large for Postgres: {OVERFLOW}").into());
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
}

// set the length now that we know what it is.
buf[len_start..(len_start + 4)].copy_from_slice(count.to_be_bytes().as_slice());
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved

Ok(IsNull::No)
}
}

impl<'q, T, I> Encode<'q, Postgres> for PgBindIter<I>
where
T: Type<Postgres> + Encode<'q, Postgres>,
// Clone is required for the encode_by_ref call since we can't iterate with a shared reference
I: Iterator<Item = T> + Clone,
{
fn encode_by_ref(
&self,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
Self::encode_inner(self.0.clone(), buf)
}
fn encode(
self,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>>
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
where
Self: Sized,
{
Self::encode_inner(self.0, buf)
}
}
2 changes: 2 additions & 0 deletions sqlx-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::executor::Executor;

mod advisory_lock;
mod arguments;
mod bind_iter;
mod column;
mod connection;
mod copy;
Expand Down Expand Up @@ -44,6 +45,7 @@ pub(crate) use sqlx_core::driver_prelude::*;

pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use bind_iter::PgBindIter;
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{PgCopyIn, PgPoolCopyExt};
Expand Down
32 changes: 3 additions & 29 deletions sqlx-postgres/src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::borrow::Cow;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::type_info::PgType;
use crate::types::Oid;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
Expand Down Expand Up @@ -156,39 +155,14 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
let type_info = self
.first()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

// element type
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let array_len = i32::try_from(self.len()).map_err(|_| {
// do the length check early to avoid doing unnecessary work
i32::try_from(self.len()).map_err(|_| {
format!(
"encoded array length is too large for Postgres: {}",
self.len()
)
})?;

buf.extend(array_len.to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound

for element in self.iter() {
buf.encode(element)?;
}

Ok(IsNull::No)
crate::bind_iter::PgBindIter::new(self.iter()).encode(buf)
}
}

Expand Down
Loading