diff --git a/bindings/rust/extended/s2n-tls/src/callbacks.rs b/bindings/rust/extended/s2n-tls/src/callbacks.rs index c1d476ec1b8..1a710673604 100644 --- a/bindings/rust/extended/s2n-tls/src/callbacks.rs +++ b/bindings/rust/extended/s2n-tls/src/callbacks.rs @@ -48,7 +48,7 @@ pub use pkey::*; /// callbacks were configured through the Rust bindings. pub(crate) unsafe fn with_context(conn_ptr: *mut s2n_connection, action: F) -> T where - F: FnOnce(&mut Connection, &mut Context) -> T, + F: FnOnce(&mut Connection, &Context) -> T, { let raw = NonNull::new(conn_ptr).expect("connection should not be null"); // Since this is a callback, it receives a pointer to the connection @@ -57,8 +57,8 @@ where // We must make the connection `ManuallyDrop` before `action`, otherwise a panic // in `action` would cause the unwind mechanism to drop the connection. let mut conn = ManuallyDrop::new(Connection::from_raw(raw)); - let mut config = conn.config().expect("config should not be null"); - let context = config.context_mut(); + let config = conn.config().expect("config should not be null"); + let context = config.context(); action(&mut conn, context) } diff --git a/bindings/rust/extended/s2n-tls/src/config.rs b/bindings/rust/extended/s2n-tls/src/config.rs index c289b1bf0a8..75d63bbab47 100644 --- a/bindings/rust/extended/s2n-tls/src/config.rs +++ b/bindings/rust/extended/s2n-tls/src/config.rs @@ -93,14 +93,16 @@ impl Config { /// Retrieve a mutable reference to the [`Context`] stored on the config. /// /// Corresponds to [s2n_config_get_ctx]. - pub(crate) fn context_mut(&mut self) -> &mut Context { + /// + /// SAFETY: There must only ever by mutable reference to `Context` alive at + /// any time. Configs can be shared across threads, so this method is + /// likely not correct for your usecase. + pub(crate) unsafe fn context_mut(&mut self) -> &mut Context { let mut ctx = core::ptr::null_mut(); - unsafe { - s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx) - .into_result() - .unwrap(); - &mut *(ctx as *mut Context) - } + s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx) + .into_result() + .unwrap(); + &mut *(ctx as *mut Context) } #[cfg(test)] @@ -135,7 +137,7 @@ impl Clone for Config { impl Drop for Config { /// Corresponds to [s2n_config_free]. fn drop(&mut self) { - let context = self.context_mut(); + let context = self.context(); let count = context.refcount.fetch_sub(1, Ordering::Release); debug_assert!(count > 0, "refcount should not drop below 1 instance"); @@ -158,18 +160,20 @@ impl Drop for Config { // https://github.com/rust-lang/rust/blob/e012a191d768adeda1ee36a99ef8b92d51920154/library/alloc/src/sync.rs#L1637 std::sync::atomic::fence(Ordering::Acquire); - unsafe { - // This is the last instance so free the context. - let context = Box::from_raw(context); - drop(context); + // This is the last instance so free the context. + let context = unsafe { + // SAFETY: The reference count is verified to be 1, so this is the + // last instance of the config, and the only reference to the context. + Box::from_raw(self.context_mut()) + }; + drop(context); - let _ = s2n_config_free(self.0.as_ptr()).into_result(); - } + let _ = unsafe { s2n_config_free(self.0.as_ptr()).into_result() }; } } pub struct Builder { - config: Config, + pub(crate) config: Config, load_system_certs: bool, enable_ocsp: bool, } @@ -334,7 +338,12 @@ impl Builder { ) .into_result() }; - self.context_mut().application_owned_certs.push(chain); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because the + // Builder owns the only reference to the config. + self.config.context_mut() + }; + context.application_owned_certs.push(chain); result?; Ok(self) @@ -373,9 +382,12 @@ impl Builder { let collected_chains = chain_arrays.into_iter().take(cert_chain_count).flatten(); - self.context_mut() - .application_owned_certs - .extend(collected_chains); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because the + // Builder owns the only reference to the config. + self.config.context_mut() + }; + context.application_owned_certs.extend(collected_chains); unsafe { s2n_config_set_cert_chain_and_key_defaults( @@ -547,12 +559,17 @@ impl Builder { verify_host(host_name, host_name_len, handler) } - self.context_mut().verify_host_callback = Some(Box::new(handler)); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because the + // Builder owns the only reference to the config. + self.config.context_mut() + }; + context.verify_host_callback = Some(Box::new(handler)); unsafe { s2n_config_set_verify_host_callback( self.as_mut_ptr(), Some(verify_host_cb_fn), - self.context_mut() as *mut Context as *mut c_void, + self.config.context() as *const Context as *mut c_void, ) .into_result()?; } @@ -607,7 +624,11 @@ impl Builder { } let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.client_hello_callback = Some(handler); unsafe { @@ -628,7 +649,11 @@ impl Builder { ) -> Result<&mut Self, Error> { // Store callback in config context let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.connection_initializer = Some(handler); Ok(self) } @@ -659,14 +684,18 @@ impl Builder { // Store callback in context let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.session_ticket_callback = Some(handler); unsafe { s2n_config_set_session_ticket_cb( self.as_mut_ptr(), Some(session_ticket_cb), - self.context_mut() as *mut Context as *mut c_void, + self.config.context() as *const Context as *mut c_void, ) .into_result() }?; @@ -698,7 +727,11 @@ impl Builder { } let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.private_key_callback = Some(handler); unsafe { @@ -734,13 +767,17 @@ impl Builder { } let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.wall_clock = Some(handler); unsafe { s2n_config_set_wall_clock( self.as_mut_ptr(), Some(clock_cb), - self.context_mut() as *mut _ as *mut c_void, + self.config.context() as *const _ as *mut c_void, ) .into_result()?; } @@ -773,13 +810,17 @@ impl Builder { } let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.monotonic_clock = Some(handler); unsafe { s2n_config_set_monotonic_clock( self.as_mut_ptr(), Some(clock_cb), - self.context_mut() as *mut _ as *mut c_void, + self.config.context() as *const _ as *mut c_void, ) .into_result()?; } @@ -913,17 +954,6 @@ impl Builder { pub(crate) fn as_mut_ptr(&mut self) -> *mut s2n_config { self.config.as_mut_ptr() } - - /// Retrieve a mutable reference to the [`Context`] stored on the config. - pub(crate) fn context_mut(&mut self) -> &mut Context { - let mut ctx = core::ptr::null_mut(); - unsafe { - s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx) - .into_result() - .unwrap(); - &mut *(ctx as *mut Context) - } - } } #[cfg(feature = "quic")] diff --git a/bindings/rust/extended/s2n-tls/src/renegotiate.rs b/bindings/rust/extended/s2n-tls/src/renegotiate.rs index 215ca630c42..08f5dc8fe2a 100644 --- a/bindings/rust/extended/s2n-tls/src/renegotiate.rs +++ b/bindings/rust/extended/s2n-tls/src/renegotiate.rs @@ -41,7 +41,7 @@ //! //! impl RenegotiateCallback for Callback { //! fn on_renegotiate_request( -//! &mut self, +//! &self, //! conn: &mut Connection, //! ) -> Option { //! let response = match conn.server_name() { @@ -51,7 +51,7 @@ //! Some(response) //! } //! -//! fn on_renegotiate_wipe(&mut self, conn: &mut Connection) -> Result<(), Error> { +//! fn on_renegotiate_wipe(&self, conn: &mut Connection) -> Result<(), Error> { //! conn.set_application_protocol_preference(Some("http"))?; //! Ok(()) //! } @@ -155,23 +155,20 @@ pub trait RenegotiateCallback: 'static + Send + Sync { // // This method returns Option instead of Result because the callback has no mechanism // for surfacing errors to the application, so Result would be somewhat deceptive. - fn on_renegotiate_request( - &mut self, - connection: &mut Connection, - ) -> Option; + fn on_renegotiate_request(&self, connection: &mut Connection) -> Option; /// A callback that triggers after the connection is wiped for renegotiation. /// /// Because renegotiation requires wiping the connection, connection-level /// configuration will need to be set again via this callback. /// See [`Connection::wipe_for_renegotiate()`] for more information. - fn on_renegotiate_wipe(&mut self, _connection: &mut Connection) -> Result<(), Error> { + fn on_renegotiate_wipe(&self, _connection: &mut Connection) -> Result<(), Error> { Ok(()) } } impl RenegotiateCallback for RenegotiateResponse { - fn on_renegotiate_request(&mut self, _conn: &mut Connection) -> Option { + fn on_renegotiate_request(&self, _conn: &mut Connection) -> Option { Some(*self) } } @@ -248,8 +245,8 @@ impl Connection { // We trigger the callback last so that the application can modify any // preserved configuration (like the server name or waker) if necessary. - if let Some(mut config) = self.config() { - if let Some(callback) = config.context_mut().renegotiate.as_mut() { + if let Some(config) = self.config() { + if let Some(callback) = config.context().renegotiate.as_ref() { callback.on_renegotiate_wipe(self)?; } } @@ -390,7 +387,7 @@ impl config::Builder { response: *mut s2n_renegotiate_response::Type, ) -> libc::c_int { with_context(conn_ptr, |conn, context| { - let callback = context.renegotiate.as_mut(); + let callback = context.renegotiate.as_ref(); if let Some(callback) = callback { if let Some(result) = callback.on_renegotiate_request(conn) { // If the callback indicates renegotiation, schedule it. @@ -408,7 +405,11 @@ impl config::Builder { } let handler = Box::new(handler); - let context = self.context_mut(); + let context = unsafe { + // SAFETY: usage of context_mut is safe in the builder, because while + // it is being built, the Builder is the only reference to the config. + self.config.context_mut() + }; context.renegotiate = Some(handler); unsafe { s2n_config_set_renegotiate_request_cb( @@ -717,10 +718,7 @@ mod tests { fn error_callback() -> Result<(), Box> { struct ErrorRenegotiateCallback {} impl RenegotiateCallback for ErrorRenegotiateCallback { - fn on_renegotiate_request( - &mut self, - _: &mut Connection, - ) -> Option { + fn on_renegotiate_request(&self, _: &mut Connection) -> Option { None } }