diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 2c539392f..625582402 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -73,6 +73,12 @@ use session::Session; // seconds). const BANNED_NODES_CHECK: u64 = 300; // Check every 5 minutes. +// The one-time session timeout. +const ONE_TIME_SESSION_TIMEOUT: u64 = 30; + +// The maximum number of established one-time sessions to maintain. +const ONE_TIME_SESSION_CACHE_CAPACITY: usize = 100; + /// Messages sent from the application layer to `Handler`. #[derive(Debug, Clone, PartialEq)] #[allow(clippy::large_enum_variant)] @@ -191,6 +197,8 @@ pub struct Handler { active_challenges: HashMapDelay, /// Established sessions with peers. sessions: LruTimeCache, + /// Established sessions with peers for a specific request, stored just one per node. + one_time_sessions: LruTimeCache, /// The channel to receive messages from the application layer. service_recv: mpsc::UnboundedReceiver, /// The channel to send messages to the application layer. @@ -281,6 +289,10 @@ impl Handler { config.session_timeout, Some(config.session_cache_capacity), ), + one_time_sessions: LruTimeCache::new( + Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), + Some(ONE_TIME_SESSION_CACHE_CAPACITY), + ), active_challenges: HashMapDelay::new(config.request_timeout), service_recv, service_send, @@ -516,23 +528,23 @@ impl Handler { response: Response, ) { // Check for an established session - if let Some(session) = self.sessions.get_mut(&node_address) { - // Encrypt the message and send - let packet = match session.encrypt_message::

(self.node_id, &response.encode()) { - Ok(packet) => packet, - Err(e) => { - warn!("Could not encrypt response: {:?}", e); - return; - } - }; - self.send(node_address, packet).await; + let packet = if let Some(session) = self.sessions.get_mut(&node_address) { + session.encrypt_message::

(self.node_id, &response.encode()) + } else if let Some(mut session) = self.remove_one_time_session(&node_address, &response.id) + { + session.encrypt_message::

(self.node_id, &response.encode()) } else { // Either the session is being established or has expired. We simply drop the // response in this case. - warn!( + return warn!( "Session is not established. Dropping response {} for node: {}", response, node_address.node_id ); + }; + + match packet { + Ok(packet) => self.send(node_address, packet).await, + Err(e) => warn!("Could not encrypt response: {:?}", e), } } @@ -780,7 +792,7 @@ impl Handler { ephem_pubkey, enr_record, ) { - Ok((session, enr)) => { + Ok((mut session, enr)) => { // Receiving an AuthResponse must give us an up-to-date view of the node ENR. // Verify the ENR is valid if self.verify_enr(&enr, &node_address) { @@ -820,6 +832,38 @@ impl Handler { ); self.fail_session(&node_address, RequestError::InvalidRemoteEnr, true) .await; + + // Respond to PING request even if the ENR or NodeAddress don't match + // so that the source node can notice its external IP address has been changed. + let maybe_ping_request = match session.decrypt_message( + message_nonce, + message, + authenticated_data, + ) { + Ok(m) => match Message::decode(&m) { + Ok(Message::Request(request)) if request.msg_type() == 1 => { + Some(request) + } + _ => None, + }, + _ => None, + }; + if let Some(request) = maybe_ping_request { + debug!( + "Responding to a PING request using a one-time session. node_address: {}", + node_address + ); + self.one_time_sessions + .insert(node_address.clone(), (request.id.clone(), session)); + if let Err(e) = self + .service_send + .send(HandlerOut::Request(node_address.clone(), Box::new(request))) + .await + { + warn!("Failed to report request to application {}", e); + self.one_time_sessions.remove(&node_address); + } + } } } Err(Discv5Error::InvalidChallengeSignature(challenge)) => { @@ -1119,6 +1163,24 @@ impl Handler { } } + /// Remove one-time session by the given NodeAddress and RequestId if exists. + fn remove_one_time_session( + &mut self, + node_address: &NodeAddress, + request_id: &RequestId, + ) -> Option { + match self.one_time_sessions.peek(node_address) { + Some((id, _)) if id == request_id => { + let (_, session) = self + .one_time_sessions + .remove(node_address) + .expect("one-time session must exist"); + Some(session) + } + _ => None, + } + } + /// A request has failed. async fn fail_request( &mut self, diff --git a/src/handler/session.rs b/src/handler/session.rs index fedeb031d..9f79f9017 100644 --- a/src/handler/session.rs +++ b/src/handler/session.rs @@ -265,3 +265,11 @@ impl Session { Ok((packet, session)) } } + +#[cfg(test)] +pub(crate) fn build_dummy_session() -> Session { + Session::new(Keys { + encryption_key: [0; 16], + decryption_key: [0; 16], + }) +} diff --git a/src/handler/tests.rs b/src/handler/tests.rs index d64fab792..db9c7d3a8 100644 --- a/src/handler/tests.rs +++ b/src/handler/tests.rs @@ -9,7 +9,10 @@ use crate::{ }; use std::net::{Ipv4Addr, Ipv6Addr}; -use crate::{handler::HandlerOut::RequestFailed, RequestError::SelfRequest}; +use crate::{ + handler::{session::build_dummy_session, HandlerOut::RequestFailed}, + RequestError::SelfRequest, +}; use active_requests::ActiveRequests; use enr::EnrBuilder; use std::time::Duration; @@ -21,6 +24,66 @@ fn init() { .try_init(); } +async fn build_handler() -> Handler { + let config = Discv5ConfigBuilder::new(ListenConfig::default()).build(); + let key = CombinedKey::generate_secp256k1(); + let enr = EnrBuilder::new("v4") + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9000) + .build(&key) + .unwrap(); + let mut listen_sockets = SmallVec::default(); + listen_sockets.push((Ipv4Addr::LOCALHOST, 9000).into()); + let node_id = enr.node_id(); + let filter_expected_responses = Arc::new(RwLock::new(HashMap::new())); + + let socket = { + let socket_config = { + let filter_config = FilterConfig { + enabled: config.enable_packet_filter, + rate_limiter: config.filter_rate_limiter.clone(), + max_nodes_per_ip: config.filter_max_nodes_per_ip, + max_bans_per_ip: config.filter_max_bans_per_ip, + }; + + socket::SocketConfig { + executor: config.executor.clone().expect("Executor must exist"), + filter_config, + listen_config: config.listen_config.clone(), + local_node_id: node_id, + expected_responses: filter_expected_responses.clone(), + ban_duration: config.ban_duration, + } + }; + + Socket::new::

(socket_config).await.unwrap() + }; + let (_, service_recv) = mpsc::unbounded_channel(); + let (service_send, _) = mpsc::channel(50); + let (_, exit) = oneshot::channel(); + + Handler { + request_retries: config.request_retries, + node_id, + enr: Arc::new(RwLock::new(enr)), + key: Arc::new(RwLock::new(key)), + active_requests: ActiveRequests::new(config.request_timeout), + pending_requests: HashMap::new(), + filter_expected_responses, + sessions: LruTimeCache::new(config.session_timeout, Some(config.session_cache_capacity)), + one_time_sessions: LruTimeCache::new( + Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), + Some(ONE_TIME_SESSION_CACHE_CAPACITY), + ), + active_challenges: HashMapDelay::new(config.request_timeout), + service_recv, + service_send, + listen_sockets, + socket, + exit, + } +} + macro_rules! arc_rw { ( $x: expr ) => { Arc::new(RwLock::new($x)) @@ -353,3 +416,40 @@ async fn test_self_request_ipv6() { handler_out ); } + +#[tokio::test] +async fn remove_one_time_session() { + let mut handler = build_handler::().await; + + let enr = { + let key = CombinedKey::generate_secp256k1(); + EnrBuilder::new("v4") + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9000) + .build(&key) + .unwrap() + }; + let node_address = NodeAddress::new("127.0.0.1:9000".parse().unwrap(), enr.node_id()); + let request_id = RequestId::random(); + let session = build_dummy_session(); + handler + .one_time_sessions + .insert(node_address.clone(), (request_id.clone(), session)); + + let other_request_id = RequestId::random(); + assert!(handler + .remove_one_time_session(&node_address, &other_request_id) + .is_none()); + assert_eq!(1, handler.one_time_sessions.len()); + + let other_node_address = NodeAddress::new("127.0.0.1:9001".parse().unwrap(), enr.node_id()); + assert!(handler + .remove_one_time_session(&other_node_address, &request_id) + .is_none()); + assert_eq!(1, handler.one_time_sessions.len()); + + assert!(handler + .remove_one_time_session(&node_address, &request_id) + .is_some()); + assert_eq!(0, handler.one_time_sessions.len()); +}