From 0ea90881c1889e264ed919bd273b3b9f8fb38728 Mon Sep 17 00:00:00 2001 From: gridbox Date: Wed, 5 Jun 2024 22:06:15 -0400 Subject: [PATCH] feat: Add set_update_hook on SqliteConnection (#3260) * feat: Add set_update_hook on SqliteConnection * refactor: Address PR comments * fix: Expose UpdateHookResult for public use --------- Co-authored-by: John Smith --- sqlx-sqlite/src/connection/establish.rs | 1 + sqlx-sqlite/src/connection/mod.rs | 97 ++++++++++++++++++++++++- sqlx-sqlite/src/lib.rs | 2 +- tests/sqlite/sqlite.rs | 56 +++++++++++++- 4 files changed, 152 insertions(+), 4 deletions(-) diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 5e3ff6462e..cc99637667 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -294,6 +294,7 @@ impl EstablishParams { transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, + update_hook_callback: None }) } } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index b6afc0c52b..9f938791d4 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,4 +1,5 @@ use std::cmp::Ordering; +use std::ffi::CStr; use std::fmt::Write; use std::fmt::{self, Debug, Formatter}; use std::os::raw::{c_int, c_void}; @@ -8,7 +9,10 @@ use std::ptr::NonNull; use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; -use libsqlite3_sys::{sqlite3, sqlite3_progress_handler}; +use libsqlite3_sys::{ + sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, + SQLITE_UPDATE, +}; pub(crate) use handle::ConnectionHandle; use sqlx_core::common::StatementCache; @@ -58,6 +62,34 @@ pub struct LockedSqliteHandle<'a> { pub(crate) struct Handler(NonNull bool + Send + 'static>); unsafe impl Send for Handler {} +#[derive(Debug, PartialEq, Eq)] +pub enum SqliteOperation { + Insert, + Update, + Delete, + Unknown(i32), +} + +impl From for SqliteOperation { + fn from(value: i32) -> Self { + match value { + SQLITE_INSERT => SqliteOperation::Insert, + SQLITE_UPDATE => SqliteOperation::Update, + SQLITE_DELETE => SqliteOperation::Delete, + code => SqliteOperation::Unknown(code), + } + } +} + +pub struct UpdateHookResult<'a> { + pub operation: SqliteOperation, + pub database: &'a str, + pub table: &'a str, + pub rowid: i64, +} +pub(crate) struct UpdateHookHandler(NonNull); +unsafe impl Send for UpdateHookHandler {} + pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -71,6 +103,8 @@ pub(crate) struct ConnectionState { /// Stores the progress handler set on the current connection. If the handler returns `false`, /// the query is interrupted. progress_handler_callback: Option, + + update_hook_callback: Option, } impl ConnectionState { @@ -78,7 +112,16 @@ impl ConnectionState { pub(crate) fn remove_progress_handler(&mut self) { if let Some(mut handler) = self.progress_handler_callback.take() { unsafe { - sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _); + sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } + + pub(crate) fn remove_update_hook(&mut self) { + if let Some(mut handler) = self.update_hook_callback.take() { + unsafe { + sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut()); let _ = { Box::from_raw(handler.0.as_mut()) }; } } @@ -215,6 +258,31 @@ where } } +extern "C" fn update_hook( + callback: *mut c_void, + op_code: c_int, + database: *const i8, + table: *const i8, + rowid: i64, +) where + F: FnMut(UpdateHookResult), +{ + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + let operation: SqliteOperation = op_code.into(); + let database = CStr::from_ptr(database).to_str().unwrap_or_default(); + let table = CStr::from_ptr(table).to_str().unwrap_or_default(); + (*callback)(UpdateHookResult { + operation, + database, + table, + rowid, + }) + }); + } +} + impl LockedSqliteHandle<'_> { /// Returns the underlying sqlite3* connection handle. /// @@ -279,10 +347,34 @@ impl LockedSqliteHandle<'_> { } } + pub fn set_update_hook(&mut self, callback: F) + where + F: FnMut(UpdateHookResult) + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_update_hook(); + self.guard.update_hook_callback = Some(UpdateHookHandler(callback)); + + sqlite3_update_hook( + self.as_raw_handle().as_mut(), + Some(update_hook::), + handler, + ); + } + } + /// Removes the progress handler on a database connection. The method does nothing if no handler was set. pub fn remove_progress_handler(&mut self) { self.guard.remove_progress_handler(); } + + pub fn remove_update_hook(&mut self) { + self.guard.remove_update_hook(); + } } impl Drop for ConnectionState { @@ -290,6 +382,7 @@ impl Drop for ConnectionState { // explicitly drop statements before the connection handle is dropped self.statements.clear(); self.remove_progress_handler(); + self.remove_update_hook(); } } diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index db09cc2f48..9f0cb376ad 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -33,7 +33,7 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; -pub use connection::{LockedSqliteHandle, SqliteConnection}; +pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 2d0b3267ba..c47b1a772b 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,7 +1,7 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; -use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, @@ -794,3 +794,57 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow:: assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn test_query_with_update_hook() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_update_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Insert); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + assert_eq!(result.rowid, 3); + }); + + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )") + .execute(&mut conn) + .await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_update_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_update_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_update_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_update_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +}