diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index 88da043f03..0347704abe 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -7,6 +7,7 @@ use redis::{FromRedisValue, RedisResult}; use std::{ ffi::{c_void, CStr, CString}, os::raw::c_char, + sync::Arc, }; use tokio::runtime::Builder; use tokio::runtime::Runtime; @@ -68,7 +69,8 @@ fn create_client_internal( }) } -/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. +/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. +/// All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. #[no_mangle] pub extern "C" fn create_client( host: *const c_char, @@ -79,38 +81,48 @@ pub extern "C" fn create_client( ) -> *const c_void { match create_client_internal(host, port, use_tls, success_callback, failure_callback) { Err(_) => std::ptr::null(), // TODO - log errors - Ok(client) => Box::into_raw(Box::new(client)) as *const c_void, + Ok(client) => Arc::into_raw(Arc::new(client)) as *const c_void, } } +/// # Safety +/// +/// This function should only be called once per pointer created by [create_client]. After calling this function +/// the `client_ptr` is not in a valid state. #[no_mangle] pub extern "C" fn close_client(client_ptr: *const c_void) { - let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) }; - let _runtime_handle = client_ptr.runtime.enter(); - drop(client_ptr); + let count = Arc::strong_count(&unsafe { Arc::from_raw(client_ptr as *mut Client) }); + assert!(count == 1, "Client is still in use."); } /// Expects that key and value will be kept valid until the callback is called. +/// +/// # Safety +/// +/// This function should only be called should with a pointer created by [create_client], before [close_client] was called with the pointer. #[no_mangle] -pub extern "C" fn command( +pub unsafe extern "C" fn command( client_ptr: *const c_void, callback_index: usize, request_type: RequestType, args: *const *mut c_char, arg_count: u32, ) { - let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; + let client = unsafe { + // we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc. + Arc::increment_strong_count(client_ptr); + Arc::from_raw(client_ptr as *mut Client) + }; + let core_client_clone = client.clone(); // The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. - let ptr_address = client_ptr as usize; let args_address = args as usize; let mut client_clone = client.client.clone(); client.runtime.spawn(async move { let Some(mut cmd) = request_type.get_command() else { unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); - (client.failure_callback)(callback_index); // TODO - report errors + (core_client_clone.failure_callback)(callback_index); // TODO - report errors return; } }; @@ -128,11 +140,12 @@ pub extern "C" fn command( .await .and_then(Option::::from_owned_redis_value); unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); match result { - Ok(None) => (client.success_callback)(callback_index, std::ptr::null()), - Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()), - Err(_) => (client.failure_callback)(callback_index), // TODO - report errors + Ok(None) => (core_client_clone.success_callback)(callback_index, std::ptr::null()), + Ok(Some(c_str)) => { + (core_client_clone.success_callback)(callback_index, c_str.as_ptr()) + } + Err(_) => (core_client_clone.failure_callback)(callback_index), // TODO - report errors }; } }); diff --git a/go/src/lib.rs b/go/src/lib.rs index f1eb794d31..8dde0edaf6 100644 --- a/go/src/lib.rs +++ b/go/src/lib.rs @@ -16,6 +16,7 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, RedisResult, Value}; use std::slice::from_raw_parts; +use std::sync::Arc; use std::{ ffi::{c_void, CString}, mem, @@ -180,8 +181,8 @@ fn create_client_internal( /// /// * `connection_request_bytes` must point to `connection_request_len` consecutive properly initialized bytes. It must be a well-formed Protobuf `ConnectionRequest` object. The array must be allocated by the caller and subsequently freed by the caller after this function returns. /// * `connection_request_len` must not be greater than the length of the connection request bytes array. It must also not be greater than the max value of a signed pointer-sized integer. -/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [`close_client`]. -/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [`free_connection_response`]. +/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [close_client]. +/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [free_connection_response]. /// * Both the `success_callback` and `failure_callback` function pointers need to live while the client is open/active. The caller is responsible for freeing both callbacks. // TODO: Consider making this async #[no_mangle] @@ -201,7 +202,7 @@ pub unsafe extern "C" fn create_client( ), }, Ok(client) => ConnectionResponse { - conn_ptr: Box::into_raw(Box::new(client)) as *const c_void, + conn_ptr: Arc::into_raw(Arc::new(client)) as *const c_void, connection_error_message: std::ptr::null(), }, }; @@ -220,13 +221,15 @@ pub unsafe extern "C" fn create_client( /// /// * `close_client` can only be called once per client. Calling it twice is undefined behavior, since the address will be freed twice. /// * `close_client` must be called after `free_connection_response` has been called to avoid creating a dangling pointer in the `ConnectionResponse`. -/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [`create_client`]. +/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [create_client]. /// * `client_adapter_ptr` must be valid until `close_client` is called. // TODO: Ensure safety when command has not completed yet #[no_mangle] pub unsafe extern "C" fn close_client(client_adapter_ptr: *const c_void) { assert!(!client_adapter_ptr.is_null()); - drop(unsafe { Box::from_raw(client_adapter_ptr as *mut ClientAdapter) }); + let client_adapter = unsafe { Arc::from_raw(client_adapter_ptr as *mut ClientAdapter) }; + let count = Arc::strong_count(&client_adapter); + assert!(count == 1, "Client is still in use. {count} references remain."); } /// Deallocates a `ConnectionResponse`. @@ -502,7 +505,7 @@ fn valkey_value_to_command_response(value: Value) -> RedisResult value, Err(err) => { @@ -552,7 +554,7 @@ pub unsafe extern "C" fn command( let c_err_str = CString::into_raw( CString::new(message).expect("Couldn't convert error message to CString"), ); - unsafe { (client_adapter.failure_callback)(channel, c_err_str, error_type) }; + unsafe { (client_adapter_clone.failure_callback)(channel, c_err_str, error_type) }; return; } }; @@ -561,9 +563,10 @@ pub unsafe extern "C" fn command( unsafe { match result { - Ok(message) => { - (client_adapter.success_callback)(channel, Box::into_raw(Box::new(message))) - } + Ok(message) => (client_adapter_clone.success_callback)( + channel, + Box::into_raw(Box::new(message)), + ), Err(err) => { let message = errors::error_message(&err); let error_type = errors::error_type(&err); @@ -571,7 +574,7 @@ pub unsafe extern "C" fn command( let c_err_str = CString::into_raw( CString::new(message).expect("Couldn't convert error message to CString"), ); - (client_adapter.failure_callback)(channel, c_err_str, error_type); + (client_adapter_clone.failure_callback)(channel, c_err_str, error_type); } }; }