diff --git a/.changelog/5879.feature.md b/.changelog/5879.feature.md new file mode 100644 index 00000000000..f1f23c0ea6f --- /dev/null +++ b/.changelog/5879.feature.md @@ -0,0 +1 @@ +runtime/src/enclave_rpc/client: Support concurrent sessions diff --git a/keymanager/src/client/remote.rs b/keymanager/src/client/remote.rs index a8fbddfc294..96f49ff2f4c 100644 --- a/keymanager/src/client/remote.rs +++ b/keymanager/src/client/remote.rs @@ -64,6 +64,14 @@ use super::KeyManagerClient; /// Key manager RPC endpoint. const KEY_MANAGER_ENDPOINT: &str = "key-manager"; +/// Maximum total number of EnclaveRPC sessions. +const RPC_MAX_SESSIONS: usize = 32; +/// Maximum concurrent EnclaveRPC sessions per peer. In case more sessions are open, old sessions +/// will be closed to make room for new sessions. +const RPC_MAX_SESSIONS_PER_PEER: usize = 2; +/// EnclaveRPC sessions without any processed frame for more than RPC_STALE_SESSION_TIMEOUT_SECS +/// seconds can be closed to make room for new sessions. +const RPC_STALE_SESSION_TIMEOUT_SECS: i64 = 10; /// A key manager client which talks to a remote key manager enclave. pub struct RemoteClient { @@ -125,17 +133,22 @@ impl RemoteClient { identity: Arc, keys_cache_sizes: usize, ) -> Self { + let builder = session::Builder::default() + .remote_enclaves(enclaves) + .quote_policy(policy) + .local_identity(identity) + .consensus_verifier(Some(consensus_verifier.clone())) + .remote_runtime_id(km_runtime_id); + Self::new( runtime_id, RpcClient::new_runtime( - session::Builder::default() - .remote_enclaves(enclaves) - .quote_policy(policy) - .local_identity(identity) - .consensus_verifier(Some(consensus_verifier.clone())) - .remote_runtime_id(km_runtime_id), protocol, KEY_MANAGER_ENDPOINT, + builder, + RPC_MAX_SESSIONS, + RPC_MAX_SESSIONS_PER_PEER, + RPC_STALE_SESSION_TIMEOUT_SECS, ), consensus_verifier, keys_cache_sizes, @@ -189,7 +202,7 @@ impl RemoteClient { } /// Set allowed enclaves and runtime signing key from key manager status. - pub fn set_status(&self, status: KeyManagerStatus) -> Result<(), KeyManagerError> { + pub async fn set_status(&self, status: KeyManagerStatus) -> Result<(), KeyManagerError> { // Set runtime signing key. if let Some(rsk) = status.rsk { self.rsk.write().unwrap().replace(rsk); @@ -197,7 +210,7 @@ impl RemoteClient { // Set key manager runtime ID. *self.key_manager_id.write().unwrap() = Some(status.id); - self.rpc_client.update_runtime_id(Some(status.id)); + self.rpc_client.update_runtime_id(Some(status.id)).await; // Verify and apply the policy, if set. let untrusted_policy = match status.policy { @@ -211,15 +224,15 @@ impl RemoteClient { if !Policy::unsafe_skip() { let enclaves: HashSet = HashSet::from_iter(policy.enclaves.keys().cloned()); - self.rpc_client.update_enclaves(Some(enclaves)); + self.rpc_client.update_enclaves(Some(enclaves)).await; } Ok(()) } /// Set key manager's quote policy. - pub fn set_quote_policy(&self, policy: QuotePolicy) { - self.rpc_client.update_quote_policy(policy); + pub async fn set_quote_policy(&self, policy: QuotePolicy) { + self.rpc_client.update_quote_policy(policy).await; } fn verify_public_key( diff --git a/runtime/src/enclave_rpc/client.rs b/runtime/src/enclave_rpc/client.rs index e7b28c14297..07531a6cdb0 100644 --- a/runtime/src/enclave_rpc/client.rs +++ b/runtime/src/enclave_rpc/client.rs @@ -1,36 +1,37 @@ //! Enclave RPC client. 
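// Illustrative sketch, not part of the patch: roughly how a caller might wire
// the reworked `RpcClient::new_runtime` signature and issue a call, mirroring
// the key-manager client changes above. The endpoint name, the limit values,
// the abbreviated builder configuration and the surrounding function are
// assumptions only; the items used are the ones this module already imports.
async fn example_client(
    protocol: Arc<Protocol>,
    runtime_id: Namespace,
) -> Result<u64, RpcClientError> {
    let builder = Builder::default().remote_runtime_id(Some(runtime_id));
    let client = RpcClient::new_runtime(
        protocol,
        "example-endpoint", // Hypothetical EnclaveRPC endpoint name.
        builder,
        32, // Maximum total number of sessions.
        2,  // Maximum number of sessions per peer.
        10, // Stale session timeout in seconds.
    );

    // Calls may now run concurrently; each call checks out (or establishes)
    // one of the client's multiplexed sessions.
    client
        .secure_call("example_method", 42u64, vec![])
        .await
        .into_result_with_feedback()
        .await
}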
use std::{ collections::HashSet, - mem, sync::{ atomic::{AtomicU32, Ordering}, Arc, }, }; +use futures::stream::{FuturesUnordered, StreamExt}; use lazy_static::lazy_static; #[cfg(not(test))] use rand::{rngs::OsRng, RngCore}; + use thiserror::Error; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::OwnedMutexGuard; use crate::{ common::{ crypto::signature, namespace::Namespace, sgx::{EnclaveIdentity, QuotePolicy}, + time::insecure_posix_time, }, - enclave_rpc::{ - session::{Builder, Session}, - types, - }, + enclave_rpc::{session::Builder, types}, + future::block_on, protocol::Protocol, }; -use super::transport::{RuntimeTransport, Transport}; +use super::{ + sessions::{self, MultiplexedSession, Sessions, SharedSession}, + transport::{RuntimeTransport, Transport}, +}; -/// Internal command queue backlog. -const CMDQ_BACKLOG: usize = 32; /// Maximum number of retries on transport errors. const MAX_TRANSPORT_ERROR_RETRIES: usize = 3; @@ -56,369 +57,20 @@ pub enum RpcClientError { Dropped, #[error("decode error: {0}")] DecodeError(#[from] cbor::DecodeError), + #[error("sessions error: {0}")] + SessionsError(#[from] sessions::Error), #[error("unknown error: {0}")] Unknown(#[from] anyhow::Error), } -/// A command sent to the client controller task. -#[derive(Debug)] -enum Command { - Call( - types::Request, - types::Kind, - Vec, - oneshot::Sender>, - ), - PeerFeedback(u64, types::PeerFeedback, types::Kind), - UpdateEnclaves(Option>), - UpdateQuotePolicy(QuotePolicy), - UpdateRuntimeID(Option), - #[cfg(test)] - Ping(oneshot::Sender<()>), -} - -struct MultiplexedSession { - /// Session builder for resetting sessions. - builder: Builder, - /// Unique session identifier. - id: types::SessionID, - /// Current underlying protocol session. - inner: Session, -} - -impl MultiplexedSession { - fn new(builder: Builder) -> Self { - Self { - builder: builder.clone(), - id: types::SessionID::random(), - inner: builder.build_initiator(), - } - } - - fn reset(&mut self) { - self.id = types::SessionID::random(); - self.inner = self.builder.clone().build_initiator(); - } -} - -struct Controller { - /// Multiplexed session. - session: MultiplexedSession, - /// Used transport. - transport: Box, - /// Internal command queue (receiver part). - cmdq: mpsc::Receiver, - /// The ID of the client. - client_id: u32, - /// The total number of requests sent. - sent_request_count: u32, -} - -impl Controller { - async fn run(mut self) { - while let Some(cmd) = self.cmdq.recv().await { - match cmd { - Command::Call(request, kind, nodes, sender) => { - self.call(request, kind, nodes, sender).await - } - Command::PeerFeedback(request_id, peer_feedback, kind) => { - let _ = self - .transport - .submit_peer_feedback(request_id, peer_feedback) - .await; // Ignore error. - - // In case the peer feedback is bad, reset the session so a new peer can be - // selected for a subsequent session. 
- if !matches!(peer_feedback, types::PeerFeedback::Success) - && kind == types::Kind::NoiseSession - { - self.reset().await; - } - } - Command::UpdateEnclaves(enclaves) => { - if self.session.builder.get_remote_enclaves() == &enclaves { - continue; - } - - self.session.builder = - mem::take(&mut self.session.builder).remote_enclaves(enclaves); - self.reset().await; - } - Command::UpdateQuotePolicy(policy) => { - let policy = Some(Arc::new(policy)); - if self.session.builder.get_quote_policy() == &policy { - continue; - } - - self.session.builder = - mem::take(&mut self.session.builder).quote_policy(policy); - self.reset().await; - } - Command::UpdateRuntimeID(id) => { - if self.session.builder.get_remote_runtime_id() == &id { - continue; - } - - self.session.builder = - mem::take(&mut self.session.builder).remote_runtime_id(id); - self.reset().await; - } - #[cfg(test)] - Command::Ping(sender) => { - let _ = sender.send(()); - } - } - } - - // Close stream after the client is dropped. - let _ = self.close().await; - } - - async fn call( - &mut self, - request: types::Request, - kind: types::Kind, - nodes: Vec, - sender: oneshot::Sender>, - ) { - let result = async { - match kind { - types::Kind::NoiseSession => { - // Attempt to establish a connection. This will not do anything in case the - // session has already been established. - self.connect(nodes).await?; - - // Perform the call. - self.secure_call_raw(request).await - } - types::Kind::InsecureQuery => { - // Perform the call. - self.insecure_call_raw(request, nodes).await - } - _ => Err(RpcClientError::UnsupportedRpcKind), - } - } - .await; - - let request_id = self.get_request_id(); - - if result.is_err() { - // Set peer feedback immediately so retries can try new peers. - let _ = self - .transport - .submit_peer_feedback(request_id, types::PeerFeedback::Failure) - .await; // Ignore error. - - // In case there was a transport error we need to reset the session immediately as no - // progress is possible. - if kind == types::Kind::NoiseSession { - self.reset().await; - } - } - - let _ = sender.send(result.map(|rsp| (request_id, rsp))); - } - - async fn connect(&mut self, nodes: Vec) -> Result<(), RpcClientError> { - // No need to create a new session if we are connected to one of the nodes. - if self.session.inner.is_connected() - && (nodes.is_empty() || self.session.inner.is_connected_to(&nodes)) - { - return Ok(()); - } - // Make sure the session is reset for a new connection. - self.reset().await; - - // Handshake1 -> Handshake2 - let mut buffer = vec![]; - self.session - .inner - .process_data(vec![], &mut buffer) - .await - .expect("initiation must always succeed"); - let session_id = self.session.id; - - let request_id = self.increment_request_id(); - - let rsp = self - .transport - .write_noise_session(request_id, session_id, buffer, String::new(), nodes) - .await - .map_err(|_| RpcClientError::Transport)?; - - // Update the session with the identity of the remote node. The latter still needs to be - // verified using the RAK from the consensus layer. - self.session.inner.set_remote_node(rsp.node)?; - - // Handshake2 -> Transport - let mut buffer = vec![]; - self.session - .inner - .process_data(rsp.data, &mut buffer) - .await - .map_err(|_| RpcClientError::Transport)?; - - let _ = self - .transport - .submit_peer_feedback(request_id, types::PeerFeedback::Success) - .await; // Ignore error. 
- - let request_id = self.increment_request_id(); - - self.transport - .write_noise_session( - request_id, - session_id, - buffer, - String::new(), - vec![rsp.node], - ) - .await - .map_err(|_| RpcClientError::Transport)?; - - // Check if the session has failed authentication. In this case, notify the other side - // (returning an error here will do that in `call`). - if self.session.inner.is_unauthenticated() { - return Err(RpcClientError::Transport); - } - - let _ = self - .transport - .submit_peer_feedback(request_id, types::PeerFeedback::Success) - .await; // Ignore error. - - Ok(()) - } - - async fn secure_call_raw( - &mut self, - request: types::Request, - ) -> Result { - let method = request.method.clone(); - let msg = types::Message::Request(request); - - // Prepare the request message. - let mut buffer = vec![]; - self.session - .inner - .write_message(msg, &mut buffer) - .map_err(|_| RpcClientError::Transport)?; - let node = self.session.inner.get_node()?; - - // Send the request and receive the response. - let request_id = self.increment_request_id(); - - let rsp = self - .transport - .write_noise_session(request_id, self.session.id, buffer, method, vec![node]) - .await - .map_err(|_| RpcClientError::Transport)?; - - // Process the response. - let msg = self - .session - .inner - .process_data(rsp.data, vec![]) - .await? - .expect("message must be decoded if there is no error"); - - match msg { - types::Message::Response(rsp) => Ok(rsp), - msg => Err(RpcClientError::ExpectedResponseMessage(msg)), - } - } - - async fn insecure_call_raw( - &mut self, - request: types::Request, - nodes: Vec, - ) -> Result { - let request_id = self.increment_request_id(); - - let rsp = self - .transport - .write_insecure_query(request_id, cbor::to_vec(request), nodes) - .await - .map_err(|_| RpcClientError::Transport)?; - - cbor::from_slice(&rsp.data).map_err(RpcClientError::DecodeError) - } - - async fn reset(&mut self) { - // Notify the other end (if any) of session closure. - let _ = self.close_notify().await; - // Reset the session. - self.session.reset(); - } - - async fn close_notify(&mut self) -> Result, RpcClientError> { - let node = self.session.inner.get_node()?; - - let mut buffer = vec![]; - self.session - .inner - .write_message(types::Message::Close, &mut buffer) - .map_err(|_| RpcClientError::Transport)?; - - let request_id = self.increment_request_id(); - - self.transport - .write_noise_session( - request_id, - self.session.id, - buffer, - String::new(), - vec![node], - ) - .await - .map_err(|_| RpcClientError::Transport) - .map(|rsp| rsp.data) - - // Skipping peer feedback, as the request was sent only to inform - // the other side of a graceful session close. - } - - async fn close(&mut self) -> Result<(), RpcClientError> { - if !self.session.inner.is_connected() { - return Ok(()); - } - - let data = self.close_notify().await?; - - // Close the session and check the received message. - let msg = self - .session - .inner - .process_data(data, vec![]) - .await? 
- .expect("message must be decoded if there is no error"); - self.session.inner.close(); - - match msg { - types::Message::Close => Ok(()), - msg => Err(RpcClientError::ExpectedCloseMessage(msg)), - } - } - - fn get_request_id(&self) -> u64 { - ((self.client_id as u64) << 32) + (self.sent_request_count as u64) - } - - fn increment_request_id(&mut self) -> u64 { - self.sent_request_count = self.sent_request_count.wrapping_add(1); - self.get_request_id() - } -} - /// An EnclaveRPC response that can be used to provide peer feedback. -pub struct Response { - inner: Result, - kind: types::Kind, - cmdq: mpsc::WeakSender, +pub struct Response<'a, T> { + transport: &'a dyn Transport, request_id: Option, + inner: Result, } -impl Response { +impl<'a, T> Response<'a, T> { /// Report success if result was `Ok(_)` and failure if result was `Err(_)`, then return the /// inner result consuming the response instance. pub async fn into_result_with_feedback(mut self) -> Result { @@ -456,48 +108,101 @@ impl Response { } /// Send peer feedback. - async fn send_peer_feedback(&mut self, pf: types::PeerFeedback) { + async fn send_peer_feedback(&mut self, feedback: types::PeerFeedback) { if let Some(request_id) = self.request_id.take() { // Only count feedback once. - if let Some(cmdq) = self.cmdq.upgrade() { - let _ = cmdq - .send(Command::PeerFeedback(request_id, pf, self.kind)) - .await; - } + let _ = self + .transport + .submit_peer_feedback(request_id, feedback) + .await; // Ignore error. } } } /// RPC client. pub struct RpcClient { - /// Internal command queue (sender part). - cmdq: mpsc::Sender, + /// Used transport. + transport: Box, + /// Multiplexed sessions. + sessions: tokio::sync::Mutex>, + /// The ID of the client. + client_id: u32, + /// The ID of the next transport request. + next_request_id: AtomicU32, } impl RpcClient { - fn new(transport: Box, builder: Builder) -> Self { - // Create the command channel. - let (tx, rx) = mpsc::channel(CMDQ_BACKLOG); - - // Ensure every client has a unique ID. + fn new( + transport: Box, + builder: Builder, + max_sessions: usize, + max_sessions_per_peer: usize, + stale_session_timeout: i64, + ) -> Self { + // Assign a unique ID to each client to avoid overlapping request IDs. let client_id = NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst); // Wraps if overflows. + let next_request_id = AtomicU32::new(1); - // Create the controller task and start it. - let controller = Controller { - session: MultiplexedSession::new(builder), + let sessions = tokio::sync::Mutex::new(Sessions::new( + builder, + max_sessions, + max_sessions_per_peer, + stale_session_timeout, + )); + + Self { transport, - cmdq: rx, + sessions, client_id, - sent_request_count: 0, + next_request_id, + } + } + + /// Construct an unconnected RPC client with runtime-internal transport. + pub fn new_runtime( + protocol: Arc, + endpoint: &str, + builder: Builder, + max_sessions: usize, + max_sessions_per_peer: usize, + stale_session_timeout: i64, + ) -> Self { + let transport = Box::new(RuntimeTransport::new(protocol, endpoint)); + + Self::new( + transport, + builder, + max_sessions, + max_sessions_per_peer, + stale_session_timeout, + ) + } + + /// Update allowed remote enclave identities. + pub async fn update_enclaves(&self, enclaves: Option>) { + let sessions = { + let mut sessions = self.sessions.lock().await; + sessions.update_enclaves(enclaves) }; - tokio::spawn(controller.run()); + self.close_all(sessions).await; + } - Self { cmdq: tx } + /// Update remote end's quote policy. 
+ pub async fn update_quote_policy(&self, policy: QuotePolicy) { + let sessions = { + let mut sessions = self.sessions.lock().await; + sessions.update_quote_policy(policy) + }; + self.close_all(sessions).await; } - /// Construct an unconnected RPC client with runtime-internal transport. - pub fn new_runtime(builder: Builder, protocol: Arc, endpoint: &str) -> Self { - Self::new(Box::new(RuntimeTransport::new(protocol, endpoint)), builder) + /// Update remote runtime id. + pub async fn update_runtime_id(&self, id: Option) { + let sessions = { + let mut sessions = self.sessions.lock().await; + sessions.update_runtime_id(id) + }; + self.close_all(sessions).await; } /// Call a remote method using an encrypted and authenticated Noise session. @@ -572,10 +277,9 @@ impl RpcClient { }; Response { - inner, - kind, - cmdq: self.cmdq.downgrade(), + transport: &*self.transport, request_id, + inner, } } @@ -585,48 +289,328 @@ impl RpcClient { kind: types::Kind, nodes: Vec, ) -> Result<(u64, types::Response), RpcClientError> { - let (tx, rx) = oneshot::channel(); - self.cmdq - .send(Command::Call(request, kind, nodes, tx)) + match kind { + types::Kind::NoiseSession => { + // Attempt to establish a connection. This will not do anything in case the + // session has already been established. + let session = self.connect(nodes).await?; + let mut session = session.lock_owned().await; + + // Perform the call. + let result = self.secure_call_raw(request, &mut session).await; + + // In case there was a transport error we need to remove the session immediately + // as no progress is possible. The next call should select another peer or + // the same peer but another session. + if result.is_err() { + let mut sessions = self.sessions.lock().await; + sessions.remove(&session); + } + + result + } + types::Kind::InsecureQuery => { + // Perform the call. + self.insecure_call_raw(request, nodes).await + } + _ => Err(RpcClientError::UnsupportedRpcKind), + } + } + + async fn connect( + &self, + nodes: Vec, + ) -> Result, RpcClientError> { + // Create a new session. + let mut session = { + let mut sessions = self.sessions.lock().await; + + // No need to create a new session if we are connected to one of the nodes. + if let Some(session) = sessions.find(&nodes) { + return Ok(session); + } + + // Since the peer ID is not yet known, use the default value and set it later. + let peer_id = Default::default(); + sessions.create_initiator(peer_id) + }; + + // Copy session ID to avoid moved value errors. + let session_id = *session.get_session_id(); + + // Prepare buffers upfront. + let mut buffer1 = vec![]; + let mut buffer2 = vec![]; + + // Session Handshake1: prepare initialization request. + session + .process_data(&[], &mut buffer1) .await - .map_err(|_| RpcClientError::Dropped)?; + .expect("initiation must always succeed"); - rx.await.map_err(|_| RpcClientError::Dropped)? + let request_id = self.next_request_id(); + let result: Result<_, RpcClientError> = async { + // Transport: send initialization request and receive a response. + let rsp = self + .transport + .write_noise_session(request_id, session_id, buffer1, String::new(), nodes) + .await + .map_err(|_| RpcClientError::Transport)?; + + // Update the session with unverified identity of the remote node. + // The identity will be verified in Handshake2 using the RAK from + // the consensus layer. 
+ session.set_peer_id(rsp.node); + session + .set_remote_node(rsp.node) + .expect("remote node should not be set"); + + // Session Handshake2: process initialization response, verify + // remote node, and prepare the next request containing RAK binding. + let _ = session + .process_data(&rsp.data, &mut buffer2) + .await + .map_err(|_| RpcClientError::Transport)?; + + Ok(rsp) + } + .await; + + // Submit peer feedback for the last transport and the received + // initialization response. + let feedback = match result { + Ok(_) => types::PeerFeedback::Success, + Err(_) => types::PeerFeedback::Failure, + }; + let _ = self + .transport + .submit_peer_feedback(request_id, feedback) + .await; // Ignore error. + + // Forward error after peer feedback is sent. + let rsp = result?; + + let request_id = self.next_request_id(); + let result = async { + // Transport: send RAK binding request. + let rsp = self + .transport + .write_noise_session( + request_id, + session_id, + buffer2, + String::new(), + vec![rsp.node], + ) + .await + .map_err(|_| RpcClientError::Transport)?; + + if session.is_unauthenticated() { + return Err(RpcClientError::Transport); + } + + Ok(rsp) + } + .await; + + // Submit peer feedback for the last transport and session + // authentication. + let feedback = match result { + Ok(_) => types::PeerFeedback::Success, + Err(_) => types::PeerFeedback::Failure, + }; + let _ = self + .transport + .submit_peer_feedback(request_id, feedback) + .await; // Ignore error. + + // Forward error after peer feedback is sent. + if let Err(err) = result { + // Failed to complete handshake. Gracefully close the session. + let session = Arc::new(tokio::sync::Mutex::new(session)) + .lock_owned() + .await; + let _ = self.close(session).await; // Ignore error. + + return Err(err); + } + + // The connection has been successfully established. The session can + // be added to the set of active sessions if there is space available, + // or if we can make space by removing a stale session. + let now = insecure_posix_time(); + let mut sessions = self.sessions.lock().await; + let maybe_removed_session = match sessions.remove_for(&rsp.node, now) { + Ok(maybe_removed_session) => maybe_removed_session, + Err(err) => { + // Unable to make space. Gracefully close the session. + drop(sessions); // Unlock. + + let session = Arc::new(tokio::sync::Mutex::new(session)) + .lock_owned() + .await; + let _ = self.close(session).await; // Ignore error. + + return Err(err.into()); + } + }; + let session = sessions + .add(session, now) + .expect("there should be space for the new session"); + + if let Some(removed_session) = maybe_removed_session { + // A stale session was removed. Gracefully close the removed session. + drop(sessions); // Unlock. + + let _ = self.close(removed_session).await; // Ignore error. + } + + Ok(session) } - /// Update allowed remote enclave identities. - /// - /// Useful if the key manager's policy has changed. - /// - /// # Panics - /// - /// This function panics if called within an asynchronous execution context. - pub fn update_enclaves(&self, enclaves: Option>) { - self.cmdq - .blocking_send(Command::UpdateEnclaves(enclaves)) - .unwrap(); + async fn secure_call_raw( + &self, + request: types::Request, + session: &mut OwnedMutexGuard>, + ) -> Result<(u64, types::Response), RpcClientError> { + let method = request.method.clone(); + let msg = types::Message::Request(request); + let session_id = *session.get_session_id(); + + // Session Transport: prepare the request message. 
+ let mut buffer = vec![]; + session + .write_message(msg, &mut buffer) + .map_err(|_| RpcClientError::Transport)?; + let node = session.get_remote_node()?; + + let request_id = self.next_request_id(); + let result = async { + // Transport: send the request and receive a response. + let rsp = self + .transport + .write_noise_session(request_id, session_id, buffer, method, vec![node]) + .await + .map_err(|_| RpcClientError::Transport)?; + + // Session Transport: process the response. + session.process_data(&rsp.data, vec![]).await + } + .await; + + // Submit negative peer feedback for the last transport + // and the received response immediately. + if result.is_err() { + let _ = self + .transport + .submit_peer_feedback(request_id, types::PeerFeedback::Failure) + .await; // Ignore error. + } + + // Forward error after peer feedback is sent. + let maybe_msg = result?; + + // Unwrap response. + let msg = maybe_msg.expect("message must be decoded if there is no error"); + let rsp = match msg { + types::Message::Response(rsp) => rsp, + msg => return Err(RpcClientError::ExpectedResponseMessage(msg)), + }; + + Ok((request_id, rsp)) } - /// Update key manager's quote policy. - /// - /// # Panics - /// - /// This function panics if called within an asynchronous execution context. - pub fn update_quote_policy(&self, policy: QuotePolicy) { - self.cmdq - .blocking_send(Command::UpdateQuotePolicy(policy)) - .unwrap(); + async fn insecure_call_raw( + &self, + request: types::Request, + nodes: Vec, + ) -> Result<(u64, types::Response), RpcClientError> { + // Transport: send the request. + let request_id = self.next_request_id(); + let result = self + .transport + .write_insecure_query(request_id, cbor::to_vec(request), nodes) + .await + .map_err(|_| RpcClientError::Transport); + + // Submit negative peer feedback for the last transport immediately. + if result.is_err() { + let _ = self + .transport + .submit_peer_feedback(request_id, types::PeerFeedback::Failure) + .await; // Ignore error. + } + + // Forward error after peer feedback is sent. + let rsp = result?; + + // Unwrap response. + let rsp = cbor::from_slice(&rsp.data).map_err(RpcClientError::DecodeError)?; + + Ok((request_id, rsp)) } - /// Update remote runtime id. - /// - /// # Panics - /// - /// This function panics if called within an asynchronous execution context. - pub fn update_runtime_id(&self, id: Option) { - self.cmdq - .blocking_send(Command::UpdateRuntimeID(id)) - .unwrap(); + /// Close the session. + async fn close( + &self, + mut session: OwnedMutexGuard>, + ) -> Result<(), RpcClientError> { + if !session.is_connected() && !session.is_unauthenticated() { + return Ok(()); + } + + let session_id = *session.get_session_id(); + let node = session.get_remote_node()?; + + // Session Transport: prepare close request. + let mut buffer = vec![]; + session + .write_message(types::Message::Close, &mut buffer) + .map_err(|_| RpcClientError::Transport)?; + + // Transport: send close request. + let request_id = self.next_request_id(); + let rsp = self + .transport + .write_noise_session(request_id, session_id, buffer, String::new(), vec![node]) + .await + .map_err(|_| RpcClientError::Transport)?; + + // Skipping peer feedback, as the request was sent only to inform + // the other side of a graceful session close. + + // Session Transport: process the response. + let msg = session + .process_data(&rsp.data, vec![]) + .await? + .expect("message must be decoded if there is no error"); + + // Close the session. 
+ session.close(); + + match msg { + types::Message::Close => Ok(()), + msg => Err(RpcClientError::ExpectedCloseMessage(msg)), + } + } + + /// Close all sessions. + async fn close_all(&self, sessions: Vec>) { + let futures = FuturesUnordered::new(); + for session in sessions { + let future = async { + let locked_session = session.lock_owned().await; + let _ = self.close(locked_session).await; // Ignore errors. + }; + futures.push(future); + } + futures.collect::<()>().await; + } + + /// Return the ID of the next transport request. + fn next_request_id(&self) -> u64 { + let next_request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); // Wraps if overflows. + ((self.client_id as u64) << 32) + (next_request_id as u64) } /// Generate a random client ID. @@ -637,17 +621,18 @@ impl RpcClient { #[cfg(not(test))] OsRng.next_u32() } +} - /// Wait for the controller to process all queued messages. - #[cfg(test)] - async fn flush_cmd_queue(&self) -> Result<(), RpcClientError> { - let (tx, rx) = oneshot::channel(); - self.cmdq - .send(Command::Ping(tx)) - .await - .map_err(|_| RpcClientError::Dropped)?; - - rx.await.map_err(|_| RpcClientError::Dropped) +impl Drop for RpcClient { + fn drop(&mut self) { + // Close all sessions after the client is dropped. + block_on(async { + let sessions = { + let mut sessions = self.sessions.lock().await; + sessions.drain() + }; + self.close_all(sessions).await; + }); } } @@ -727,18 +712,23 @@ mod test { match message { Some(message) => { + let mut buffer = Vec::new(); + // Message, process and write reply. - let body = match message { + match message { types::Message::Request(rq) => { // Just echo back what was given. - types::Body::Success(rq.args) + let response = types::Message::Response(types::Response { + body: types::Body::Success(rq.args), + }); + + session.write_message(response, &mut buffer)?; + } + types::Message::Close => { + self.demux.close(session, &mut buffer)?; } _ => panic!("unhandled message type"), }; - let response = types::Message::Response(types::Response { body }); - - let mut buffer = Vec::new(); - session.write_message(response, &mut buffer)?; let rsp = EnclaveResponse { data: buffer, @@ -793,7 +783,7 @@ mod test { let _guard = rt.enter(); // Ensure Tokio runtime is available. let transport = MockTransport::new(); let builder = session::Builder::default(); - let client = RpcClient::new(Box::new(transport.clone()), builder); + let client = RpcClient::new(Box::new(transport.clone()), builder, 8, 2, 60); // Basic secure call. let result: u64 = rt @@ -805,7 +795,6 @@ mod test { .await }) .unwrap(); - rt.block_on(client.flush_cmd_queue()).unwrap(); // Flush cmd queue to get peer feedback. assert_eq!(result, 42, "secure call should work"); assert_eq!( transport.take_peer_feedback_history(), @@ -828,7 +817,6 @@ mod test { .await }) .unwrap(); - rt.block_on(client.flush_cmd_queue()).unwrap(); // Flush cmd queue to get peer feedback. assert_eq!(result, 43, "secure call should work"); assert_eq!( transport.take_peer_feedback_history(), @@ -853,16 +841,15 @@ mod test { .await }) .unwrap(); - rt.block_on(client.flush_cmd_queue()).unwrap(); // Flush cmd queue to get peer feedback. assert_eq!(result, 44, "secure call should work"); assert_eq!( transport.take_peer_feedback_history(), vec![ (8, types::PeerFeedback::Failure), // Handshake failed due to induced error. // (9, types::PeerFeedback::Failure), // Session close failed due to decrypt error (handshake not completed). [skipped] + (9, types::PeerFeedback::Success), // New handshake. 
(10, types::PeerFeedback::Success), // New handshake. - (11, types::PeerFeedback::Success), // New handshake. - (12, types::PeerFeedback::Success), // Handled call. + (11, types::PeerFeedback::Success), // Handled call. ] ); @@ -876,12 +863,11 @@ mod test { .await }) .unwrap(); - rt.block_on(client.flush_cmd_queue()).unwrap(); // Flush cmd queue to get peer feedback. assert_eq!(result, 45, "insecure call should work"); assert_eq!( transport.take_peer_feedback_history(), vec![ - (13, types::PeerFeedback::Success), // Handled call. + (12, types::PeerFeedback::Success), // Handled call. ] ); @@ -897,13 +883,12 @@ mod test { .await }) .unwrap(); - rt.block_on(client.flush_cmd_queue()).unwrap(); // Flush cmd queue to get peer feedback. assert_eq!(result, 46, "insecure call should work"); assert_eq!( transport.take_peer_feedback_history(), vec![ - (14, types::PeerFeedback::Failure), // Failed call due to induced error. - (15, types::PeerFeedback::Success), // Handled call. + (13, types::PeerFeedback::Failure), // Failed call due to induced error. + (14, types::PeerFeedback::Success), // Handled call. ] ); } diff --git a/runtime/src/enclave_rpc/demux.rs b/runtime/src/enclave_rpc/demux.rs index 4c10eaf08de..7dc7eb4c0bc 100644 --- a/runtime/src/enclave_rpc/demux.rs +++ b/runtime/src/enclave_rpc/demux.rs @@ -1,15 +1,12 @@ //! Session demultiplexer. -use std::{ - collections::{BTreeSet, HashMap}, - io::Write, - sync::{Arc, Mutex}, -}; +use std::{io::Write, sync::Mutex}; use thiserror::Error; use tokio::sync::OwnedMutexGuard; use super::{ - session::{Builder, Session, SessionInfo}, + session::Builder, + sessions::{self, MultiplexedSession, Sessions}, types::{Frame, Message, SessionID}, }; use crate::common::time::insecure_posix_time; @@ -21,8 +18,8 @@ pub enum Error { MalformedPayload(#[from] cbor::DecodeError), #[error("malformed request method")] MalformedRequestMethod, - #[error("max concurrent sessions reached")] - MaxConcurrentSessions, + #[error("sessions error: {0}")] + SessionsError(#[from] sessions::Error), #[error("{0}")] Other(#[from] anyhow::Error), } @@ -32,7 +29,7 @@ impl Error { match self { Error::MalformedPayload(_) => 1, Error::MalformedRequestMethod => 2, - Error::MaxConcurrentSessions => 3, + Error::SessionsError(_) => 3, Error::Other(_) => 4, } } @@ -48,252 +45,9 @@ impl From for crate::types::Error { } } -/// Peer identifier. -type PeerID = Vec; - -/// Shared pointer to a multiplexed session. -type SharedSession = Arc>; - -/// Key for use in the by-idle-time index. -type SessionByTimeKey = (i64, PeerID, SessionID); - -/// Structure used for session accounting. -struct SessionMeta { - /// Peer identifier. - peer_id: PeerID, - /// Session identifier. - session_id: SessionID, - /// Timestamp when the session was last accessed. - last_access_time: i64, - /// The shared session pointer that needs to be locked for access. - inner: SharedSession, -} - -impl SessionMeta { - /// Key for ordering in the by-idle-time index. - fn by_time_key(&self) -> SessionByTimeKey { - (self.last_access_time, self.peer_id.clone(), self.session_id) - } -} - -/// Session indices and management operations. -struct Sessions { - /// Session builder. - builder: Builder, - /// Maximum number of sessions. - max_sessions: usize, - /// Maximum number of sessions per peer. - max_sessions_per_peer: usize, - /// Stale session timeout (in seconds). - stale_session_timeout: i64, - - /// A map of sessions for each peer. - by_peer: HashMap>, - /// A set of all sessions, ordered by idle time. 
- by_idle_time: BTreeSet, -} - -impl Sessions { - /// Create a new session management instance. - fn new( - builder: Builder, - max_sessions: usize, - max_sessions_per_peer: usize, - stale_session_timeout: i64, - ) -> Self { - Self { - builder, - max_sessions, - max_sessions_per_peer, - stale_session_timeout, - by_peer: HashMap::new(), - by_idle_time: BTreeSet::new(), - } - } - - /// Create a new multiplexed session. - fn create_session( - mut builder: Builder, - peer_id: PeerID, - session_id: SessionID, - now: i64, - ) -> SessionMeta { - // If no quote policy is set, use the local one. - if builder.get_quote_policy().is_none() { - let policy = builder - .get_local_identity() - .as_ref() - .and_then(|id| id.quote_policy()); - builder = builder.quote_policy(policy); - } - - SessionMeta { - inner: Arc::new(tokio::sync::Mutex::new(MultiplexedSession { - peer_id: peer_id.clone(), - session_id, - inner: builder.build_responder(), - })), - peer_id, - session_id, - last_access_time: now, - } - } - - /// Fetch an existing session given its identifier or create a new one. - fn get_or_create( - &mut self, - peer_id: PeerID, - session_id: SessionID, - ) -> Result<(SharedSession, bool), Error> { - let now = insecure_posix_time(); - - // Check if peer exists. - if let Some(sessions) = self.by_peer.get_mut(&peer_id) { - // Check if the session exists. If so, return it. - if let Some(session) = sessions.get_mut(&session_id) { - // Remove old idle time. - self.by_idle_time.remove(&session.by_time_key()); - // Update idle time. - session.last_access_time = now; - self.by_idle_time.insert(session.by_time_key()); - - return Ok((session.inner.clone(), false)); - } - - // Check if the peer has max sessions or if no more sessions are available globally. If - // so, remove the oldest or return an error. - if sessions.len() >= self.max_sessions_per_peer - || self.by_idle_time.len() >= self.max_sessions - { - // Force close the oldest idle session so we can start a new one. - let inner = sessions - .iter() - .min_by_key(|(_, s)| { - if let Ok(_inner) = s.inner.try_lock() { - s.last_access_time - } else { - i64::MAX // Session is currently in use. - } - }) - .map(|(_, s)| s.inner.clone()) - .ok_or(Error::MaxConcurrentSessions)?; - - if let Ok(inner) = inner.try_lock_owned() { - self.remove(&inner); - } else { - // All sessions are in use. - return Err(Error::MaxConcurrentSessions); - } - } - } - - // Check if there are too many sessions. If so, remove one or return an error. - if self.by_idle_time.len() >= self.max_sessions { - // Attempt to prune stale sessions, starting with the oldest ones. - let mut remove_session: Option> = None; - for (last_process_frame_time, peer_id, session_id) in self.by_idle_time.iter() { - if now.saturating_sub(*last_process_frame_time) < self.stale_session_timeout { - // This is the oldest session, all next ones will be more fresh. - return Err(Error::MaxConcurrentSessions); - } - - // Fetch session and attempt to lock it. - if let Some(sessions) = self.by_peer.get(peer_id) { - if let Some(session) = sessions.get(session_id) { - if let Ok(session) = session.inner.clone().try_lock_owned() { - remove_session = Some(session); - break; - } - } - } - } - - if let Some(session) = remove_session { - // We found a session that can be removed. - self.remove(&session); - } else { - // All stale sessions are in use. - return Err(Error::MaxConcurrentSessions); - } - } - - // Create a new session. 
- let sessions = self.by_peer.entry(peer_id.clone()).or_default(); - let session = Self::create_session(self.builder.clone(), peer_id.clone(), session_id, now); - let inner = session.inner.clone(); - sessions.insert(session_id, session); - self.by_idle_time.insert((now, peer_id, session_id)); - - Ok((inner, true)) - } - - /// Remove a session that must be currently owned by the caller. - fn remove(&mut self, session: &OwnedMutexGuard) { - let sessions = self.by_peer.get_mut(&session.peer_id).unwrap(); - let session_meta = sessions.get(&session.session_id).unwrap(); - let key = session_meta.by_time_key(); - sessions.remove(&session.session_id); - self.by_idle_time.remove(&key); - - // If peer doesn't have any more sessions, remove the peer. - if sessions.is_empty() { - self.by_peer.remove(&session.peer_id); - } - } - - /// Clear all sessions. - fn clear(&mut self) { - self.by_peer.clear(); - self.by_idle_time.clear(); - } - - /// Number of all sessions. - #[cfg(test)] - fn session_count(&self) -> usize { - self.by_idle_time.len() - } - - /// Number of all peers. - #[cfg(test)] - fn peer_count(&self) -> usize { - self.by_peer.len() - } -} - /// Session demultiplexer. pub struct Demux { - sessions: Mutex, -} - -/// A multiplexed session. -pub struct MultiplexedSession { - /// Peer identifier (needed for resolution when only given the shared pointer). - peer_id: PeerID, - /// Session identifier (needed for resolution when only given the shared pointer). - session_id: SessionID, - /// The actual session. - inner: Session, -} - -impl MultiplexedSession { - /// Session information. - pub fn info(&self) -> Option> { - self.inner.session_info() - } - - /// Process incoming session data. - async fn process_data( - &mut self, - data: Vec, - writer: W, - ) -> Result, Error> { - Ok(self.inner.process_data(data, writer).await?) - } - - /// Write message to session and generate a response. - pub fn write_message(&mut self, msg: Message, mut writer: W) -> Result<(), Error> { - Ok(self.inner.write_message(msg, &mut writer)?) - } + sessions: Mutex>>, } impl Demux { @@ -316,12 +70,22 @@ impl Demux { async fn get_or_create_session( &self, - peer_id: PeerID, + peer_id: Vec, session_id: SessionID, - ) -> Result, Error> { - let (session, _) = { + ) -> Result>>, Error> { + let session = { let mut sessions = self.sessions.lock().unwrap(); - sessions.get_or_create(peer_id, session_id)? + match sessions.get(&peer_id, &session_id) { + Some(session) => session, + None => { + let now = insecure_posix_time(); + let _ = sessions.remove_for(&peer_id, now)?; + let session = sessions.create_responder(peer_id, session_id); + sessions + .add(session, now) + .expect("there should be space for the new session") + } + } }; Ok(session.lock_owned().await) @@ -332,16 +96,22 @@ impl Demux { /// Any data that needs to be transmitted back to the peer is written to the passed writer. pub async fn process_frame( &self, - peer_id: PeerID, + peer_id: Vec, data: Vec, writer: W, - ) -> Result<(OwnedMutexGuard, Option), Error> { + ) -> Result< + ( + OwnedMutexGuard>>, + Option, + ), + Error, + > { // Decode frame. let frame: Frame = cbor::from_slice(&data)?; // Get the existing session or create a new one. let mut session = self.get_or_create_session(peer_id, frame.session).await?; // Process session data. 
- match session.process_data(frame.payload, writer).await { + match session.process_data(&frame.payload, writer).await { Ok(msg) => { if let Some(Message::Request(ref req)) = msg { // Make sure that the untrusted_plaintext matches the request's method. @@ -354,11 +124,11 @@ impl Demux { } Err(err) => { // In case the session was closed, remove the session. - if session.inner.is_closed() { + if session.is_closed() { let mut sessions = self.sessions.lock().unwrap(); sessions.remove(&session); } - Err(err) + Err(Error::Other(err)) } } } @@ -368,7 +138,7 @@ impl Demux { /// Any data that needs to be transmitted back to the peer is written to the passed writer. pub fn close( &self, - mut session: OwnedMutexGuard, + mut session: OwnedMutexGuard>>, writer: W, ) -> Result<(), Error> { let mut sessions = self.sessions.lock().unwrap(); @@ -381,251 +151,6 @@ impl Demux { /// Resets all open sessions. pub fn reset(&self) { let mut sessions = self.sessions.lock().unwrap(); - sessions.clear(); - } -} - -#[cfg(test)] -mod test { - use crate::enclave_rpc::{session::Builder, types::SessionID}; - - use super::{Error, Sessions}; - - fn ids() -> (Vec>, Vec) { - let peer_ids: Vec> = (1..16).map(|x| vec![x]).collect(); - let session_ids: Vec = (1..16).map(|_| SessionID::random()).collect(); - - (peer_ids, session_ids) - } - - #[test] - fn test_namespacing() { - let (peer_ids, session_ids) = ids(); - let mut sessions = Sessions::new(Builder::default(), 16, 4, 60); - - let (s1, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let s1_owned = s1.try_lock().unwrap(); - assert_eq!(&s1_owned.peer_id, &peer_ids[0]); - assert_eq!(&s1_owned.session_id, &session_ids[0]); - drop(s1_owned); - assert_eq!(sessions.session_count(), 1); - assert_eq!(sessions.peer_count(), 1); - - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[1]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[2]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[3]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 1); - - // Requesting an existing session for an existing peer should return it. - let (s1r, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(!created, "session should be reused"); - let s1r_owned = s1r.try_lock().unwrap(); - assert_eq!(&s1r_owned.peer_id, &peer_ids[0]); - assert_eq!(&s1r_owned.session_id, &session_ids[0]); - drop(s1r_owned); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 1); - - // Sessions should be properly namespaced by peer. 
- let (s5, created) = sessions - .get_or_create(peer_ids[1].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created due to namespacing"); - let s5_owned = s5.try_lock().unwrap(); - assert_eq!(&s5_owned.peer_id, &peer_ids[1]); - assert_eq!(&s5_owned.session_id, &session_ids[0]); - drop(s5_owned); - assert_eq!(sessions.session_count(), 5); - assert_eq!(sessions.peer_count(), 2); - } - - #[test] - fn test_max_sessions_per_peer() { - let (peer_ids, session_ids) = ids(); - let mut sessions = Sessions::new(Builder::default(), 16, 4, 60); // Stale timeout is ignored. - - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - - // Sleep to make sure the first session is the oldest. - std::thread::sleep(std::time::Duration::from_millis(1100)); - - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[1]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[2]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[3]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 1); - - // Creating more sessions for the same peer should result in the oldest session being - // closed. - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[4]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 1); - - // Only the oldest session should be closed. 
- let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[1]) - .expect("get_or_create should succeed"); - assert!(!created, "session should be reused"); - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[2]) - .expect("get_or_create should succeed"); - assert!(!created, "session should be reused"); - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[3]) - .expect("get_or_create should succeed"); - assert!(!created, "session should be reused"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 1); - - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 1); - } - - #[test] - fn test_max_sessions() { - let (peer_ids, session_ids) = ids(); - let mut sessions = Sessions::new(Builder::default(), 4, 4, 60); - - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[1].clone(), session_ids[1]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[2].clone(), session_ids[2]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[3].clone(), session_ids[3]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 4); - - // Creating more sessions for a different peer should fail as no sessions are available and - // none are stale. - let res = sessions.get_or_create(peer_ids[4].clone(), session_ids[4]); - assert!( - matches!(res, Err(Error::MaxConcurrentSessions)), - "get_or_create should fail" - ); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 4); - - // Creating more sessions for one of the existing peers should still work as it should force - // evict an old session. Note that each peer has 4 available slots, but globally there are - // only 4 slots so if global slots are full this should still trigger peer session eviction. - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[5]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 4); - } - - #[test] - fn test_max_sessions_prune_stale() { - let (peer_ids, session_ids) = ids(); - let mut sessions = Sessions::new(Builder::default(), 4, 4, 0); // Stale timeout is zero. 
- - let (_, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[1].clone(), session_ids[1]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[2].clone(), session_ids[2]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[3].clone(), session_ids[3]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 4); - - // Creating more sessions for a different peer should succeed as one of the stale sessions - // should be removed to make room for a new session. - let (_, created) = sessions - .get_or_create(peer_ids[4].clone(), session_ids[4]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 4); - } - - #[test] - fn test_remove() { - let (peer_ids, session_ids) = ids(); - let mut sessions = Sessions::new(Builder::default(), 16, 4, 0); // Stale timeout is zero. - - let (s1, created) = sessions - .get_or_create(peer_ids[0].clone(), session_ids[0]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (s2, created) = sessions - .get_or_create(peer_ids[1].clone(), session_ids[1]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[1].clone(), session_ids[2]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - let (_, created) = sessions - .get_or_create(peer_ids[2].clone(), session_ids[3]) - .expect("get_or_create should succeed"); - assert!(created, "new session should be created"); - assert_eq!(sessions.session_count(), 4); - assert_eq!(sessions.peer_count(), 3); - - let s1r = s1.try_lock_owned().unwrap(); - sessions.remove(&s1r); - assert_eq!(sessions.session_count(), 3); - assert_eq!(sessions.peer_count(), 2); - - let s2r = s2.try_lock_owned().unwrap(); - sessions.remove(&s2r); - assert_eq!(sessions.session_count(), 2); - assert_eq!(sessions.peer_count(), 2); + let _ = sessions.drain(); } } diff --git a/runtime/src/enclave_rpc/mod.rs b/runtime/src/enclave_rpc/mod.rs index e4d3685043d..2a0e7a4bf2f 100644 --- a/runtime/src/enclave_rpc/mod.rs +++ b/runtime/src/enclave_rpc/mod.rs @@ -5,6 +5,7 @@ pub mod context; pub mod demux; pub mod dispatcher; pub mod session; +pub mod sessions; mod transport; pub mod types; diff --git a/runtime/src/enclave_rpc/session.rs b/runtime/src/enclave_rpc/session.rs index 6f406a2215c..d6a1c158573 100644 --- a/runtime/src/enclave_rpc/session.rs +++ b/runtime/src/enclave_rpc/session.rs @@ -116,7 +116,7 @@ impl Session { /// protocol replies need to be generated. pub async fn process_data( &mut self, - data: Vec, + data: &[u8], mut writer: W, ) -> Result> { // Replace the state with a closed state. 
In case processing fails for whatever @@ -134,7 +134,7 @@ impl Session { writer.write_all(&self.buf[..len])?; } else { // <- e - state.read_message(&data, &mut self.buf)?; + state.read_message(data, &mut self.buf)?; // -> e, ee, s, es let len = state.write_message(&self.get_rak_binding(), &mut self.buf)?; @@ -145,7 +145,7 @@ impl Session { } State::Handshake2(mut state) => { // Process data sent during Handshake1 phase. - let len = state.read_message(&data, &mut self.buf)?; + let len = state.read_message(data, &mut self.buf)?; let remote_static = state .get_remote_static() .expect("dh exchange just happened"); @@ -178,7 +178,7 @@ impl Session { } State::Transport(mut state) => { // TODO: Restore session in case of errors. - let len = state.read_message(&data, &mut self.buf)?; + let len = state.read_message(data, &mut self.buf)?; let msg = cbor::from_slice(&self.buf[..len])?; self.state = State::Transport(state); @@ -314,7 +314,7 @@ impl Session { } /// Return remote node identifier. - pub fn get_node(&self) -> Result { + pub fn get_remote_node(&self) -> Result { self.remote_node.ok_or(SessionError::NodeNotSet.into()) } diff --git a/runtime/src/enclave_rpc/sessions.rs b/runtime/src/enclave_rpc/sessions.rs new file mode 100644 index 00000000000..bddf80ba0d3 --- /dev/null +++ b/runtime/src/enclave_rpc/sessions.rs @@ -0,0 +1,945 @@ +//! Session demultiplexer. +use std::{ + collections::{BTreeSet, HashMap, HashSet}, + hash::Hash, + io::Write, + mem, + sync::Arc, +}; + +use anyhow::Result; +use rand::{rngs::OsRng, Rng}; +use tokio::sync::OwnedMutexGuard; + +use super::{ + session::{Builder, Session, SessionInfo}, + types::{Message, SessionID}, +}; +use crate::common::{ + crypto::signature, + namespace::Namespace, + sgx::{EnclaveIdentity, QuotePolicy}, + time::insecure_posix_time, +}; + +/// Shared pointer to a multiplexed session. +pub type SharedSession = Arc>>; + +/// Key for use in the by-idle-time index. +pub type SessionByTimeKey = (i64, PeerID, SessionID); + +/// Sessions error. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("max concurrent sessions reached")] + MaxConcurrentSessions, +} + +/// A multiplexed session. +pub struct MultiplexedSession { + /// Peer identifier (needed for resolution when only given the shared pointer). + peer_id: PeerID, + /// Session identifier (needed for resolution when only given the shared pointer). + session_id: SessionID, + /// The actual session. + inner: Session, +} + +impl MultiplexedSession { + /// Return the session's peer ID. + pub fn get_peer_id(&self) -> &PeerID { + &self.peer_id + } + + /// Set the session's peer ID. + pub fn set_peer_id(&mut self, peer_id: PeerID) { + self.peer_id = peer_id; + } + + /// Return the session ID. + pub fn get_session_id(&self) -> &SessionID { + &self.session_id + } + + /// Session information. + pub fn info(&self) -> Option> { + self.inner.session_info() + } + + /// Whether the session is in closed state. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } + + /// Process incoming session data. + pub async fn process_data( + &mut self, + data: &[u8], + writer: W, + ) -> Result> { + self.inner.process_data(data, writer).await + } + + /// Write message to session and generate a response. + pub fn write_message(&mut self, msg: Message, mut writer: W) -> Result<()> { + self.inner.write_message(msg, &mut writer) + } + + /// Return remote node identifier. + pub fn get_remote_node(&self) -> Result { + self.inner.get_remote_node() + } + + /// Set the remote node identifier. 
+ pub fn set_remote_node(&mut self, node: signature::PublicKey) -> Result<()> { + self.inner.set_remote_node(node) + } + + /// Whether the session handshake has completed and the session + /// is in transport mode. + pub fn is_connected(&self) -> bool { + self.inner.is_connected() + } + + /// Whether the session is in unauthenticated transport state. In this state the session can + /// only be used to transmit a close notification. + pub fn is_unauthenticated(&self) -> bool { + self.inner.is_unauthenticated() + } + + /// Mark the session as closed. + /// + /// After the session is closed it can no longer be used to transmit + /// or receive messages and any such use will result in an error. + pub fn close(&mut self) { + self.inner.close() + } +} + +/// Structure used for session accounting. +pub struct SessionMeta { + /// Peer identifier. + peer_id: PeerID, + /// Session identifier. + session_id: SessionID, + /// Timestamp when the session was last accessed. + last_access_time: i64, + /// The shared session pointer that needs to be locked for access. + inner: SharedSession, +} + +impl SessionMeta +where + PeerID: Clone + Ord + Hash, +{ + /// Key for ordering in the by-idle-time index. + fn by_time_key(&self) -> SessionByTimeKey { + (self.last_access_time, self.peer_id.clone(), self.session_id) + } +} + +/// Session indices and management operations. +pub struct Sessions { + /// Session builder. + builder: Builder, + /// Maximum number of sessions. + max_sessions: usize, + /// Maximum number of sessions per peer. + max_sessions_per_peer: usize, + /// Stale session timeout (in seconds). + stale_session_timeout: i64, + + /// A map of sessions for each peer. + by_peer: HashMap>>, + /// A set of all sessions, ordered by idle time. + by_idle_time: BTreeSet>, +} + +impl Sessions +where + PeerID: Clone + Ord + Hash, +{ + /// Create a new session management instance. + pub fn new( + builder: Builder, + max_sessions: usize, + max_sessions_per_peer: usize, + stale_session_timeout: i64, + ) -> Self { + Self { + builder, + max_sessions, + max_sessions_per_peer, + stale_session_timeout, + by_peer: HashMap::new(), + by_idle_time: BTreeSet::new(), + } + } + + /// Update remote enclave identity verification in the session builder + /// and clear all sessions if the identity has changed. + pub fn update_enclaves( + &mut self, + enclaves: Option>, + ) -> Vec> { + if self.builder.get_remote_enclaves() == &enclaves { + return vec![]; + } + + self.builder = mem::take(&mut self.builder).remote_enclaves(enclaves); + self.drain() + } + + /// Update quote policy used for remote quote verification in the session builder + /// and clear all sessions if the policy has changed. + pub fn update_quote_policy(&mut self, policy: QuotePolicy) -> Vec> { + let policy = Some(Arc::new(policy)); + if self.builder.get_quote_policy() == &policy { + return vec![]; + } + + self.builder = mem::take(&mut self.builder).quote_policy(policy); + self.drain() + } + + /// Update remote runtime ID for node identity verification in the session builder + /// and clear all sessions if the runtime ID has changed. + pub fn update_runtime_id(&mut self, id: Option) -> Vec> { + if self.builder.get_remote_runtime_id() == &id { + return vec![]; + } + + self.builder = mem::take(&mut self.builder).remote_runtime_id(id); + self.drain() + } + + /// Create a new multiplexed responder session. + pub fn create_responder( + &mut self, + peer_id: PeerID, + session_id: SessionID, + ) -> MultiplexedSession { + // If no quote policy is set, use the local one. 
+ if self.builder.get_quote_policy().is_none() { + let policy = self + .builder + .get_local_identity() + .as_ref() + .and_then(|id| id.quote_policy()); + + self.builder = mem::take(&mut self.builder).quote_policy(policy); + } + + MultiplexedSession { + peer_id: peer_id.clone(), + session_id, + inner: self.builder.clone().build_responder(), + } + } + + /// Create a new multiplexed initiator session. + pub fn create_initiator(&self, peer_id: PeerID) -> MultiplexedSession { + let session_id = SessionID::random(); + + MultiplexedSession { + peer_id: peer_id.clone(), + session_id, + inner: self.builder.clone().build_initiator(), + } + } + + /// Fetch an existing session given its identifier. + pub fn get( + &mut self, + peer_id: &PeerID, + session_id: &SessionID, + ) -> Option> { + // Check if peer exists. + let sessions = match self.by_peer.get_mut(peer_id) { + Some(sessions) => sessions, + None => return None, + }; + + // Check if the session exists. If so, return it. + let session = match sessions.get_mut(session_id) { + Some(session) => session, + None => return None, + }; + + Self::update_access_time(session, &mut self.by_idle_time); + + Some(session.inner.clone()) + } + + /// Fetch an existing session from one of the given peers. If no peers + /// are provided, a session from any peer will be returned. + pub fn find(&mut self, peer_ids: &[PeerID]) -> Option> { + match peer_ids.is_empty() { + true => self.find_any(), + false => self.find_one(peer_ids), + } + } + + /// Fetch an existing session from any peer. + pub fn find_any(&mut self) -> Option> { + if self.by_idle_time.is_empty() { + return None; + } + + // Check if there is a session that is not currently in use. + for (_, peer_id, session_id) in self.by_idle_time.iter() { + let session = self + .by_peer + .get_mut(peer_id) + .unwrap() + .get_mut(session_id) + .unwrap(); + + if session.inner.clone().try_lock_owned().is_ok() { + Self::update_access_time(session, &mut self.by_idle_time); + return Some(session.inner.clone()); + } + } + + // If all sessions are in use, return a random one. + let n = OsRng.gen_range(0..self.by_idle_time.len()); + let (_, peer_id, session_id) = self.by_idle_time.iter().nth(n).unwrap(); + let session = self + .by_peer + .get_mut(peer_id) + .unwrap() + .get_mut(session_id) + .unwrap(); + + Self::update_access_time(session, &mut self.by_idle_time); + + Some(session.inner.clone()) + } + + /// Fetch an existing session from one of the given peers. + pub fn find_one(&mut self, peer_ids: &[PeerID]) -> Option> { + let mut all_sessions = vec![]; + + for peer_id in peer_ids.iter() { + let sessions = match self.by_peer.get_mut(peer_id) { + Some(sessions) => sessions, + None => return None, + }; + + // Check if peer has a session that is not currently in use. + let session = sessions + .values_mut() + .filter(|s| s.inner.clone().try_lock_owned().is_ok()) + .min_by_key(|s| s.last_access_time); + + if let Some(session) = session { + Self::update_access_time(session, &mut self.by_idle_time); + return Some(session.inner.clone()); + } + + for session in sessions.values() { + all_sessions.push((session.peer_id.clone(), session.session_id)); + } + } + + if all_sessions.is_empty() { + return None; + } + + // If all sessions are in use, return a random one. 
+ let n = OsRng.gen_range(0..all_sessions.len()); + let (peer_id, session_id) = all_sessions.get(n).unwrap(); + let session = self + .by_peer + .get_mut(peer_id) + .unwrap() + .get_mut(session_id) + .unwrap(); + + Self::update_access_time(session, &mut self.by_idle_time); + + Some(session.inner.clone()) + } + + /// Remove one session to free up a slot for the given peer. + pub fn remove_for( + &mut self, + peer_id: &PeerID, + now: i64, + ) -> Result>>, Error> { + if let Some(session) = self.remove_from(peer_id)? { + return Ok(Some(session)); + } + self.remove_one(now) + } + + /// Remove one existing session from the given peer if the peer has reached + /// the maximum number of sessions or if the total number of sessions exceeds + /// the global session limit. + pub fn remove_from( + &mut self, + peer_id: &PeerID, + ) -> Result>>, Error> { + // Check if peer exists. + let sessions = match self.by_peer.get_mut(peer_id) { + Some(sessions) => sessions, + None => return Ok(None), + }; + + // Check if the peer has max sessions or if no more sessions are available globally. + // If so, remove the oldest or return an error. + if sessions.len() < self.max_sessions_per_peer + && self.by_idle_time.len() < self.max_sessions + { + return Ok(None); + } + + // Force close the oldest idle session. + let remove_session = sessions + .iter() + .min_by_key(|(_, s)| { + if let Ok(_inner) = s.inner.try_lock() { + s.last_access_time + } else { + i64::MAX // Session is currently in use. + } + }) + .map(|(_, s)| s.inner.clone()) + .ok_or(Error::MaxConcurrentSessions)?; + + let session = match remove_session.try_lock_owned() { + Ok(inner) => inner, + Err(_) => return Err(Error::MaxConcurrentSessions), // All sessions are in use. + }; + + self.remove(&session); + + Ok(Some(session)) + } + + /// Remove one stale session if the total number of sessions exceeds + /// the global session limit. + pub fn remove_one( + &mut self, + now: i64, + ) -> Result>>, Error> { + // Check if there are too many sessions. If so, remove one or return an error. + if self.by_idle_time.len() < self.max_sessions { + return Ok(None); + } + + // Attempt to prune stale sessions, starting with the oldest ones. + let mut remove_session: Option>> = None; + + for (last_process_frame_time, peer_id, session_id) in self.by_idle_time.iter() { + if now.saturating_sub(*last_process_frame_time) < self.stale_session_timeout { + // This is the oldest session, all next ones will be more fresh. + return Err(Error::MaxConcurrentSessions); + } + + // Fetch session and attempt to lock it. + if let Some(sessions) = self.by_peer.get(peer_id) { + if let Some(session) = sessions.get(session_id) { + if let Ok(session) = session.inner.clone().try_lock_owned() { + remove_session = Some(session); + break; + } + } + } + } + + // Check if we found a session that can be removed. + let session = match remove_session { + Some(session) => session, + None => return Err(Error::MaxConcurrentSessions), // All stale sessions are in use. + }; + + self.remove(&session); + + Ok(Some(session)) + } + + /// Add a session if there is an available spot. 
+ pub fn add( + &mut self, + session: MultiplexedSession, + now: i64, + ) -> Result, Error> { + if self.by_idle_time.len() >= self.max_sessions { + return Err(Error::MaxConcurrentSessions); + } + + let sessions = self.by_peer.entry(session.peer_id.clone()).or_default(); + if sessions.len() >= self.max_sessions_per_peer { + return Err(Error::MaxConcurrentSessions); + } + + let peer_id = session.peer_id.clone(); + let session_id = session.session_id; + + let session = SessionMeta { + inner: Arc::new(tokio::sync::Mutex::new(session)), + peer_id, + session_id, + last_access_time: now, + }; + let inner = session.inner.clone(); + + self.by_idle_time.insert(session.by_time_key()); + sessions.insert(session.session_id, session); + + Ok(inner) + } + + /// Remove a session that must be currently owned by the caller. + pub fn remove(&mut self, session: &OwnedMutexGuard>) { + let sessions = self.by_peer.get_mut(&session.peer_id).unwrap(); + let session_meta = sessions.get(&session.session_id).unwrap(); + let key = session_meta.by_time_key(); + sessions.remove(&session.session_id); + self.by_idle_time.remove(&key); + + // If peer doesn't have any more sessions, remove the peer. + if sessions.is_empty() { + self.by_peer.remove(&session.peer_id); + } + } + + /// Removes and returns all sessions. + pub fn drain(&mut self) -> Vec> { + self.by_idle_time.clear(); + + let mut all_sessions = vec![]; + for (_, mut sessions) in self.by_peer.drain() { + for (_, session) in sessions.drain() { + all_sessions.push(session.inner); + } + } + + all_sessions + } + + fn update_access_time( + session: &mut SessionMeta, + by_idle_time: &mut BTreeSet>, + ) { + // Remove old idle time. + by_idle_time.remove(&session.by_time_key()); + + // Update idle time. + session.last_access_time = insecure_posix_time(); + by_idle_time.insert(session.by_time_key()); + } + + /// Number of all sessions. + #[cfg(test)] + fn session_count(&self) -> usize { + self.by_idle_time.len() + } + + /// Number of all peers. + #[cfg(test)] + fn peer_count(&self) -> usize { + self.by_peer.len() + } +} + +#[cfg(test)] +mod test { + use crate::enclave_rpc::{session::Builder, types::SessionID}; + + use super::{Error, Sessions}; + + fn ids() -> (Vec>, Vec) { + let peer_ids: Vec> = (1..8).map(|x| vec![x]).collect(); + let session_ids: Vec = (1..8).map(|_| SessionID::random()).collect(); + + (peer_ids, session_ids) + } + + #[test] + fn test_add() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 4, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0], 1, 1, true), + (&peer_ids[0], &session_ids[1], 2, 1, true), // Different session ID. + (&peer_ids[0], &session_ids[2], 2, 1, false), // Too many sessions per peer. + (&peer_ids[1], &session_ids[0], 3, 2, true), // Different peer ID. + (&peer_ids[2], &session_ids[2], 4, 3, true), // Different peer ID and session ID. + (&peer_ids[3], &session_ids[3], 4, 3, false), // Too many sessions. 
+ ]; + + let now = 0; + for (peer_id, session_id, num_sessions, num_peers, created) in test_vector { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let res = sessions.add(session, now); + match created { + true => { + assert!(res.is_ok(), "session should be created"); + let s = res.unwrap(); + let s_owned = s.try_lock().unwrap(); + assert_eq!(&s_owned.peer_id, peer_id); + assert_eq!(&s_owned.session_id, session_id); + } + false => { + assert!(res.is_err(), "session should not be created"); + assert!(matches!(res, Err(Error::MaxConcurrentSessions))); + } + }; + assert_eq!(sessions.session_count(), num_sessions); + assert_eq!(sessions.peer_count(), num_peers); + } + } + + #[test] + fn test_get() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 8, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0], true), + (&peer_ids[0], &session_ids[1], false), // Different peer ID. + (&peer_ids[1], &session_ids[0], false), // Different session ID. + (&peer_ids[1], &session_ids[1], false), // Different peer ID and session ID. + ]; + + let now = 0; + for (peer_id, session_id, create) in test_vector { + if create { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + } + + let maybe_s = sessions.get(peer_id, session_id); + match create { + true => { + assert!(maybe_s.is_some(), "session should exist"); + let s = maybe_s.unwrap(); + let s_owned = s.try_lock_owned().unwrap(); + assert_eq!(&s_owned.peer_id, peer_id); + assert_eq!(&s_owned.session_id, session_id); + } + false => assert!(maybe_s.is_none(), "session should not exist"), + } + } + } + + #[test] + fn test_find_any() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 8, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0]), + (&peer_ids[0], &session_ids[1]), + (&peer_ids[1], &session_ids[2]), + ]; + + // No sessions. + let maybe_s = sessions.find_any(); + assert!(maybe_s.is_none(), "session should not be found"); + + let mut now = 0; + for (peer_id, session_id) in test_vector { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + now += 1 + } + + // No sessions in use. + let maybe_s = sessions.find_any(); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s1_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s1_owned.peer_id, &peer_ids[0]); + assert_eq!(&s1_owned.session_id, &session_ids[0]); + + // One session in use. + let maybe_s = sessions.find_any(); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s2_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s2_owned.peer_id, &peer_ids[0]); + assert_eq!(&s2_owned.session_id, &session_ids[1]); // Different session found. + + // Two sessions in use. + let maybe_s = sessions.find_any(); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s3_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s3_owned.peer_id, &peer_ids[1]); + assert_eq!(&s3_owned.session_id, &session_ids[2]); // Different session found. + + // All sessions in use. + let maybe_s = sessions.find_any(); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let res = s.try_lock_owned(); // Session now in use. 
+ assert!(res.is_err(), "session should be in use"); + + // Free one session. + drop(s2_owned); + + // Two sessions in use. + let maybe_s = sessions.find_any(); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s_owned.peer_id, &peer_ids[0]); + assert_eq!(&s_owned.session_id, &session_ids[1]); + } + + #[test] + fn test_find_one() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 8, 2, 60); + + let test_vector = vec![ + (&peer_ids[2], &session_ids[0]), // Incorrect peer. + (&peer_ids[0], &session_ids[0]), + (&peer_ids[3], &session_ids[1]), // Incorrect peer. + (&peer_ids[0], &session_ids[1]), + (&peer_ids[3], &session_ids[2]), // Incorrect peer. + (&peer_ids[1], &session_ids[2]), + (&peer_ids[2], &session_ids[2]), // Incorrect peer. + ]; + + // No sessions. + let maybe_s = sessions.find_one(&peer_ids[0..2]); + assert!(maybe_s.is_none(), "session should not be found"); + + let mut now = 0; + for (peer_id, session_id) in test_vector { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + now += 1 + } + + // Peers without sessions. + let maybe_s = sessions.find_one(&peer_ids[4..]); + assert!(maybe_s.is_none(), "session should not be found"); + + // No sessions in use. + let maybe_s = sessions.find_one(&peer_ids[0..2]); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s1_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s1_owned.peer_id, &peer_ids[0]); + assert_eq!(&s1_owned.session_id, &session_ids[0]); + + // One session in use. + let maybe_s = sessions.find_one(&peer_ids[0..2]); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s2_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s2_owned.peer_id, &peer_ids[0]); + assert_eq!(&s2_owned.session_id, &session_ids[1]); // Different session found. + + // Two sessions in use. + let maybe_s = sessions.find_one(&peer_ids[0..2]); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s3_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s3_owned.peer_id, &peer_ids[1]); + assert_eq!(&s3_owned.session_id, &session_ids[2]); // Different session found. + + // All sessions in use. + let maybe_s = sessions.find_one(&peer_ids[0..2]); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let res = s.try_lock_owned(); // Session now in use. + assert!(res.is_err(), "session should be in use"); + + // Free one session. + drop(s2_owned); + + // Two sessions in use. + let maybe_s = sessions.find_one(&peer_ids[0..2]); + assert!(maybe_s.is_some(), "session should be found"); + let s = maybe_s.unwrap(); + let s_owned = s.try_lock_owned().unwrap(); // Session now in use. + assert_eq!(&s_owned.peer_id, &peer_ids[0]); + assert_eq!(&s_owned.session_id, &session_ids[1]); + } + + #[test] + fn test_remove_from() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 4, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0]), + (&peer_ids[1], &session_ids[1]), + (&peer_ids[2], &session_ids[2]), + (&peer_ids[2], &session_ids[3]), // Max sessions per peer reached. + // Max sessions reached. 
+ ]; + + let mut now = 0; + for (peer_id, session_id) in test_vector.clone() { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + now += 1; + } + + // Removing one session from an unknown peer should have no effect, + // even if all global session slots are occupied. + let res = sessions.remove_from(&peer_ids[3]); + assert!(res.is_ok(), "remove_from should succeed"); + let maybe_s_owned = res.unwrap(); + assert!(maybe_s_owned.is_none(), "no sessions should be removed"); + assert_eq!(sessions.session_count(), 4); + assert_eq!(sessions.peer_count(), 3); + + // Removing one session for one of the existing peers should work + // as it should force evict an old session. + // Note that each peer has 2 available slots, but globally there are + // only 4 slots so if global slots are full this should trigger peer + // session eviction. + let res = sessions.remove_from(&peer_ids[0]); + assert!(res.is_ok(), "remove_from should succeed"); + let maybe_s_owned = res.unwrap(); + assert!(maybe_s_owned.is_some(), "one session should be removed"); + let s_owned = maybe_s_owned.unwrap(); + assert_eq!(&s_owned.peer_id, &peer_ids[0]); + assert_eq!(&s_owned.session_id, &session_ids[0]); + assert_eq!(sessions.session_count(), 3); + assert_eq!(sessions.peer_count(), 2); + + // Removing another session should fail as one global session slot + // is available. + for peer_id in vec![&peer_ids[0], &peer_ids[1]] { + let res = sessions.remove_from(peer_id); + assert!(res.is_ok(), "remove_from should succeed"); + let maybe_s_owned = res.unwrap(); + assert!(maybe_s_owned.is_none(), "no sessions should be removed"); + assert_eq!(sessions.session_count(), 3); + assert_eq!(sessions.peer_count(), 2); + } + + // Removing one session from a peer with max sessions should succeed + // even if one global slot is available. + let res = sessions.remove_from(&peer_ids[2]); + assert!(res.is_ok(), "remove_from should succeed"); + let maybe_s_owned = res.unwrap(); + assert!(maybe_s_owned.is_some(), "one session should be removed"); + let s_owned = maybe_s_owned.unwrap(); + assert_eq!(&s_owned.peer_id, &peer_ids[2]); + assert_eq!(&s_owned.session_id, &session_ids[2]); + assert_eq!(sessions.session_count(), 2); + assert_eq!(sessions.peer_count(), 2); + } + + #[test] + fn test_remove_one() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 4, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0]), + (&peer_ids[1], &session_ids[1]), + (&peer_ids[2], &session_ids[2]), + (&peer_ids[2], &session_ids[3]), // Max sessions reached. + ]; + + let mut now = 0; + for (peer_id, session_id) in test_vector.clone() { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + now += 1; + } + + // Forward time (stale_session_timeout - test_vector.len() - 1). + now += 60 - 4 - 1; + + // Removing one session should fail as there are none stale sessions. + let res = sessions.remove_one(now); + assert!(res.is_err(), "remove_one should fail"); + assert!(matches!(res, Err(Error::MaxConcurrentSessions))); + assert_eq!(sessions.session_count(), 4); + assert_eq!(sessions.peer_count(), 3); + + // Forward time. + now += 1; + + // Removing one session should succeed as no session slots + // are available and there is one stale session. 
+ let res = sessions.remove_one(now); + assert!(res.is_ok(), "remove_one should succeed"); + let maybe_s_owned = res.unwrap(); + assert!(maybe_s_owned.is_some(), "one session should be removed"); + let s_owned = maybe_s_owned.unwrap(); + assert_eq!(&s_owned.peer_id, &peer_ids[0]); + assert_eq!(&s_owned.session_id, &session_ids[0]); + assert_eq!(sessions.session_count(), 3); + assert_eq!(sessions.peer_count(), 2); + + // Forward time. + now += 100; + + // Removing one session should fail even though there are stale sessions + // because there is one session slot available. + let res = sessions.remove_one(now); + assert!(res.is_ok(), "remove_one should succeed"); + let maybe_s_owned = res.unwrap(); + assert!(maybe_s_owned.is_none(), "no sessions should be removed"); + assert_eq!(sessions.session_count(), 3); + assert_eq!(sessions.peer_count(), 2); + } + + #[test] + fn test_remove() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 8, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0], 3, 2), + (&peer_ids[1], &session_ids[1], 2, 1), + (&peer_ids[2], &session_ids[2], 1, 1), + (&peer_ids[2], &session_ids[3], 0, 0), + ]; + + let now = 0; + for (peer_id, session_id, _, _) in test_vector.clone() { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + } + + for (peer_id, session_id, num_sessions, num_peers) in test_vector { + let maybe_s = sessions.get(peer_id, session_id); + assert!(maybe_s.is_some(), "session should exist"); + let s = maybe_s.unwrap(); + let s_owned = s.try_lock_owned().unwrap(); + + sessions.remove(&s_owned); + assert_eq!(sessions.session_count(), num_sessions); + assert_eq!(sessions.peer_count(), num_peers); + } + } + + #[test] + fn test_clear() { + let (peer_ids, session_ids) = ids(); + let mut sessions = Sessions::new(Builder::default(), 8, 2, 60); + + let test_vector = vec![ + (&peer_ids[0], &session_ids[0]), + (&peer_ids[1], &session_ids[1]), + (&peer_ids[2], &session_ids[2]), + (&peer_ids[2], &session_ids[3]), + ]; + + let now = 0; + for (peer_id, session_id) in test_vector.clone() { + let session = sessions.create_responder(peer_id.clone(), session_id.clone()); + let _ = sessions.add(session, now); + } + + let removed_sessions = sessions.drain(); + assert_eq!(removed_sessions.len(), 4); + assert_eq!(sessions.session_count(), 0); + assert_eq!(sessions.peer_count(), 0); + } +} diff --git a/runtime/src/enclave_rpc/transport.rs b/runtime/src/enclave_rpc/transport.rs index 8f91bc679e1..3d0586a7a72 100644 --- a/runtime/src/enclave_rpc/transport.rs +++ b/runtime/src/enclave_rpc/transport.rs @@ -62,7 +62,7 @@ pub trait Transport: Send + Sync { async fn submit_peer_feedback( &self, request_id: u64, - peer_feedback: types::PeerFeedback, + feedback: types::PeerFeedback, ) -> Result<(), AnyError>; } diff --git a/tests/runtimes/simple-keyvalue/src/main.rs b/tests/runtimes/simple-keyvalue/src/main.rs index 4c46322ba2b..1fc08bcfe23 100644 --- a/tests/runtimes/simple-keyvalue/src/main.rs +++ b/tests/runtimes/simple-keyvalue/src/main.rs @@ -394,8 +394,7 @@ pub fn main_with_version(version: Version) { state .rpc_dispatcher .set_keymanager_status_update_handler(Some(Box::new(move |status| { - key_manager - .set_status(status) + block_on(key_manager.set_status(status)) .expect("failed to update km client status"); }))); @@ -403,7 +402,7 @@ pub fn main_with_version(version: Version) { state .rpc_dispatcher 
.set_keymanager_quote_policy_update_handler(Some(Box::new(move |policy| { - key_manager.set_quote_policy(policy); + block_on(key_manager.set_quote_policy(policy)); }))); let dispatcher = Dispatcher::new(hi, km_client, state.consensus_verifier.clone());
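
The handler changes just above bridge the dispatcher's synchronous callbacks to the now-async key manager client methods by driving each future to completion with block_on. The sketch below shows that pattern in isolation; Client, Status and the use of futures::executor::block_on are illustrative stand-ins for the runtime's own types and executor, not the actual API.

    use std::sync::Arc;

    struct Status {
        id: u64,
    }

    struct Client;

    impl Client {
        // The client method is async because it must await updates to the
        // underlying EnclaveRPC sessions.
        async fn set_status(&self, status: Status) -> Result<(), String> {
            println!("applying key manager status for id {}", status.id);
            Ok(())
        }
    }

    fn main() {
        let client = Arc::new(Client);

        // The dispatcher expects a plain `Fn` closure, so the async call is
        // driven to completion inside the callback.
        let handler: Box<dyn Fn(Status)> = Box::new(move |status| {
            futures::executor::block_on(client.set_status(status))
                .expect("failed to update km client status");
        });

        handler(Status { id: 42 });
    }

This keeps the dispatcher interface synchronous while the client exposes a single async code path for all callers.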
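The new sessions.rs module keeps two indices in sync: a per-peer map of sessions and a BTreeSet ordered by (last access time, peer, session), so the least recently used session is always cheap to find when a slot has to be freed. The standalone model below reproduces just that bookkeeping under simplified types; Pool, Entry, touch and evict_stale are hypothetical names, and the real module additionally enforces a per-peer cap and skips sessions whose mutex is currently held.

    use std::collections::{BTreeSet, HashMap};

    type PeerId = Vec<u8>;
    type SessionId = u64;

    struct Entry {
        last_access_time: i64,
    }

    struct Pool {
        max_sessions: usize,
        stale_timeout: i64,
        by_peer: HashMap<PeerId, HashMap<SessionId, Entry>>,
        // Ordered by (idle time, peer, session) so the oldest entry is first.
        by_idle_time: BTreeSet<(i64, PeerId, SessionId)>,
    }

    impl Pool {
        fn new(max_sessions: usize, stale_timeout: i64) -> Self {
            Self {
                max_sessions,
                stale_timeout,
                by_peer: HashMap::new(),
                by_idle_time: BTreeSet::new(),
            }
        }

        /// Insert a session, evicting one stale session first if the pool is full.
        fn add(&mut self, peer: PeerId, session: SessionId, now: i64) -> Result<(), &'static str> {
            if self.by_idle_time.len() >= self.max_sessions {
                self.evict_stale(now).ok_or("max concurrent sessions")?;
            }
            self.by_peer
                .entry(peer.clone())
                .or_default()
                .insert(session, Entry { last_access_time: now });
            self.by_idle_time.insert((now, peer, session));
            Ok(())
        }

        /// Refresh a session's position in the idle-time index. The old key must
        /// be removed before the timestamp changes, otherwise the two indices diverge.
        fn touch(&mut self, peer: &PeerId, session: SessionId, now: i64) {
            if let Some(entry) = self.by_peer.get_mut(peer).and_then(|s| s.get_mut(&session)) {
                self.by_idle_time
                    .remove(&(entry.last_access_time, peer.clone(), session));
                entry.last_access_time = now;
                self.by_idle_time.insert((now, peer.clone(), session));
            }
        }

        /// Remove the oldest session that has been idle longer than the timeout.
        fn evict_stale(&mut self, now: i64) -> Option<(PeerId, SessionId)> {
            let (t, peer, session) = self.by_idle_time.iter().next()?.clone();
            if now.saturating_sub(t) < self.stale_timeout {
                return None; // The oldest entry is still fresh, so nothing can be evicted.
            }
            self.by_idle_time.remove(&(t, peer.clone(), session));
            if let Some(sessions) = self.by_peer.get_mut(&peer) {
                sessions.remove(&session);
                if sessions.is_empty() {
                    self.by_peer.remove(&peer);
                }
            }
            Some((peer, session))
        }
    }

    fn main() {
        let mut pool = Pool::new(2, 10);
        pool.add(vec![1], 1, 0).unwrap();
        pool.add(vec![2], 2, 1).unwrap();
        pool.touch(&vec![1], 1, 5);
        // The pool is full; adding a third session succeeds only because the
        // session from peer 2 has been idle for more than 10 seconds.
        pool.add(vec![3], 3, 12).unwrap();
        assert!(pool.by_peer.get(&vec![2]).is_none());
    }

Because the set is ordered by idle time, the eviction scan can stop as soon as it hits the first entry that is still fresh, which is the same early-exit remove_one relies on.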
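A session handed out by find_any/find_one, or considered for eviction, counts as "in use" whenever its tokio::sync::Mutex is currently locked, which try_lock_owned detects without blocking. A minimal sketch of that selection rule, with a hypothetical pick_idle helper and plain u32 payloads standing in for multiplexed sessions:

    use std::sync::Arc;
    use tokio::sync::{Mutex, OwnedMutexGuard};

    /// Return a guard for the first session that is not currently locked.
    fn pick_idle(sessions: &[Arc<Mutex<u32>>]) -> Option<OwnedMutexGuard<u32>> {
        sessions
            .iter()
            .find_map(|s| s.clone().try_lock_owned().ok())
    }

    #[tokio::main]
    async fn main() {
        let sessions: Vec<_> = (0..3).map(|i| Arc::new(Mutex::new(i))).collect();

        // Hold the first session so it is considered "in use".
        let _busy = sessions[0].clone().lock_owned().await;

        // Selection skips the busy session and returns the next idle one.
        let idle = pick_idle(&sessions).expect("an idle session exists");
        assert_eq!(*idle, 1);
    }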