Skip to content

Commit

Permalink
fix cancellation issues with PgListener, PgStream::recv() (#3467)
Browse files Browse the repository at this point in the history
* fix(postgres): make `PgStream::recv_unchecked()` cancel-safe

* fix(postgres): make `PgListener` close the connection on-error

* fix: incorrect math in `BufferedSocket::read_buffered()`
  • Loading branch information
abonander authored Aug 27, 2024
1 parent 20ba796 commit e10789d
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 21 deletions.
47 changes: 37 additions & 10 deletions sqlx-core/src/net/socket/buffered.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::error::Error;
use crate::net::Socket;
use bytes::BytesMut;
use std::ops::ControlFlow;
use std::{cmp, io};

use crate::error::Error;

use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};

// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
Expand Down Expand Up @@ -45,8 +45,39 @@ impl<S: Socket> BufferedSocket<S> {
}
}

pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> {
self.read_buf.read(len, &mut self.socket).await
pub async fn read_buffered(&mut self, len: usize) -> Result<BytesMut, Error> {
self.try_read(|buf| {
Ok(if buf.len() < len {
ControlFlow::Continue(len)
} else {
ControlFlow::Break(buf.split_to(len))
})
})
.await
}

/// Retryable read operation.
///
/// The callback should check the contents of the buffer passed to it and either:
///
/// * Remove a full message from the buffer and return [`ControlFlow::Break`], or:
/// * Return [`ControlFlow::Continue`] with the expected _total_ length of the buffer,
/// _without_ modifying it.
///
/// Cancel-safe as long as the callback does not modify the passed `BytesMut`
/// before returning [`ControlFlow::Continue`].
pub async fn try_read<F, R>(&mut self, mut try_read: F) -> Result<R, Error>
where
F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
{
loop {
let read_len = match try_read(&mut self.read_buf.read)? {
ControlFlow::Continue(read_len) => read_len,
ControlFlow::Break(ret) => return Ok(ret),
};

self.read_buf.read(read_len, &mut self.socket).await?;
}
}

pub fn write_buffer(&self) -> &WriteBuffer {
Expand Down Expand Up @@ -244,7 +275,7 @@ impl WriteBuffer {
}

impl ReadBuffer {
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> {
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> {
// Because of how `BytesMut` works, we should only be shifting capacity back and forth
// between `read` and `available` unless we have to read an oversize message.
while self.read.len() < len {
Expand All @@ -266,7 +297,7 @@ impl ReadBuffer {
self.advance(read);
}

Ok(self.drain(len))
Ok(())
}

fn reserve(&mut self, amt: usize) {
Expand All @@ -279,10 +310,6 @@ impl ReadBuffer {
self.read.unsplit(self.available.split_to(amt));
}

fn drain(&mut self, amt: usize) -> BytesMut {
self.read.split_to(amt)
}

fn shrink(&mut self) {
if self.available.capacity() > DEFAULT_BUF_SIZE {
// `BytesMut` doesn't have a way to shrink its capacity,
Expand Down
40 changes: 40 additions & 0 deletions sqlx-core/src/pool/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner};
use crate::pool::options::PoolConnectionMetadata;
use std::future::Future;

const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);

/// A connection managed by a [`Pool`][crate::pool::Pool].
///
/// Will be returned to the pool on-drop.
pub struct PoolConnection<DB: Database> {
live: Option<Live<DB>>,
close_on_drop: bool,
pub(crate) pool: Arc<PoolInner<DB>>,
}

Expand Down Expand Up @@ -85,6 +88,16 @@ impl<DB: Database> PoolConnection<DB> {
floating.inner.raw.close().await
}

/// Close this connection on-drop, instead of returning it to the pool.
///
/// May be used in cases where waiting for the [`.close()`][Self::close] call
/// to complete is unacceptable, but you still want the connection to be closed gracefully
/// so that the server can clean up resources.
#[inline(always)]
pub fn close_on_drop(&mut self) {
self.close_on_drop = true;
}

/// Detach this connection from the pool, allowing it to open a replacement.
///
/// Note that if your application uses a single shared pool, this
Expand Down Expand Up @@ -140,6 +153,27 @@ impl<DB: Database> PoolConnection<DB> {
}
}
}

fn take_and_close(&mut self) -> impl Future<Output = ()> + Send + 'static {
// float the connection in the pool before we move into the task
// in case the returned `Future` isn't executed, like if it's spawned into a dying runtime
// https://github.com/launchbadge/sqlx/issues/1396
// Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22).
let floating = self.live.take().map(|live| live.float(self.pool.clone()));

let pool = self.pool.clone();

async move {
if let Some(floating) = floating {
// Don't hold the connection forever if it hangs while trying to close
crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close())
.await
.ok();
}

pool.min_connections_maintenance(None).await;
}
}
}

impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB> {
Expand All @@ -164,6 +198,11 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB
/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
impl<DB: Database> Drop for PoolConnection<DB> {
fn drop(&mut self) {
if self.close_on_drop {
crate::rt::spawn(self.take_and_close());
return;
}

// We still need to spawn a task to maintain `min_connections`.
if self.live.is_some() || self.pool.options.min_connections > 0 {
crate::rt::spawn(self.return_to_pool());
Expand Down Expand Up @@ -221,6 +260,7 @@ impl<DB: Database> Floating<DB, Live<DB>> {
guard.cancel();
PoolConnection {
live: Some(inner),
close_on_drop: false,
pool,
}
}
Expand Down
47 changes: 38 additions & 9 deletions sqlx-postgres/src/connection/stream.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut};
use std::ops::{ControlFlow, Deref, DerefMut};
use std::str::FromStr;

use futures_channel::mpsc::UnboundedSender;
use futures_util::SinkExt;
use log::Level;
use sqlx_core::bytes::{Buf, Bytes};
use sqlx_core::bytes::Buf;

use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error;
Expand Down Expand Up @@ -77,16 +77,45 @@ impl PgStream {
}

pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let mut header: Bytes = self.inner.read(5).await?;
// NOTE: to not break everything, this should be cancel-safe;
// DO NOT modify `buf` unless a full message has been read
self.inner
.try_read(|buf| {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let Some(mut header) = buf.get(..5) else {
return Ok(ControlFlow::Continue(5));
};

let format = BackendMessageFormat::try_from_u8(header.get_u8())?;

let message_len = header.get_u32() as usize;

let expected_len = message_len
.checked_add(1)
// this shouldn't really happen but is mostly a sanity check
.ok_or_else(|| {
err_protocol!("message_len + 1 overflows usize: {message_len}")
})?;

if buf.len() < expected_len {
return Ok(ControlFlow::Continue(expected_len));
}

// `buf` SHOULD NOT be modified ABOVE this line

// pop off the format code since it's not counted in `message_len`
buf.advance(1);

let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let size = (header.get_u32() - 4) as usize;
// consume the message, including the length prefix
let mut contents = buf.split_to(message_len).freeze();

let contents = self.inner.read(size).await?;
// cut off the length prefix
contents.advance(4);

Ok(ReceivedMessage { format, contents })
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
})
.await
}

// Get the next message from the server
Expand Down
7 changes: 5 additions & 2 deletions sqlx-postgres/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,11 @@ impl PgListener {
if (err.kind() == io::ErrorKind::ConnectionAborted
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
{
self.buffer_tx = self.connection().await?.stream.notifications.take();
self.connection = None;
if let Some(mut conn) = self.connection.take() {
self.buffer_tx = conn.stream.notifications.take();
// Close the connection in a background task, so we can continue.
conn.close_on_drop();
}

// lost connection
return Ok(None);
Expand Down

0 comments on commit e10789d

Please sign in to comment.