Skip to content

Commit

Permalink
Use entry to access/mutate ActiveRequests
Browse files Browse the repository at this point in the history
  • Loading branch information
njgheorghita committed Jul 19, 2023
1 parent f1d4d80 commit 15b5ed7
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 123 deletions.
174 changes: 72 additions & 102 deletions src/handler/active_requests.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
use super::*;
use delay_map::HashMapDelay;
use std::fmt;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ActiveRequestsError {
InvalidState,
}

impl fmt::Display for ActiveRequestsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ActiveRequestsError::InvalidState => {
write!(f, "Invalid state: active requests mappings are not in sync")
}
}
}
}

impl std::error::Error for ActiveRequestsError {}
use more_asserts::debug_unreachable;
use std::collections::hash_map::Entry;

pub(super) struct ActiveRequests {
/// A list of raw messages we are awaiting a response from the remote.
Expand All @@ -40,94 +24,86 @@ impl ActiveRequests {
// Insert a new request into the active requests mapping.
pub fn insert(&mut self, node_address: NodeAddress, request_call: RequestCall) {
let nonce = *request_call.packet().message_nonce();
let mut request_calls = self
.active_requests_mapping
.remove(&node_address)
.unwrap_or_default();
request_calls.push(request_call);
self.active_requests_mapping
.insert(node_address.clone(), request_calls);
.entry(node_address.clone())
.or_insert_with(Vec::new)
.push(request_call);
self.active_requests_nonce_mapping
.insert(nonce, node_address);
}

// Remove a single request identified by its nonce.
pub fn remove_by_nonce(
&mut self,
nonce: &MessageNonce,
) -> Result<(NodeAddress, RequestCall), ActiveRequestsError> {
let node_address = self
.active_requests_nonce_mapping
.remove(nonce)
.ok_or_else(|| ActiveRequestsError::InvalidState)?;
let mut requests = self
.active_requests_mapping
.remove(&node_address)
.ok_or_else(|| ActiveRequestsError::InvalidState)?;
let index = match requests
.iter()
.position(|req| req.packet().message_nonce() == nonce)
{
Some(index) => index,
None => {
// if nonce req is missing, reinsert remaining requests into mapping
if !requests.is_empty() {
self.active_requests_mapping
.insert(node_address.clone(), requests);
}
return Err(ActiveRequestsError::InvalidState);
}
pub fn remove_by_nonce(&mut self, nonce: &MessageNonce) -> Option<(NodeAddress, RequestCall)> {
let node_address = match self.active_requests_nonce_mapping.remove(nonce) {
Some(val) => val,
None => return None,
};
let req = requests.remove(index);
if !requests.is_empty() {
self.active_requests_mapping
.insert(node_address.clone(), requests);
match self.active_requests_mapping.entry(node_address.clone()) {
Entry::Vacant(_) => {
debug_unreachable!("expected to find node address in active_requests_mapping");
error!("expected to find node address in active_requests_mapping");
None
}
Entry::Occupied(mut requests) => {
let index = requests
.get()
.iter()
.position(|req| req.packet().message_nonce() == nonce)
.expect("to find request call by nonce");
Some((node_address, requests.get_mut().remove(index)))
}
}
Ok((node_address, req))
}

// Remove all requests associated with a node.
pub fn remove_requests(
&mut self,
node_address: &NodeAddress,
) -> Result<Vec<RequestCall>, ActiveRequestsError> {
let requests = self
.active_requests_mapping
.remove(&node_address)
.ok_or_else(|| ActiveRequestsError::InvalidState)?;
pub fn remove_requests(&mut self, node_address: &NodeAddress) -> Option<Vec<RequestCall>> {
let requests = match self.active_requests_mapping.remove(node_address) {
Some(val) => val,
None => return None,
};
// Account for node addresses in `active_requests_nonce_mapping` with an empty list
if requests.is_empty() {
return None;
}
for req in &requests {
self.active_requests_nonce_mapping
.remove(req.packet().message_nonce());
if self
.active_requests_nonce_mapping
.remove(req.packet().message_nonce())
.is_none()
{
debug_unreachable!("expected to find req with nonce");
error!("expected to find req with nonce");
}
}
Ok(requests)
Some(requests)
}

// Remove a single request identified by its id.
pub fn remove_request(
&mut self,
node_address: &NodeAddress,
id: &RequestId,
) -> Result<RequestCall, ActiveRequestsError> {
let reqs = self
.active_requests_mapping
.get(node_address)
.ok_or_else(|| ActiveRequestsError::InvalidState)?;
let index = reqs
.iter()
.position(|req| {
let req_id: RequestId = req.id().into();
&req_id == id
})
.ok_or_else(|| ActiveRequestsError::InvalidState)?;
let nonce = reqs
.get(index)
.ok_or_else(|| ActiveRequestsError::InvalidState)?
.packet()
.message_nonce()
.clone();
// Remove the associated nonce mapping.
let (_, request_call) = self.remove_by_nonce(&nonce)?;
Ok(request_call)
) -> Option<RequestCall> {
match self.active_requests_mapping.entry(node_address.to_owned()) {
Entry::Vacant(_) => None,
Entry::Occupied(mut requests) => {
let index = requests.get().iter().position(|req| {
let req_id: RequestId = req.id().into();
&req_id == id
});
let index = match index {
Some(index) => index,
// Node address existence in active requests mapping does not guarantee request
// id existence.
None => return None,
};
let request_call = requests.get_mut().remove(index);
// Remove the associated nonce mapping.
self.active_requests_nonce_mapping
.remove(request_call.packet().message_nonce());
Some(request_call)
}
}
}

/// Checks that `active_requests_mapping` and `active_requests_nonce_mapping` are in sync.
Expand Down Expand Up @@ -160,23 +136,17 @@ impl Stream for ActiveRequests {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.active_requests_nonce_mapping.poll_next_unpin(cx) {
Poll::Ready(Some(Ok((nonce, node_address)))) => {
// remove the associated mapping
let mut reqs = self
.active_requests_mapping
.remove(&node_address)
.ok_or_else(|| ActiveRequestsError::InvalidState)
.unwrap();
let index = reqs
.iter()
.position(|req| req.packet().message_nonce() == &nonce)
.ok_or_else(|| ActiveRequestsError::InvalidState)
.unwrap();
let req = reqs.remove(index);
if reqs.len() > 0 {
self.active_requests_mapping
.insert(node_address.clone(), reqs);
match self.active_requests_mapping.entry(node_address.clone()) {
Entry::Vacant(_) => panic!("invalid ActiveRequests state"),
Entry::Occupied(mut requests) => {
let index = requests
.get()
.iter()
.position(|req| req.packet().message_nonce() == &nonce)
.expect("to find request call by nonce");
Poll::Ready(Some(Ok((node_address, requests.get_mut().remove(index)))))
}
}
Poll::Ready(Some(Ok((node_address, req))))
}
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(None) => Poll::Ready(None),
Expand Down
94 changes: 76 additions & 18 deletions src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,16 @@ enum HandlerReqId {
External(RequestId),
}

impl Into<RequestId> for &HandlerReqId {
fn into(self) -> RequestId {
match self {
/// A request queued for sending.
struct PendingRequest {
contact: NodeContact,
request_id: HandlerReqId,
request: RequestBody,
}

impl From<&HandlerReqId> for RequestId {
fn from(id: &HandlerReqId) -> Self {
match id {
HandlerReqId::Internal(id) => id.clone(),
HandlerReqId::External(id) => id.clone(),
}
Expand All @@ -193,6 +200,8 @@ pub struct Handler {
active_requests: ActiveRequests,
/// The expected responses by SocketAddr which allows packets to pass the underlying filter.
filter_expected_responses: Arc<RwLock<HashMap<SocketAddr, usize>>>,
/// Requests awaiting a handshake completion.
pending_requests: HashMap<NodeAddress, Vec<PendingRequest>>,
/// Currently in-progress outbound handshakes (WHOAREYOU packets) with peers.
active_challenges: HashMapDelay<NodeAddress, Challenge>,
/// Established sessions with peers.
Expand Down Expand Up @@ -283,6 +292,7 @@ impl Handler {
enr,
key,
active_requests: ActiveRequests::new(config.request_timeout),
pending_requests: HashMap::new(),
filter_expected_responses,
sessions: LruTimeCache::new(
config.session_timeout,
Expand Down Expand Up @@ -463,6 +473,20 @@ impl Handler {
return Err(RequestError::SelfRequest);
}

// If there is already an active request or an active challenge (WHOAREYOU sent) for this node, add to pending requests
if self.active_challenges.get(&node_address).is_some() {
trace!("Request queued for node: {}", node_address);
self.pending_requests
.entry(node_address)
.or_insert_with(Vec::new)
.push(PendingRequest {
contact,
request_id,
request,
});
return Ok(());
}

let (packet, initiating_session) = {
if let Some(session) = self.sessions.get_mut(&node_address) {
// Encrypt the message and send
Expand All @@ -477,14 +501,23 @@ impl Handler {
.map_err(|e| RequestError::EncryptionFailed(format!("{e:?}")))?;
(packet, false)
} else {
// No session exists, start a new handshake
// No session exists, start a new handshake initiating a new session
trace!(
"Starting session. Sending random packet to: {}",
node_address
);
let packet =
Packet::new_random(&self.node_id).map_err(RequestError::EntropyFailure)?;
// We are initiating a new session
// Queue the request for sending after the handshake completes
self.pending_requests
.entry(node_address.clone())
.or_insert_with(Vec::new)
.push(PendingRequest {
contact: contact.clone(),
request_id: request_id.clone(),
request: request.clone(),
});

(packet, true)
}
};
Expand Down Expand Up @@ -587,7 +620,7 @@ impl Handler {
// Check that this challenge matches a known active request.
// If this message passes all the requisite checks, a request call is returned.
let mut request_call = match self.active_requests.remove_by_nonce(&request_nonce) {
Ok((node_address, request_call)) => {
Some((node_address, request_call)) => {
// Verify that the src_addresses match
if node_address.socket_addr != src_address {
debug!("Received a WHOAREYOU packet for a message with a non-expected source. Source {}, expected_source: {} message_nonce {}", src_address, node_address.socket_addr, hex::encode(request_nonce));
Expand All @@ -597,7 +630,7 @@ impl Handler {
}
request_call
}
Err(_) => {
None => {
trace!("Received a WHOAREYOU packet that references an unknown or expired request. Source {}, message_nonce {}", src_address, hex::encode(request_nonce));
return;
}
Expand Down Expand Up @@ -804,7 +837,7 @@ impl Handler {
.await;
// We could have pending messages that were awaiting this session to be
// established. If so process them.
self.replay_active_requests::<P>(&node_address).await;
self.send_pending_requests::<P>(&node_address).await;
} else {
// IP's or NodeAddress don't match. Drop the session.
warn!(
Expand Down Expand Up @@ -874,24 +907,49 @@ impl Handler {
}
}

async fn send_pending_requests<P: ProtocolIdentity>(&mut self, node_address: &NodeAddress) {
let pending_requests = self
.pending_requests
.remove(node_address)
.unwrap_or_default();
for req in pending_requests {
if let Err(request_error) = self
.send_request::<P>(req.contact, req.request_id.clone(), req.request)
.await
{
warn!("Failed to send next pending request {request_error}");
// Inform the service that the request failed
match req.request_id {
HandlerReqId::Internal(_) => {
// An internal request could not be sent. For now we do nothing about
// this.
}
HandlerReqId::External(id) => {
if let Err(e) = self
.service_send
.send(HandlerOut::RequestFailed(id, request_error))
.await
{
warn!("Failed to inform that request failed {e}");
}
}
}
}
}
}

async fn replay_active_requests<P: ProtocolIdentity>(&mut self, node_address: &NodeAddress) {
let active_requests = self
.active_requests
.remove_requests(node_address)
.unwrap_or_default();
for req in active_requests {
let request_id = req.id().clone();
if let Err(request_error) = self
.send_request::<P>(
req.contact().clone(),
request_id.clone(),
req.body().clone(),
)
.await
let (req_id, contact, body) = req.into_request_parts();
if let Err(request_error) = self.send_request::<P>(contact, req_id.clone(), body).await
{
warn!("Failed to send next awaiting request {request_error}");
// Inform the service that the request failed
match request_id {
match req_id {
HandlerReqId::Internal(_) => {
// An internal request could not be sent. For now we do nothing about
// this.
Expand Down Expand Up @@ -1044,7 +1102,7 @@ impl Handler {
response: Response,
) {
// Find a matching request, if any
if let Ok(mut request_call) = self
if let Some(mut request_call) = self
.active_requests
.remove_request(&node_address, &response.id)
{
Expand Down
Loading

0 comments on commit 15b5ed7

Please sign in to comment.