diff --git a/src/discv5.rs b/src/discv5.rs index 17a463a45..0acf1587a 100644 --- a/src/discv5.rs +++ b/src/discv5.rs @@ -49,7 +49,7 @@ lazy_static! { RwLock::new(crate::PermitBanList::default()); } -mod test; +pub(crate) mod test; /// Events that can be produced by the `Discv5` event stream. #[derive(Debug)] diff --git a/src/discv5/test.rs b/src/discv5/test.rs index 47ecad9b8..f6cd70225 100644 --- a/src/discv5/test.rs +++ b/src/discv5/test.rs @@ -116,7 +116,7 @@ async fn build_nodes_from_keypairs_dual_stack( } /// Generate `n` deterministic keypairs from a given seed. -fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec<CombinedKey> { +pub(crate) fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec<CombinedKey> { let mut keypairs = Vec::new(); for i in 0..n { let sk = { diff --git a/src/handler/active_requests.rs b/src/handler/active_requests.rs index e46ccee83..8cfa2d602 100644 --- a/src/handler/active_requests.rs +++ b/src/handler/active_requests.rs @@ -1,68 +1,142 @@ use super::*; use delay_map::HashMapDelay; use more_asserts::debug_unreachable; +use std::collections::hash_map::Entry; pub(super) struct ActiveRequests { /// A list of raw messages we are awaiting a response from the remote. - active_requests_mapping: HashMapDelay<NodeAddress, RequestCall>, + active_requests_mapping: HashMap<NodeAddress, Vec<RequestCall>>, // WHOAREYOU messages do not include the source node id. We therefore maintain another // mapping of active_requests via message_nonce. This allows us to match WHOAREYOU // requests with active requests sent. - /// A mapping of all pending active raw requests message nonces to their NodeAddress. - active_requests_nonce_mapping: HashMap<MessageNonce, NodeAddress>, + /// A mapping of the message nonces of all active raw requests to their NodeAddress. + active_requests_nonce_mapping: HashMapDelay<MessageNonce, NodeAddress>, } impl ActiveRequests { pub fn new(request_timeout: Duration) -> Self { ActiveRequests { - active_requests_mapping: HashMapDelay::new(request_timeout), - active_requests_nonce_mapping: HashMap::new(), + active_requests_mapping: HashMap::new(), + active_requests_nonce_mapping: HashMapDelay::new(request_timeout), } } + /// Insert a new request into the active requests mapping. pub fn insert(&mut self, node_address: NodeAddress, request_call: RequestCall) { let nonce = *request_call.packet().message_nonce(); self.active_requests_mapping - .insert(node_address.clone(), request_call); + .entry(node_address.clone()) + .or_default() + .push(request_call); self.active_requests_nonce_mapping .insert(nonce, node_address); } - pub fn get(&self, node_address: &NodeAddress) -> Option<&RequestCall> { + /// Update the underlying packet for a request identified by its message nonce.
+ pub fn update_packet(&mut self, old_nonce: MessageNonce, new_packet: Packet) { + let node_address = + if let Some(node_address) = self.active_requests_nonce_mapping.remove(&old_nonce) { + node_address + } else { + debug_unreachable!("expected to find nonce in active_requests_nonce_mapping"); + error!("expected to find nonce in active_requests_nonce_mapping"); + return; + }; + + self.active_requests_nonce_mapping + .insert(new_packet.header.message_nonce, node_address.clone()); + + match self.active_requests_mapping.entry(node_address) { + Entry::Occupied(mut requests) => { + let maybe_request_call = requests + .get_mut() + .iter_mut() + .find(|req| req.packet().message_nonce() == &old_nonce); + + if let Some(request_call) = maybe_request_call { + request_call.update_packet(new_packet); + } else { + debug_unreachable!("expected to find request call in active_requests_mapping"); + error!("expected to find request call in active_requests_mapping"); + } + } + Entry::Vacant(_) => { + debug_unreachable!("expected to find node address in active_requests_mapping"); + error!("expected to find node address in active_requests_mapping"); + } + } + } + + pub fn get(&self, node_address: &NodeAddress) -> Option<&Vec<RequestCall>> { self.active_requests_mapping.get(node_address) } + /// Remove a single request identified by its nonce. pub fn remove_by_nonce(&mut self, nonce: &MessageNonce) -> Option<(NodeAddress, RequestCall)> { - match self.active_requests_nonce_mapping.remove(nonce) { - Some(node_address) => match self.active_requests_mapping.remove(&node_address) { - Some(request_call) => Some((node_address, request_call)), - None => { - debug_unreachable!("A matching request call doesn't exist"); - error!("A matching request call doesn't exist"); - None + let node_address = self.active_requests_nonce_mapping.remove(nonce)?; + match self.active_requests_mapping.entry(node_address.clone()) { + Entry::Vacant(_) => { + debug_unreachable!("expected to find node address in active_requests_mapping"); + error!("expected to find node address in active_requests_mapping"); + None + } + Entry::Occupied(mut requests) => { + let result = requests + .get() + .iter() + .position(|req| req.packet().message_nonce() == nonce) + .map(|index| (node_address, requests.get_mut().remove(index))); + if requests.get().is_empty() { + requests.remove(); } - }, - None => None, + result + } } } - pub fn remove(&mut self, node_address: &NodeAddress) -> Option<RequestCall> { - match self.active_requests_mapping.remove(node_address) { - Some(request_call) => { - // Remove the associated nonce mapping. - match self - .active_requests_nonce_mapping - .remove(request_call.packet().message_nonce()) - { - Some(_) => Some(request_call), - None => { - debug_unreachable!("A matching nonce mapping doesn't exist"); - error!("A matching nonce mapping doesn't exist"); - None - } + /// Remove all requests associated with a node.
+ pub fn remove_requests(&mut self, node_address: &NodeAddress) -> Option<Vec<RequestCall>> { + let requests = self.active_requests_mapping.remove(node_address)?; + // Account for node addresses in `active_requests_mapping` with an empty list + if requests.is_empty() { + debug_unreachable!("expected to find requests in active_requests_mapping"); + return None; + } + for req in &requests { + if self + .active_requests_nonce_mapping + .remove(req.packet().message_nonce()) + .is_none() + { + debug_unreachable!("expected to find req with nonce"); + error!("expected to find req with nonce"); + } + } + Some(requests) + } + + /// Remove a single request identified by its id. + pub fn remove_request( + &mut self, + node_address: &NodeAddress, + id: &RequestId, + ) -> Option<RequestCall> { + match self.active_requests_mapping.entry(node_address.clone()) { + Entry::Vacant(_) => None, + Entry::Occupied(mut requests) => { + let index = requests.get().iter().position(|req| { + let req_id: RequestId = req.id().into(); + &req_id == id + })?; + let request_call = requests.get_mut().remove(index); + if requests.get().is_empty() { + requests.remove(); } + // Remove the associated nonce mapping. + self.active_requests_nonce_mapping + .remove(request_call.packet().message_nonce()); + Some(request_call) } - None => None, } } @@ -80,10 +154,12 @@ impl ActiveRequests { } } - for (address, request) in self.active_requests_mapping.iter() { - let nonce = request.packet().message_nonce(); - if !self.active_requests_nonce_mapping.contains_key(nonce) { - panic!("Address {} maps to request with nonce {:?}, which does not exist in `active_requests_nonce_mapping`", address, nonce); + for (address, requests) in self.active_requests_mapping.iter() { + for req in requests { + let nonce = req.packet().message_nonce(); + if !self.active_requests_nonce_mapping.contains_key(nonce) { + panic!("Address {} maps to request with nonce {:?}, which does not exist in `active_requests_nonce_mapping`", address, nonce); + } } } } @@ -92,12 +168,27 @@ impl ActiveRequests { impl Stream for ActiveRequests { type Item = Result<(NodeAddress, RequestCall), String>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match self.active_requests_mapping.poll_next_unpin(cx) { - Poll::Ready(Some(Ok((node_address, request_call)))) => { - // Remove the associated nonce mapping.
- self.active_requests_nonce_mapping - .remove(request_call.packet().message_nonce()); - Poll::Ready(Some(Ok((node_address, request_call)))) + match self.active_requests_nonce_mapping.poll_next_unpin(cx) { + Poll::Ready(Some(Ok((nonce, node_address)))) => { + match self.active_requests_mapping.entry(node_address.clone()) { + Entry::Vacant(_) => Poll::Ready(None), + Entry::Occupied(mut requests) => { + match requests + .get() + .iter() + .position(|req| req.packet().message_nonce() == &nonce) + { + Some(index) => { + let result = (node_address, requests.get_mut().remove(index)); + if requests.get().is_empty() { + requests.remove(); + } + Poll::Ready(Some(Ok(result))) + } + None => Poll::Ready(None), + } + } + } } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(None) => Poll::Ready(None), diff --git a/src/handler/mod.rs b/src/handler/mod.rs index c3c8f28f7..4547078e8 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -39,6 +39,7 @@ use crate::{ use delay_map::HashMapDelay; use enr::{CombinedKey, NodeId}; use futures::prelude::*; +use more_asserts::debug_unreachable; use parking_lot::RwLock; use smallvec::SmallVec; use std::{ @@ -176,6 +177,15 @@ struct PendingRequest { request: RequestBody, } +impl From<&HandlerReqId> for RequestId { + fn from(id: &HandlerReqId) -> Self { + match id { + HandlerReqId::Internal(id) => id.clone(), + HandlerReqId::External(id) => id.clone(), + } + } +} + /// Process to handle handshakes and sessions established from raw RPC communications between nodes. pub struct Handler { /// Configuration for the discv5 service. @@ -187,7 +197,7 @@ pub struct Handler { enr: Arc<RwLock<Enr>>, /// The key to sign the ENR and set up encrypted communication with peers. key: Arc<RwLock<CombinedKey>>, - /// Pending raw requests. + /// Active requests that are awaiting a response. active_requests: ActiveRequests, /// The expected responses by SocketAddr which allows packets to pass the underlying filter. filter_expected_responses: Arc<RwLock<HashMap<SocketAddr, usize>>>, @@ -331,13 +341,13 @@ impl Handler { Some(inbound_packet) = self.socket.recv.recv() => { self.process_inbound_packet::
<P>(inbound_packet).await; } - Some(Ok((node_address, pending_request))) = self.active_requests.next() => { - self.handle_request_timeout(node_address, pending_request).await; + Some(Ok((node_address, active_request))) = self.active_requests.next() => { + self.handle_request_timeout(node_address, active_request).await; } Some(Ok((node_address, _challenge))) = self.active_challenges.next() => { // A challenge has expired. There could be pending requests awaiting this // challenge. We process them here - self.send_next_request::
<P>(node_address).await; + self.send_pending_requests::
<P>(&node_address).await; } _ = banned_nodes_check.tick() => self.unban_nodes_check(), // Unban nodes that are past the timeout _ = &mut self.exit => { @@ -392,7 +402,7 @@ impl Handler { socket_addr: inbound_packet.src_address, node_id: src_id, }; - self.handle_message::
<P>( + self.handle_message( node_address, message_nonce, &inbound_packet.message, @@ -464,9 +474,10 @@ impl Handler { return Err(RequestError::SelfRequest); } - // If there is already an active request or an active challenge (WHOAREYOU sent) for this node, add to pending requests - if self.active_requests.get(&node_address).is_some() - || self.active_challenges.get(&node_address).is_some() + // If there is already an active challenge (WHOAREYOU sent) for this node, or if we are + // awaiting a session with this node to be established, add the request to pending requests. + if self.active_challenges.get(&node_address).is_some() + || self.is_awaiting_session_to_be_established(&node_address) { trace!("Request queued for node: {}", node_address); self.pending_requests @@ -494,14 +505,13 @@ impl Handler { .map_err(|e| RequestError::EncryptionFailed(format!("{e:?}")))?; (packet, false) } else { - // No session exists, start a new handshake + // No session exists, start a new handshake initiating a new session trace!( "Starting session. Sending random packet to: {}", node_address ); let packet = Packet::new_random(&self.node_id).map_err(RequestError::EntropyFailure)?; - // We are initiating a new session (packet, true) } }; @@ -688,6 +698,7 @@ impl Handler { // All sent requests must have an associated node_id. Therefore the following // must not panic. let node_address = request_call.contact().node_address(); + let auth_message_nonce = auth_packet.header.message_nonce; match request_call.contact().enr() { Some(enr) => { // NOTE: Here we decide if the session is outgoing or ingoing. The condition for an @@ -701,7 +712,11 @@ impl Handler { }; // We already know the ENR. Send the handshake response packet - trace!("Sending Authentication response to node: {}", node_address); + trace!( + "Sending Authentication response to node: {} ({:?})", + node_address, + request_call.id() + ); request_call.update_packet(auth_packet.clone()); request_call.set_handshake_sent(); request_call.set_initiating_session(false); @@ -725,7 +740,11 @@ impl Handler { // Send the Auth response let contact = request_call.contact().clone(); - trace!("Sending Authentication response to node: {}", node_address); + trace!( + "Sending Authentication response to node: {} ({:?})", + node_address, + request_call.id() + ); request_call.update_packet(auth_packet.clone()); request_call.set_handshake_sent(); // Reinsert the request_call @@ -743,7 +762,8 @@ impl Handler { } } } - self.new_session(node_address, session); + self.new_session::
<P>(node_address.clone(), session, Some(auth_message_nonce)) + .await; } /// Verifies a Node ENR to its observed address. If it fails, any associated session is also @@ -813,17 +833,18 @@ { warn!("Failed to inform of established session {}", e) } - self.new_session(node_address.clone(), session); - self.handle_message::
<P>( + // When (re-)establishing a session from an outgoing challenge, we do not need + // to filter out this request from active requests, so we do not pass + // the message nonce on to `new_session`. + self.new_session::
<P>(node_address.clone(), session, None) + .await; + self.handle_message( node_address.clone(), message_nonce, message, authenticated_data, ) .await; - // We could have pending messages that were awaiting this session to be - // established. If so process them. - self.send_next_request::
<P>(node_address).await; } else { // IP's or NodeAddress don't match. Drop the session. warn!( @@ -893,41 +914,37 @@ } } - async fn send_next_request<P: ProtocolIdentity>(&mut self, node_address: NodeAddress) { - // ensure we are not over writing any existing requests - if self.active_requests.get(&node_address).is_none() { - if let std::collections::hash_map::Entry::Occupied(mut entry) = - self.pending_requests.entry(node_address) + /// Send all pending requests for the given node address that were waiting for a new + /// session to be established or for an active outgoing challenge to expire. + async fn send_pending_requests<P: ProtocolIdentity>(&mut self, node_address: &NodeAddress) { + let pending_requests = self + .pending_requests + .remove(node_address) + .unwrap_or_default(); + for req in pending_requests { + trace!( + "Sending pending request {} to {node_address}. {}", + RequestId::from(&req.request_id), + req.request, + ); + if let Err(request_error) = self + .send_request::
<P>(req.contact, req.request_id.clone(), req.request) + .await { - // If it exists, there must be a request here - let PendingRequest { - contact, - request_id, - request, - } = entry.get_mut().remove(0); - if entry.get().is_empty() { - entry.remove(); - } - trace!("Sending next awaiting message. Node: {}", contact); - if let Err(request_error) = self - .send_request::
<P>(contact, request_id.clone(), request) - .await - { - warn!("Failed to send next awaiting request {}", request_error); - // Inform the service that the request failed - match request_id { - HandlerReqId::Internal(_) => { - // An internal request could not be sent. For now we do nothing about - // this. - } - HandlerReqId::External(id) => { - if let Err(e) = self - .service_send - .send(HandlerOut::RequestFailed(id, request_error)) - .await - { - warn!("Failed to inform that request failed {}", e); - } + warn!("Failed to send next pending request {request_error}"); + // Inform the service that the request failed + match req.request_id { + HandlerReqId::Internal(_) => { + // An internal request could not be sent. For now we do nothing about + // this. + } + HandlerReqId::External(id) => { + if let Err(e) = self + .service_send + .send(HandlerOut::RequestFailed(id, request_error)) + .await + { + warn!("Failed to inform that request failed {e}") } } } @@ -935,9 +952,67 @@ } } + /// Replays all active requests for the given node address, in the case that a new session has + /// been established. If an optional message nonce is provided, the corresponding request will + /// be skipped, e.g. the request that established the new session. + async fn replay_active_requests<P: ProtocolIdentity>( + &mut self, + node_address: &NodeAddress, + // Optional message nonce to filter out the request used to establish the session. + message_nonce: Option<MessageNonce>, + ) { + trace!( + "Replaying active requests. {}, {:?}", + node_address, + message_nonce + ); + + let packets = if let Some(session) = self.sessions.get_mut(node_address) { + let mut packets = vec![]; + for request_call in self + .active_requests + .get(node_address) + .unwrap_or(&vec![]) + .iter() + .filter(|req| { + // Except the active request that was used to establish the new session, as it has + // already been handled and shouldn't be replayed. + if let Some(nonce) = message_nonce.as_ref() { + req.packet().message_nonce() != nonce + } else { + true + } + }) + { + if let Ok(new_packet) = + session.encrypt_message::
<P>(self.node_id, &request_call.encode()) + { + packets.push((*request_call.packet().message_nonce(), new_packet)); + } else { + error!( + "Failed to re-encrypt packet while replaying active request with id: {:?}", + request_call.id() + ); + } + } + + packets + } else { + debug_unreachable!("Attempted to replay active requests but session doesn't exist."); + error!("Attempted to replay active requests but session doesn't exist."); + return; + }; + + for (old_nonce, new_packet) in packets { + self.active_requests + .update_packet(old_nonce, new_packet.clone()); + self.send(node_address.clone(), new_packet).await; + } + } + /// Handle a standard message that does not contain an authentication header. #[allow(clippy::single_match)] - async fn handle_message<P: ProtocolIdentity>( + async fn handle_message( &mut self, node_address: NodeAddress, message_nonce: MessageNonce, @@ -1039,7 +1114,7 @@ } } // Handle standard responses - self.handle_response::
<P>(node_address, response).await; + self.handle_response(node_address, response).await; } } } else { @@ -1063,28 +1138,13 @@ /// Handles a response to a request. Re-inserts the request call if the response is a multiple /// Nodes response. - async fn handle_response<P: ProtocolIdentity>( - &mut self, - node_address: NodeAddress, - response: Response, - ) { + async fn handle_response(&mut self, node_address: NodeAddress, response: Response) { // Find a matching request, if any - if let Some(mut request_call) = self.active_requests.remove(&node_address) { - let id = match request_call.id() { - HandlerReqId::Internal(id) | HandlerReqId::External(id) => id, - }; - if id != &response.id { - trace!( - "Received an RPC Response to an unknown request. Likely late response. {}", - node_address - ); - // add the request back and reset the timer - self.active_requests.insert(node_address, request_call); - return; - } - + if let Some(mut request_call) = self + .active_requests + .remove_request(&node_address, &response.id) + { // The response matches a request - // Check to see if this is a Nodes response, in which case we may require to wait for // extra responses if let ResponseBody::Nodes { total, .. } = response.body { @@ -1138,7 +1198,6 @@ { warn!("Failed to inform of response {}", e) } - self.send_next_request::
<P>(node_address).await; } else { // This is likely a late response and we have already failed the request. These get // dropped here. @@ -1154,14 +1213,31 @@ self.active_requests.insert(node_address, request_call); } - fn new_session(&mut self, node_address: NodeAddress, session: Session) { + /// Establishes a new session with a peer, or re-establishes an existing session if a + /// new challenge was issued during an ongoing session. + async fn new_session<P: ProtocolIdentity>( + &mut self, + node_address: NodeAddress, + session: Session, + // Optional message nonce is required to filter out the request that was used in the + // handshake to re-establish a session, if applicable. + message_nonce: Option<MessageNonce>, + ) { if let Some(current_session) = self.sessions.get_mut(&node_address) { current_session.update(session); + // If a session is re-established due to a new handshake during an ongoing + // session, we need to replay any active requests from the prior session, excluding + // the request that was used to re-establish the session handshake. + self.replay_active_requests::
<P>(&node_address, message_nonce) + .await; } else { - self.sessions.insert(node_address, session); + self.sessions.insert(node_address.clone(), session); METRICS .active_sessions .store(self.sessions.len(), Ordering::Relaxed); + // We could have pending messages that were awaiting this session to be + // established. If so process them. + self.send_pending_requests::
<P>(&node_address).await; } } @@ -1212,7 +1288,7 @@ .await; } - /// Removes a session and updates associated metrics and fields. + /// Removes a session, fails all of that session's active & pending requests, and updates associated metrics and fields. async fn fail_session( &mut self, node_address: &NodeAddress, @@ -1225,6 +1301,7 @@ .active_sessions .store(self.sessions.len(), Ordering::Relaxed); } + // fail all pending requests if let Some(to_remove) = self.pending_requests.remove(node_address) { for PendingRequest { request_id, .. } in to_remove { match request_id { @@ -1243,6 +1320,28 @@ } } } + // fail all active requests + for req in self + .active_requests + .remove_requests(node_address) + .unwrap_or_default() + { + match req.id() { + HandlerReqId::Internal(_) => { + // Do not report failures on requests belonging to the handler. + } + HandlerReqId::External(id) => { + if let Err(e) = self + .service_send + .send(HandlerOut::RequestFailed(id.clone(), error.clone())) + .await + { + warn!("Failed to inform request failure {e}") + } + } + } + self.remove_expected_response(node_address.socket_addr); + } } /// Sends a packet to the send handler to be encoded and sent. @@ -1267,4 +1366,19 @@ .ban_nodes .retain(|_, time| time.is_none() || Some(Instant::now()) < *time); } + + /// Returns whether no session with this node exists and a request that initiates + /// a session has already been sent. + fn is_awaiting_session_to_be_established(&mut self, node_address: &NodeAddress) -> bool { + if self.sessions.get(node_address).is_some() { + // session exists + return false; + } + + if let Some(requests) = self.active_requests.get(node_address) { + requests.iter().any(|req| req.initiating_session()) + } else { + false + } + } } diff --git a/src/handler/tests.rs b/src/handler/tests.rs index 187ed5aba..a2bc95659 100644 --- a/src/handler/tests.rs +++ b/src/handler/tests.rs @@ -8,8 +8,11 @@ use crate::{ ConfigBuilder, IpMode, }; use std::{ + collections::HashSet, convert::TryInto, net::{Ipv4Addr, Ipv6Addr}, + num::NonZeroU16, + ops::Add, }; use crate::{ @@ -333,35 +336,170 @@ async fn multiple_messages() { } } +fn create_node() -> Enr { + let key = CombinedKey::generate_secp256k1(); + let ip = "127.0.0.1".parse().unwrap(); + let port = 8080 + rand::random::<u16>() % 1000; + Enr::builder().ip4(ip).udp4(port).build(&key).unwrap() +} + +fn create_req_call(node: &Enr) -> (RequestCall, NodeAddress) { + let node_contact: NodeContact = node.clone().into(); + let packet = Packet::new_random(&node.node_id()).unwrap(); + let id = HandlerReqId::Internal(RequestId::random()); + let request = RequestBody::Ping { enr_seq: 1 }; + let initiating_session = true; + let node_addr = node_contact.node_address(); + let req = RequestCall::new(node_contact, packet, id, request, initiating_session); + (req, node_addr) +} + #[tokio::test] async fn test_active_requests_insert() { const EXPIRY: Duration = Duration::from_secs(5); let mut active_requests = ActiveRequests::new(EXPIRY); - // Create the test values needed - let port = 5000; - let ip = "127.0.0.1".parse().unwrap(); + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); - let key = CombinedKey::generate_secp256k1(); + // insert the pair and verify the mapping remains in sync + active_requests.insert(req_1_addr, req_1); +
active_requests.check_invariant(); + active_requests.insert(req_2_addr, req_2); + active_requests.check_invariant(); + active_requests.insert(req_3_addr, req_3); + active_requests.check_invariant(); +} - let enr = Enr::builder().ip4(ip).udp4(port).build(&key).unwrap(); - let node_id = enr.node_id(); +#[tokio::test] +async fn test_active_requests_remove_requests() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); - let contact: NodeContact = enr.into(); - let node_address = contact.node_address(); + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + active_requests.insert(req_1_addr.clone(), req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr.clone(), req_3); + active_requests.check_invariant(); + let reqs = active_requests.remove_requests(&req_1_addr).unwrap(); + assert_eq!(reqs.len(), 1); + active_requests.check_invariant(); + let reqs = active_requests.remove_requests(&req_2_addr).unwrap(); + assert_eq!(reqs.len(), 2); + active_requests.check_invariant(); + assert!(active_requests.remove_requests(&req_3_addr).is_none()); +} - let packet = Packet::new_random(&node_id).unwrap(); - let id = HandlerReqId::Internal(RequestId::random()); - let request = RequestBody::Ping { enr_seq: 1 }; - let initiating_session = true; - let request_call = RequestCall::new(contact, packet, id, request, initiating_session); +#[tokio::test] +async fn test_active_requests_remove_request() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); - // insert the pair and verify the mapping remains in sync - let nonce = *request_call.packet().message_nonce(); - active_requests.insert(node_address, request_call); + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + let req_1_id = req_1.id().into(); + let req_2_id = req_2.id().into(); + let req_3_id = req_3.id().into(); + + active_requests.insert(req_1_addr.clone(), req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr.clone(), req_3); + active_requests.check_invariant(); + let req_id: RequestId = active_requests + .remove_request(&req_1_addr, &req_1_id) + .unwrap() + .id() + .into(); + assert_eq!(req_id, req_1_id); active_requests.check_invariant(); - active_requests.remove_by_nonce(&nonce); + let req_id: RequestId = active_requests + .remove_request(&req_2_addr, &req_2_id) + .unwrap() + .id() + .into(); + assert_eq!(req_id, req_2_id); active_requests.check_invariant(); + let req_id: RequestId = active_requests + .remove_request(&req_3_addr, &req_3_id) + .unwrap() + .id() + .into(); + assert_eq!(req_id, req_3_id); + active_requests.check_invariant(); + assert!(active_requests + .remove_request(&req_3_addr, &req_3_id) + .is_none()); +} + +#[tokio::test] +async fn test_active_requests_remove_by_nonce() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); + + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + let 
req_1_nonce = *req_1.packet().message_nonce(); + let req_2_nonce = *req_2.packet().message_nonce(); + let req_3_nonce = *req_3.packet().message_nonce(); + + active_requests.insert(req_1_addr.clone(), req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr.clone(), req_3); + active_requests.check_invariant(); + + let req = active_requests.remove_by_nonce(&req_1_nonce).unwrap(); + assert_eq!(req.0, req_1_addr); + active_requests.check_invariant(); + let req = active_requests.remove_by_nonce(&req_2_nonce).unwrap(); + assert_eq!(req.0, req_2_addr); + active_requests.check_invariant(); + let req = active_requests.remove_by_nonce(&req_3_nonce).unwrap(); + assert_eq!(req.0, req_3_addr); + active_requests.check_invariant(); + let random_nonce = rand::random(); + assert!(active_requests.remove_by_nonce(&random_nonce).is_none()); +} + +#[tokio::test] +async fn test_active_requests_update_packet() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); + + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + + let old_nonce = *req_2.packet().message_nonce(); + active_requests.insert(req_1_addr, req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr, req_3); + active_requests.check_invariant(); + + let new_packet = Packet::new_random(&node_2.node_id()).unwrap(); + let new_nonce = new_packet.message_nonce(); + active_requests.update_packet(old_nonce, new_packet.clone()); + active_requests.check_invariant(); + + assert_eq!(2, active_requests.get(&req_2_addr).unwrap().len()); + assert!(active_requests.remove_by_nonce(&old_nonce).is_none()); + let (addr, req) = active_requests.remove_by_nonce(new_nonce).unwrap(); + assert_eq!(addr, req_2_addr); + assert_eq!(req.packet(), &new_packet); } #[tokio::test] @@ -485,3 +623,417 @@ async fn remove_one_time_session() { .is_some()); assert_eq!(0, handler.one_time_sessions.len()); } + +// Tests replaying active requests. +// +// In this test, Receiver's session expires and Receiver returns WHOAREYOU. +// Sender then creates a new session and resends active requests. +// +// ```mermaid +// sequenceDiagram +// participant Sender +// participant Receiver +// Note over Sender: Start discv5 server +// Note over Receiver: Start discv5 server +// +// Note over Sender,Receiver: Session established +// +// rect rgb(100, 100, 0) +// Note over Receiver: ** Session expired ** +// end +// +// rect rgb(10, 10, 10) +// Note left of Sender: Sender sends requests
<br/>**in parallel**. +// par +// Sender ->> Receiver: PING(id:2) +// and +// Sender -->> Receiver: PING(id:3) +// and +// Sender -->> Receiver: PING(id:4) +// and +// Sender -->> Receiver: PING(id:5) +// end +// end +// +// Note over Receiver: Send WHOAREYOU
<br/>since the session has expired +// Receiver ->> Sender: WHOAREYOU +// +// rect rgb(100, 100, 0) +// Note over Receiver: Drop PING(id:2,3,4,5) requests
<br/>since WHOAREYOU was already sent. +// end +// +// Note over Sender: New session established with Receiver +// +// Sender ->> Receiver: Handshake message (id:2) +// +// Note over Receiver: New session established with Sender +// +// rect rgb(10, 10, 10) +// Note left of Sender: Handler::replay_active_requests() +// Sender ->> Receiver: PING (id:3) +// Sender ->> Receiver: PING (id:4) +// Sender ->> Receiver: PING (id:5) +// end +// +// Receiver ->> Sender: PONG (id:2) +// Receiver ->> Sender: PONG (id:3) +// Receiver ->> Sender: PONG (id:4) +// Receiver ->> Sender: PONG (id:5) +// ``` +#[tokio::test] +async fn test_replay_active_requests() { + init(); + let sender_port = 5006; + let receiver_port = 5007; + let ip = "127.0.0.1".parse().unwrap(); + let key1 = CombinedKey::generate_secp256k1(); + let key2 = CombinedKey::generate_secp256k1(); + + let sender_enr = Enr::builder() + .ip4(ip) + .udp4(sender_port) + .build(&key1) + .unwrap(); + + let receiver_enr = Enr::builder() + .ip4(ip) + .udp4(receiver_port) + .build(&key2) + .unwrap(); + + // Build sender handler + let (sender_exit, sender_send, mut sender_recv, mut handler) = { + let sender_listen_config = ListenConfig::Ipv4 { + ip: sender_enr.ip4().unwrap(), + port: sender_enr.udp4().unwrap(), + }; + let sender_config = ConfigBuilder::new(sender_listen_config).build(); + build_handler::<DefaultProtocolId>(sender_enr.clone(), key1, sender_config).await + }; + let sender = async move { + // Start sender handler. + handler.start::<DefaultProtocolId>().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + // Build receiver handler + // Shorten the receiver's timeout to reproduce session expiry. + let receiver_session_timeout = Duration::from_secs(1); + let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = { + let receiver_listen_config = ListenConfig::Ipv4 { + ip: receiver_enr.ip4().unwrap(), + port: receiver_enr.udp4().unwrap(), + }; + let receiver_config = ConfigBuilder::new(receiver_listen_config) + .session_timeout(receiver_session_timeout) + .build(); + build_handler::<DefaultProtocolId>(receiver_enr.clone(), key2, receiver_config).await + }; + let receiver = async move { + // Start receiver handler. + handler.start::<DefaultProtocolId>().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + let messages_to_send = 5usize; + + let sender_ops = async move { + let mut response_count = 0usize; + let mut expected_request_ids = HashSet::new(); + expected_request_ids.insert(RequestId(vec![1])); + + // The sender sends the first message, then awaits session establishment + let _ = sender_send.send(HandlerIn::Request( + receiver_enr.clone().into(), + Box::new(Request { + id: RequestId(vec![1]), + body: RequestBody::Ping { enr_seq: 1 }, + }), + )); + + match sender_recv.recv().await { + Some(HandlerOut::Established(_, _, _)) => { + // Sleep until receiver's session expired.
+ tokio::time::sleep(receiver_session_timeout.add(Duration::from_millis(500))).await; + // send the rest of the messages + for req_id in 2..=messages_to_send { + let request_id = RequestId(vec![req_id as u8]); + expected_request_ids.insert(request_id.clone()); + let _ = sender_send.send(HandlerIn::Request( + receiver_enr.clone().into(), + Box::new(Request { + id: request_id, + body: RequestBody::Ping { enr_seq: 1 }, + }), + )); + } + } + handler_out => panic!("Unexpected message: {:?}", handler_out), + } + + loop { + match sender_recv.recv().await { + Some(HandlerOut::Response(_, response)) => { + assert!(expected_request_ids.remove(&response.id)); + response_count += 1; + if response_count == messages_to_send { + // Notify the handlers that the message exchange has been completed. + assert!(expected_request_ids.is_empty()); + sender_exit.send(()).unwrap(); + receiver_exit.send(()).unwrap(); + return; + } + } + _ => continue, + }; + } + }; + + let receiver_ops = async move { + let mut message_count = 0usize; + loop { + match receiver_recv.recv().await { + Some(HandlerOut::WhoAreYou(wru_ref)) => { + receiver_send + .send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone()))) + .unwrap(); + } + Some(HandlerOut::Request(addr, request)) => { + assert!(matches!(request.body, RequestBody::Ping { .. })); + let pong_response = Response { + id: request.id, + body: ResponseBody::Pong { + enr_seq: 1, + ip: ip.into(), + port: NonZeroU16::new(sender_port).unwrap(), + }, + }; + receiver_send + .send(HandlerIn::Response(addr, Box::new(pong_response))) + .unwrap(); + message_count += 1; + if message_count == messages_to_send { + return; + } + } + _ => { + continue; + } + } + } + }; + + let sleep_future = sleep(Duration::from_secs(5)); + let message_exchange = async move { + let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops); + }; + + tokio::select! { + _ = message_exchange => {} + _ = sleep_future => { + panic!("Test timed out"); + } + } +} + +// Tests sending pending requests. +// +// Sender attempts to send multiple requests in parallel, but due to the absence of a session, only +// one of the requests from Sender is sent and others are inserted into `pending_requests`. +// The pending requests are sent once a session is established. +// +// ```mermaid +// sequenceDiagram +// participant Sender +// participant Receiver +// +// Note over Sender: No session with Receiver +// +// rect rgb(10, 10, 10) +// Note left of Sender: Sender attempts to send multiple requests in parallel
<br/>but no session with Receiver.
<br/>So Sender sends a random packet for the first request,
<br/>and the rest of the requests are inserted into pending_requests. +// par +// Sender ->> Receiver: Random packet (id:1) +// Note over Sender: Insert the request into `active_requests` +// and +// Note over Sender: Insert Request(id:2) into *pending_requests* +// and +// Note over Sender: Insert Request(id:3) into *pending_requests* +// end +// end +// +// Receiver ->> Sender: WHOAREYOU (id:1) +// +// Note over Sender: New session established with Receiver +// +// rect rgb(0, 100, 0) +// Note over Sender: Send pending requests since a session has been established. +// Sender ->> Receiver: Request (id:2) +// Sender ->> Receiver: Request (id:3) +// end +// +// Sender ->> Receiver: Handshake message (id:1) +// +// Note over Receiver: New session established with Sender +// +// Receiver ->> Sender: Response (id:2) +// Receiver ->> Sender: Response (id:3) +// Receiver ->> Sender: Response (id:1) +// +// Note over Sender: The request (id:2) completed. +// Note over Sender: The request (id:3) completed. +// Note over Sender: The request (id:1) completed. +// ``` +#[tokio::test] +async fn test_send_pending_request() { + init(); + let sender_port = 5008; + let receiver_port = 5009; + let ip = "127.0.0.1".parse().unwrap(); + let key1 = CombinedKey::generate_secp256k1(); + let key2 = CombinedKey::generate_secp256k1(); + + let sender_enr = Enr::builder() + .ip4(ip) + .udp4(sender_port) + .build(&key1) + .unwrap(); + + let receiver_enr = Enr::builder() + .ip4(ip) + .udp4(receiver_port) + .build(&key2) + .unwrap(); + + // Build sender handler + let (sender_exit, sender_send, mut sender_recv, mut handler) = { + let sender_listen_config = ListenConfig::Ipv4 { + ip: sender_enr.ip4().unwrap(), + port: sender_enr.udp4().unwrap(), + }; + let sender_config = ConfigBuilder::new(sender_listen_config).build(); + build_handler::<DefaultProtocolId>(sender_enr.clone(), key1, sender_config).await + }; + let sender = async move { + // Start sender handler. + handler.start::<DefaultProtocolId>().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + // Build receiver handler + // Shorten the receiver's timeout to reproduce session expiry. + let receiver_session_timeout = Duration::from_secs(1); + let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = { + let receiver_listen_config = ListenConfig::Ipv4 { + ip: receiver_enr.ip4().unwrap(), + port: receiver_enr.udp4().unwrap(), + }; + let receiver_config = ConfigBuilder::new(receiver_listen_config) + .session_timeout(receiver_session_timeout) + .build(); + build_handler::<DefaultProtocolId>(receiver_enr.clone(), key2, receiver_config).await + }; + let receiver = async move { + // Start receiver handler. + handler.start::<DefaultProtocolId>().await; + // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + let messages_to_send = 3usize; + + let sender_ops = async move { + let mut response_count = 0usize; + let mut expected_request_ids = HashSet::new(); + + // send requests + for req_id in 1..=messages_to_send { + let request_id = RequestId(vec![req_id as u8]); + expected_request_ids.insert(request_id.clone()); + let _ = sender_send.send(HandlerIn::Request( + receiver_enr.clone().into(), + Box::new(Request { + id: request_id, + body: RequestBody::Ping { enr_seq: 1 }, + }), + )); + } + + loop { + match sender_recv.recv().await { + Some(HandlerOut::Response(_, response)) => { + assert!(expected_request_ids.remove(&response.id)); + response_count += 1; + if response_count == messages_to_send { + // Notify the handlers that the message exchange has been completed. + assert!(expected_request_ids.is_empty()); + sender_exit.send(()).unwrap(); + receiver_exit.send(()).unwrap(); + return; + } + } + _ => continue, + }; + } + }; + + let receiver_ops = async move { + let mut message_count = 0usize; + loop { + match receiver_recv.recv().await { + Some(HandlerOut::WhoAreYou(wru_ref)) => { + receiver_send + .send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone()))) + .unwrap(); + } + Some(HandlerOut::Request(addr, request)) => { + assert!(matches!(request.body, RequestBody::Ping { .. })); + let pong_response = Response { + id: request.id, + body: ResponseBody::Pong { + enr_seq: 1, + ip: ip.into(), + port: NonZeroU16::new(sender_port).unwrap(), + }, + }; + receiver_send + .send(HandlerIn::Response(addr, Box::new(pong_response))) + .unwrap(); + message_count += 1; + if message_count == messages_to_send { + return; + } + } + _ => { + continue; + } + } + } + }; + + let sleep_future = sleep(Duration::from_secs(5)); + let message_exchange = async move { + let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops); + }; + + tokio::select! { + _ = message_exchange => {} + _ = sleep_future => { + panic!("Test timed out"); + } + } +} diff --git a/src/service.rs b/src/service.rs index 3b3713ce3..1c99e2903 100644 --- a/src/service.rs +++ b/src/service.rs @@ -188,7 +188,7 @@ pub struct Service { active_requests: FnvHashMap<RequestId, ActiveRequest>, /// Keeps track of the number of responses received from a NODES response. - active_nodes_responses: HashMap<NodeId, NodesResponse>, + active_nodes_responses: HashMap<RequestId, NodesResponse>, /// A map of votes nodes have made about our external IP address. We accept the majority. ip_votes: Option<IpVote>, @@ -733,10 +733,8 @@ impl Service { // handle the case that there is more than one response if total > 1 { - let mut current_response = self - .active_nodes_responses - .remove(&node_id) - .unwrap_or_default(); + let mut current_response = + self.active_nodes_responses.remove(&id).unwrap_or_default(); debug!( "Nodes Response: {} of {} received", @@ -754,7 +752,7 @@ current_response.received_nodes.append(&mut nodes); self.active_nodes_responses - .insert(node_id, current_response); + .insert(id.clone(), current_response); self.active_requests.insert(id, active_request); return; } @@ -776,7 +774,7 @@ // in a later response sends a response with a total of 1, all previous nodes // will be ignored.
// ensure any mapping is removed in this rare case - self.active_nodes_responses.remove(&node_id); + self.active_nodes_responses.remove(&id); self.discovered(&node_id, nodes, active_request.query_id); } @@ -1452,7 +1450,7 @@ impl Service { // if a failed FindNodes request, ensure we haven't partially received packets. If // so, process the partially found nodes RequestBody::FindNode { .. } => { - if let Some(nodes_response) = self.active_nodes_responses.remove(&node_id) { + if let Some(nodes_response) = self.active_nodes_responses.remove(&id) { if !nodes_response.received_nodes.is_empty() { warn!( "NODES Response failed, but was partially processed from: {}", diff --git a/src/service/test.rs b/src/service/test.rs index bf1f10946..c85f6a97a 100644 --- a/src/service/test.rs +++ b/src/service/test.rs @@ -3,6 +3,7 @@ use super::*; use crate::{ + discv5::test::generate_deterministic_keypair, handler::Handler, kbucket, kbucket::{BucketInsertResult, KBucketsTable, NodeStatus}, @@ -16,7 +17,7 @@ use crate::{ }; use enr::CombinedKey; use parking_lot::RwLock; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, net::Ipv4Addr, sync::Arc, time::Duration}; use tokio::sync::{mpsc, oneshot}; /// Default UDP port number to use for tests requiring UDP exposure @@ -221,3 +222,122 @@ async fn test_connection_direction_on_inject_session_established() { assert!(status.is_connected()); assert_eq!(ConnectionDirection::Outgoing, status.direction); } + +#[tokio::test] +async fn test_handling_concurrent_responses() { + init(); + + // Seed is chosen such that all nodes are at the 256th distance from the first node. + let seed = 1652; + let mut keypairs = generate_deterministic_keypair(5, seed); + + let mut service = { + let enr_key = keypairs.pop().unwrap(); + let enr = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(10005) + .build(&enr_key) + .unwrap(); + build_service::<DefaultProtocolId>( + Arc::new(RwLock::new(enr)), + Arc::new(RwLock::new(enr_key)), + false, + ) + .await + }; + + let node_contact: NodeContact = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(10006) + .build(&keypairs.remove(0)) + .unwrap() + .into(); + let node_address = node_contact.node_address(); + + // Add fake requests + // Request1 + service.active_requests.insert( + RequestId(vec![1]), + ActiveRequest { + contact: node_contact.clone(), + request_body: RequestBody::FindNode { + distances: vec![254, 255, 256], + }, + query_id: Some(QueryId(1)), + callback: None, + }, + ); + // Request2 + service.active_requests.insert( + RequestId(vec![2]), + ActiveRequest { + contact: node_contact, + request_body: RequestBody::FindNode { + distances: vec![254, 255, 256], + }, + query_id: Some(QueryId(2)), + callback: None, + }, + ); + + assert_eq!(3, keypairs.len()); + let mut enrs_for_response = keypairs + .iter() + .enumerate() + .map(|(i, key)| { + Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(10007 + i as u16) + .build(key) + .unwrap() + }) + .collect::<Vec<_>>(); + + // The response to `Request1` is sent as two separate messages in total. Handle the first of + // the two messages here. + service.handle_rpc_response( + node_address.clone(), + Response { + id: RequestId(vec![1]), + body: ResponseBody::Nodes { + total: 2, + nodes: vec![enrs_for_response.pop().unwrap()], + }, + }, + ); + // The service still has two active requests since we are waiting for the second NODES + // response to `Request1`. + assert_eq!(2, service.active_requests.len()); + // The service stores the first response to `Request1` in `active_nodes_responses`.
+ assert!(!service.active_nodes_responses.is_empty()); + + // Second, handle a response to *`Request2`* before the second response to `Request1`. + service.handle_rpc_response( + node_address.clone(), + Response { + id: RequestId(vec![2]), + body: ResponseBody::Nodes { + total: 1, + nodes: vec![enrs_for_response.pop().unwrap()], + }, + }, + ); + // `Request2` is completed so now the number of active requests should be one. + assert_eq!(1, service.active_requests.len()); + // Service still keeps the first response in `active_nodes_responses`. + assert!(!service.active_nodes_responses.is_empty()); + + // Finally, handle the second response to `Request1`. + service.handle_rpc_response( + node_address, + Response { + id: RequestId(vec![1]), + body: ResponseBody::Nodes { + total: 2, + nodes: vec![enrs_for_response.pop().unwrap()], + }, + }, + ); + assert!(service.active_requests.is_empty()); + assert!(service.active_nodes_responses.is_empty()); +}
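
A note on the reworked `ActiveRequests` above: correctness rests on the two indexes staying in sync — `active_requests_mapping` keyed by peer and `active_requests_nonce_mapping` keyed by message nonce — which is exactly what the `check_invariant` test hook asserts. The following is a minimal, self-contained sketch of that dual-index bookkeeping, assuming simplified stand-in types (`String` for `NodeAddress`, `u64` for `MessageNonce`, and a plain `HashMap` in place of `delay_map::HashMapDelay`, so request expiry is omitted); it is illustrative only, not the crate's actual definitions:

    use std::collections::{hash_map::Entry, HashMap};

    // Simplified stand-ins for the crate's types (assumptions for illustration).
    type NodeAddress = String;
    type MessageNonce = u64;

    #[derive(Debug)]
    struct RequestCall {
        nonce: MessageNonce,
    }

    /// Two synchronized indexes over the same set of requests: one keyed by
    /// peer, one keyed by message nonce (mirroring the diff's `ActiveRequests`).
    #[derive(Default)]
    struct ActiveRequests {
        by_peer: HashMap<NodeAddress, Vec<RequestCall>>,
        by_nonce: HashMap<MessageNonce, NodeAddress>,
    }

    impl ActiveRequests {
        fn insert(&mut self, peer: NodeAddress, call: RequestCall) {
            self.by_nonce.insert(call.nonce, peer.clone());
            self.by_peer.entry(peer).or_default().push(call);
        }

        /// Removal must touch both indexes, and must drop a peer's entry
        /// entirely once its vector empties, or an invariant check would
        /// find a dangling key.
        fn remove_by_nonce(&mut self, nonce: MessageNonce) -> Option<(NodeAddress, RequestCall)> {
            let peer = self.by_nonce.remove(&nonce)?;
            match self.by_peer.entry(peer.clone()) {
                Entry::Vacant(_) => None,
                Entry::Occupied(mut calls) => {
                    let idx = calls.get().iter().position(|c| c.nonce == nonce)?;
                    let call = calls.get_mut().remove(idx);
                    if calls.get().is_empty() {
                        calls.remove();
                    }
                    Some((peer, call))
                }
            }
        }
    }

    fn main() {
        let mut active = ActiveRequests::default();
        active.insert("peer-a".into(), RequestCall { nonce: 1 });
        active.insert("peer-a".into(), RequestCall { nonce: 2 });
        assert!(active.remove_by_nonce(1).is_some());
        // peer-a still has one request outstanding; both indexes agree.
        assert_eq!(active.by_peer.get("peer-a").map(Vec::len), Some(1));
        assert_eq!(active.by_nonce.len(), 1);
    }

Keying the expiring side by nonce rather than by peer is what lets several in-flight requests to one peer time out independently, which the new `Stream` implementation relies on.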
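The `src/service.rs` change rekeys `active_nodes_responses` from `NodeId` to `RequestId`, so two concurrent FINDNODE requests to the same peer accumulate their multi-message NODES responses independently — the behaviour exercised by `test_handling_concurrent_responses`. Below is a minimal sketch of that accumulation logic under the same assumption pattern (hypothetical simplified types: `u8` request ids, `&str` ENRs; as in the crate, `total` counts NODES messages, not nodes):

    use std::collections::HashMap;

    // Simplified stand-ins (illustration only).
    type RequestId = u8;
    type Enr = &'static str;

    #[derive(Default)]
    struct NodesResponse {
        count: usize,          // NODES messages received so far for this request
        received_nodes: Vec<Enr>,
    }

    struct Service {
        // Keyed by request id, not peer id, so concurrent requests to the
        // same peer never merge their partial results.
        active_nodes_responses: HashMap<RequestId, NodesResponse>,
    }

    impl Service {
        /// Accumulate one NODES message; returns Some(all nodes) when complete.
        fn on_nodes_response(
            &mut self,
            id: RequestId,
            total: usize,
            mut nodes: Vec<Enr>,
        ) -> Option<Vec<Enr>> {
            let mut current = self.active_nodes_responses.remove(&id).unwrap_or_default();
            current.count += 1;
            current.received_nodes.append(&mut nodes);
            if current.count < total {
                // More messages expected for this request id: stash the partial result.
                self.active_nodes_responses.insert(id, current);
                return None;
            }
            Some(current.received_nodes)
        }
    }

    fn main() {
        let mut service = Service { active_nodes_responses: HashMap::new() };
        // Two in-flight requests to the same peer no longer collide:
        assert!(service.on_nodes_response(1, 2, vec!["enr-a"]).is_none()); // half of request 1
        assert!(service.on_nodes_response(2, 1, vec!["enr-c"]).is_some()); // request 2 completes
        let done = service.on_nodes_response(1, 2, vec!["enr-b"]).unwrap(); // request 1 completes
        assert_eq!(done, vec!["enr-a", "enr-b"]);
    }

With the old `NodeId` key, the second request's single-message response would have been folded into the first request's partial state and the bookkeeping left inconsistent; per-request keying makes each accumulator's lifetime match its request's.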