Skip to content

Commit

Permalink
Implement Acquire for PgListener (#3550)
Browse files Browse the repository at this point in the history
* Implement Acquire for PgListener

* Add a test which checks that PgListener implements Acquire

* Drop unnecessary call to `.acquire()`

* Rename test channel to avoid conflict with other tests
  • Loading branch information
sandhose authored Oct 28, 2024
1 parent eac4b7a commit 709226c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
17 changes: 16 additions & 1 deletion sqlx-postgres/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use std::str::from_utf8;
use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::{BoxStream, Stream};
use futures_util::{FutureExt, StreamExt, TryStreamExt};
use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use sqlx_core::acquire::Acquire;
use sqlx_core::transaction::Transaction;
use sqlx_core::Either;

use crate::describe::Describe;
Expand Down Expand Up @@ -328,6 +330,19 @@ impl Drop for PgListener {
}
}

impl<'c> Acquire<'c> for &'c mut PgListener {
type Database = Postgres;
type Connection = &'c mut PgConnection;

fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
self.connection().boxed()
}

fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
self.connection().and_then(|c| c.begin()).boxed()
}
}

impl<'c> Executor<'c> for &'c mut PgListener {
type Database = Postgres;

Expand Down
39 changes: 39 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,45 @@ async fn test_pg_listener_allows_pool_to_close() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_pg_listener_implements_acquire() -> anyhow::Result<()> {
use sqlx::Acquire;

let pool = pool::<Postgres>().await?;

let mut listener = PgListener::connect_with(&pool).await?;
listener
.listen("test_pg_listener_implements_acquire")
.await?;

// Start a transaction on the underlying connection
let mut txn = listener.begin().await?;

// This will reuse the same connection, so this connection should be listening to the channel
let channels: Vec<String> = sqlx::query_scalar("SELECT pg_listening_channels()")
.fetch_all(&mut *txn)
.await?;

assert_eq!(channels, vec!["test_pg_listener_implements_acquire"]);

// Send a notification
sqlx::query("NOTIFY test_pg_listener_implements_acquire, 'hello'")
.execute(&mut *txn)
.await?;

txn.commit().await?;

// And now we can receive the notification we sent in the transaction
let notification = listener.recv().await?;
assert_eq!(
notification.channel(),
"test_pg_listener_implements_acquire"
);
assert_eq!(notification.payload(), "hello");

Ok(())
}

#[sqlx_macros::test]
async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> {
// Only supported in Postgres 11+
Expand Down

0 comments on commit 709226c

Please sign in to comment.