diff --git a/csharp/lib/ConnectionConfiguration.cs b/csharp/lib/ConnectionConfiguration.cs index 12ffc0c1dc..e5f0b06213 100644 --- a/csharp/lib/ConnectionConfiguration.cs +++ b/csharp/lib/ConnectionConfiguration.cs @@ -11,7 +11,7 @@ public abstract class ConnectionConfiguration [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] internal struct ConnectionRequest { - public uint address_count; + public nuint address_count; public IntPtr addresses; // ** NodeAddress - array pointer public TlsMode tls_mode; public bool cluster_mode; @@ -23,7 +23,6 @@ internal struct ConnectionRequest public ProtocolVersion protocol; [MarshalAs(UnmanagedType.LPStr)] public string? client_name; - } [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] @@ -31,7 +30,7 @@ internal struct NodeAddress { [MarshalAs(UnmanagedType.LPStr)] public string host; - public uint port; + public ushort port; } [StructLayout(LayoutKind.Sequential)] diff --git a/csharp/lib/src/connection.rs b/csharp/lib/src/connection.rs index 3dc5a7f889..e7d40aa0f3 100644 --- a/csharp/lib/src/connection.rs +++ b/csharp/lib/src/connection.rs @@ -1,28 +1,28 @@ /** * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ - use std::ffi::c_char; #[repr(C)] -pub struct ConnectionRequest { - pub address_count : u32, - pub addresses : *const *const NodeAddress, - pub tls_mode : TlsMode, - pub cluster_mode : bool, - pub request_timeout : u32, - pub read_from : ReadFrom, - pub connection_retry_strategy : ConnectionRetryStrategy, - pub authentication_info : AuthenticationInfo, - pub database_id : u32, - pub protocol : ProtocolVersion, - pub client_name : *const c_char, +pub struct ConnectionConfig { + pub address_count: usize, + /// Pointer to an array. + pub addresses: *const *const NodeAddress, + pub tls_mode: TlsMode, + pub cluster_mode: bool, + pub request_timeout: u32, + pub read_from: ReadFrom, + pub connection_retry_strategy: ConnectionRetryStrategy, + pub authentication_info: AuthenticationInfo, + pub database_id: u32, + pub protocol: ProtocolVersion, + pub client_name: *const c_char, } #[repr(C)] pub struct NodeAddress { - pub host : *const c_char, - pub port : u32 + pub host: *const c_char, + pub port: u16, } #[repr(C)] @@ -49,8 +49,8 @@ pub struct ConnectionRetryStrategy { #[repr(C)] pub struct AuthenticationInfo { - pub username : *const c_char, - pub password : *const c_char + pub username: *const c_char, + pub password: *const c_char, } #[repr(C)] diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index 6ad834c5fa..dbee695d23 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -1,12 +1,14 @@ /** * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +mod connection; +pub use connection::{ + AuthenticationInfo, ConnectionConfig, ConnectionRetryStrategy, NodeAddress, ProtocolVersion, + ReadFrom, TlsMode, +}; -pub(crate) mod connection; -pub use connection::*; - +use glide_core::client::Client as GlideClient; use glide_core::connection_request; -use glide_core::{client::Client as GlideClient}; use redis::{Cmd, FromRedisValue, RedisResult}; use std::{ ffi::{c_void, CStr, CString}, @@ -29,78 +31,92 @@ pub struct Client { runtime: Runtime, } +/// Convert raw C string to a rust string. +/// # Safety +/// Unsafe function because creating a string from a pointer. unsafe fn ptr_to_str(ptr: *const c_char) -> &'static str { if ptr as i64 != 0 { - CStr::from_ptr(ptr).to_str().unwrap() + unsafe { CStr::from_ptr(ptr) }.to_str().unwrap() } else { "" } } +/// Convert raw array pointer to a vector of `NodeAddress`es. +/// # Safety +/// Unsafe function because dereferencing a pointer. pub unsafe fn node_addresses_to_proto( data: *const *const NodeAddress, len: usize, ) -> Vec { - std::slice::from_raw_parts(data as *mut NodeAddress, len).iter().map(|addr| { - //dbg!(); - let mut address_info = connection_request::NodeAddress::new(); - address_info.host = ptr_to_str(addr.host).into(); - address_info.port = addr.port; - //dbg!(address_info) - address_info - }).collect() + unsafe { std::slice::from_raw_parts(data as *mut NodeAddress, len) } + .iter() + .map(|addr| { + let mut address_info = connection_request::NodeAddress::new(); + address_info.host = unsafe { ptr_to_str(addr.host) }.into(); + address_info.port = addr.port as u32; + address_info + }) + .collect() } +/// Convert connection configuration to a corresponding protobuf object. +/// # Safety +/// Unsafe function because dereferencing a pointer. unsafe fn create_connection_request( - config: *const ConnectionRequest, + config: *const ConnectionConfig, ) -> connection_request::ConnectionRequest { - //dbg!(); let mut connection_request = connection_request::ConnectionRequest::new(); - //dbg!(); - connection_request.addresses = node_addresses_to_proto((*config).addresses, (*config).address_count as usize); - //dbg!(); - connection_request.tls_mode = match (*config).tls_mode { + + let config_ref = unsafe { &*config }; + + connection_request.addresses = + unsafe { node_addresses_to_proto(config_ref.addresses, config_ref.address_count) }; + + connection_request.tls_mode = match config_ref.tls_mode { TlsMode::SecureTls => connection_request::TlsMode::SecureTls, TlsMode::InsecureTls => connection_request::TlsMode::InsecureTls, TlsMode::NoTls => connection_request::TlsMode::NoTls, } .into(); - //dbg!("tls {}", connection_request.tls_mode); - connection_request.cluster_mode_enabled = (*config).cluster_mode; - connection_request.request_timeout = (*config).request_timeout; - //dbg!("cluster {}, timeout {}", connection_request.cluster_mode_enabled, connection_request.request_timeout); - connection_request.read_from = match (*config).read_from { + connection_request.cluster_mode_enabled = config_ref.cluster_mode; + connection_request.request_timeout = config_ref.request_timeout; + + connection_request.read_from = match config_ref.read_from { ReadFrom::AZAffinity => connection_request::ReadFrom::AZAffinity, ReadFrom::PreferReplica => connection_request::ReadFrom::PreferReplica, ReadFrom::Primary => connection_request::ReadFrom::Primary, ReadFrom::LowestLatency => connection_request::ReadFrom::LowestLatency, - }.into(); - //dbg!("read {}", connection_request.read_from); + } + .into(); + let mut retry_strategy = connection_request::ConnectionRetryStrategy::new(); - retry_strategy.number_of_retries = (*config).connection_retry_strategy.number_of_retries; - retry_strategy.factor = (*config).connection_retry_strategy.factor; - retry_strategy.exponent_base = (*config).connection_retry_strategy.exponent_base; + retry_strategy.number_of_retries = config_ref.connection_retry_strategy.number_of_retries; + retry_strategy.factor = config_ref.connection_retry_strategy.factor; + retry_strategy.exponent_base = config_ref.connection_retry_strategy.exponent_base; connection_request.connection_retry_strategy = Some(retry_strategy).into(); - //dbg!("strategy {}", connection_request.connection_retry_strategy.clone()); + let mut auth_info = connection_request::AuthenticationInfo::new(); - auth_info.username = ptr_to_str((*config).authentication_info.username).into(); - auth_info.password = ptr_to_str((*config).authentication_info.password).into(); + auth_info.username = unsafe { ptr_to_str(config_ref.authentication_info.username) }.into(); + auth_info.password = unsafe { ptr_to_str(config_ref.authentication_info.password) }.into(); connection_request.authentication_info = Some(auth_info).into(); - //dbg!("auth {}", connection_request.authentication_info.clone()); - connection_request.database_id = (*config).database_id; - connection_request.protocol = match (*config).protocol { + + connection_request.database_id = config_ref.database_id; + connection_request.protocol = match config_ref.protocol { ProtocolVersion::RESP2 => connection_request::ProtocolVersion::RESP2, ProtocolVersion::RESP3 => connection_request::ProtocolVersion::RESP3, - }.into(); + } + .into(); - connection_request.client_name = ptr_to_str((*config).client_name).into(); + connection_request.client_name = unsafe { ptr_to_str(config_ref.client_name) }.into(); - dbg!(connection_request) - //connection_request + connection_request } -fn create_client_internal( - config: *const ConnectionRequest, +/// # Safety +/// Unsafe function because calling other unsafe function. +unsafe fn create_client_internal( + config: *const ConnectionConfig, success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), failure_callback: unsafe extern "C" fn(usize) -> (), ) -> RedisResult { @@ -120,20 +136,24 @@ fn create_client_internal( } /// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. +/// # Safety +/// Unsafe function because calling other unsafe function. #[no_mangle] -pub extern "C" fn create_client( - config: *const ConnectionRequest, +pub unsafe extern "C" fn create_client( + config: *const ConnectionConfig, success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), failure_callback: unsafe extern "C" fn(usize) -> (), ) -> *const c_void { - match create_client_internal(config, success_callback, failure_callback) { + match unsafe { create_client_internal(config, success_callback, failure_callback) } { Err(_) => std::ptr::null(), // TODO - log errors Ok(client) => Box::into_raw(Box::new(client)) as *const c_void, } } +/// # Safety +/// Unsafe function because dereferencing a pointer. #[no_mangle] -pub extern "C" fn close_client(client_ptr: *const c_void) { +pub unsafe extern "C" fn close_client(client_ptr: *const c_void) { let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) }; let _runtime_handle = client_ptr.runtime.enter(); drop(client_ptr);