=
RwLock::new(crate::PermitBanList::default());
}
-mod test;
+pub(crate) mod test;
/// Events that can be produced by the `Discv5` event stream.
#[derive(Debug)]
-pub enum Discv5Event {
+pub enum Event {
/// A node has been discovered from a FINDNODES request.
///
/// The ENR of the node is returned. Various properties can be derived from the ENR.
@@ -71,8 +77,11 @@ pub enum Discv5Event {
/// The main Discv5 Service struct. This provides the user-level API for performing queries and
/// interacting with the underlying service.
-pub struct Discv5 {
- config: Discv5Config,
+pub struct Discv5<P = DefaultProtocolId>
+where
+ P: ProtocolIdentity,
+{
+ config: Config,
/// The channel to make requests from the main service.
service_channel: Option<mpsc::Sender<ServiceRequest>>,
/// The exit channel to shutdown the underlying service.
@@ -83,13 +92,17 @@ pub struct Discv5 {
local_enr: Arc<RwLock<Enr>>,
/// The key associated with the local ENR, required for updating the local ENR.
enr_key: Arc<RwLock<CombinedKey>>,
+ // Type of socket we are using
+ ip_mode: IpMode,
+ /// Phantom for the protocol id.
+ _phantom: PhantomData<P>,
}
-impl Discv5 {
+impl<P: ProtocolIdentity> Discv5<P> {
pub fn new(
local_enr: Enr,
enr_key: CombinedKey,
- mut config: Discv5Config,
+ mut config: Config,
) -> Result<Self, &'static str> {
// ensure the keypair matches the one that signed the enr.
if local_enr.public_key() != enr_key.public() {
@@ -126,6 +139,8 @@ impl Discv5 {
// Update the PermitBan list based on initial configuration
*PERMIT_BAN_LIST.write() = config.permit_ban_list.clone();
+ let ip_mode = IpMode::new_from_listen_config(&config.listen_config);
+
Ok(Discv5 {
config,
service_channel: None,
@@ -133,23 +148,24 @@ impl Discv5 {
kbuckets,
local_enr,
enr_key,
+ ip_mode,
+ _phantom: Default::default(),
})
}
/// Starts the required tasks and begins listening on a given UDP SocketAddr.
- pub async fn start(&mut self, listen_socket: SocketAddr) -> Result<(), Discv5Error> {
+ pub async fn start(&mut self) -> Result<(), Error> {
if self.service_channel.is_some() {
warn!("Service is already started");
- return Err(Discv5Error::ServiceAlreadyStarted);
+ return Err(Error::ServiceAlreadyStarted);
}
// create the main service
- let (service_exit, service_channel) = Service::spawn(
+ let (service_exit, service_channel) = Service::spawn::<P>(
self.local_enr.clone(),
self.enr_key.clone(),
self.kbuckets.clone(),
self.config.clone(),
- listen_socket,
)
.await?;
self.service_exit = Some(service_exit);
@@ -178,7 +194,7 @@ impl Discv5 {
/// them upfront.
pub fn add_enr(&self, enr: Enr) -> Result<(), &'static str> {
// only add ENR's that have a valid udp socket.
- if self.config.ip_mode.get_contactable_addr(&enr).is_none() {
+ if self.ip_mode.get_contactable_addr(&enr).is_none() {
warn!("ENR attempted to be added without an UDP socket compatible with configured IpMode has been ignored.");
return Err("ENR has no compatible UDP socket to connect to");
}
@@ -285,6 +301,13 @@ impl Discv5 {
self.local_enr.read().clone()
}
+ /// Identical to `Discv5::local_enr` except that this exposes the `Arc` itself.
+ ///
+ /// This is useful for synchronising views of the local ENR outside of `Discv5`.
+ pub fn external_enr(&self) -> Arc<RwLock<Enr>> {
+ self.local_enr.clone()
+ }
+
/// Returns the routing table of the discv5 service
pub fn kbuckets(&self) -> KBucketsTable<NodeId, Enr> {
self.kbuckets.read().clone()
@@ -300,6 +323,31 @@ impl Discv5 {
None
}
+ /// Sends a PING request to a node.
+ pub fn send_ping(
+ &self,
+ enr: Enr,
+ ) -> impl Future<Output = Result<Pong, RequestError>> + 'static {
+ let (callback_send, callback_recv) = oneshot::channel();
+ let channel = self.clone_channel();
+
+ async move {
+ let channel = channel.map_err(|_| RequestError::ServiceNotStarted)?;
+
+ let event = ServiceRequest::Ping(enr, Some(callback_send));
+
+ // send the request
+ channel
+ .send(event)
+ .await
+ .map_err(|_| RequestError::ChannelFailed("Service channel closed".into()))?;
+ // await the response
+ callback_recv
+ .await
+ .map_err(|e| RequestError::ChannelFailed(e.to_string()))?
+ }
+ }
+
/// Bans a node from the server. This will remove the node from the routing table if it exists
/// and block all incoming packets from the node until the timeout specified. Setting the
/// timeout to `None` creates a permanent ban.
@@ -351,39 +399,45 @@ impl Discv5 {
/// Updates the local ENR TCP/UDP socket.
pub fn update_local_enr_socket(&self, socket_addr: SocketAddr, is_tcp: bool) -> bool {
let mut local_enr = self.local_enr.write();
- let update_socket: Option<SocketAddr> = match socket_addr {
- SocketAddr::V4(socket_addr) => {
- if Some(socket_addr) != local_enr.udp4_socket() {
- Some(socket_addr.into())
- } else {
- None
+ match (is_tcp, socket_addr) {
+ (false, SocketAddr::V4(specific_socket_addr)) => {
+ if Some(specific_socket_addr) != local_enr.udp4_socket() {
+ return local_enr
+ .set_udp_socket(socket_addr, &self.enr_key.read())
+ .is_ok();
}
}
- SocketAddr::V6(socket_addr) => {
- if Some(socket_addr) != local_enr.udp6_socket() {
- Some(socket_addr.into())
- } else {
- None
+ (true, SocketAddr::V4(specific_socket_addr)) => {
+ if Some(specific_socket_addr) != local_enr.tcp4_socket() {
+ return local_enr
+ .set_tcp_socket(socket_addr, &self.enr_key.read())
+ .is_ok();
}
}
- };
- if let Some(new_socket_addr) = update_socket {
- if is_tcp {
- local_enr
- .set_tcp_socket(new_socket_addr, &self.enr_key.read())
- .is_ok()
- } else {
- local_enr
- .set_udp_socket(new_socket_addr, &self.enr_key.read())
- .is_ok()
+ (false, SocketAddr::V6(specific_socket_addr)) => {
+ if Some(specific_socket_addr) != local_enr.udp6_socket() {
+ return local_enr
+ .set_udp_socket(socket_addr, &self.enr_key.read())
+ .is_ok();
+ }
+ }
+ (true, SocketAddr::V6(specific_socket_addr)) => {
+ if Some(specific_socket_addr) != local_enr.tcp6_socket() {
+ return local_enr
+ .set_tcp_socket(socket_addr, &self.enr_key.read())
+ .is_ok();
+ }
}
- } else {
- false
}
+ false
}
/// Allows application layer to insert an arbitrary field into the local ENR.
- pub fn enr_insert(&self, key: &str, value: &[u8]) -> Result<Option<Vec<u8>>, EnrError> {
+ pub fn enr_insert<T: rlp::Encodable>(
+ &self,
+ key: &str,
+ value: &T,
+ ) -> Result<Option<Vec<u8>>, EnrError> {
self.local_enr
.write()
.insert(key, value, &self.enr_key.read())
@@ -452,14 +506,33 @@ impl Discv5 {
let (callback_send, callback_recv) = oneshot::channel();
- let event = ServiceRequest::FindEnr(node_contact, callback_send);
+ let event =
+ ServiceRequest::FindNodeDesignated(node_contact.clone(), vec![0], callback_send);
+
+ // send the request
channel
.send(event)
.await
.map_err(|_| RequestError::ChannelFailed("Service channel closed".into()))?;
- callback_recv
+ // await the response
+ match callback_recv
.await
.map_err(|e| RequestError::ChannelFailed(e.to_string()))?
+ {
+ Ok(mut nodes) => {
+ // This must be for asking for an ENR
+ if nodes.len() > 1 {
+ warn!(
+ "Peer returned more than one ENR for itself. {}",
+ node_contact
+ );
+ }
+ nodes
+ .pop()
+ .ok_or(RequestError::InvalidEnr("Peer did not return an ENR"))
+ }
+ Err(err) => Err(err),
+ }
}
}
@@ -474,7 +547,7 @@ impl Discv5 {
let (callback_send, callback_recv) = oneshot::channel();
let channel = self.clone_channel();
- let ip_mode = self.config.ip_mode;
+ let ip_mode = self.ip_mode;
async move {
let node_contact = NodeContact::try_from_enr(enr, ip_mode)?;
@@ -494,6 +567,35 @@ impl Discv5 {
}
}
+ /// Send a FINDNODE request for nodes that fall within the given set of distances,
+ /// to the designated peer and wait for a response.
+ pub fn find_node_designated_peer(
+ &self,
+ enr: Enr,
+ distances: Vec<u64>,
+ ) -> impl Future<Output = Result<Vec<Enr>, RequestError>> + 'static {
+ let (callback_send, callback_recv) = oneshot::channel();
+ let channel = self.clone_channel();
+ let ip_mode = self.ip_mode;
+
+ async move {
+ let node_contact = NodeContact::try_from_enr(enr, ip_mode)?;
+ let channel = channel.map_err(|_| RequestError::ServiceNotStarted)?;
+
+ let event = ServiceRequest::FindNodeDesignated(node_contact, distances, callback_send);
+
+ // send the request
+ channel
+ .send(event)
+ .await
+ .map_err(|_| RequestError::ChannelFailed("Service channel closed".into()))?;
+ // await the response
+ callback_recv
+ .await
+ .map_err(|e| RequestError::ChannelFailed(e.to_string()))?
+ }
+ }
+
/// Runs an iterative `FIND_NODE` request.
///
/// This will return peers containing contactable nodes of the DHT closest to the
@@ -575,7 +677,7 @@ impl Discv5 {
/// Creates an event stream channel which can be polled to receive Discv5 events.
pub fn event_stream(
&self,
- ) -> impl Future<Output = Result<mpsc::Receiver<Discv5Event>, Discv5Error>> + 'static {
+ ) -> impl Future<Output = Result<mpsc::Receiver<Event>, Error>> + 'static {
let channel = self.clone_channel();
async move {
@@ -587,25 +689,23 @@ impl Discv5 {
channel
.send(event)
.await
- .map_err(|_| Discv5Error::ServiceChannelClosed)?;
+ .map_err(|_| Error::ServiceChannelClosed)?;
- callback_recv
- .await
- .map_err(|_| Discv5Error::ServiceChannelClosed)
+ callback_recv.await.map_err(|_| Error::ServiceChannelClosed)
}
}
/// Internal helper function to send events to the Service.
- fn clone_channel(&self) -> Result<mpsc::Sender<ServiceRequest>, Discv5Error> {
+ fn clone_channel(&self) -> Result<mpsc::Sender<ServiceRequest>, Error> {
if let Some(channel) = self.service_channel.as_ref() {
Ok(channel.clone())
} else {
- Err(Discv5Error::ServiceNotStarted)
+ Err(Error::ServiceNotStarted)
}
}
}
-impl Drop for Discv5 {
+impl<P: ProtocolIdentity> Drop for Discv5<P> {
fn drop(&mut self) {
self.shutdown();
}
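// Usage sketch: how a caller drives the reworked API above. This assumes the crate-root
// re-exports used elsewhere in this patch (`Discv5`, `Enr`, `ConfigBuilder`, `ListenConfig`,
// `enr::CombinedKey`); it is illustrative only, not part of the diff itself.
use discv5::{enr::CombinedKey, ConfigBuilder, Discv5, Enr, ListenConfig};
use std::net::Ipv4Addr;

async fn start_local_node() {
    let enr_key = CombinedKey::generate_secp256k1();
    let enr = Enr::builder().build(&enr_key).unwrap();
    // The listen address now lives in the config rather than being passed to `start()`.
    let listen_config = ListenConfig::Ipv4 {
        ip: Ipv4Addr::LOCALHOST,
        port: 9000,
    };
    let config = ConfigBuilder::new(listen_config).build();
    // `Discv5` defaults its protocol generic to `DefaultProtocolId`.
    let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap();
    discv5.start().await.unwrap();
    // With a peer's ENR in hand (hypothetical `remote_enr`), the new request helpers apply:
    // discv5.send_ping(remote_enr.clone()).await;
    // discv5.find_node_designated_peer(remote_enr, vec![253, 254, 255]).await;
}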
diff --git a/src/discv5/test.rs b/src/discv5/test.rs
index c5a432d16..f6cd70225 100644
--- a/src/discv5/test.rs
+++ b/src/discv5/test.rs
@@ -1,9 +1,12 @@
#![cfg(test)]
-use crate::{Discv5, *};
-use enr::{k256, CombinedKey, Enr, EnrBuilder, EnrKey, NodeId};
+use crate::{socket::ListenConfig, Discv5, *};
+use enr::{k256, CombinedKey, Enr, EnrKey, NodeId};
use rand_core::{RngCore, SeedableRng};
-use std::{collections::HashMap, net::Ipv4Addr};
+use std::{
+ collections::HashMap,
+ net::{Ipv4Addr, Ipv6Addr},
+};
fn init() {
let _ = tracing_subscriber::fmt()
@@ -11,7 +14,7 @@ fn init() {
.try_init();
}
-fn update_enr(discv5: &mut Discv5, key: &str, value: &[u8]) -> bool {
+fn update_enr<T: rlp::Encodable>(discv5: &mut Discv5, key: &str, value: &T) -> bool {
discv5.enr_insert(key, value).is_ok()
}
@@ -22,17 +25,13 @@ async fn build_nodes(n: usize, base_port: u16) -> Vec<Discv5> {
for port in base_port..base_port + n as u16 {
let enr_key = CombinedKey::generate_secp256k1();
- let config = Discv5Config::default();
+ let listen_config = ListenConfig::Ipv4 { ip, port };
+ let config = ConfigBuilder::new(listen_config).build();
- let enr = EnrBuilder::new("v4")
- .ip4(ip)
- .udp4(port)
- .build(&enr_key)
- .unwrap();
+ let enr = Enr::builder().ip4(ip).udp4(port).build(&enr_key).unwrap();
// transport for building a swarm
- let socket_addr = enr.udp4_socket().unwrap();
let mut discv5 = Discv5::new(enr, enr_key, config).unwrap();
- discv5.start(socket_addr.into()).await.unwrap();
+ discv5.start().await.unwrap();
nodes.push(discv5);
}
nodes
@@ -46,24 +45,78 @@ async fn build_nodes_from_keypairs(keys: Vec<CombinedKey>, base_port: u16) -> Ve
for (i, enr_key) in keys.into_iter().enumerate() {
let port = base_port + i as u16;
- let config = Discv5ConfigBuilder::new().build();
+ let listen_config = ListenConfig::Ipv4 { ip, port };
+ let config = ConfigBuilder::new(listen_config).build();
- let enr = EnrBuilder::new("v4")
- .ip4(ip)
- .udp4(port)
+ let enr = Enr::builder().ip4(ip).udp4(port).build(&enr_key).unwrap();
+
+ let mut discv5 = Discv5::new(enr, enr_key, config).unwrap();
+ discv5.start().await.unwrap();
+ nodes.push(discv5);
+ }
+ nodes
+}
+
+async fn build_nodes_from_keypairs_ipv6(keys: Vec<CombinedKey>, base_port: u16) -> Vec<Discv5> {
+ let mut nodes = Vec::new();
+
+ for (i, enr_key) in keys.into_iter().enumerate() {
+ let port = base_port + i as u16;
+
+ let listen_config = ListenConfig::Ipv6 {
+ ip: Ipv6Addr::LOCALHOST,
+ port,
+ };
+ let config = ConfigBuilder::new(listen_config).build();
+
+ let enr = Enr::builder()
+ .ip6(Ipv6Addr::LOCALHOST)
+ .udp6(port)
.build(&enr_key)
.unwrap();
- let socket_addr = enr.udp4_socket().unwrap();
let mut discv5 = Discv5::new(enr, enr_key, config).unwrap();
- discv5.start(socket_addr.into()).await.unwrap();
+ discv5.start().await.unwrap();
+ nodes.push(discv5);
+ }
+ nodes
+}
+
+async fn build_nodes_from_keypairs_dual_stack(
+ keys: Vec<CombinedKey>,
+ base_port: u16,
+) -> Vec<Discv5> {
+ let mut nodes = Vec::new();
+
+ for (i, enr_key) in keys.into_iter().enumerate() {
+ let ipv4_port = base_port + i as u16;
+ let ipv6_port = ipv4_port + 1000;
+
+ let listen_config = ListenConfig::DualStack {
+ ipv4: Ipv4Addr::LOCALHOST,
+ ipv4_port,
+ ipv6: Ipv6Addr::LOCALHOST,
+ ipv6_port,
+ };
+ let config = ConfigBuilder::new(listen_config).build();
+
+ let enr = Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(ipv4_port)
+ .ip6(Ipv6Addr::LOCALHOST)
+ .udp6(ipv6_port)
+ .build(&enr_key)
+ .unwrap();
+
+ let mut discv5 = Discv5::new(enr, enr_key, config).unwrap();
+ discv5.start().await.unwrap();
nodes.push(discv5);
}
nodes
}
/// Generate `n` deterministic keypairs from a given seed.
-fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec<CombinedKey> {
+pub(crate) fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec<CombinedKey> {
let mut keypairs = Vec::new();
for i in 0..n {
let sk = {
@@ -72,7 +125,7 @@ fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec<CombinedKey> {
loop {
// until a value is given within the curve order
rng.fill_bytes(&mut b);
- if let Ok(k) = k256::ecdsa::SigningKey::from_bytes(&b) {
+ if let Ok(k) = k256::ecdsa::SigningKey::from_slice(&b) {
break k;
}
}
@@ -88,6 +141,24 @@ fn get_distance(node1: NodeId, node2: NodeId) -> Option<u64> {
node1.log2_distance(&node2.into())
}
+#[macro_export]
+macro_rules! return_if_ipv6_is_not_supported {
+ () => {
+ let mut is_ipv6_supported = false;
+ for i in if_addrs::get_if_addrs().expect("network interfaces").iter() {
+ if !i.is_loopback() && i.addr.ip().is_ipv6() {
+ is_ipv6_supported = true;
+ break;
+ }
+ }
+
+ if !is_ipv6_supported {
+ tracing::error!("Seems Ipv6 is not supported. Test won't be run.");
+ return;
+ }
+ };
+}
+
// Simple searching function to find seeds that give node ids for a range of testing and different
// topologies
#[allow(dead_code)]
@@ -252,16 +323,137 @@ fn find_seed_linear_topology() {
}
}
-/// This is a smaller version of the star topology test designed to debug issues with queries.
+/// Test for running a simple query test for a topology consisting of IPv4 nodes.
+#[tokio::test]
+async fn test_discovery_three_peers_ipv4() {
+ init();
+ let total_nodes = 3;
+ // Seed is chosen such that all nodes are in the 256th bucket of bootstrap
+ let seed = 1652;
+ // Generate `num_nodes` + bootstrap_node and target_node keypairs from given seed
+ let keypairs = generate_deterministic_keypair(total_nodes + 2, seed);
+ // IPv4
+ let nodes = build_nodes_from_keypairs(keypairs, 10000).await;
+
+ assert_eq!(
+ total_nodes,
+ test_discovery_three_peers(nodes, total_nodes).await
+ );
+}
+
+/// Test for running a simple query test for a topology consisting of IPv6 nodes.
#[tokio::test]
-async fn test_discovery_three_peers() {
+async fn test_discovery_three_peers_ipv6() {
+ return_if_ipv6_is_not_supported!();
+
init();
let total_nodes = 3;
// Seed is chosen such that all nodes are in the 256th bucket of bootstrap
let seed = 1652;
// Generate `num_nodes` + bootstrap_node and target_node keypairs from given seed
let keypairs = generate_deterministic_keypair(total_nodes + 2, seed);
- let mut nodes = build_nodes_from_keypairs(keypairs, 11200).await;
+ // IPv6
+ let nodes = build_nodes_from_keypairs_ipv6(keypairs, 10010).await;
+
+ assert_eq!(
+ total_nodes,
+ test_discovery_three_peers(nodes, total_nodes).await
+ );
+}
+
+/// Test for running a simple query test for a topology consisting of dual stack nodes.
+#[tokio::test]
+async fn test_discovery_three_peers_dual_stack() {
+ return_if_ipv6_is_not_supported!();
+
+ init();
+ let total_nodes = 3;
+ // Seed is chosen such that all nodes are in the 256th bucket of bootstrap
+ let seed = 1652;
+ // Generate `num_nodes` + bootstrap_node and target_node keypairs from given seed
+ let keypairs = generate_deterministic_keypair(total_nodes + 2, seed);
+ // DualStack
+ let nodes = build_nodes_from_keypairs_dual_stack(keypairs, 10020).await;
+
+ assert_eq!(
+ total_nodes,
+ test_discovery_three_peers(nodes, total_nodes).await
+ );
+}
+
+/// Test for running a simple query test for a mixed topology of IPv4, IPv6 and dual stack nodes.
+/// The node to run the query is DualStack.
+#[tokio::test]
+async fn test_discovery_three_peers_mixed() {
+ return_if_ipv6_is_not_supported!();
+
+ init();
+ let total_nodes = 3;
+ // Seed is chosen such that all nodes are in the 256th bucket of bootstrap
+ let seed = 1652;
+ // Generate `num_nodes` + bootstrap_node and target_node keypairs from given seed
+ let mut keypairs = generate_deterministic_keypair(total_nodes + 2, seed);
+
+ let mut nodes = vec![];
+ // Bootstrap node (DualStack)
+ nodes.append(&mut build_nodes_from_keypairs_dual_stack(vec![keypairs.remove(0)], 10030).await);
+ // A node to run query (DualStack)
+ nodes.append(&mut build_nodes_from_keypairs_dual_stack(vec![keypairs.remove(0)], 10031).await);
+ // IPv4 node
+ nodes.append(&mut build_nodes_from_keypairs(vec![keypairs.remove(0)], 10032).await);
+ // IPv6 node
+ nodes.append(&mut build_nodes_from_keypairs_ipv6(vec![keypairs.remove(0)], 10033).await);
+ // Target node (DualStack)
+ nodes.append(&mut build_nodes_from_keypairs_dual_stack(vec![keypairs.remove(0)], 10034).await);
+
+ assert!(keypairs.is_empty());
+ assert_eq!(5, nodes.len());
+ assert_eq!(
+ total_nodes,
+ test_discovery_three_peers(nodes, total_nodes).await
+ );
+}
+
+/// Test for running a simple query test for a mixed topology of IPv4, IPv6 and dual stack nodes.
+/// The node to run the query is IPv4.
+// NOTE: This test emits the error log below because the node that runs the query is in IPv4 mode,
+// so the IPv6 address included in the response is not contactable.
+// `ERROR discv5::service: Query 0 has a non contactable enr: ENR: NodeId: 0xe030..dcbe, IpV4 Socket: None IpV6 Socket: Some([::1]:10043)`
+#[tokio::test]
+async fn test_discovery_three_peers_mixed_query_from_ipv4() {
+ return_if_ipv6_is_not_supported!();
+
+ init();
+ let total_nodes = 3;
+ // Seed is chosen such that all nodes are in the 256th bucket of bootstrap
+ let seed = 1652;
+ // Generate `num_nodes` + bootstrap_node and target_node keypairs from given seed
+ let mut keypairs = generate_deterministic_keypair(total_nodes + 2, seed);
+
+ let mut nodes = vec![];
+ // Bootstrap node (DualStack)
+ nodes.append(&mut build_nodes_from_keypairs_dual_stack(vec![keypairs.remove(0)], 10040).await);
+ // A node to run query (** IPv4 **)
+ nodes.append(&mut build_nodes_from_keypairs(vec![keypairs.remove(0)], 10041).await);
+ // IPv4 node
+ nodes.append(&mut build_nodes_from_keypairs(vec![keypairs.remove(0)], 10042).await);
+ // IPv6 node
+ nodes.append(&mut build_nodes_from_keypairs_ipv6(vec![keypairs.remove(0)], 10043).await);
+ // Target node (DualStack)
+ nodes.append(&mut build_nodes_from_keypairs_dual_stack(vec![keypairs.remove(0)], 10044).await);
+
+ assert!(keypairs.is_empty());
+ assert_eq!(5, nodes.len());
+
+ // `2` is expected here since the node that runs the query is IPv4.
+ // The response from the bootstrap node will include the IPv6 node, but it will be ignored
+ // because it is not contactable.
+ assert_eq!(2, test_discovery_three_peers(nodes, total_nodes).await);
+}
+
+/// This is a smaller version of the star topology test designed to debug issues with queries.
+async fn test_discovery_three_peers(mut nodes: Vec<Discv5>, total_nodes: usize) -> usize {
+ init();
// Last node is bootstrap node in a star topology
let bootstrap_node = nodes.remove(0);
// target_node is not polled.
@@ -307,7 +499,7 @@ async fn test_discovery_three_peers() {
result_nodes.len(),
total_nodes
);
- assert_eq!(result_nodes.len(), total_nodes);
+ result_nodes.len()
}
/// Test for a star topology with `num_nodes` connected to a `bootstrap_node`
@@ -544,12 +736,12 @@ async fn test_table_limits() {
let mut keypairs = generate_deterministic_keypair(12, 9487);
let ip: Ipv4Addr = "127.0.0.1".parse().unwrap();
let enr_key: CombinedKey = keypairs.remove(0);
- let config = Discv5ConfigBuilder::new().ip_limit().build();
- let enr = EnrBuilder::new("v4")
- .ip4(ip)
- .udp4(9050)
- .build(&enr_key)
- .unwrap();
+ let enr = Enr::builder().ip4(ip).udp4(9050).build(&enr_key).unwrap();
+ let listen_config = ListenConfig::Ipv4 {
+ ip: enr.ip4().unwrap(),
+ port: enr.udp4().unwrap(),
+ };
+ let config = ConfigBuilder::new(listen_config).ip_limit().build();
// let socket_addr = enr.udp_socket().unwrap();
let discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap();
@@ -559,7 +751,7 @@ async fn test_table_limits() {
.map(|i| {
let ip: Ipv4Addr = Ipv4Addr::new(192, 168, 1, i as u8);
let enr_key: CombinedKey = keypairs.remove(0);
- EnrBuilder::new("v4")
+ Enr::builder()
.ip4(ip)
.udp4(9050 + i as u16)
.build(&enr_key)
@@ -578,11 +770,7 @@ async fn test_table_limits() {
async fn test_bucket_limits() {
let enr_key = CombinedKey::generate_secp256k1();
let ip: Ipv4Addr = "127.0.0.1".parse().unwrap();
- let enr = EnrBuilder::new("v4")
- .ip4(ip)
- .udp4(9500)
- .build(&enr_key)
- .unwrap();
+ let enr = Enr::builder().ip4(ip).udp4(9500).build(&enr_key).unwrap();
let bucket_limit: usize = 2;
// Generate `bucket_limit + 1` keypairs that go in `enr` node's 256th bucket.
let keys = {
@@ -590,7 +778,7 @@ async fn test_bucket_limits() {
for _ in 0..bucket_limit + 1 {
loop {
let key = CombinedKey::generate_secp256k1();
- let enr_new = EnrBuilder::new("v4").build(&key).unwrap();
+ let enr_new = Enr::empty(&key).unwrap();
let node_key: Key<NodeId> = enr.node_id().into();
let distance = node_key.log2_distance(&enr_new.node_id().into()).unwrap();
if distance == 256 {
@@ -606,7 +794,7 @@ async fn test_bucket_limits() {
.map(|i| {
let kp = &keys[i - 1];
let ip: Ipv4Addr = Ipv4Addr::new(192, 168, 1, i as u8);
- EnrBuilder::new("v4")
+ Enr::builder()
.ip4(ip)
.udp4(9500 + i as u16)
.build(kp)
@@ -614,8 +802,13 @@ async fn test_bucket_limits() {
})
.collect();
- let config = Discv5ConfigBuilder::new().ip_limit().build();
- let discv5 = Discv5::new(enr, enr_key, config).unwrap();
+ let listen_config = ListenConfig::Ipv4 {
+ ip: enr.ip4().unwrap(),
+ port: enr.udp4().unwrap(),
+ };
+ let config = ConfigBuilder::new(listen_config).ip_limit().build();
+
+ let discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap();
for enr in enrs {
let _ = discv5.add_enr(enr.clone()); // we expect some of these to fail based on the filter.
}
diff --git a/src/error.rs b/src/error.rs
index d35ba90d4..35b2b4768 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -4,7 +4,7 @@ use std::fmt;
#[derive(Debug)]
/// A general error that is used throughout the Discv5 library.
-pub enum Discv5Error {
+pub enum Error {
/// An invalid ENR was received.
InvalidEnr,
/// The public key type is not known.
@@ -41,9 +41,9 @@ pub enum Discv5Error {
Io(std::io::Error),
}
-impl From<std::io::Error> for Discv5Error {
- fn from(err: std::io::Error) -> Discv5Error {
- Discv5Error::Io(err)
+impl From<std::io::Error> for Error {
+ fn from(err: std::io::Error) -> Error {
+ Error::Io(err)
}
}
@@ -127,7 +127,7 @@ pub enum QueryError {
InvalidMultiaddr(String),
}
-impl fmt::Display for Discv5Error {
+impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{self:?}")
}
diff --git a/src/executor.rs b/src/executor.rs
index f3b6532e4..04dac7e64 100644
--- a/src/executor.rs
+++ b/src/executor.rs
@@ -1,6 +1,5 @@
-///! A simple trait to allow generic executors or wrappers for spawning the discv5 tasks.
-use std::future::Future;
-use std::pin::Pin;
+//! A simple trait to allow generic executors or wrappers for spawning the discv5 tasks.
+use std::{future::Future, pin::Pin};
pub trait Executor: ExecutorClone {
/// Run the given future in the background until it ends.
diff --git a/src/handler/active_requests.rs b/src/handler/active_requests.rs
index e46ccee83..8cfa2d602 100644
--- a/src/handler/active_requests.rs
+++ b/src/handler/active_requests.rs
@@ -1,68 +1,142 @@
use super::*;
use delay_map::HashMapDelay;
use more_asserts::debug_unreachable;
+use std::collections::hash_map::Entry;
pub(super) struct ActiveRequests {
/// A list of raw messages we are awaiting a response from the remote.
- active_requests_mapping: HashMapDelay<NodeAddress, RequestCall>,
+ active_requests_mapping: HashMap<NodeAddress, Vec<RequestCall>>,
// WHOAREYOU messages do not include the source node id. We therefore maintain another
// mapping of active_requests via message_nonce. This allows us to match WHOAREYOU
// requests with active requests sent.
- /// A mapping of all pending active raw requests message nonces to their NodeAddress.
- active_requests_nonce_mapping: HashMap<MessageNonce, NodeAddress>,
+ /// A mapping of all active raw requests message nonces to their NodeAddress.
+ active_requests_nonce_mapping: HashMapDelay<MessageNonce, NodeAddress>,
}
impl ActiveRequests {
pub fn new(request_timeout: Duration) -> Self {
ActiveRequests {
- active_requests_mapping: HashMapDelay::new(request_timeout),
- active_requests_nonce_mapping: HashMap::new(),
+ active_requests_mapping: HashMap::new(),
+ active_requests_nonce_mapping: HashMapDelay::new(request_timeout),
}
}
+ /// Insert a new request into the active requests mapping.
pub fn insert(&mut self, node_address: NodeAddress, request_call: RequestCall) {
let nonce = *request_call.packet().message_nonce();
self.active_requests_mapping
- .insert(node_address.clone(), request_call);
+ .entry(node_address.clone())
+ .or_default()
+ .push(request_call);
self.active_requests_nonce_mapping
.insert(nonce, node_address);
}
- pub fn get(&self, node_address: &NodeAddress) -> Option<&RequestCall> {
+ /// Update the underlying packet for the request via message nonce.
+ pub fn update_packet(&mut self, old_nonce: MessageNonce, new_packet: Packet) {
+ let node_address =
+ if let Some(node_address) = self.active_requests_nonce_mapping.remove(&old_nonce) {
+ node_address
+ } else {
+ debug_unreachable!("expected to find nonce in active_requests_nonce_mapping");
+ error!("expected to find nonce in active_requests_nonce_mapping");
+ return;
+ };
+
+ self.active_requests_nonce_mapping
+ .insert(new_packet.header.message_nonce, node_address.clone());
+
+ match self.active_requests_mapping.entry(node_address) {
+ Entry::Occupied(mut requests) => {
+ let maybe_request_call = requests
+ .get_mut()
+ .iter_mut()
+ .find(|req| req.packet().message_nonce() == &old_nonce);
+
+ if let Some(request_call) = maybe_request_call {
+ request_call.update_packet(new_packet);
+ } else {
+ debug_unreachable!("expected to find request call in active_requests_mapping");
+ error!("expected to find request call in active_requests_mapping");
+ }
+ }
+ Entry::Vacant(_) => {
+ debug_unreachable!("expected to find node address in active_requests_mapping");
+ error!("expected to find node address in active_requests_mapping");
+ }
+ }
+ }
+
+ pub fn get(&self, node_address: &NodeAddress) -> Option<&Vec<RequestCall>> {
self.active_requests_mapping.get(node_address)
}
+ /// Remove a single request identified by its nonce.
pub fn remove_by_nonce(&mut self, nonce: &MessageNonce) -> Option<(NodeAddress, RequestCall)> {
- match self.active_requests_nonce_mapping.remove(nonce) {
- Some(node_address) => match self.active_requests_mapping.remove(&node_address) {
- Some(request_call) => Some((node_address, request_call)),
- None => {
- debug_unreachable!("A matching request call doesn't exist");
- error!("A matching request call doesn't exist");
- None
+ let node_address = self.active_requests_nonce_mapping.remove(nonce)?;
+ match self.active_requests_mapping.entry(node_address.clone()) {
+ Entry::Vacant(_) => {
+ debug_unreachable!("expected to find node address in active_requests_mapping");
+ error!("expected to find node address in active_requests_mapping");
+ None
+ }
+ Entry::Occupied(mut requests) => {
+ let result = requests
+ .get()
+ .iter()
+ .position(|req| req.packet().message_nonce() == nonce)
+ .map(|index| (node_address, requests.get_mut().remove(index)));
+ if requests.get().is_empty() {
+ requests.remove();
}
- },
- None => None,
+ result
+ }
}
}
- pub fn remove(&mut self, node_address: &NodeAddress) -> Option<RequestCall> {
- match self.active_requests_mapping.remove(node_address) {
- Some(request_call) => {
- // Remove the associated nonce mapping.
- match self
- .active_requests_nonce_mapping
- .remove(request_call.packet().message_nonce())
- {
- Some(_) => Some(request_call),
- None => {
- debug_unreachable!("A matching nonce mapping doesn't exist");
- error!("A matching nonce mapping doesn't exist");
- None
- }
+ /// Remove all requests associated with a node.
+ pub fn remove_requests(&mut self, node_address: &NodeAddress) -> Option<Vec<RequestCall>> {
+ let requests = self.active_requests_mapping.remove(node_address)?;
+ // Account for node addresses in `active_requests_nonce_mapping` with an empty list
+ if requests.is_empty() {
+ debug_unreachable!("expected to find requests in active_requests_mapping");
+ return None;
+ }
+ for req in &requests {
+ if self
+ .active_requests_nonce_mapping
+ .remove(req.packet().message_nonce())
+ .is_none()
+ {
+ debug_unreachable!("expected to find req with nonce");
+ error!("expected to find req with nonce");
+ }
+ }
+ Some(requests)
+ }
+
+ /// Remove a single request identified by its id.
+ pub fn remove_request(
+ &mut self,
+ node_address: &NodeAddress,
+ id: &RequestId,
+ ) -> Option<RequestCall> {
+ match self.active_requests_mapping.entry(node_address.clone()) {
+ Entry::Vacant(_) => None,
+ Entry::Occupied(mut requests) => {
+ let index = requests.get().iter().position(|req| {
+ let req_id: RequestId = req.id().into();
+ &req_id == id
+ })?;
+ let request_call = requests.get_mut().remove(index);
+ if requests.get().is_empty() {
+ requests.remove();
}
+ // Remove the associated nonce mapping.
+ self.active_requests_nonce_mapping
+ .remove(request_call.packet().message_nonce());
+ Some(request_call)
}
- None => None,
}
}
@@ -80,10 +154,12 @@ impl ActiveRequests {
}
}
- for (address, request) in self.active_requests_mapping.iter() {
- let nonce = request.packet().message_nonce();
- if !self.active_requests_nonce_mapping.contains_key(nonce) {
- panic!("Address {} maps to request with nonce {:?}, which does not exist in `active_requests_nonce_mapping`", address, nonce);
+ for (address, requests) in self.active_requests_mapping.iter() {
+ for req in requests {
+ let nonce = req.packet().message_nonce();
+ if !self.active_requests_nonce_mapping.contains_key(nonce) {
+ panic!("Address {} maps to request with nonce {:?}, which does not exist in `active_requests_nonce_mapping`", address, nonce);
+ }
}
}
}
@@ -92,12 +168,27 @@ impl ActiveRequests {
impl Stream for ActiveRequests {
type Item = Result<(NodeAddress, RequestCall), String>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- match self.active_requests_mapping.poll_next_unpin(cx) {
- Poll::Ready(Some(Ok((node_address, request_call)))) => {
- // Remove the associated nonce mapping.
- self.active_requests_nonce_mapping
- .remove(request_call.packet().message_nonce());
- Poll::Ready(Some(Ok((node_address, request_call))))
+ match self.active_requests_nonce_mapping.poll_next_unpin(cx) {
+ Poll::Ready(Some(Ok((nonce, node_address)))) => {
+ match self.active_requests_mapping.entry(node_address.clone()) {
+ Entry::Vacant(_) => Poll::Ready(None),
+ Entry::Occupied(mut requests) => {
+ match requests
+ .get()
+ .iter()
+ .position(|req| req.packet().message_nonce() == &nonce)
+ {
+ Some(index) => {
+ let result = (node_address, requests.get_mut().remove(index));
+ if requests.get().is_empty() {
+ requests.remove();
+ }
+ Poll::Ready(Some(Ok(result)))
+ }
+ None => Poll::Ready(None),
+ }
+ }
+ }
}
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(None) => Poll::Ready(None),
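// Sketch of the bookkeeping introduced above, using simplified stand-in types (String for
// NodeAddress, u64 for MessageNonce, &str for the request) rather than the crate's own:
// requests are grouped per peer while timeouts stay keyed by nonce, and the two maps must be
// kept in sync whenever a single request is removed.
use std::collections::HashMap;

#[derive(Default)]
struct RequestBook {
    by_peer: HashMap<String, Vec<(u64, &'static str)>>, // peer -> outstanding (nonce, request)
    peer_by_nonce: HashMap<u64, String>,                 // nonce -> peer (timeouts keyed here)
}

impl RequestBook {
    fn insert(&mut self, peer: &str, nonce: u64, request: &'static str) {
        self.by_peer
            .entry(peer.to_string())
            .or_default()
            .push((nonce, request));
        self.peer_by_nonce.insert(nonce, peer.to_string());
    }

    fn remove_by_nonce(&mut self, nonce: u64) -> Option<&'static str> {
        let peer = self.peer_by_nonce.remove(&nonce)?;
        let requests = self.by_peer.get_mut(&peer)?;
        let index = requests.iter().position(|(n, _)| *n == nonce)?;
        let (_, request) = requests.remove(index);
        if requests.is_empty() {
            // Drop the peer entry once its last outstanding request is gone.
            self.by_peer.remove(&peer);
        }
        Some(request)
    }
}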
diff --git a/src/handler/crypto/ecdh.rs b/src/handler/crypto/ecdh.rs
index 350db7d0c..ce12f7da7 100644
--- a/src/handler/crypto/ecdh.rs
+++ b/src/handler/crypto/ecdh.rs
@@ -7,10 +7,10 @@ use super::k256::{
pub fn ecdh(public_key: &VerifyingKey, secret_key: &SigningKey) -> Vec<u8> {
k256::PublicKey::from_affine(
- (k256::PublicKey::from_sec1_bytes(public_key.to_bytes().as_ref())
+ (k256::PublicKey::from_sec1_bytes(public_key.to_sec1_bytes().as_ref())
.unwrap()
.to_projective()
- * k256::SecretKey::from_be_bytes(&secret_key.to_bytes())
+ * k256::SecretKey::from_slice(&secret_key.to_bytes())
.unwrap()
.to_nonzero_scalar()
.as_ref())
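// Sanity-check sketch for the k256 calls adopted above (`SigningKey::random`,
// `verifying_key()`): the derived secret must be the same from either side of the exchange.
// `ecdh` is the helper in this file; `enr::k256` and `rand` are used the same way elsewhere
// in this patch. Illustrative only.
#[cfg(test)]
mod ecdh_symmetry {
    use super::ecdh;
    use enr::k256;

    #[test]
    fn shared_secret_matches_from_both_sides() {
        let a = k256::ecdsa::SigningKey::random(&mut rand::thread_rng());
        let b = k256::ecdsa::SigningKey::random(&mut rand::thread_rng());
        assert_eq!(ecdh(b.verifying_key(), &a), ecdh(a.verifying_key(), &b));
    }
}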
diff --git a/src/handler/crypto/mod.rs b/src/handler/crypto/mod.rs
index dbeb10991..6b609d3a8 100644
--- a/src/handler/crypto/mod.rs
+++ b/src/handler/crypto/mod.rs
@@ -6,7 +6,7 @@
//! encryption and key-derivation algorithms. Future versions may abstract some of these to allow
//! for different algorithms.
use crate::{
- error::Discv5Error,
+ error::Error,
node_info::NodeContact,
packet::{ChallengeData, MessageNonce},
};
@@ -19,8 +19,7 @@ use enr::{
k256::{
self,
ecdsa::{
- digest::Update,
- signature::{DigestSigner, DigestVerifier, Signature as _},
+ signature::{DigestSigner, DigestVerifier},
Signature,
},
sha2::{Digest, Sha256},
@@ -49,18 +48,16 @@ pub(crate) fn generate_session_keys(
local_id: &NodeId,
contact: &NodeContact,
challenge_data: &ChallengeData,
-) -> Result<(Key, Key, Vec<u8>), Discv5Error> {
+) -> Result<(Key, Key, Vec<u8>), Error> {
let (secret, ephem_pk) = {
match contact.public_key() {
CombinedPublicKey::Secp256k1(remote_pk) => {
- let ephem_sk = k256::ecdsa::SigningKey::random(rand::thread_rng());
+ let ephem_sk = k256::ecdsa::SigningKey::random(&mut rand::thread_rng());
let secret = ecdh(&remote_pk, &ephem_sk);
let ephem_pk = ephem_sk.verifying_key();
- (secret, ephem_pk.to_bytes().to_vec())
- }
- CombinedPublicKey::Ed25519(_) => {
- return Err(Discv5Error::KeyTypeNotSupported("Ed25519"))
+ (secret, ephem_pk.to_sec1_bytes().to_vec())
}
+ CombinedPublicKey::Ed25519(_) => return Err(Error::KeyTypeNotSupported("Ed25519")),
}
};
@@ -75,7 +72,7 @@ fn derive_key(
first_id: &NodeId,
second_id: &NodeId,
challenge_data: &ChallengeData,
-) -> Result<(Key, Key), Discv5Error> {
+) -> Result<(Key, Key), Error> {
let mut info = [0u8; INFO_LENGTH];
info[0..26].copy_from_slice(KEY_AGREEMENT_STRING.as_bytes());
info[26..26 + NODE_ID_LENGTH].copy_from_slice(&first_id.raw());
@@ -85,7 +82,7 @@ fn derive_key(
let mut okm = [0u8; 2 * KEY_LENGTH];
hk.expand(&info, &mut okm)
- .map_err(|_| Discv5Error::KeyDerivationFailed)?;
+ .map_err(|_| Error::KeyDerivationFailed)?;
let mut initiator_key: Key = Default::default();
let mut recipient_key: Key = Default::default();
@@ -102,17 +99,17 @@ pub(crate) fn derive_keys_from_pubkey(
remote_id: &NodeId,
challenge_data: &ChallengeData,
ephem_pubkey: &[u8],
-) -> Result<(Key, Key), Discv5Error> {
+) -> Result<(Key, Key), Error> {
let secret = {
match local_key {
CombinedKey::Secp256k1(key) => {
// convert remote pubkey into secp256k1 public key
// the key type should match our own node record
let remote_pubkey = k256::ecdsa::VerifyingKey::from_sec1_bytes(ephem_pubkey)
- .map_err(|_| Discv5Error::InvalidRemotePublicKey)?;
+ .map_err(|_| Error::InvalidRemotePublicKey)?;
ecdh(&remote_pubkey, key)
}
- CombinedKey::Ed25519(_) => return Err(Discv5Error::KeyTypeNotSupported("Ed25519")),
+ CombinedKey::Ed25519(_) => return Err(Error::KeyTypeNotSupported("Ed25519")),
}
};
@@ -128,18 +125,18 @@ pub(crate) fn sign_nonce(
challenge_data: &ChallengeData,
ephem_pubkey: &[u8],
dst_id: &NodeId,
-) -> Result<Vec<u8>, Discv5Error> {
+) -> Result<Vec<u8>, Error> {
let signing_message = generate_signing_nonce(challenge_data, ephem_pubkey, dst_id);
match signing_key {
CombinedKey::Secp256k1(key) => {
- let message = Sha256::new().chain(signing_message);
+ let message = Sha256::new().chain_update(signing_message);
let signature: Signature = key
.try_sign_digest(message)
- .map_err(|e| Discv5Error::Error(format!("Failed to sign message: {e}")))?;
- Ok(signature.as_bytes().to_vec())
+ .map_err(|e| Error::Error(format!("Failed to sign message: {e}")))?;
+ Ok(signature.to_vec())
}
- CombinedKey::Ed25519(_) => Err(Discv5Error::KeyTypeNotSupported("Ed25519")),
+ CombinedKey::Ed25519(_) => Err(Error::KeyTypeNotSupported("Ed25519")),
}
}
@@ -157,7 +154,7 @@ pub(crate) fn verify_authentication_nonce(
CombinedPublicKey::Secp256k1(key) => {
if let Ok(sig) = k256::ecdsa::Signature::try_from(sig) {
return key
- .verify_digest(Sha256::new().chain(signing_nonce), &sig)
+ .verify_digest(Sha256::new().chain_update(signing_nonce), &sig)
.is_ok();
}
false
@@ -192,9 +189,9 @@ pub(crate) fn decrypt_message(
message_nonce: MessageNonce,
msg: &[u8],
aad: &[u8],
-) -> Result<Vec<u8>, Discv5Error> {
+) -> Result<Vec<u8>, Error> {
if msg.len() < 16 {
- return Err(Discv5Error::DecryptionFailed(
+ return Err(Error::DecryptionFailed(
"Message not long enough to contain a MAC".into(),
));
}
@@ -202,7 +199,7 @@ pub(crate) fn decrypt_message(
let aead = Aes128Gcm::new(GenericArray::from_slice(key));
let payload = Payload { msg, aad };
aead.decrypt(GenericArray::from_slice(&message_nonce), payload)
- .map_err(|e| Discv5Error::DecryptionFailed(e.to_string()))
+ .map_err(|e| Error::DecryptionFailed(e.to_string()))
}
/* Encryption related functions */
@@ -214,17 +211,19 @@ pub(crate) fn encrypt_message(
message_nonce: MessageNonce,
msg: &[u8],
aad: &[u8],
-) -> Result<Vec<u8>, Discv5Error> {
+) -> Result<Vec<u8>, Error> {
let aead = Aes128Gcm::new(GenericArray::from_slice(key));
let payload = Payload { msg, aad };
aead.encrypt(GenericArray::from_slice(&message_nonce), payload)
- .map_err(|e| Discv5Error::DecryptionFailed(e.to_string()))
+ .map_err(|e| Error::DecryptionFailed(e.to_string()))
}
#[cfg(test)]
mod tests {
+ use crate::packet::DefaultProtocolId;
+
use super::*;
- use enr::{CombinedKey, EnrBuilder, EnrKey};
+ use enr::{CombinedKey, Enr, EnrKey};
use std::convert::TryInto;
fn hex_decode(x: &'static str) -> Vec<u8> {
@@ -260,7 +259,7 @@ mod tests {
.unwrap();
let remote_pk = k256::ecdsa::VerifyingKey::from_sec1_bytes(&remote_pubkey).unwrap();
- let local_sk = k256::ecdsa::SigningKey::from_bytes(&local_secret_key).unwrap();
+ let local_sk = k256::ecdsa::SigningKey::from_slice(&local_secret_key).unwrap();
let secret = ecdh(&remote_pk, &local_sk);
assert_eq!(secret, expected_secret);
@@ -276,7 +275,7 @@ mod tests {
.unwrap();
let remote_pk = k256::ecdsa::VerifyingKey::from_sec1_bytes(&dest_pubkey).unwrap();
- let local_sk = k256::ecdsa::SigningKey::from_bytes(&ephem_key).unwrap();
+ let local_sk = k256::ecdsa::SigningKey::from_slice(&ephem_key).unwrap();
let secret = ecdh(&remote_pk, &local_sk);
@@ -310,7 +309,7 @@ mod tests {
let expected_sig = hex::decode("94852a1e2318c4e5e9d422c98eaf19d1d90d876b29cd06ca7cb7546d0fff7b484fe86c09a064fe72bdbef73ba8e9c34df0cd2b53e9d65528c2c7f336d5dfc6e6").unwrap();
let challenge_data = ChallengeData::try_from(hex::decode("000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000").unwrap().as_slice()).unwrap();
- let key = k256::ecdsa::SigningKey::from_bytes(&local_secret_key).unwrap();
+ let key = k256::ecdsa::SigningKey::from_slice(&local_secret_key).unwrap();
let sig = sign_nonce(&key.into(), &challenge_data, &ephemeral_pubkey, &dst_id).unwrap();
assert_eq!(sig, expected_sig);
@@ -342,12 +341,12 @@ mod tests {
let node1_key = CombinedKey::generate_secp256k1();
let node2_key = CombinedKey::generate_secp256k1();
- let node1_enr = EnrBuilder::new("v4")
+ let node1_enr = Enr::builder()
.ip("127.0.0.1".parse().unwrap())
.udp4(9000)
.build(&node1_key)
.unwrap();
- let node2_enr = EnrBuilder::new("v4")
+ let node2_enr = Enr::builder()
.ip("127.0.0.1".parse().unwrap())
.udp4(9000)
.build(&node2_key)
@@ -394,7 +393,8 @@ mod tests {
let dst_id: NodeId = node_key_2().public().into();
let encoded_ref_packet = hex::decode("00000000000000000000000000000000088b3d4342774649325f313964a39e55ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d34c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc").unwrap();
let (_packet, auth_data) =
- crate::packet::Packet::decode(&dst_id, &encoded_ref_packet).unwrap();
+ crate::packet::Packet::decode::<DefaultProtocolId>(&dst_id, &encoded_ref_packet)
+ .unwrap();
let ciphertext = hex::decode("b84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc").unwrap();
let read_key = hex::decode("00000000000000000000000000000000").unwrap();
diff --git a/src/handler/mod.rs b/src/handler/mod.rs
index fac44c9da..d1e13b022 100644
--- a/src/handler/mod.rs
+++ b/src/handler/mod.rs
@@ -27,10 +27,10 @@
//! Messages from a node on the network come by [`Socket`] and get the form of a [`HandlerOut`]
//! and can be forwarded to the application layer via the send channel.
use crate::{
- config::Discv5Config,
+ config::Config,
discv5::PERMIT_BAN_LIST,
- error::{Discv5Error, RequestError},
- packet::{ChallengeData, IdNonce, MessageNonce, Packet, PacketKind},
+ error::{Error, RequestError},
+ packet::{ChallengeData, IdNonce, MessageNonce, Packet, PacketKind, ProtocolIdentity},
rpc::{Message, Request, RequestBody, RequestId, Response, ResponseBody},
socket,
socket::{FilterConfig, Socket},
@@ -39,7 +39,9 @@ use crate::{
use delay_map::HashMapDelay;
use enr::{CombinedKey, NodeId};
use futures::prelude::*;
+use more_asserts::debug_unreachable;
use parking_lot::RwLock;
+use smallvec::SmallVec;
use std::{
collections::HashMap,
convert::TryFrom,
@@ -63,7 +65,7 @@ pub use crate::node_info::{NodeAddress, NodeContact};
use crate::metrics::METRICS;
-use crate::lru_time_cache::LruTimeCache;
+use crate::{lru_time_cache::LruTimeCache, socket::ListenConfig};
use active_requests::ActiveRequests;
use request_call::RequestCall;
use session::Session;
@@ -72,6 +74,12 @@ use session::Session;
// seconds).
const BANNED_NODES_CHECK: u64 = 300; // Check every 5 minutes.
+// The one-time session timeout.
+const ONE_TIME_SESSION_TIMEOUT: u64 = 30;
+
+// The maximum number of established one-time sessions to maintain.
+const ONE_TIME_SESSION_CACHE_CAPACITY: usize = 100;
+
/// Messages sent from the application layer to `Handler`.
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::large_enum_variant)]
@@ -172,6 +180,15 @@ struct PendingRequest {
request: RequestBody,
}
+impl From<&HandlerReqId> for RequestId {
+ fn from(id: &HandlerReqId) -> Self {
+ match id {
+ HandlerReqId::Internal(id) => id.clone(),
+ HandlerReqId::External(id) => id.clone(),
+ }
+ }
+}
+
/// Process to handle handshakes and sessions established from raw RPC communications between nodes.
pub struct Handler {
/// Configuration for the discv5 service.
@@ -183,7 +200,7 @@ pub struct Handler {
enr: Arc<RwLock<Enr>>,
/// The key to sign the ENR and set up encrypted communication with peers.
key: Arc<RwLock<CombinedKey>>,
- /// Pending raw requests.
+ /// Active requests that are awaiting a response.
active_requests: ActiveRequests,
/// The expected responses by SocketAddr which allows packets to pass the underlying filter.
filter_expected_responses: Arc<RwLock<HashMap<SocketAddr, usize>>>,
@@ -193,12 +210,14 @@ pub struct Handler {
active_challenges: HashMapDelay<NodeAddress, Challenge>,
/// Established sessions with peers.
sessions: LruTimeCache<NodeAddress, Session>,
+ /// Established sessions with peers for a specific request, stored just one per node.
+ one_time_sessions: LruTimeCache<NodeAddress, (RequestId, Session)>,
/// The channel to receive messages from the application layer.
service_recv: mpsc::UnboundedReceiver<HandlerIn>,
/// The channel to send messages to the application layer.
service_send: mpsc::Sender<HandlerOut>,
- /// The listening socket to filter out any attempted requests to self.
- listen_socket: SocketAddr,
+ /// The listening sockets to filter out any attempted requests to self.
+ listen_sockets: SmallVec<[SocketAddr; 2]>,
/// The discovery v5 UDP socket tasks.
socket: Socket,
/// Exit channel to shutdown the handler.
@@ -213,11 +232,10 @@ type HandlerReturn = (
impl Handler {
/// A new Session service which instantiates the UDP socket send/recv tasks.
- pub async fn spawn(
+ pub async fn spawn<P: ProtocolIdentity>(
enr: Arc<RwLock<Enr>>,
key: Arc<RwLock<CombinedKey>>,
- listen_socket: SocketAddr,
- config: Discv5Config,
+ config: Config,
) -> Result<HandlerReturn, std::io::Error> {
let (exit_sender, exit) = oneshot::channel();
// create the channels to send/receive messages from the application
@@ -233,7 +251,6 @@ impl Handler {
let node_id = enr.read().node_id();
// enable the packet filter if required
-
let filter_config = FilterConfig {
enabled: config.enable_packet_filter,
rate_limiter: config.filter_rate_limiter.clone(),
@@ -241,18 +258,32 @@ impl Handler {
max_bans_per_ip: config.filter_max_bans_per_ip,
};
+ let mut listen_sockets = SmallVec::default();
+ match config.listen_config {
+ ListenConfig::Ipv4 { ip, port } => listen_sockets.push((ip, port).into()),
+ ListenConfig::Ipv6 { ip, port } => listen_sockets.push((ip, port).into()),
+ ListenConfig::DualStack {
+ ipv4,
+ ipv4_port,
+ ipv6,
+ ipv6_port,
+ } => {
+ listen_sockets.push((ipv4, ipv4_port).into());
+ listen_sockets.push((ipv6, ipv6_port).into());
+ }
+ };
+
let socket_config = socket::SocketConfig {
executor: config.executor.clone().expect("Executor must exist"),
- socket_addr: listen_socket,
filter_config,
+ listen_config: config.listen_config.clone(),
local_node_id: node_id,
expected_responses: filter_expected_responses.clone(),
ban_duration: config.ban_duration,
- ip_mode: config.ip_mode,
};
// Attempt to bind to the socket before spinning up the send/recv tasks.
- let socket = Socket::new(socket_config).await?;
+ let socket = Socket::new::<P>(socket_config).await?;
config
.executor
@@ -271,22 +302,26 @@ impl Handler {
config.session_timeout,
Some(config.session_cache_capacity),
),
+ one_time_sessions: LruTimeCache::new(
+ Duration::from_secs(ONE_TIME_SESSION_TIMEOUT),
+ Some(ONE_TIME_SESSION_CACHE_CAPACITY),
+ ),
active_challenges: HashMapDelay::new(config.request_timeout),
service_recv,
service_send,
- listen_socket,
+ listen_sockets,
socket,
exit,
};
debug!("Handler Starting");
- handler.start().await;
+ handler.start::<P>().await;
}));
Ok((exit_sender, handler_send, handler_recv))
}
/// The main execution loop for the handler.
- async fn start(&mut self) {
+ async fn start<P: ProtocolIdentity>(&mut self) {
let mut banned_nodes_check = tokio::time::interval(Duration::from_secs(BANNED_NODES_CHECK));
loop {
@@ -295,7 +330,7 @@ impl Handler {
match handler_request {
HandlerIn::Request(contact, request) => {
let Request { id, body: request } = *request;
- if let Err(request_error) = self.send_request(contact, HandlerReqId::External(id.clone()), request).await {
+ if let Err(request_error) = self.send_request::<P>(contact, HandlerReqId::External(id.clone()), request).await {
// If the sending failed report to the application
if let Err(e) = self.service_send.send(HandlerOut::RequestFailed(id, request_error)).await {
warn!("Failed to inform that request failed {}", e)
@@ -304,25 +339,25 @@ impl Handler {
}
HandlerIn::RequestNoPending(contact, request) => {
let Request { id, body: request } = *request;
- if let Err(request_error) = self.send_request_no_pending(contact, HandlerReqId::External(id.clone()), request).await {
+ if let Err(request_error) = self.send_request_no_pending::<P>(contact, HandlerReqId::External(id.clone()), request).await {
// If the sending failed report to the application
let _ = self.service_send.send(HandlerOut::RequestFailed(id, request_error)).await;
}
}
- HandlerIn::Response(dst, response) => self.send_response(dst, *response).await,
- HandlerIn::WhoAreYou(wru_ref, enr) => self.send_challenge(wru_ref, enr).await,
+ HandlerIn::Response(dst, response) => self.send_response::<P>(dst, *response).await,
+ HandlerIn::WhoAreYou(wru_ref, enr) => self.send_challenge::<P>(wru_ref, enr).await,
}
}
Some(inbound_packet) = self.socket.recv.recv() => {
- self.process_inbound_packet(inbound_packet).await;
+ self.process_inbound_packet::<P>(inbound_packet).await;
}
- Some(Ok((node_address, pending_request))) = self.active_requests.next() => {
- self.handle_request_timeout(node_address, pending_request).await;
+ Some(Ok((node_address, active_request))) = self.active_requests.next() => {
+ self.handle_request_timeout(node_address, active_request).await;
}
Some(Ok((node_address, _challenge))) = self.active_challenges.next() => {
// A challenge has expired. There could be pending requests awaiting this
// challenge. We process them here
- self.send_next_request(node_address).await;
+ self.send_pending_requests::<P>(&node_address).await;
}
_ = banned_nodes_check.tick() => self.unban_nodes_check(), // Unban nodes that are past the timeout
_ = &mut self.exit => {
@@ -333,14 +368,17 @@ impl Handler {
}
/// Processes an inbound decoded packet.
- async fn process_inbound_packet(&mut self, inbound_packet: socket::InboundPacket) {
+ async fn process_inbound_packet<P: ProtocolIdentity>(
+ &mut self,
+ inbound_packet: socket::InboundPacket,
+ ) {
let message_nonce = inbound_packet.header.message_nonce;
match inbound_packet.header.kind {
PacketKind::WhoAreYou { enr_seq, .. } => {
let challenge_data =
ChallengeData::try_from(inbound_packet.authenticated_data.as_slice())
.expect("Must be correct size");
- self.handle_challenge(
+ self.handle_challenge::<P>(
inbound_packet.src_address,
message_nonce,
enr_seq,
@@ -358,7 +396,7 @@ impl Handler {
socket_addr: inbound_packet.src_address,
node_id: src_id,
};
- self.handle_auth_message(
+ self.handle_auth_message::<P>(
node_address,
message_nonce,
&id_nonce_sig,
@@ -433,7 +471,7 @@ impl Handler {
}
/// Sends a `Request` to a node.
- async fn send_request(
+ async fn send_request<P: ProtocolIdentity>(
&mut self,
contact: NodeContact,
request_id: HandlerReqId,
@@ -441,19 +479,20 @@ impl Handler {
) -> Result<(), RequestError> {
let node_address = contact.node_address();
- if node_address.socket_addr == self.listen_socket {
+ if self.listen_sockets.contains(&node_address.socket_addr) {
debug!("Filtered request to self");
return Err(RequestError::SelfRequest);
}
- // If there is already an active request or an active challenge (WHOAREYOU sent) for this node, add to pending requests
- if self.active_requests.get(&node_address).is_some()
- || self.active_challenges.get(&node_address).is_some()
+ // If there is already an active challenge (WHOAREYOU sent) for this node, or if we are
+ // awaiting a session with this node to be established, add the request to pending requests.
+ if self.active_challenges.get(&node_address).is_some()
+ || self.is_awaiting_session_to_be_established(&node_address)
{
trace!("Request queued for node: {}", node_address);
self.pending_requests
.entry(node_address)
- .or_insert_with(Vec::new)
+ .or_default()
.push(PendingRequest {
contact,
request_id,
@@ -472,18 +511,17 @@ impl Handler {
},
};
let packet = session
- .encrypt_message(self.node_id, &request.encode())
+ .encrypt_message::<P>(self.node_id, &request.encode())
.map_err(|e| RequestError::EncryptionFailed(format!("{e:?}")))?;
(packet, false)
} else {
- // No session exists, start a new handshake
+ // No session exists, start a new handshake initiating a new session
trace!(
"Starting session. Sending random packet to: {}",
node_address
);
let packet =
Packet::new_random(&self.node_id).map_err(RequestError::EntropyFailure)?;
- // We are initiating a new session
(packet, true)
}
};
@@ -505,7 +543,7 @@ impl Handler {
// Sends a request to a node if a session is established, otherwise reverts to the normal send
// request. These are unreliable sends which do not timeout or error.
- async fn send_request_no_pending(
+ async fn send_request_no_pending<P: ProtocolIdentity>(
&mut self,
contact: NodeContact,
request_id: HandlerReqId,
@@ -513,7 +551,7 @@ impl Handler {
) -> Result<(), RequestError> {
let node_address = contact.node_address();
- if node_address.socket_addr == self.listen_socket {
+ if self.listen_sockets.contains(&node_address.socket_addr) {
debug!("Filtered request to self");
return Err(RequestError::SelfRequest);
}
@@ -528,11 +566,11 @@ impl Handler {
};
// Encrypt the message and send
let packet = session
- .encrypt_message(self.node_id, &request.clone().encode())
+ .encrypt_message::<P>(self.node_id, &request.clone().encode())
.map_err(|e| RequestError::EncryptionFailed(format!("{:?}", e)))?;
packet
} else {
- return self.send_request(contact, request_id, request).await;
+ return self.send_request::<P>(contact, request_id, request).await;
}
};
@@ -543,31 +581,39 @@ impl Handler {
}
/// Sends an RPC Response.
- async fn send_response(&mut self, node_address: NodeAddress, response: Response) {
+ async fn send_response<P: ProtocolIdentity>(
+ &mut self,
+ node_address: NodeAddress,
+ response: Response,
+ ) {
// Check for an established session
- if let Some(session) = self.sessions.get_mut(&node_address) {
- // Encrypt the message and send
- let packet = match session.encrypt_message(self.node_id, &response.encode()) {
- Ok(packet) => packet,
- Err(e) => {
- warn!("Could not encrypt response: {:?}", e);
- return;
- }
- };
- self.send(node_address, packet).await;
+ let packet = if let Some(session) = self.sessions.get_mut(&node_address) {
+ session.encrypt_message::<P>(self.node_id, &response.encode())
+ } else if let Some(mut session) = self.remove_one_time_session(&node_address, &response.id)
+ {
+ session.encrypt_message::<P>(self.node_id, &response.encode())
} else {
// Either the session is being established or has expired. We simply drop the
// response in this case.
- warn!(
+ return warn!(
"Session is not established. Dropping response {} for node: {}",
response, node_address.node_id
);
+ };
+
+ match packet {
+ Ok(packet) => self.send(node_address, packet).await,
+ Err(e) => warn!("Could not encrypt response: {:?}", e),
}
}
/// This is called in response to a `HandlerOut::WhoAreYou` event. The applications finds the
/// highest known ENR for a node then we respond to the node with a WHOAREYOU packet.
- async fn send_challenge(&mut self, wru_ref: WhoAreYouRef, remote_enr: Option<Enr>) {
+ async fn send_challenge<P: ProtocolIdentity>(
+ &mut self,
+ wru_ref: WhoAreYouRef,
+ remote_enr: Option<Enr>,
+ ) {
let node_address = wru_ref.0;
let message_nonce = wru_ref.1;
@@ -590,7 +636,7 @@ impl Handler {
let enr_seq = remote_enr.clone().map_or_else(|| 0, |enr| enr.seq());
let id_nonce: IdNonce = rand::random();
let packet = Packet::new_whoareyou(message_nonce, id_nonce, enr_seq);
- let challenge_data = ChallengeData::try_from(packet.authenticated_data().as_slice())
+ let challenge_data = ChallengeData::try_from(packet.authenticated_data::<P>().as_slice())
.expect("Must be the correct challenge size");
debug!("Sending WHOAREYOU to {}", node_address);
self.add_expected_response(node_address.socket_addr);
@@ -607,7 +653,7 @@ impl Handler {
/* Packet Handling */
/// Handles a WHOAREYOU packet that was received from the network.
- async fn handle_challenge(
+ async fn handle_challenge<P: ProtocolIdentity>(
&mut self,
src_address: SocketAddr,
request_nonce: MessageNonce,
@@ -620,7 +666,7 @@ impl Handler {
Some((node_address, request_call)) => {
// Verify that the src_addresses match
if node_address.socket_addr != src_address {
- trace!("Received a WHOAREYOU packet for a message with a non-expected source. Source {}, expected_source: {} message_nonce {}", src_address, node_address.socket_addr, hex::encode(request_nonce));
+ debug!("Received a WHOAREYOU packet for a message with a non-expected source. Source {}, expected_source: {} message_nonce {}", src_address, node_address.socket_addr, hex::encode(request_nonce));
// Add the request back if src_address doesn't match
self.active_requests.insert(node_address, request_call);
return;
@@ -670,7 +716,7 @@ impl Handler {
};
// Generate a new session and authentication packet
- let (auth_packet, mut session) = match Session::encrypt_with_header(
+ let (auth_packet, mut session) = match Session::encrypt_with_header::<P>(
request_call.contact(),
self.key.clone(),
updated_enr,
@@ -701,22 +747,25 @@ impl Handler {
// All sent requests must have an associated node_id. Therefore the following
// must not panic.
let node_address = request_call.contact().node_address();
+ let auth_message_nonce = auth_packet.header.message_nonce;
match request_call.contact().enr() {
Some(enr) => {
// NOTE: Here we decide if the session is outgoing or ingoing. The condition for an
// outgoing session is that we originally sent a RANDOM packet (signifying we did
// not have a session for a request) and the packet is not a PING (we are not
// trying to update an old session that may have expired.
- let connection_direction = {
- match (request_call.initiating_session(), &request_call.body()) {
- (true, RequestBody::Ping { .. }) => ConnectionDirection::Incoming,
- (true, _) => ConnectionDirection::Outgoing,
- (false, _) => ConnectionDirection::Incoming,
- }
+ let connection_direction = if request_call.initiating_session() {
+ ConnectionDirection::Outgoing
+ } else {
+ ConnectionDirection::Incoming
};
// We already know the ENR. Send the handshake response packet
- trace!("Sending Authentication response to node: {}", node_address);
+ trace!(
+ "Sending Authentication response to node: {} ({:?})",
+ node_address,
+ request_call.id()
+ );
request_call.update_packet(auth_packet.clone());
request_call.set_handshake_sent();
request_call.set_initiating_session(false);
@@ -740,7 +789,11 @@ impl Handler {
// Send the Auth response
let contact = request_call.contact().clone();
- trace!("Sending Authentication response to node: {}", node_address);
+ trace!(
+ "Sending Authentication response to node: {} ({:?})",
+ node_address,
+ request_call.id()
+ );
request_call.update_packet(auth_packet.clone());
request_call.set_handshake_sent();
// Reinsert the request_call
@@ -751,14 +804,15 @@ impl Handler {
let request = RequestBody::FindNode { distances: vec![0] };
session.awaiting_enr = Some(id.clone());
if let Err(e) = self
- .send_request(contact, HandlerReqId::Internal(id), request)
+ .send_request::<P>(contact, HandlerReqId::Internal(id), request)
.await
{
warn!("Failed to send Enr request {}", e)
}
}
}
- self.new_session(node_address, session);
+ self.new_session::<P>(node_address.clone(), session, Some(auth_message_nonce))
+ .await;
}
/// Verifies a Node ENR against its observed address. If it fails, any associated session is also
@@ -779,7 +833,7 @@ impl Handler {
/// Handle a message that contains an authentication header.
#[allow(clippy::too_many_arguments)]
- async fn handle_auth_message(
+ async fn handle_auth_message<P: ProtocolIdentity>(
&mut self,
node_address: NodeAddress,
message_nonce: MessageNonce,
@@ -807,7 +861,9 @@ impl Handler {
ephem_pubkey,
enr_record,
) {
- Ok((session, enr)) => {
+ Ok((mut session, enr)) => {
+ // Remove the expected response for the challenge.
+ self.remove_expected_response(node_address.socket_addr);
// Receiving an AuthResponse must give us an up-to-date view of the node ENR.
// Verify the ENR is valid
if self.verify_enr(&enr, &node_address) {
@@ -826,7 +882,11 @@ impl Handler {
{
warn!("Failed to inform of established session {}", e)
}
- self.new_session(node_address.clone(), session);
+ // When (re-)establishing a session from an outgoing challenge, we do not need
+ // to filter out this request from active requests, so we do not pass
+ // the message nonce on to `new_session`.
+ self.new_session::<P>(node_address.clone(), session, None)
+ .await;
self.handle_message(
node_address.clone(),
message_nonce,
@@ -834,9 +894,6 @@ impl Handler {
authenticated_data,
)
.await;
- // We could have pending messages that were awaiting this session to be
- // established. If so process them.
- self.send_next_request(node_address).await;
} else {
// IP's or NodeAddress don't match. Drop the session.
warn!(
@@ -847,9 +904,41 @@ impl Handler {
);
self.fail_session(&node_address, RequestError::InvalidRemoteEnr, true)
.await;
+
+ // Respond to PING request even if the ENR or NodeAddress don't match
+ // so that the source node can notice its external IP address has been changed.
+ let maybe_ping_request = match session.decrypt_message(
+ message_nonce,
+ message,
+ authenticated_data,
+ ) {
+ Ok(m) => match Message::decode(&m) {
+ Ok(Message::Request(request)) if request.msg_type() == 1 => {
+ Some(request)
+ }
+ _ => None,
+ },
+ _ => None,
+ };
+ if let Some(request) = maybe_ping_request {
+ debug!(
+ "Responding to a PING request using a one-time session. node_address: {}",
+ node_address
+ );
+ self.one_time_sessions
+ .insert(node_address.clone(), (request.id.clone(), session));
+ if let Err(e) = self
+ .service_send
+ .send(HandlerOut::Request(node_address.clone(), Box::new(request)))
+ .await
+ {
+ warn!("Failed to report request to application {}", e);
+ self.one_time_sessions.remove(&node_address);
+ }
+ }
}
}
- Err(Discv5Error::InvalidChallengeSignature(challenge)) => {
+ Err(Error::InvalidChallengeSignature(challenge)) => {
warn!(
"Authentication header contained invalid signature. Ignoring packet from: {}",
node_address
@@ -868,47 +957,43 @@ impl Handler {
}
} else {
warn!(
- "Received an authenticated header without a matching WHOAREYOU request. {}",
- node_address
+ node_id = %node_address.node_id, addr = %node_address.socket_addr,
+ "Received an authenticated header without a matching WHOAREYOU request",
);
}
}
- async fn send_next_request(&mut self, node_address: NodeAddress) {
- // ensure we are not over writing any existing requests
- if self.active_requests.get(&node_address).is_none() {
- if let std::collections::hash_map::Entry::Occupied(mut entry) =
- self.pending_requests.entry(node_address)
+ /// Sends all pending requests for the given node address that were waiting for a new
+ /// session to be established or for an active outgoing challenge to expire.
+ async fn send_pending_requests(&mut self, node_address: &NodeAddress) {
+ let pending_requests = self
+ .pending_requests
+ .remove(node_address)
+ .unwrap_or_default();
+ for req in pending_requests {
+ trace!(
+ "Sending pending request {} to {node_address}. {}",
+ RequestId::from(&req.request_id),
+ req.request,
+ );
+ if let Err(request_error) = self
+ .send_request::<P>(req.contact, req.request_id.clone(), req.request)
+ .await
{
- // If it exists, there must be a request here
- let PendingRequest {
- contact,
- request_id,
- request,
- } = entry.get_mut().remove(0);
- if entry.get().is_empty() {
- entry.remove();
- }
- trace!("Sending next awaiting message. Node: {}", contact);
- if let Err(request_error) = self
- .send_request(contact, request_id.clone(), request)
- .await
- {
- warn!("Failed to send next awaiting request {}", request_error);
- // Inform the service that the request failed
- match request_id {
- HandlerReqId::Internal(_) => {
- // An internal request could not be sent. For now we do nothing about
- // this.
- }
- HandlerReqId::External(id) => {
- if let Err(e) = self
- .service_send
- .send(HandlerOut::RequestFailed(id, request_error))
- .await
- {
- warn!("Failed to inform that request failed {}", e);
- }
+ warn!("Failed to send next pending request {request_error}");
+ // Inform the service that the request failed
+ match req.request_id {
+ HandlerReqId::Internal(_) => {
+ // An internal request could not be sent. For now we do nothing about
+ // this.
+ }
+ HandlerReqId::External(id) => {
+ if let Err(e) = self
+ .service_send
+ .send(HandlerOut::RequestFailed(id, request_error))
+ .await
+ {
+ warn!("Failed to inform that request failed {e}");
}
}
}
@@ -916,6 +1001,64 @@ impl Handler {
}
}
+ /// Replays all active requests for the given node address when a new session has been
+ /// established. If a message nonce is provided, the corresponding request is skipped,
+ /// e.g. the request that established the new session.
+ async fn replay_active_requests<P: ProtocolIdentity>(
+ &mut self,
+ node_address: &NodeAddress,
+ // Optional message nonce to filter out the request used to establish the session.
+ message_nonce: Option<MessageNonce>,
+ ) {
+ trace!(
+ "Replaying active requests. {}, {:?}",
+ node_address,
+ message_nonce
+ );
+
+ let packets = if let Some(session) = self.sessions.get_mut(node_address) {
+ let mut packets = vec![];
+ for request_call in self
+ .active_requests
+ .get(node_address)
+ .unwrap_or(&vec![])
+ .iter()
+ .filter(|req| {
+ // Except the active request that was used to establish the new session, as it has
+ // already been handled and shouldn't be replayed.
+ if let Some(nonce) = message_nonce.as_ref() {
+ req.packet().message_nonce() != nonce
+ } else {
+ true
+ }
+ })
+ {
+ if let Ok(new_packet) =
+ session.encrypt_message::<P>(self.node_id, &request_call.encode())
+ {
+ packets.push((*request_call.packet().message_nonce(), new_packet));
+ } else {
+ error!(
+ "Failed to re-encrypt packet while replaying active request with id: {:?}",
+ request_call.id()
+ );
+ }
+ }
+
+ packets
+ } else {
+ debug_unreachable!("Attempted to replay active requests but session doesn't exist.");
+ error!("Attempted to replay active requests but session doesn't exist.");
+ return;
+ };
+
+ for (old_nonce, new_packet) in packets {
+ self.active_requests
+ .update_packet(old_nonce, new_packet.clone());
+ self.send(node_address.clone(), new_packet).await;
+ }
+ }
+
/// Handle a standard message that does not contain an authentication header.
#[allow(clippy::single_match)]
async fn handle_message(
@@ -1046,22 +1189,11 @@ impl Handler {
/// Nodes response.
async fn handle_response(&mut self, node_address: NodeAddress, response: Response) {
// Find a matching request, if any
- if let Some(mut request_call) = self.active_requests.remove(&node_address) {
- let id = match request_call.id() {
- HandlerReqId::Internal(id) | HandlerReqId::External(id) => id,
- };
- if id != &response.id {
- trace!(
- "Received an RPC Response to an unknown request. Likely late response. {}",
- node_address
- );
- // add the request back and reset the timer
- self.active_requests.insert(node_address, request_call);
- return;
- }
-
+ if let Some(mut request_call) = self
+ .active_requests
+ .remove_request(&node_address, &response.id)
+ {
// The response matches a request
-
// Check to see if this is a Nodes response, in which case we may require to wait for
// extra responses
if let ResponseBody::Nodes { total, .. } = response.body {
@@ -1115,7 +1247,6 @@ impl Handler {
{
warn!("Failed to inform of response {}", e)
}
- self.send_next_request(node_address).await;
} else {
// This is likely a late response and we have already failed the request. These get
// dropped here.
@@ -1131,14 +1262,49 @@ impl Handler {
self.active_requests.insert(node_address, request_call);
}
- fn new_session(&mut self, node_address: NodeAddress, session: Session) {
+ /// Establishes a new session with a peer, or re-establishes an existing session if a
+ /// new challenge was issued during an ongoing session.
+ async fn new_session<P: ProtocolIdentity>(
+ &mut self,
+ node_address: NodeAddress,
+ session: Session,
+ // Optional message nonce is required to filter out the request that was used in the
+ // handshake to re-establish a session, if applicable.
+ message_nonce: Option<MessageNonce>,
+ ) {
if let Some(current_session) = self.sessions.get_mut(&node_address) {
current_session.update(session);
+ // If a session is re-established, due to a new handshake during an ongoing
+ // session, we need to replay any active requests from the prior session, excluding
+ // the request that was used to re-establish the session handshake.
+ self.replay_active_requests::<P>(&node_address, message_nonce)
+ .await;
} else {
- self.sessions.insert(node_address, session);
+ self.sessions.insert(node_address.clone(), session);
METRICS
.active_sessions
.store(self.sessions.len(), Ordering::Relaxed);
+ // We could have pending messages that were awaiting this session to be
+ // established. If so process them.
+ self.send_pending_requests::<P>(&node_address).await;
+ }
+ }
+
+ /// Remove one-time session by the given NodeAddress and RequestId if exists.
+ fn remove_one_time_session(
+ &mut self,
+ node_address: &NodeAddress,
+ request_id: &RequestId,
+ ) -> Option<Session> {
+ match self.one_time_sessions.peek(node_address) {
+ Some((id, _)) if id == request_id => {
+ let (_, session) = self
+ .one_time_sessions
+ .remove(node_address)
+ .expect("one-time session must exist");
+ Some(session)
+ }
+ _ => None,
}
}
@@ -1171,7 +1337,7 @@ impl Handler {
.await;
}
- /// Removes a session and updates associated metrics and fields.
+ /// Removes a session, fails all of that session's active & pending requests, and updates associated metrics and fields.
async fn fail_session(
&mut self,
node_address: &NodeAddress,
@@ -1184,6 +1350,7 @@ impl Handler {
.active_sessions
.store(self.sessions.len(), Ordering::Relaxed);
}
+ // fail all pending requests
if let Some(to_remove) = self.pending_requests.remove(node_address) {
for PendingRequest { request_id, .. } in to_remove {
match request_id {
@@ -1202,6 +1369,28 @@ impl Handler {
}
}
}
+ // fail all active requests
+ for req in self
+ .active_requests
+ .remove_requests(node_address)
+ .unwrap_or_default()
+ {
+ match req.id() {
+ HandlerReqId::Internal(_) => {
+ // Do not report failures on requests belonging to the handler.
+ }
+ HandlerReqId::External(id) => {
+ if let Err(e) = self
+ .service_send
+ .send(HandlerOut::RequestFailed(id.clone(), error.clone()))
+ .await
+ {
+ warn!("Failed to inform request failure {e}")
+ }
+ }
+ }
+ self.remove_expected_response(node_address.socket_addr);
+ }
}
/// Sends a packet to the send handler to be encoded and sent.
@@ -1226,4 +1415,19 @@ impl Handler {
.ban_nodes
.retain(|_, time| time.is_none() || Some(Instant::now()) < *time);
}
+
+ /// Returns `true` if no session exists with this node and a session-initiating request
+ /// has already been sent.
+ fn is_awaiting_session_to_be_established(&mut self, node_address: &NodeAddress) -> bool {
+ if self.sessions.get(node_address).is_some() {
+ // session exists
+ return false;
+ }
+
+ if let Some(requests) = self.active_requests.get(node_address) {
+ requests.iter().any(|req| req.initiating_session())
+ } else {
+ false
+ }
+ }
}
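A central behavioural change in the handler hunks above: when a peer issues a new WHOAREYOU during an ongoing session, `new_session` now replays every still-active request except the one whose message nonce carried the handshake, since that request has already been answered. The following is a minimal, self-contained sketch of just that filtering step, using toy stand-ins (`Nonce`, `Req`) rather than the crate's real `MessageNonce`/`RequestCall` types:

```rust
use std::collections::HashMap;

type Nonce = u16;

#[derive(Debug, Clone)]
struct Req {
    nonce: Nonce,
    payload: &'static str,
}

/// Collect every active request for `peer` except the one whose nonce established the
/// new session (that one was already answered by the handshake itself).
fn requests_to_replay(
    active: &HashMap<&'static str, Vec<Req>>,
    peer: &str,
    handshake_nonce: Option<Nonce>,
) -> Vec<Req> {
    active
        .get(peer)
        .map(|reqs| {
            reqs.iter()
                .filter(|r| handshake_nonce.map_or(true, |n| r.nonce != n))
                .cloned()
                .collect()
        })
        .unwrap_or_default()
}

fn main() {
    let mut active = HashMap::new();
    active.insert(
        "peer-a",
        vec![
            Req { nonce: 1, payload: "PING id:2" },
            Req { nonce: 2, payload: "PING id:3" },
            Req { nonce: 3, payload: "PING id:4" },
        ],
    );
    // Nonce 1 carried the handshake, so only the requests with nonces 2 and 3 are replayed.
    let replay = requests_to_replay(&active, "peer-a", Some(1));
    assert_eq!(replay.len(), 2);
    println!("{replay:?}");
}
```

The real `replay_active_requests` additionally re-encrypts each surviving request with the new session keys and updates the nonce-keyed index through `ActiveRequests::update_packet` before resending.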
diff --git a/src/handler/request_call.rs b/src/handler/request_call.rs
index b91c7643e..42506aa76 100644
--- a/src/handler/request_call.rs
+++ b/src/handler/request_call.rs
@@ -1,4 +1,4 @@
-pub use crate::node_info::{NodeAddress, NodeContact};
+pub use crate::node_info::NodeContact;
use crate::{
packet::Packet,
rpc::{Request, RequestBody},
diff --git a/src/handler/session.rs b/src/handler/session.rs
index 7a15e3f9f..9c1f96a8f 100644
--- a/src/handler/session.rs
+++ b/src/handler/session.rs
@@ -1,7 +1,9 @@
use super::*;
use crate::{
node_info::NodeContact,
- packet::{ChallengeData, Packet, PacketHeader, PacketKind, MESSAGE_NONCE_LENGTH},
+ packet::{
+ ChallengeData, Packet, PacketHeader, PacketKind, ProtocolIdentity, MESSAGE_NONCE_LENGTH,
+ },
};
use enr::{CombinedKey, NodeId};
use zeroize::Zeroize;
@@ -55,11 +57,11 @@ impl Session {
/// Uses the current `Session` to encrypt a message. Encrypt packets with the current session
/// key if we are awaiting a response from AuthMessage.
- pub(crate) fn encrypt_message(
+ pub(crate) fn encrypt_message<P: ProtocolIdentity>(
&mut self,
src_id: NodeId,
message: &[u8],
- ) -> Result<Packet, Discv5Error> {
+ ) -> Result<Packet, Error> {
self.counter += 1;
// If the message nonce length is ever set below 4 bytes this will explode. The packet
@@ -77,7 +79,7 @@ impl Session {
};
let mut authenticated_data = iv.to_be_bytes().to_vec();
- authenticated_data.extend_from_slice(&header.encode());
+ authenticated_data.extend_from_slice(&header.encode::<P>());
let cipher = crypto::encrypt_message(
&self.keys.encryption_key,
@@ -102,7 +104,7 @@ impl Session {
message_nonce: MessageNonce,
message: &[u8],
aad: &[u8],
- ) -> Result<Vec<u8>, Discv5Error> {
+ ) -> Result<Vec<u8>, Error> {
// First try with the canonical keys.
let result_canon =
crypto::decrypt_message(&self.keys.decryption_key, message_nonce, message, aad);
@@ -138,7 +140,7 @@ impl Session {
id_nonce_sig: &[u8],
ephem_pubkey: &[u8],
enr_record: Option,
- ) -> Result<(Session, Enr), Discv5Error> {
+ ) -> Result<(Session, Enr), Error> {
// check and verify a potential ENR update
// Duplicate code here to avoid cloning an ENR
@@ -158,7 +160,7 @@ impl Session {
"Peer did not respond with their ENR. Session could not be established. Node: {}",
remote_id
);
- return Err(Discv5Error::SessionNotEstablished);
+ return Err(Error::SessionNotEstablished);
}
};
enr.public_key()
@@ -172,7 +174,7 @@ impl Session {
local_id,
id_nonce_sig,
) {
- return Err(Discv5Error::InvalidChallengeSignature(challenge));
+ return Err(Error::InvalidChallengeSignature(challenge));
}
// The keys are derived after the message has been verified to prevent potential extra work
@@ -211,14 +213,14 @@ impl Session {
}
/// Encrypts a message and produces an AuthMessage.
- pub(crate) fn encrypt_with_header(
+ pub(crate) fn encrypt_with_header<P: ProtocolIdentity>(
remote_contact: &NodeContact,
local_key: Arc>,
updated_enr: Option,
local_node_id: &NodeId,
challenge_data: &ChallengeData,
message: &[u8],
- ) -> Result<(Packet, Session), Discv5Error> {
+ ) -> Result<(Packet, Session), Error> {
// generate the session keys
let (encryption_key, decryption_key, ephem_pubkey) =
crypto::generate_session_keys(local_node_id, remote_contact, challenge_data)?;
@@ -235,7 +237,7 @@ impl Session {
&ephem_pubkey,
&remote_contact.node_id(),
)
- .map_err(|_| Discv5Error::Custom("Could not sign WHOAREYOU nonce"))?;
+ .map_err(|_| Error::Custom("Could not sign WHOAREYOU nonce"))?;
// build an authentication packet
let message_nonce: MessageNonce = rand::random();
@@ -250,7 +252,7 @@ impl Session {
// Create the authenticated data for the new packet.
let mut authenticated_data = packet.iv.to_be_bytes().to_vec();
- authenticated_data.extend_from_slice(&packet.header.encode());
+ authenticated_data.extend_from_slice(&packet.header.encode::<P>());
// encrypt the message
let message_ciphertext =
@@ -263,3 +265,11 @@ impl Session {
Ok((packet, session))
}
}
+
+#[cfg(test)]
+pub(crate) fn build_dummy_session() -> Session {
+ Session::new(Keys {
+ encryption_key: [0; 16],
+ decryption_key: [0; 16],
+ })
+}
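The session changes keep the existing AES-GCM scheme; what changes is that the packet header is now encoded through the `ProtocolIdentity` parameter, and that encoded header, prefixed by the packet IV, is the associated data handed to the cipher. The sketch below shows the encrypt/decrypt-with-AAD shape this relies on, using the `aes-gcm` crate's 0.10-style API as an assumed dependency; the key, nonce, AAD and message bytes are placeholders, not discv5's real key derivation or header encoding:

```rust
use aes_gcm::{
    aead::{Aead, KeyInit, Payload},
    Aes128Gcm, Nonce,
};

fn main() {
    // 16-byte session key and 12-byte message nonce, matching the sizes used by discv5 sessions.
    let key_bytes = [7u8; 16];
    let message_nonce = [9u8; 12];

    // Associated data: IV bytes followed by the encoded packet header
    // (placeholder bytes here; the handler builds `iv || header.encode::<P>()`).
    let mut aad = 0x0123_4567_u128.to_be_bytes().to_vec();
    aad.extend_from_slice(b"encoded-header");

    let cipher = Aes128Gcm::new_from_slice(&key_bytes).expect("16-byte key");
    let ciphertext = cipher
        .encrypt(
            Nonce::from_slice(&message_nonce),
            Payload { msg: b"rlp-encoded message", aad: &aad },
        )
        .expect("encryption should not fail");

    // Decryption must present the same AAD, otherwise the authentication tag check fails.
    let plaintext = cipher
        .decrypt(
            Nonce::from_slice(&message_nonce),
            Payload { msg: ciphertext.as_ref(), aad: &aad },
        )
        .expect("round trip");
    assert_eq!(&plaintext[..], &b"rlp-encoded message"[..]);
}
```

Because the header is authenticated rather than encrypted, tampering with it, or encoding it under a different protocol identity, makes the tag check fail on decryption.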
diff --git a/src/handler/tests.rs b/src/handler/tests.rs
index 5ed889984..a2bc95659 100644
--- a/src/handler/tests.rs
+++ b/src/handler/tests.rs
@@ -1,12 +1,25 @@
#![cfg(test)]
+
use super::*;
use crate::{
+ packet::DefaultProtocolId,
+ return_if_ipv6_is_not_supported,
rpc::{Request, Response},
- Discv5ConfigBuilder,
+ ConfigBuilder, IpMode,
+};
+use std::{
+ collections::HashSet,
+ convert::TryInto,
+ net::{Ipv4Addr, Ipv6Addr},
+ num::NonZeroU16,
+ ops::Add,
};
+use crate::{
+ handler::{session::build_dummy_session, HandlerOut::RequestFailed},
+ RequestError::SelfRequest,
+};
use active_requests::ActiveRequests;
-use enr::EnrBuilder;
use std::time::Duration;
use tokio::time::sleep;
@@ -16,6 +29,69 @@ fn init() {
.try_init();
}
+async fn build_handler<P: ProtocolIdentity>(
+ enr: Enr,
+ key: CombinedKey,
+ config: Config,
+) -> (
+ oneshot::Sender<()>,
+ mpsc::UnboundedSender<HandlerIn>,
+ mpsc::Receiver<HandlerOut>,
+ Handler,
+) {
+ let mut listen_sockets = SmallVec::default();
+ listen_sockets.push((Ipv4Addr::LOCALHOST, 9000).into());
+ let node_id = enr.node_id();
+ let filter_expected_responses = Arc::new(RwLock::new(HashMap::new()));
+
+ let socket = {
+ let socket_config = {
+ let filter_config = FilterConfig {
+ enabled: config.enable_packet_filter,
+ rate_limiter: config.filter_rate_limiter.clone(),
+ max_nodes_per_ip: config.filter_max_nodes_per_ip,
+ max_bans_per_ip: config.filter_max_bans_per_ip,
+ };
+
+ socket::SocketConfig {
+ executor: config.executor.clone().expect("Executor must exist"),
+ filter_config,
+ listen_config: config.listen_config.clone(),
+ local_node_id: node_id,
+ expected_responses: filter_expected_responses.clone(),
+ ban_duration: config.ban_duration,
+ }
+ };
+
+ Socket::new::<P>(socket_config).await.unwrap()
+ };
+ let (handler_send, service_recv) = mpsc::unbounded_channel();
+ let (service_send, handler_recv) = mpsc::channel(50);
+ let (exit_sender, exit) = oneshot::channel();
+
+ let handler = Handler {
+ request_retries: config.request_retries,
+ node_id,
+ enr: Arc::new(RwLock::new(enr)),
+ key: Arc::new(RwLock::new(key)),
+ active_requests: ActiveRequests::new(config.request_timeout),
+ pending_requests: HashMap::new(),
+ filter_expected_responses,
+ sessions: LruTimeCache::new(config.session_timeout, Some(config.session_cache_capacity)),
+ one_time_sessions: LruTimeCache::new(
+ Duration::from_secs(ONE_TIME_SESSION_TIMEOUT),
+ Some(ONE_TIME_SESSION_CACHE_CAPACITY),
+ ),
+ active_challenges: HashMapDelay::new(config.request_timeout),
+ service_recv,
+ service_send,
+ listen_sockets,
+ socket,
+ exit,
+ };
+ (exit_sender, handler_send, handler_recv, handler)
+}
+
macro_rules! arc_rw {
( $x: expr ) => {
Arc::new(RwLock::new($x))
@@ -34,33 +110,43 @@ async fn simple_session_message() {
let key1 = CombinedKey::generate_secp256k1();
let key2 = CombinedKey::generate_secp256k1();
- let config = Discv5ConfigBuilder::new().enable_packet_filter().build();
-
- let sender_enr = EnrBuilder::new("v4")
+ let sender_enr = Enr::builder()
.ip4(ip)
.udp4(sender_port)
.build(&key1)
.unwrap();
- let receiver_enr = EnrBuilder::new("v4")
+ let receiver_enr = Enr::builder()
.ip4(ip)
.udp4(receiver_port)
.build(&key2)
.unwrap();
- let (_exit_send, sender_send, _sender_recv) = Handler::spawn(
+ let sender_listen_config = ListenConfig::Ipv4 {
+ ip: sender_enr.ip4().unwrap(),
+ port: sender_enr.udp4().unwrap(),
+ };
+ let sender_config = ConfigBuilder::new(sender_listen_config)
+ .enable_packet_filter()
+ .build();
+ let (_exit_send, sender_send, _sender_recv) = Handler::spawn::<DefaultProtocolId>(
arc_rw!(sender_enr.clone()),
arc_rw!(key1),
- sender_enr.udp4_socket().unwrap().into(),
- config.clone(),
+ sender_config,
)
.await
.unwrap();
- let (_exit_recv, recv_send, mut receiver_recv) = Handler::spawn(
+ let receiver_listen_config = ListenConfig::Ipv4 {
+ ip: receiver_enr.ip4().unwrap(),
+ port: receiver_enr.udp4().unwrap(),
+ };
+ let receiver_config = ConfigBuilder::new(receiver_listen_config)
+ .enable_packet_filter()
+ .build();
+ let (_exit_recv, recv_send, mut receiver_recv) = Handler::spawn::<DefaultProtocolId>(
arc_rw!(receiver_enr.clone()),
arc_rw!(key2),
- receiver_enr.udp4_socket().unwrap().into(),
- config,
+ receiver_config,
)
.await
.unwrap();
@@ -111,35 +197,55 @@ async fn multiple_messages() {
let key1 = CombinedKey::generate_secp256k1();
let key2 = CombinedKey::generate_secp256k1();
- let config = Discv5ConfigBuilder::new().build();
- let sender_enr = EnrBuilder::new("v4")
+ let sender_enr = Enr::builder()
.ip4(ip)
.udp4(sender_port)
.build(&key1)
.unwrap();
- let receiver_enr = EnrBuilder::new("v4")
+
+ let receiver_enr = Enr::builder()
.ip4(ip)
.udp4(receiver_port)
.build(&key2)
.unwrap();
- let (_exit_send, sender_handler, mut sender_handler_recv) = Handler::spawn(
- arc_rw!(sender_enr.clone()),
- arc_rw!(key1),
- sender_enr.udp4_socket().unwrap().into(),
- config.clone(),
- )
- .await
- .unwrap();
+ // Build sender handler
+ let (sender_exit, sender_send, mut sender_recv, mut handler) = {
+ let sender_listen_config = ListenConfig::Ipv4 {
+ ip: sender_enr.ip4().unwrap(),
+ port: sender_enr.udp4().unwrap(),
+ };
+ let sender_config = ConfigBuilder::new(sender_listen_config).build();
+ build_handler::<DefaultProtocolId>(sender_enr.clone(), key1, sender_config).await
+ };
+ let sender = async move {
+ // Start sender handler.
+ handler.start::<DefaultProtocolId>().await;
+ // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty());
+ assert_eq!(0, handler.active_requests.count().await);
+ assert!(handler.active_challenges.is_empty());
+ assert!(handler.filter_expected_responses.read().is_empty());
+ };
- let (_exit_recv, recv_send, mut receiver_handler) = Handler::spawn(
- arc_rw!(receiver_enr.clone()),
- arc_rw!(key2),
- receiver_enr.udp4_socket().unwrap().into(),
- config,
- )
- .await
- .unwrap();
+ // Build receiver handler
+ let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = {
+ let receiver_listen_config = ListenConfig::Ipv4 {
+ ip: receiver_enr.ip4().unwrap(),
+ port: receiver_enr.udp4().unwrap(),
+ };
+ let receiver_config = ConfigBuilder::new(receiver_listen_config).build();
+ build_handler::<DefaultProtocolId>(receiver_enr.clone(), key2, receiver_config).await
+ };
+ let receiver = async move {
+ // Start receiver handler.
+ handler.start::<DefaultProtocolId>().await;
+ // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty());
+ assert_eq!(0, handler.active_requests.count().await);
+ assert!(handler.active_challenges.is_empty());
+ assert!(handler.filter_expected_responses.read().is_empty());
+ };
let send_message = Box::new(Request {
id: RequestId(vec![1]),
@@ -147,7 +253,7 @@ async fn multiple_messages() {
});
// sender to send the first message then await for the session to be established
- let _ = sender_handler.send(HandlerIn::Request(
+ let _ = sender_send.send(HandlerIn::Request(
receiver_enr.clone().into(),
send_message.clone(),
));
@@ -157,7 +263,7 @@ async fn multiple_messages() {
body: ResponseBody::Pong {
enr_seq: 1,
ip: ip.into(),
- port: sender_port,
+ port: sender_port.try_into().unwrap(),
},
};
@@ -166,35 +272,46 @@ async fn multiple_messages() {
let mut message_count = 0usize;
let recv_send_message = send_message.clone();
- let sender = async move {
+ let sender_ops = async move {
+ let mut response_count = 0usize;
loop {
- match sender_handler_recv.recv().await {
+ match sender_recv.recv().await {
Some(HandlerOut::Established(_, _, _)) => {
// now the session is established, send the rest of the messages
for _ in 0..messages_to_send - 1 {
- let _ = sender_handler.send(HandlerIn::Request(
+ let _ = sender_send.send(HandlerIn::Request(
receiver_enr.clone().into(),
send_message.clone(),
));
}
}
+ Some(HandlerOut::Response(_, _)) => {
+ response_count += 1;
+ if response_count == messages_to_send {
+ // Notify the handlers that the message exchange has been completed.
+ sender_exit.send(()).unwrap();
+ receiver_exit.send(()).unwrap();
+ return;
+ }
+ }
_ => continue,
};
}
};
- let receiver = async move {
+ let receiver_ops = async move {
loop {
- match receiver_handler.recv().await {
+ match receiver_recv.recv().await {
Some(HandlerOut::WhoAreYou(wru_ref)) => {
- let _ = recv_send.send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone())));
+ let _ =
+ receiver_send.send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone())));
}
Some(HandlerOut::Request(addr, request)) => {
assert_eq!(request, recv_send_message);
message_count += 1;
// required to send a pong response to establish the session
- let _ =
- recv_send.send(HandlerIn::Response(addr, Box::new(pong_response.clone())));
+ let _ = receiver_send
+ .send(HandlerIn::Response(addr, Box::new(pong_response.clone())));
if message_count == messages_to_send {
return;
}
@@ -207,47 +324,716 @@ async fn multiple_messages() {
};
let sleep_future = sleep(Duration::from_millis(100));
+ let message_exchange = async move {
+ let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops);
+ };
tokio::select! {
- _ = sender => {}
- _ = receiver => {}
+ _ = message_exchange => {}
_ = sleep_future => {
panic!("Test timed out");
}
}
}
+fn create_node() -> Enr {
+ let key = CombinedKey::generate_secp256k1();
+ let ip = "127.0.0.1".parse().unwrap();
+ let port = 8080 + rand::random::<u16>() % 1000;
+ Enr::builder().ip4(ip).udp4(port).build(&key).unwrap()
+}
+
+fn create_req_call(node: &Enr) -> (RequestCall, NodeAddress) {
+ let node_contact: NodeContact = node.clone().into();
+ let packet = Packet::new_random(&node.node_id()).unwrap();
+ let id = HandlerReqId::Internal(RequestId::random());
+ let request = RequestBody::Ping { enr_seq: 1 };
+ let initiating_session = true;
+ let node_addr = node_contact.node_address();
+ let req = RequestCall::new(node_contact, packet, id, request, initiating_session);
+ (req, node_addr)
+}
+
#[tokio::test]
async fn test_active_requests_insert() {
const EXPIRY: Duration = Duration::from_secs(5);
let mut active_requests = ActiveRequests::new(EXPIRY);
- // Create the test values needed
- let port = 5000;
- let ip = "127.0.0.1".parse().unwrap();
+ let node_1 = create_node();
+ let node_2 = create_node();
+ let (req_1, req_1_addr) = create_req_call(&node_1);
+ let (req_2, req_2_addr) = create_req_call(&node_2);
+ let (req_3, req_3_addr) = create_req_call(&node_2);
+
+ // insert the pair and verify the mapping remains in sync
+ active_requests.insert(req_1_addr, req_1);
+ active_requests.check_invariant();
+ active_requests.insert(req_2_addr, req_2);
+ active_requests.check_invariant();
+ active_requests.insert(req_3_addr, req_3);
+ active_requests.check_invariant();
+}
+
+#[tokio::test]
+async fn test_active_requests_remove_requests() {
+ const EXPIRY: Duration = Duration::from_secs(5);
+ let mut active_requests = ActiveRequests::new(EXPIRY);
+
+ let node_1 = create_node();
+ let node_2 = create_node();
+ let (req_1, req_1_addr) = create_req_call(&node_1);
+ let (req_2, req_2_addr) = create_req_call(&node_2);
+ let (req_3, req_3_addr) = create_req_call(&node_2);
+ active_requests.insert(req_1_addr.clone(), req_1);
+ active_requests.insert(req_2_addr.clone(), req_2);
+ active_requests.insert(req_3_addr.clone(), req_3);
+ active_requests.check_invariant();
+ let reqs = active_requests.remove_requests(&req_1_addr).unwrap();
+ assert_eq!(reqs.len(), 1);
+ active_requests.check_invariant();
+ let reqs = active_requests.remove_requests(&req_2_addr).unwrap();
+ assert_eq!(reqs.len(), 2);
+ active_requests.check_invariant();
+ assert!(active_requests.remove_requests(&req_3_addr).is_none());
+}
+
+#[tokio::test]
+async fn test_active_requests_remove_request() {
+ const EXPIRY: Duration = Duration::from_secs(5);
+ let mut active_requests = ActiveRequests::new(EXPIRY);
+
+ let node_1 = create_node();
+ let node_2 = create_node();
+ let (req_1, req_1_addr) = create_req_call(&node_1);
+ let (req_2, req_2_addr) = create_req_call(&node_2);
+ let (req_3, req_3_addr) = create_req_call(&node_2);
+ let req_1_id = req_1.id().into();
+ let req_2_id = req_2.id().into();
+ let req_3_id = req_3.id().into();
+
+ active_requests.insert(req_1_addr.clone(), req_1);
+ active_requests.insert(req_2_addr.clone(), req_2);
+ active_requests.insert(req_3_addr.clone(), req_3);
+ active_requests.check_invariant();
+ let req_id: RequestId = active_requests
+ .remove_request(&req_1_addr, &req_1_id)
+ .unwrap()
+ .id()
+ .into();
+ assert_eq!(req_id, req_1_id);
+ active_requests.check_invariant();
+ let req_id: RequestId = active_requests
+ .remove_request(&req_2_addr, &req_2_id)
+ .unwrap()
+ .id()
+ .into();
+ assert_eq!(req_id, req_2_id);
+ active_requests.check_invariant();
+ let req_id: RequestId = active_requests
+ .remove_request(&req_3_addr, &req_3_id)
+ .unwrap()
+ .id()
+ .into();
+ assert_eq!(req_id, req_3_id);
+ active_requests.check_invariant();
+ assert!(active_requests
+ .remove_request(&req_3_addr, &req_3_id)
+ .is_none());
+}
+
+#[tokio::test]
+async fn test_active_requests_remove_by_nonce() {
+ const EXPIRY: Duration = Duration::from_secs(5);
+ let mut active_requests = ActiveRequests::new(EXPIRY);
+
+ let node_1 = create_node();
+ let node_2 = create_node();
+ let (req_1, req_1_addr) = create_req_call(&node_1);
+ let (req_2, req_2_addr) = create_req_call(&node_2);
+ let (req_3, req_3_addr) = create_req_call(&node_2);
+ let req_1_nonce = *req_1.packet().message_nonce();
+ let req_2_nonce = *req_2.packet().message_nonce();
+ let req_3_nonce = *req_3.packet().message_nonce();
+
+ active_requests.insert(req_1_addr.clone(), req_1);
+ active_requests.insert(req_2_addr.clone(), req_2);
+ active_requests.insert(req_3_addr.clone(), req_3);
+ active_requests.check_invariant();
+
+ let req = active_requests.remove_by_nonce(&req_1_nonce).unwrap();
+ assert_eq!(req.0, req_1_addr);
+ active_requests.check_invariant();
+ let req = active_requests.remove_by_nonce(&req_2_nonce).unwrap();
+ assert_eq!(req.0, req_2_addr);
+ active_requests.check_invariant();
+ let req = active_requests.remove_by_nonce(&req_3_nonce).unwrap();
+ assert_eq!(req.0, req_3_addr);
+ active_requests.check_invariant();
+ let random_nonce = rand::random();
+ assert!(active_requests.remove_by_nonce(&random_nonce).is_none());
+}
+
+#[tokio::test]
+async fn test_active_requests_update_packet() {
+ const EXPIRY: Duration = Duration::from_secs(5);
+ let mut active_requests = ActiveRequests::new(EXPIRY);
+
+ let node_1 = create_node();
+ let node_2 = create_node();
+ let (req_1, req_1_addr) = create_req_call(&node_1);
+ let (req_2, req_2_addr) = create_req_call(&node_2);
+ let (req_3, req_3_addr) = create_req_call(&node_2);
+
+ let old_nonce = *req_2.packet().message_nonce();
+ active_requests.insert(req_1_addr, req_1);
+ active_requests.insert(req_2_addr.clone(), req_2);
+ active_requests.insert(req_3_addr, req_3);
+ active_requests.check_invariant();
+
+ let new_packet = Packet::new_random(&node_2.node_id()).unwrap();
+ let new_nonce = new_packet.message_nonce();
+ active_requests.update_packet(old_nonce, new_packet.clone());
+ active_requests.check_invariant();
+
+ assert_eq!(2, active_requests.get(&req_2_addr).unwrap().len());
+ assert!(active_requests.remove_by_nonce(&old_nonce).is_none());
+ let (addr, req) = active_requests.remove_by_nonce(new_nonce).unwrap();
+ assert_eq!(addr, req_2_addr);
+ assert_eq!(req.packet(), &new_packet);
+}
+
+#[tokio::test]
+async fn test_self_request_ipv4() {
+ init();
let key = CombinedKey::generate_secp256k1();
+ let enr = Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(5004)
+ .build(&key)
+ .unwrap();
+ let listen_config = ListenConfig::Ipv4 {
+ ip: enr.ip4().unwrap(),
+ port: enr.udp4().unwrap(),
+ };
+ let config = ConfigBuilder::new(listen_config)
+ .enable_packet_filter()
+ .build();
- let enr = EnrBuilder::new("v4")
- .ip4(ip)
- .udp4(port)
+ let (_exit_send, send, mut recv) =
+ Handler::spawn::<DefaultProtocolId>(arc_rw!(enr.clone()), arc_rw!(key), config)
+ .await
+ .unwrap();
+
+ // self request (IPv4)
+ let _ = send.send(HandlerIn::Request(
+ NodeContact::try_from_enr(enr.clone(), IpMode::Ip4).unwrap(),
+ Box::new(Request {
+ id: RequestId(vec![1]),
+ body: RequestBody::Ping { enr_seq: 1 },
+ }),
+ ));
+ let handler_out = recv.recv().await;
+ assert_eq!(
+ Some(RequestFailed(RequestId(vec![1]), SelfRequest)),
+ handler_out
+ );
+}
+
+#[tokio::test]
+async fn test_self_request_ipv6() {
+ return_if_ipv6_is_not_supported!();
+
+ init();
+
+ let key = CombinedKey::generate_secp256k1();
+ let enr = Enr::builder()
+ .ip6(Ipv6Addr::LOCALHOST)
+ .udp6(5005)
.build(&key)
.unwrap();
- let node_id = enr.node_id();
+ let listen_config = ListenConfig::Ipv6 {
+ ip: enr.ip6().unwrap(),
+ port: enr.udp6().unwrap(),
+ };
+ let config = ConfigBuilder::new(listen_config)
+ .enable_packet_filter()
+ .build();
- let contact: NodeContact = enr.into();
- let node_address = contact.node_address();
+ let (_exit_send, send, mut recv) =
+ Handler::spawn::<DefaultProtocolId>(arc_rw!(enr.clone()), arc_rw!(key), config)
+ .await
+ .unwrap();
- let packet = Packet::new_random(&node_id).unwrap();
- let id = HandlerReqId::Internal(RequestId::random());
- let request = RequestBody::Ping { enr_seq: 1 };
- let initiating_session = true;
- let request_call = RequestCall::new(contact, packet, id, request, initiating_session);
+ // self request (IPv6)
+ let _ = send.send(HandlerIn::Request(
+ NodeContact::try_from_enr(enr, IpMode::Ip6).unwrap(),
+ Box::new(Request {
+ id: RequestId(vec![2]),
+ body: RequestBody::Ping { enr_seq: 1 },
+ }),
+ ));
+ let handler_out = recv.recv().await;
+ assert_eq!(
+ Some(RequestFailed(RequestId(vec![2]), SelfRequest)),
+ handler_out
+ );
+}
- // insert the pair and verify the mapping remains in sync
- let nonce = *request_call.packet().message_nonce();
- active_requests.insert(node_address, request_call);
- active_requests.check_invariant();
- active_requests.remove_by_nonce(&nonce);
- active_requests.check_invariant();
+#[tokio::test]
+async fn remove_one_time_session() {
+ let config = ConfigBuilder::new(ListenConfig::default()).build();
+ let key = CombinedKey::generate_secp256k1();
+ let enr = Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(9000)
+ .build(&key)
+ .unwrap();
+ let (_, _, _, mut handler) = build_handler::<DefaultProtocolId>(enr, key, config).await;
+
+ let enr = {
+ let key = CombinedKey::generate_secp256k1();
+ Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(9000)
+ .build(&key)
+ .unwrap()
+ };
+ let node_address = NodeAddress::new("127.0.0.1:9000".parse().unwrap(), enr.node_id());
+ let request_id = RequestId::random();
+ let session = build_dummy_session();
+ handler
+ .one_time_sessions
+ .insert(node_address.clone(), (request_id.clone(), session));
+
+ let other_request_id = RequestId::random();
+ assert!(handler
+ .remove_one_time_session(&node_address, &other_request_id)
+ .is_none());
+ assert_eq!(1, handler.one_time_sessions.len());
+
+ let other_node_address = NodeAddress::new("127.0.0.1:9001".parse().unwrap(), enr.node_id());
+ assert!(handler
+ .remove_one_time_session(&other_node_address, &request_id)
+ .is_none());
+ assert_eq!(1, handler.one_time_sessions.len());
+
+ assert!(handler
+ .remove_one_time_session(&node_address, &request_id)
+ .is_some());
+ assert_eq!(0, handler.one_time_sessions.len());
+}
+
+// Tests replaying active requests.
+//
+// In this test, Receiver's session expires and Receiver returns WHOAREYOU.
+// Sender then creates a new session and re-sends the active requests.
+//
+// ```mermaid
+// sequenceDiagram
+// participant Sender
+// participant Receiver
+// Note over Sender: Start discv5 server
+// Note over Receiver: Start discv5 server
+//
+// Note over Sender,Receiver: Session established
+//
+// rect rgb(100, 100, 0)
+// Note over Receiver: ** Session expired **
+// end
+//
+// rect rgb(10, 10, 10)
+// Note left of Sender: Sender sends requests **in parallel**.
+// par
+// Sender ->> Receiver: PING(id:2)
+// and
+// Sender -->> Receiver: PING(id:3)
+// and
+// Sender -->> Receiver: PING(id:4)
+// and
+// Sender -->> Receiver: PING(id:5)
+// end
+// end
+//
+// Note over Receiver: Send WHOAREYOU since the session has been expired
+// Receiver ->> Sender: WHOAREYOU
+//
+// rect rgb(100, 100, 0)
+// Note over Receiver: Drop PING(id:2,3,4,5) requests since WHOAREYOU was already sent.
+// end
+//
+// Note over Sender: New session established with Receiver
+//
+// Sender ->> Receiver: Handshake message (id:2)
+//
+// Note over Receiver: New session established with Sender
+//
+// rect rgb(10, 10, 10)
+// Note left of Sender: Handler::replay_active_requests()
+// Sender ->> Receiver: PING (id:3)
+// Sender ->> Receiver: PING (id:4)
+// Sender ->> Receiver: PING (id:5)
+// end
+//
+// Receiver ->> Sender: PONG (id:2)
+// Receiver ->> Sender: PONG (id:3)
+// Receiver ->> Sender: PONG (id:4)
+// Receiver ->> Sender: PONG (id:5)
+// ```
+#[tokio::test]
+async fn test_replay_active_requests() {
+ init();
+ let sender_port = 5006;
+ let receiver_port = 5007;
+ let ip = "127.0.0.1".parse().unwrap();
+ let key1 = CombinedKey::generate_secp256k1();
+ let key2 = CombinedKey::generate_secp256k1();
+
+ let sender_enr = Enr::builder()
+ .ip4(ip)
+ .udp4(sender_port)
+ .build(&key1)
+ .unwrap();
+
+ let receiver_enr = Enr::builder()
+ .ip4(ip)
+ .udp4(receiver_port)
+ .build(&key2)
+ .unwrap();
+
+ // Build sender handler
+ let (sender_exit, sender_send, mut sender_recv, mut handler) = {
+ let sender_listen_config = ListenConfig::Ipv4 {
+ ip: sender_enr.ip4().unwrap(),
+ port: sender_enr.udp4().unwrap(),
+ };
+ let sender_config = ConfigBuilder::new(sender_listen_config).build();
+ build_handler::<DefaultProtocolId>(sender_enr.clone(), key1, sender_config).await
+ };
+ let sender = async move {
+ // Start sender handler.
+ handler.start::<DefaultProtocolId>().await;
+ // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty());
+ assert_eq!(0, handler.active_requests.count().await);
+ assert!(handler.active_challenges.is_empty());
+ assert!(handler.filter_expected_responses.read().is_empty());
+ };
+
+ // Build receiver handler
+ // Shorten receiver's timeout to reproduce session expired.
+ let receiver_session_timeout = Duration::from_secs(1);
+ let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = {
+ let receiver_listen_config = ListenConfig::Ipv4 {
+ ip: receiver_enr.ip4().unwrap(),
+ port: receiver_enr.udp4().unwrap(),
+ };
+ let receiver_config = ConfigBuilder::new(receiver_listen_config)
+ .session_timeout(receiver_session_timeout)
+ .build();
+ build_handler::<DefaultProtocolId>(receiver_enr.clone(), key2, receiver_config).await
+ };
+ let receiver = async move {
+ // Start receiver handler.
+ handler.start::<DefaultProtocolId>().await;
+ // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty());
+ assert_eq!(0, handler.active_requests.count().await);
+ assert!(handler.active_challenges.is_empty());
+ assert!(handler.filter_expected_responses.read().is_empty());
+ };
+
+ let messages_to_send = 5usize;
+
+ let sender_ops = async move {
+ let mut response_count = 0usize;
+ let mut expected_request_ids = HashSet::new();
+ expected_request_ids.insert(RequestId(vec![1]));
+
+ // sender to send the first message then await for the session to be established
+ let _ = sender_send.send(HandlerIn::Request(
+ receiver_enr.clone().into(),
+ Box::new(Request {
+ id: RequestId(vec![1]),
+ body: RequestBody::Ping { enr_seq: 1 },
+ }),
+ ));
+
+ match sender_recv.recv().await {
+ Some(HandlerOut::Established(_, _, _)) => {
+ // Sleep until receiver's session expired.
+ tokio::time::sleep(receiver_session_timeout.add(Duration::from_millis(500))).await;
+ // send the rest of the messages
+ for req_id in 2..=messages_to_send {
+ let request_id = RequestId(vec![req_id as u8]);
+ expected_request_ids.insert(request_id.clone());
+ let _ = sender_send.send(HandlerIn::Request(
+ receiver_enr.clone().into(),
+ Box::new(Request {
+ id: request_id,
+ body: RequestBody::Ping { enr_seq: 1 },
+ }),
+ ));
+ }
+ }
+ handler_out => panic!("Unexpected message: {:?}", handler_out),
+ }
+
+ loop {
+ match sender_recv.recv().await {
+ Some(HandlerOut::Response(_, response)) => {
+ assert!(expected_request_ids.remove(&response.id));
+ response_count += 1;
+ if response_count == messages_to_send {
+ // Notify the handlers that the message exchange has been completed.
+ assert!(expected_request_ids.is_empty());
+ sender_exit.send(()).unwrap();
+ receiver_exit.send(()).unwrap();
+ return;
+ }
+ }
+ _ => continue,
+ };
+ }
+ };
+
+ let receiver_ops = async move {
+ let mut message_count = 0usize;
+ loop {
+ match receiver_recv.recv().await {
+ Some(HandlerOut::WhoAreYou(wru_ref)) => {
+ receiver_send
+ .send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone())))
+ .unwrap();
+ }
+ Some(HandlerOut::Request(addr, request)) => {
+ assert!(matches!(request.body, RequestBody::Ping { .. }));
+ let pong_response = Response {
+ id: request.id,
+ body: ResponseBody::Pong {
+ enr_seq: 1,
+ ip: ip.into(),
+ port: NonZeroU16::new(sender_port).unwrap(),
+ },
+ };
+ receiver_send
+ .send(HandlerIn::Response(addr, Box::new(pong_response)))
+ .unwrap();
+ message_count += 1;
+ if message_count == messages_to_send {
+ return;
+ }
+ }
+ _ => {
+ continue;
+ }
+ }
+ }
+ };
+
+ let sleep_future = sleep(Duration::from_secs(5));
+ let message_exchange = async move {
+ let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops);
+ };
+
+ tokio::select! {
+ _ = message_exchange => {}
+ _ = sleep_future => {
+ panic!("Test timed out");
+ }
+ }
+}
+
+// Tests sending pending requests.
+//
+// Sender attempts to send multiple requests in parallel, but due to the absence of a session, only
+// one of the requests from Sender is sent and the others are inserted into `pending_requests`.
+// The pending requests are sent once a session is established.
+//
+// ```mermaid
+// sequenceDiagram
+// participant Sender
+// participant Receiver
+//
+// Note over Sender: No session with Receiver
+//
+// rect rgb(10, 10, 10)
+// Note left of Sender: Sender attempts to send multiple requests in parallel, but there is no session with Receiver yet. So Sender sends a random packet for the first request, and the rest of the requests are inserted into pending_requests.
+// par
+// Sender ->> Receiver: Random packet (id:1)
+// Note over Sender: Insert the request into `active_requests`
+// and
+// Note over Sender: Insert Request(id:2) into *pending_requests*
+// and
+// Note over Sender: Insert Request(id:3) into *pending_requests*
+// end
+// end
+//
+// Receiver ->> Sender: WHOAREYOU (id:1)
+//
+// Note over Sender: New session established with Receiver
+//
+// rect rgb(0, 100, 0)
+// Note over Sender: Send pending requests since a session has been established.
+// Sender ->> Receiver: Request (id:2)
+// Sender ->> Receiver: Request (id:3)
+// end
+//
+// Sender ->> Receiver: Handshake message (id:1)
+//
+// Note over Receiver: New session established with Sender
+//
+// Receiver ->> Sender: Response (id:2)
+// Receiver ->> Sender: Response (id:3)
+// Receiver ->> Sender: Response (id:1)
+//
+// Note over Sender: The request (id:2) completed.
+// Note over Sender: The request (id:3) completed.
+// Note over Sender: The request (id:1) completed.
+// ```
+#[tokio::test]
+async fn test_send_pending_request() {
+ init();
+ let sender_port = 5008;
+ let receiver_port = 5009;
+ let ip = "127.0.0.1".parse().unwrap();
+ let key1 = CombinedKey::generate_secp256k1();
+ let key2 = CombinedKey::generate_secp256k1();
+
+ let sender_enr = Enr::builder()
+ .ip4(ip)
+ .udp4(sender_port)
+ .build(&key1)
+ .unwrap();
+
+ let receiver_enr = Enr::builder()
+ .ip4(ip)
+ .udp4(receiver_port)
+ .build(&key2)
+ .unwrap();
+
+ // Build sender handler
+ let (sender_exit, sender_send, mut sender_recv, mut handler) = {
+ let sender_listen_config = ListenConfig::Ipv4 {
+ ip: sender_enr.ip4().unwrap(),
+ port: sender_enr.udp4().unwrap(),
+ };
+ let sender_config = ConfigBuilder::new(sender_listen_config).build();
+ build_handler::<DefaultProtocolId>(sender_enr.clone(), key1, sender_config).await
+ };
+ let sender = async move {
+ // Start sender handler.
+ handler.start::<DefaultProtocolId>().await;
+ // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty());
+ assert_eq!(0, handler.active_requests.count().await);
+ assert!(handler.active_challenges.is_empty());
+ assert!(handler.filter_expected_responses.read().is_empty());
+ };
+
+ // Build receiver handler
+ // Shorten receiver's timeout to reproduce session expired.
+ let receiver_session_timeout = Duration::from_secs(1);
+ let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = {
+ let receiver_listen_config = ListenConfig::Ipv4 {
+ ip: receiver_enr.ip4().unwrap(),
+ port: receiver_enr.udp4().unwrap(),
+ };
+ let receiver_config = ConfigBuilder::new(receiver_listen_config)
+ .session_timeout(receiver_session_timeout)
+ .build();
+ build_handler::<DefaultProtocolId>(receiver_enr.clone(), key2, receiver_config).await
+ };
+ let receiver = async move {
+ // Start receiver handler.
+ handler.start::<DefaultProtocolId>().await;
+ // After the handler has been terminated test the handler's states.
+ assert!(handler.pending_requests.is_empty());
+ assert_eq!(0, handler.active_requests.count().await);
+ assert!(handler.active_challenges.is_empty());
+ assert!(handler.filter_expected_responses.read().is_empty());
+ };
+
+ let messages_to_send = 3usize;
+
+ let sender_ops = async move {
+ let mut response_count = 0usize;
+ let mut expected_request_ids = HashSet::new();
+
+ // send requests
+ for req_id in 1..=messages_to_send {
+ let request_id = RequestId(vec![req_id as u8]);
+ expected_request_ids.insert(request_id.clone());
+ let _ = sender_send.send(HandlerIn::Request(
+ receiver_enr.clone().into(),
+ Box::new(Request {
+ id: request_id,
+ body: RequestBody::Ping { enr_seq: 1 },
+ }),
+ ));
+ }
+
+ loop {
+ match sender_recv.recv().await {
+ Some(HandlerOut::Response(_, response)) => {
+ assert!(expected_request_ids.remove(&response.id));
+ response_count += 1;
+ if response_count == messages_to_send {
+ // Notify the handlers that the message exchange has been completed.
+ assert!(expected_request_ids.is_empty());
+ sender_exit.send(()).unwrap();
+ receiver_exit.send(()).unwrap();
+ return;
+ }
+ }
+ _ => continue,
+ };
+ }
+ };
+
+ let receiver_ops = async move {
+ let mut message_count = 0usize;
+ loop {
+ match receiver_recv.recv().await {
+ Some(HandlerOut::WhoAreYou(wru_ref)) => {
+ receiver_send
+ .send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone())))
+ .unwrap();
+ }
+ Some(HandlerOut::Request(addr, request)) => {
+ assert!(matches!(request.body, RequestBody::Ping { .. }));
+ let pong_response = Response {
+ id: request.id,
+ body: ResponseBody::Pong {
+ enr_seq: 1,
+ ip: ip.into(),
+ port: NonZeroU16::new(sender_port).unwrap(),
+ },
+ };
+ receiver_send
+ .send(HandlerIn::Response(addr, Box::new(pong_response)))
+ .unwrap();
+ message_count += 1;
+ if message_count == messages_to_send {
+ return;
+ }
+ }
+ _ => {
+ continue;
+ }
+ }
+ }
+ };
+
+ let sleep_future = sleep(Duration::from_secs(5));
+ let message_exchange = async move {
+ let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops);
+ };
+
+ tokio::select! {
+ _ = message_exchange => {}
+ _ = sleep_future => {
+ panic!("Test timed out");
+ }
+ }
}
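The new `test_send_pending_request` exercises the bookkeeping described in its comment block: with no session, only the first request is sent (as a session-initiating random packet) and tracked as active, later requests are parked in `pending_requests`, and everything parked is flushed once the handshake completes. A rough model of just that decision logic, with plain strings standing in for the handler's real `NodeAddress`/`RequestCall` types:

```rust
use std::collections::{HashMap, VecDeque};

// Toy stand-ins for the handler's per-peer bookkeeping.
#[derive(Default)]
struct ToyHandler {
    sessions: HashMap<String, ()>,
    active: HashMap<String, Vec<String>>,
    pending: HashMap<String, VecDeque<String>>,
}

impl ToyHandler {
    /// With no session, the first request triggers a session-initiating (random) packet and
    /// becomes active; everything after it waits in `pending` until the session exists.
    fn send_request(&mut self, peer: &str, request: &str) {
        if self.sessions.contains_key(peer) {
            self.active.entry(peer.to_string()).or_default().push(request.to_string());
        } else if self.active.get(peer).map_or(false, |reqs| !reqs.is_empty()) {
            self.pending.entry(peer.to_string()).or_default().push_back(request.to_string());
        } else {
            // First contact: this request initiates the handshake.
            self.active.entry(peer.to_string()).or_default().push(request.to_string());
        }
    }

    /// Once the handshake completes, flush everything that was parked.
    fn on_session_established(&mut self, peer: &str) {
        self.sessions.insert(peer.to_string(), ());
        for req in self.pending.remove(peer).unwrap_or_default() {
            self.active.entry(peer.to_string()).or_default().push(req);
        }
    }
}

fn main() {
    let mut h = ToyHandler::default();
    h.send_request("peer-a", "PING id:1"); // goes out as a random packet
    h.send_request("peer-a", "PING id:2"); // parked
    h.send_request("peer-a", "PING id:3"); // parked
    assert_eq!(h.active["peer-a"].len(), 1);
    assert_eq!(h.pending["peer-a"].len(), 2);

    h.on_session_established("peer-a");
    assert_eq!(h.active["peer-a"].len(), 3);
    assert!(h.pending.get("peer-a").is_none());
}
```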
diff --git a/src/ipmode.rs b/src/ipmode.rs
index 25f51757a..bc292d7f9 100644
--- a/src/ipmode.rs
+++ b/src/ipmode.rs
@@ -1,64 +1,85 @@
-use crate::Enr;
+//! A set of configuration parameters to tune the discovery protocol.
+use crate::{
+ socket::ListenConfig,
+ Enr,
+ IpMode::{DualStack, Ip4, Ip6},
+};
use std::net::SocketAddr;
-///! A set of configuration parameters to tune the discovery protocol.
/// Sets the socket type to be established and also determines the type of ENRs that we will store
/// in our routing table.
/// We store ENRs that have a `get_contactable_addr()` based on the `IpMode` set.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IpMode {
/// IPv4 only. This creates an IPv4 only UDP socket and will only store ENRs in the local
/// routing table if they contain a contactable IPv4 address.
+ #[default]
Ip4,
- /// This enables IPv6 support. This creates an IPv6 socket. If `enable_mapped_addresses` is set
- /// to true, this creates a dual-stack socket capable of sending/receiving IPv4 and IPv6
- /// packets. If `enabled_mapped_addresses` is set to false, this is equivalent to running an
- /// IPv6-only node.
- Ip6 { enable_mapped_addresses: bool },
+ /// IPv6 only. This creates an IPv6 only UDP socket and will only store ENRs in the local
+ /// routing table if they contain a contactable IPv6 address. Mapped addresses will be
+ /// disabled.
+ Ip6,
+ /// Two UDP sockets are in use. One for Ipv4 and one for Ipv6.
+ DualStack,
}
-impl Default for IpMode {
- fn default() -> Self {
- IpMode::Ip4
+impl IpMode {
+ pub(crate) fn new_from_listen_config(listen_config: &ListenConfig) -> Self {
+ match listen_config {
+ ListenConfig::Ipv4 { .. } => Ip4,
+ ListenConfig::Ipv6 { .. } => Ip6,
+ ListenConfig::DualStack { .. } => DualStack,
+ }
}
-}
-impl IpMode {
pub fn is_ipv4(&self) -> bool {
- self == &IpMode::Ip4
+ self == &Ip4
}
- /// Get the contactable Socket address of an Enr under current configuration.
+ /// Get the contactable Socket address of an Enr under current configuration. When running in
+ /// dual stack, an Enr that advertises both an Ipv4 and a canonical Ipv6 address will be
+ /// contacted using its Ipv6 address.
pub fn get_contactable_addr(&self, enr: &Enr) -> Option {
- match self {
- IpMode::Ip4 => enr.udp4_socket().map(SocketAddr::V4),
- IpMode::Ip6 {
- enable_mapped_addresses,
- } => {
- // NOTE: general consensus is that ipv6 addresses should be preferred.
- let maybe_ipv6_addr = enr.udp6_socket().and_then(|socket_addr| {
- // NOTE: There is nothing in the spec preventing compat/mapped addresses from being
- // transmitted in the ENR. Here we choose to enforce canonical addresses since
- // it simplifies the logic of matching socket_addr verification. For this we prevent
- // communications with Ipv4 addresses advertized in the Ipv6 field.
- if to_ipv4_mapped(socket_addr.ip()).is_some() {
- None
- } else {
- Some(SocketAddr::V6(socket_addr))
- }
- });
- if *enable_mapped_addresses {
- // If mapped addresses are enabled we can use the Ipv4 address of the node in
- // case it doesn't have an ipv6 one
- maybe_ipv6_addr.or_else(|| enr.udp4_socket().map(SocketAddr::V4))
+ // A function to get a canonical ipv6 address from an Enr
+
+ /// NOTE: There is nothing in the spec preventing compat/mapped addresses from being
+ /// transmitted in the ENR. Here we choose to enforce canonical addresses since
+ /// it simplifies the logic of matching socket_addr verification. For this we prevent
+ /// communications with Ipv4 addresses advertised in the Ipv6 field.
+ fn canonical_ipv6_enr_addr(enr: &Enr) -> Option {
+ enr.udp6_socket().and_then(|socket_addr| {
+ if to_ipv4_mapped(socket_addr.ip()).is_some() {
+ None
} else {
- maybe_ipv6_addr
+ Some(socket_addr)
}
+ })
+ }
+
+ match self {
+ Ip4 => enr.udp4_socket().map(SocketAddr::V4),
+ Ip6 => canonical_ipv6_enr_addr(enr).map(SocketAddr::V6),
+ DualStack => {
+ canonical_ipv6_enr_addr(enr)
+ .map(SocketAddr::V6)
+ // NOTE: general consensus is that ipv6 addresses should be preferred.
+ .or_else(|| enr.udp4_socket().map(SocketAddr::V4))
}
}
}
}
+/// Copied from the standard library. See <https://github.com/rust-lang/rust/issues/27709>.
+/// The current code is behind the `ip` feature.
+pub const fn to_ipv4_mapped(ip: &std::net::Ipv6Addr) -> Option {
+ match ip.octets() {
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
+ Some(std::net::Ipv4Addr::new(a, b, c, d))
+ }
+ _ => None,
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -82,7 +103,7 @@ mod tests {
name,
enr_ip4: None,
enr_ip6: None,
- ip_mode: IpMode::Ip4,
+ ip_mode: Ip4,
expected_socket_addr: None,
}
}
@@ -115,7 +136,7 @@ mod tests {
fn test(&self) {
let test_enr = {
- let builder = &mut enr::EnrBuilder::new("v4");
+ let builder = &mut enr::Enr::builder();
if let Some(ip4) = self.enr_ip4 {
builder.ip4(ip4).udp4(IP4_TEST_PORT);
}
@@ -141,19 +162,15 @@ mod tests {
fn empty_enr_no_contactable_address() {
// Empty ENR
TestCase::new("Empty enr is non contactable by ip4 node")
- .ip_mode(IpMode::Ip4)
+ .ip_mode(Ip4)
.test();
TestCase::new("Empty enr is not contactable by ip6 only node")
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: false,
- })
+ .ip_mode(Ip6)
.test();
TestCase::new("Empty enr is not contactable by dual stack node")
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: true,
- })
+ .ip_mode(DualStack)
.test();
}
@@ -162,22 +179,18 @@ mod tests {
// Ip4 only ENR
TestCase::new("Ipv4 only enr is contactable by ip4 node")
.enr_ip4(Ipv4Addr::LOCALHOST)
- .ip_mode(IpMode::Ip4)
+ .ip_mode(Ip4)
.expect_ip4(Ipv4Addr::LOCALHOST)
.test();
TestCase::new("Ipv4 only enr is not contactable by ip6 only node")
.enr_ip4(Ipv4Addr::LOCALHOST)
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: false,
- })
+ .ip_mode(Ip6)
.test();
TestCase::new("Ipv4 only enr is contactable by dual stack node")
.enr_ip4(Ipv4Addr::LOCALHOST)
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: true,
- })
+ .ip_mode(DualStack)
.expect_ip4(Ipv4Addr::LOCALHOST)
.test();
}
@@ -187,22 +200,18 @@ mod tests {
// Ip4 only ENR
TestCase::new("Ipv6 only enr is not contactable by ip4 node")
.enr_ip6(Ipv6Addr::LOCALHOST)
- .ip_mode(IpMode::Ip4)
+ .ip_mode(Ip4)
.test();
TestCase::new("Ipv6 only enr is contactable by ip6 only node")
.enr_ip6(Ipv6Addr::LOCALHOST)
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: false,
- })
+ .ip_mode(Ip6)
.expect_ip6(Ipv6Addr::LOCALHOST)
.test();
TestCase::new("Ipv6 only enr is contactable by dual stack node")
.enr_ip6(Ipv6Addr::LOCALHOST)
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: true,
- })
+ .ip_mode(DualStack)
.expect_ip6(Ipv6Addr::LOCALHOST)
.test();
}
@@ -213,37 +222,22 @@ mod tests {
TestCase::new("Dual stack enr is contactable by ip4 node")
.enr_ip6(Ipv6Addr::LOCALHOST)
.enr_ip4(Ipv4Addr::LOCALHOST)
- .ip_mode(IpMode::Ip4)
+ .ip_mode(Ip4)
.expect_ip4(Ipv4Addr::LOCALHOST)
.test();
TestCase::new("Dual stack enr is contactable by ip6 only node")
.enr_ip6(Ipv6Addr::LOCALHOST)
.enr_ip4(Ipv4Addr::LOCALHOST)
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: false,
- })
+ .ip_mode(Ip6)
.expect_ip6(Ipv6Addr::LOCALHOST)
.test();
TestCase::new("Dual stack enr is contactable by dual stack node")
.enr_ip6(Ipv6Addr::LOCALHOST)
.enr_ip4(Ipv4Addr::LOCALHOST)
- .ip_mode(IpMode::Ip6 {
- enable_mapped_addresses: true,
- })
+ .ip_mode(DualStack)
.expect_ip6(Ipv6Addr::LOCALHOST)
.test();
}
}
-
-/// Copied from the standard library. See https://github.com/rust-lang/rust/issues/27709
-/// The current code is behind the `ip` feature.
-pub const fn to_ipv4_mapped(ip: &std::net::Ipv6Addr) -> Option {
- match ip.octets() {
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
- Some(std::net::Ipv4Addr::new(a, b, c, d))
- }
- _ => None,
- }
-}
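The reworked `IpMode::get_contactable_addr` prefers a canonical IPv6 address on dual-stack nodes and deliberately ignores IPv4-mapped addresses advertised in the IPv6 field. A small sketch of that selection rule using only `std::net` types, with a toy `ToyEnr` standing in for a real ENR and the mapped-address check copied from the `to_ipv4_mapped` helper above:

```rust
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

// Stand-in for an ENR advertising up to two UDP endpoints.
struct ToyEnr {
    udp4: Option<SocketAddrV4>,
    udp6: Option<SocketAddrV6>,
}

/// Same octet pattern as the crate's `to_ipv4_mapped` helper.
fn to_ipv4_mapped(ip: &Ipv6Addr) -> Option<Ipv4Addr> {
    match ip.octets() {
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => Some(Ipv4Addr::new(a, b, c, d)),
        _ => None,
    }
}

/// Dual-stack selection: prefer a canonical IPv6 address, rejecting v4-mapped addresses
/// advertised in the v6 field, and otherwise fall back to IPv4.
fn dual_stack_contact_addr(enr: &ToyEnr) -> Option<SocketAddr> {
    let canonical_v6 = enr.udp6.filter(|s| to_ipv4_mapped(s.ip()).is_none());
    canonical_v6
        .map(SocketAddr::V6)
        .or_else(|| enr.udp4.map(SocketAddr::V4))
}

fn main() {
    let both = ToyEnr {
        udp4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9000)),
        udp6: Some(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 9000, 0, 0)),
    };
    // Canonical IPv6 wins on dual stack.
    assert!(matches!(dual_stack_contact_addr(&both), Some(SocketAddr::V6(_))));

    let mapped_only = ToyEnr {
        udp4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9000)),
        // ::ffff:127.0.0.1 is a v4-mapped address, so the v6 field is ignored.
        udp6: Some(SocketAddrV6::new(
            Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001),
            9000,
            0,
            0,
        )),
    };
    assert!(matches!(dual_stack_contact_addr(&mapped_only), Some(SocketAddr::V4(_))));
}
```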
diff --git a/src/kbucket/entry.rs b/src/kbucket/entry.rs
index 1e0304048..97be95de1 100644
--- a/src/kbucket/entry.rs
+++ b/src/kbucket/entry.rs
@@ -25,9 +25,7 @@
//! representing the nodes participating in the Kademlia DHT.
pub use super::{
- bucket::{
- AppliedPending, ConnectionState, InsertResult, Node, NodeStatus, MAX_NODES_PER_BUCKET,
- },
+ bucket::{AppliedPending, ConnectionState, InsertResult, Node, NodeStatus},
key::*,
ConnectionDirection,
};
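Ahead of the crate-level doc rewrite in the lib.rs hunk below, which describes the renamed user-facing pieces (`ConfigBuilder`, `ListenConfig`, `Event`, `Discv5::event_stream`), here is a rough, non-authoritative sketch of how they are meant to fit together after this change set. It assumes the crate-root re-exports, the default protocol identity for `Discv5`, and a tokio runtime with the `macros` feature; exact signatures may differ:

```rust
use discv5::{enr::CombinedKey, ConfigBuilder, Discv5, Enr, ListenConfig};

#[tokio::main]
async fn main() {
    // Local keypair and ENR advertising our UDP endpoint.
    let enr_key = CombinedKey::generate_secp256k1();
    let enr = Enr::builder()
        .ip4("127.0.0.1".parse().unwrap())
        .udp4(9000)
        .build(&enr_key)
        .unwrap();

    // The listen address now lives in the config, via `ListenConfig`.
    let listen_config = ListenConfig::Ipv4 {
        ip: "0.0.0.0".parse().unwrap(),
        port: 9000,
    };
    let config = ConfigBuilder::new(listen_config).build();

    // Default protocol identity; `new` checks that the key matches the ENR's public key.
    let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap();
    discv5.start().await.unwrap();

    // The event stream yields `Event`s (discovered nodes, established sessions, ...).
    let mut events = discv5.event_stream().await.unwrap();
    while let Some(event) = events.recv().await {
        println!("{event:?}");
    }
}
```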
diff --git a/src/lib.rs b/src/lib.rs
index a2864b26d..99b2eb256 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,71 +1,65 @@
-#![warn(rust_2018_idioms)]
#![deny(rustdoc::broken_intra_doc_links)]
-#![cfg_attr(docsrs, feature(doc_cfg))]
-#![allow(clippy::needless_doctest_main)]
//! An implementation of [Discovery V5](https://github.com/ethereum/devp2p/blob/master/discv5/discv5.md).
//!
//! # Overview
//!
-//! Discovery v5 is a protocol designed for encrypted peer discovery and topic advertisement. Each peer/node
-//! on the network is identified via it's ENR ([Ethereum Name
+//! Discovery v5 is a protocol designed for encrypted peer discovery and topic advertisement. Each
+//! peer/node on the network is identified via its ENR ([Ethereum Name
//! Record](https://eips.ethereum.org/EIPS/eip-778)), which is essentially a signed key-value store
//! containing the node's public key and optionally IP address and port.
//!
-//! Discv5 employs a kademlia-like routing table to store and manage discovered peers and topics. The
-//! protocol allows for external IP discovery in NAT environments through regular PING/PONG's with
+//! Discv5 employs a Kademlia-like routing table to store and manage discovered peers. The protocol
+//! allows for external IP discovery in NAT environments through regular PING/PONGs with
//! discovered nodes. Nodes return the external IP address that they have received and a simple
//! majority is chosen as our external IP address. If an external IP address is updated, this is
-//! produced as an event to notify the swarm (if one is used for this behaviour).
+//! produced as an event.
//!
-//! For a simple CLI discovery service see [discv5-cli](https://github.com/AgeManning/discv5-cli)
+//! For a simple CLI discovery service see [discv5-cli](https://github.com/AgeManning/discv5-cli)
//!
-//! This protocol is split into four main sections/layers:
+//! This protocol is split into four main layers:
//!
-//! * Socket - The [`socket`] module is responsible for opening the underlying UDP socket. It
-//! creates individual tasks for sending/encoding and receiving/decoding packets from the UDP
-//! socket.
-//! * Handler - The protocol's communication is encrypted with `AES_GCM`. All node communication
-//! undergoes a handshake, which results in a [`Session`]. [`Session`]'s are established when
-//! needed and get dropped after a timeout. This section manages the creation and maintenance of
-//! sessions between nodes and the encryption/decryption of packets from the socket. It is realised by the [`handler::Handler`] struct and it runs in its own task.
-//! * Service - This section contains the protocol-level logic. In particular it manages the
-//! routing table of known ENR's, topic registration/advertisement and performs various queries
-//! such as peer discovery. This section is realised by the [`Service`] struct. This also runs in
-//! it's own thread.
-//! * Application - This section is the user-facing API which can start/stop the underlying
-//! tasks, initiate queries and obtain metrics about the underlying server.
+//! - [`socket`]: Responsible for opening the underlying UDP socket. It creates individual tasks
+//! for sending/encoding and receiving/decoding packets from the UDP socket.
+//! - [`handler`]: The protocol's communication is encrypted with `AES_GCM`. All node communication
+//! undergoes a handshake, which results in a `Session`. These are established when needed and get
+//! dropped after a timeout. The creation and maintenance of sessions between nodes and the
+//! encryption/decryption of packets from the socket is realised by the [`handler::Handler`] struct
+//! running in its own task.
+//! - [`service`]: Contains the protocol-level logic. The [`service::Service`] manages the routing
+//! table of known ENRs and performs parallel queries for peer discovery. It also runs in its
+//! own task.
+//! - [`Discv5`]: The application level, which manages the user-facing API. It starts/stops the
+//! underlying tasks, allows initiating queries, and obtains metrics about the underlying server.
//!
-//! ## Event Stream
+//! ## Event Stream
//!
-//! The [`Discv5`] struct provides access to an event-stream which allows the user to listen to
-//! [`Discv5Event`] that get generated from the underlying server. The stream can be obtained
-//! from the [`Discv5::event_stream()`] function.
+//! The [`Discv5`] struct provides access to an event-stream which allows the user to listen to
+//! [`Event`]s generated by the underlying server. The stream can be obtained from the
+//! [`Discv5::event_stream`] function.
//!
-//! ## Runtimes
+//! ## Runtimes
//!
-//! Discv5 requires a tokio runtime with timing and io enabled. An explicit runtime can be given
-//! via the configuration. See the [`Discv5ConfigBuilder`] for further details. Such a runtime
-//! must implement the [`Executor`] trait.
+//! Discv5 requires a tokio runtime with timing and io enabled. An explicit runtime can be given
+//! via the configuration. See the [`ConfigBuilder`] for further details. Such a runtime must
+//! implement the [`Executor`] trait.
//!
-//! If an explicit runtime is not provided via the configuration parameters, it is assumed that
-//! a tokio runtime is present when creating the [`Discv5`] struct. The struct will use the
-//! existing runtime for spawning the underlying server tasks. If a runtime is not present, the
-//! creation of the [`Discv5`] struct will panic.
+//! If an explicit runtime is not provided via the configuration parameters, it is assumed that a
+//! tokio runtime is present when creating the [`Discv5`] struct. The struct will use the existing
+//! runtime for spawning the underlying server tasks. If a runtime is not present, the creation of
+//! the [`Discv5`] struct will panic.
//!
//! # Usage
//!
//! A simple example of creating this service is as follows:
//!
//! ```rust
-//! use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, Discv5ConfigBuilder};
-//! use std::net::SocketAddr;
-//!
-//! // listening address and port
-//! let listen_addr = "0.0.0.0:9000".parse::().unwrap();
+//! use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, ConfigBuilder};
+//! use discv5::socket::ListenConfig;
+//! use std::net::{Ipv4Addr, SocketAddr};
//!
//! // construct a local ENR
//! let enr_key = CombinedKey::generate_secp256k1();
-//! let enr = enr::EnrBuilder::new("v4").build(&enr_key).unwrap();
+//! let enr = enr::Enr::empty(&enr_key).unwrap();
//!
//! // build the tokio executor
//! let mut runtime = tokio::runtime::Builder::new_multi_thread()
@@ -74,18 +68,24 @@
//! .build()
//! .unwrap();
//!
+//! // configuration for the sockets to listen on
+//! let listen_config = ListenConfig::Ipv4 {
+//! ip: Ipv4Addr::UNSPECIFIED,
+//! port: 9000,
+//! };
+//!
//! // default configuration
-//! let config = Discv5ConfigBuilder::new().build();
+//! let config = ConfigBuilder::new(listen_config).build();
//!
//! // construct the discv5 server
-//! let mut discv5 = Discv5::new(enr, enr_key, config).unwrap();
+//! let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap();
//!
//! // In order to bootstrap the routing table an external ENR should be added
//! // This can be done via add_enr. I.e.:
//! // discv5.add_enr()
//!
//! // start the discv5 server
-//! runtime.block_on(discv5.start(listen_addr));
+//! runtime.block_on(discv5.start());
//!
//! // run a find_node query
//! runtime.block_on(async {
@@ -93,14 +93,6 @@
//! println!("Found nodes: {:?}", found_nodes);
//! });
//! ```
-//!
-//! [`Discv5`]: struct.Discv5.html
-//! [`Discv5Event`]: enum.Discv5Event.html
-//! [`Discv5Config`]: config/struct.Discv5Config.html
-//! [`Discv5ConfigBuilder`]: config/struct.Discv5ConfigBuilder.html
-//! [Packet]: packet/enum.Packet.html
-//! [`Service`]: service/struct.Service.html
-//! [`Session`]: session/struct.Session.html
mod config;
mod discv5;
@@ -124,14 +116,15 @@ extern crate lazy_static;
pub type Enr = enr::Enr;
-pub use crate::discv5::{Discv5, Discv5Event};
-pub use config::{Discv5Config, Discv5ConfigBuilder};
-pub use error::{Discv5Error, QueryError, RequestError, ResponseError};
+pub use crate::discv5::{Discv5, Event};
+pub use config::{Config, ConfigBuilder};
+pub use error::{Error, QueryError, RequestError, ResponseError};
pub use executor::{Executor, TokioExecutor};
pub use ipmode::IpMode;
pub use kbucket::{ConnectionDirection, ConnectionState, Key};
+pub use packet::{DefaultProtocolId, ProtocolIdentity};
pub use permit_ban::PermitBanList;
pub use service::TalkRequest;
-pub use socket::{RateLimiter, RateLimiterBuilder};
+pub use socket::{ListenConfig, RateLimiter, RateLimiterBuilder};
// re-export the ENR crate
pub use enr;
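With the rename from `Discv5Event` to `Event`, consuming the event stream looks roughly like the sketch below. It assumes `Discv5::event_stream` resolves to a tokio `mpsc::Receiver` of `Event`s (as suggested by `ServiceRequest::RequestEventStream` in `src/service.rs` further down) and that `Event` derives `Debug`; the variant shapes are taken from this diff.

```rust
use discv5::{Discv5, Event};

// Sketch: log a few interesting events once `discv5.start()` has succeeded.
async fn log_events(discv5: &Discv5) {
    let mut events = discv5
        .event_stream()
        .await
        .expect("the service must be started first");
    while let Some(event) = events.recv().await {
        match event {
            Event::SocketUpdated(addr) => println!("external socket updated: {addr}"),
            Event::SessionEstablished(enr, addr) => {
                println!("session established with {addr}: {enr:?}")
            }
            other => println!("event: {other:?}"),
        }
    }
}
```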
diff --git a/src/node_info.rs b/src/node_info.rs
index 1b299deb2..ff609ae87 100644
--- a/src/node_info.rs
+++ b/src/node_info.rs
@@ -4,7 +4,11 @@ use enr::{CombinedPublicKey, NodeId};
use std::net::SocketAddr;
#[cfg(feature = "libp2p")]
-use libp2p_core::{identity::PublicKey, multiaddr::Protocol, multihash, Multiaddr};
+use libp2p::{
+ identity::{KeyType, PublicKey},
+ multiaddr::Protocol,
+ Multiaddr,
+};
/// This type relaxes the requirement of having an ENR to connect to a node, to allow for unsigned
/// connection types, such as multiaddrs.
@@ -94,34 +98,34 @@ impl NodeContact {
Protocol::Udp(port) => udp_port = Some(port),
Protocol::Ip4(addr) => ip_addr = Some(addr.into()),
Protocol::Ip6(addr) => ip_addr = Some(addr.into()),
- Protocol::P2p(multihash) => p2p = Some(multihash),
+ Protocol::P2p(peer_id) => p2p = Some(peer_id),
_ => {}
}
}
let udp_port = udp_port.ok_or("A UDP port must be specified in the multiaddr")?;
let ip_addr = ip_addr.ok_or("An IP address must be specified in the multiaddr")?;
- let multihash = p2p.ok_or("The p2p protocol must be specified in the multiaddr")?;
-
- // verify the correct key type
- if multihash.code() != u64::from(multihash::Code::Identity) {
- return Err("The key type is unsupported");
- }
-
- let public_key: CombinedPublicKey =
- match PublicKey::from_protobuf_encoding(&multihash.to_bytes()[2..])
- .map_err(|_| "Invalid public key")?
- {
- PublicKey::Secp256k1(pk) => {
- enr::k256::ecdsa::VerifyingKey::from_sec1_bytes(&pk.encode_uncompressed())
- .expect("Libp2p key conversion, always valid")
- .into()
- }
- PublicKey::Ed25519(pk) => enr::ed25519_dalek::PublicKey::from_bytes(&pk.encode())
- .expect("Libp2p key conversion, always valid")
- .into(),
+ let peer_id = p2p.ok_or("The p2p protocol must be specified in the multiaddr")?;
+
+ let public_key: CombinedPublicKey = {
+ let pk = PublicKey::try_decode_protobuf(&peer_id.to_bytes()[2..])
+ .map_err(|_| "Invalid public key")?;
+ match pk.key_type() {
+ KeyType::Secp256k1 => enr::k256::ecdsa::VerifyingKey::from_sec1_bytes(
+ &pk.try_into_secp256k1()
+ .expect("Must be secp256k1")
+ .to_bytes_uncompressed(),
+ )
+ .expect("Libp2p key conversion, always valid")
+ .into(),
+ KeyType::Ed25519 => enr::ed25519_dalek::VerifyingKey::from_bytes(
+ &pk.try_into_ed25519().expect("Must be ed25519").to_bytes(),
+ )
+ .expect("Libp2p key conversion, always valid")
+ .into(),
_ => return Err("The key type is not supported"),
- };
+ }
+ };
Ok(NodeContact {
public_key,
diff --git a/src/packet/mod.rs b/src/packet/mod.rs
index a7c7191ad..f071263bd 100644
--- a/src/packet/mod.rs
+++ b/src/packet/mod.rs
@@ -29,10 +29,17 @@ pub const MESSAGE_NONCE_LENGTH: usize = 12;
/// The Id nonce length (in bytes).
pub const ID_NONCE_LENGTH: usize = 16;
-/// Protocol ID sent with each message.
-const PROTOCOL_ID: &str = "discv5";
-/// The version sent with each handshake.
-const VERSION: u16 = 0x0001;
+pub struct DefaultProtocolId {}
+
+impl ProtocolIdentity for DefaultProtocolId {
+ const PROTOCOL_ID_BYTES: [u8; 6] = *b"discv5";
+ const PROTOCOL_VERSION_BYTES: [u8; 2] = 0x0001_u16.to_be_bytes();
+}
+
+pub trait ProtocolIdentity {
+ const PROTOCOL_ID_BYTES: [u8; 6];
+ const PROTOCOL_VERSION_BYTES: [u8; 2];
+}
pub(crate) const MAX_PACKET_SIZE: usize = 1280;
// The smallest packet must be at least this large
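Because the protocol id and version are now associated constants of a `ProtocolIdentity` implementation rather than module-level constants, an alternative identity for a private network could be declared as sketched below. The `MyProtocolId` name and byte values are purely illustrative.

```rust
use discv5::ProtocolIdentity;

/// Hypothetical identity for a non-standard discv5 network.
pub struct MyProtocolId;

impl ProtocolIdentity for MyProtocolId {
    // Must be exactly six bytes, mirroring the default b"discv5".
    const PROTOCOL_ID_BYTES: [u8; 6] = *b"mynet1";
    // Big-endian protocol version, mirroring the default 0x0001.
    const PROTOCOL_VERSION_BYTES: [u8; 2] = 0x0002_u16.to_be_bytes();
}
```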
@@ -92,11 +99,14 @@ pub struct PacketHeader {
impl PacketHeader {
// Encodes the header to bytes to be included into the `masked-header` of the Packet Encoding.
- pub fn encode(&self) -> Vec {
+ pub fn encode(&self) -> Vec
+ where
+ P: ProtocolIdentity,
+ {
let auth_data = self.kind.encode();
let mut buf = Vec::with_capacity(auth_data.len() + STATIC_HEADER_LENGTH);
- buf.extend_from_slice(PROTOCOL_ID.as_bytes());
- buf.extend_from_slice(&VERSION.to_be_bytes());
+ buf.extend_from_slice(&P::PROTOCOL_ID_BYTES);
+ buf.extend_from_slice(&P::PROTOCOL_VERSION_BYTES);
let kind: u8 = (&self.kind).into();
buf.extend_from_slice(&kind.to_be_bytes());
buf.extend_from_slice(&self.message_nonce);
@@ -372,15 +382,15 @@ impl Packet {
}
/// Generates the authenticated data for this packet.
- pub fn authenticated_data(&self) -> Vec {
+ pub fn authenticated_data(&self) -> Vec {
let mut authenticated_data = self.iv.to_be_bytes().to_vec();
- authenticated_data.extend_from_slice(&self.header.encode());
+ authenticated_data.extend_from_slice(&self.header.encode::());
authenticated_data
}
/// Encodes a packet to bytes and performs the AES-CTR encryption.
- pub fn encode(self, dst_id: &NodeId) -> Vec {
- let header = self.encrypt_header(dst_id);
+ pub fn encode(self, dst_id: &NodeId) -> Vec {
+ let header = self.encrypt_header::(dst_id);
let mut buf = Vec::with_capacity(IV_LENGTH + header.len() + self.message.len());
buf.extend_from_slice(&self.iv.to_be_bytes());
buf.extend_from_slice(&header);
@@ -389,8 +399,8 @@ impl Packet {
}
/// Creates the masked header of a packet performing the required AES-CTR encryption.
- fn encrypt_header(&self, dst_id: &NodeId) -> Vec {
- let mut header_bytes = self.header.encode();
+ fn encrypt_header(&self, dst_id: &NodeId) -> Vec {
+ let mut header_bytes = self.header.encode::();
/* Encryption is done inline
*
@@ -410,7 +420,10 @@ impl Packet {
/// Decodes a packet (data) given our local source id (src_key).
///
/// This also returns the authenticated data for further decryption in the handler.
- pub fn decode(src_id: &NodeId, data: &[u8]) -> Result<(Self, Vec), PacketError> {
+ pub fn decode(
+ src_id: &NodeId,
+ data: &[u8],
+ ) -> Result<(Self, Vec), PacketError> {
if data.len() > MAX_PACKET_SIZE {
return Err(PacketError::TooLarge);
}
@@ -440,17 +453,15 @@ impl Packet {
}
// Check the protocol id
- if &static_header[..6] != PROTOCOL_ID.as_bytes() {
+ if static_header[..6] != P::PROTOCOL_ID_BYTES {
return Err(PacketError::HeaderDecryptionFailed);
}
+ let version_bytes = &static_header[6..8];
// Check the version matches
- let version = u16::from_be_bytes(
- static_header[6..8]
- .try_into()
- .expect("Must be correct size"),
- );
- if version != VERSION {
+ if version_bytes != P::PROTOCOL_VERSION_BYTES {
+ let version =
+ u16::from_be_bytes(version_bytes.try_into().expect("Must be correct size"));
return Err(PacketError::InvalidVersion(version));
}
@@ -607,7 +618,7 @@ mod tests {
message,
};
- let encoded = packet.encode(&node_id_b);
+ let encoded = packet.encode::(&node_id_b);
dbg!(hex::encode(&encoded));
assert_eq!(expected_result, encoded);
}
@@ -640,7 +651,7 @@ mod tests {
message: Vec::new(),
};
- assert_eq!(packet.encode(&dst_id), expected_output);
+ assert_eq!(packet.encode::(&dst_id), expected_output);
}
#[test]
@@ -672,7 +683,7 @@ mod tests {
header,
message: Vec::new(),
};
- let encoded = packet.encode(&dst_id);
+ let encoded = packet.encode::(&dst_id);
assert_eq!(encoded, expected_output);
}
@@ -705,7 +716,7 @@ mod tests {
header,
message: Vec::new(),
};
- let encoded = packet.encode(&dst_id);
+ let encoded = packet.encode::(&dst_id);
assert_eq!(encoded, expected_output);
}
@@ -730,7 +741,7 @@ mod tests {
header,
message: ciphertext,
};
- let encoded = packet.encode(&dst_id);
+ let encoded = packet.encode::(&dst_id);
assert_eq!(encoded, expected_output);
}
@@ -742,9 +753,9 @@ mod tests {
let packet = Packet::new_random(&src_id).unwrap();
- let encoded_packet = packet.clone().encode(&dst_id);
+ let encoded_packet = packet.clone().encode::(&dst_id);
let (decoded_packet, _authenticated_data) =
- Packet::decode(&dst_id, &encoded_packet).unwrap();
+ Packet::decode::(&dst_id, &encoded_packet).unwrap();
assert_eq!(decoded_packet, packet);
}
@@ -759,9 +770,9 @@ mod tests {
let packet = Packet::new_whoareyou(message_nonce, id_nonce, enr_seq);
- let encoded_packet = packet.clone().encode(&dst_id);
+ let encoded_packet = packet.clone().encode::(&dst_id);
let (decoded_packet, _authenticated_data) =
- Packet::decode(&dst_id, &encoded_packet).unwrap();
+ Packet::decode::(&dst_id, &encoded_packet).unwrap();
assert_eq!(decoded_packet, packet);
}
@@ -779,9 +790,9 @@ mod tests {
let packet =
Packet::new_authheader(src_id, message_nonce, id_nonce_sig, pubkey, enr_record);
- let encoded_packet = packet.clone().encode(&dst_id);
+ let encoded_packet = packet.clone().encode::(&dst_id);
let (decoded_packet, _authenticated_data) =
- Packet::decode(&dst_id, &encoded_packet).unwrap();
+ Packet::decode::(&dst_id, &encoded_packet).unwrap();
assert_eq!(decoded_packet, packet);
}
@@ -808,7 +819,8 @@ mod tests {
let encoded_ref_packet = hex::decode("00000000000000000000000000000000088b3d4342774649325f313964a39e55ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d34c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc").unwrap();
- let (packet, _auth_data) = Packet::decode(&dst_id, &encoded_ref_packet).unwrap();
+ let (packet, _auth_data) =
+ Packet::decode::(&dst_id, &encoded_ref_packet).unwrap();
assert_eq!(packet, expected_packet);
}
@@ -844,7 +856,8 @@ mod tests {
let decoded_ref_packet = hex::decode("00000000000000000000000000000000088b3d4342774649305f313964a39e55ea96c005ad521d8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d34c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfba776b21301f8b65efd5796706adff216ab862a9186875f9494150c4ae06fa4d1f0396c93f215fa4ef524f1eadf5f0f4126b79336671cbcf7a885b1f8bd2a5d839cf8").unwrap();
- let (packet, _auth_data) = Packet::decode(&dst_id, &decoded_ref_packet).unwrap();
+ let (packet, _auth_data) =
+ Packet::decode::(&dst_id, &decoded_ref_packet).unwrap();
assert_eq!(packet, expected_packet);
}
@@ -880,7 +893,8 @@ mod tests {
let encoded_ref_packet = hex::decode("00000000000000000000000000000000088b3d4342774649305f313964a39e55ea96c005ad539c8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d34c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be98562fafa0a1fa86d8e7a3b95ae78cc2b988ded6a5b59eb83ad58097252188b902b21481e30e5e285f19735796706adff216ab862a9186875f9494150c4ae06fa4d1f0396c93f215fa4ef524e0ed04c3c21e39b1868e1ca8105e585ec17315e755e6cfc4dd6cb7fd8e1a1f55e49b4b5eb024221482105346f3c82b15fdaae36a3bb12a494683b4a3c7f2ae41306252fed84785e2bbff3b022812d0882f06978df84a80d443972213342d04b9048fc3b1d5fcb1df0f822152eced6da4d3f6df27e70e4539717307a0208cd208d65093ccab5aa596a34d7511401987662d8cf62b139471").unwrap();
- let (packet, _auth_data) = Packet::decode(&dst_id, &encoded_ref_packet).unwrap();
+ let (packet, _auth_data) =
+ Packet::decode::(&dst_id, &encoded_ref_packet).unwrap();
assert_eq!(packet, expected_packet);
}
@@ -889,11 +903,11 @@ mod tests {
let src_id: NodeId = node_key_1().public().into();
let data = [0; MAX_PACKET_SIZE + 1];
- let result = Packet::decode(&src_id, &data);
+ let result = Packet::decode::(&src_id, &data);
assert_eq!(result, Err(PacketError::TooLarge));
let data = [0; MIN_PACKET_SIZE - 1];
- let result = Packet::decode(&src_id, &data);
+ let result = Packet::decode::(&src_id, &data);
assert_eq!(result, Err(PacketError::TooSmall));
}
}
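Spelled out, the per-call protocol identity in the tests above is a plain turbofish; assuming the elided parameter is the re-exported `DefaultProtocolId`, a round trip reads as in this crate-internal fragment (`src_id` and `dst_id` are `NodeId`s, as in the tests above):

```rust
// Crate-internal sketch: the protocol identity is now chosen per call site.
let packet = Packet::new_random(&src_id).unwrap();
let encoded = packet.clone().encode::<DefaultProtocolId>(&dst_id);
let (decoded, _auth_data) = Packet::decode::<DefaultProtocolId>(&dst_id, &encoded).unwrap();
assert_eq!(decoded, packet);
```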
diff --git a/src/query_pool/peers/closest.rs b/src/query_pool/peers/closest.rs
index 936cf2c9d..334d8a39f 100644
--- a/src/query_pool/peers/closest.rs
+++ b/src/query_pool/peers/closest.rs
@@ -23,7 +23,7 @@
//
use super::*;
use crate::{
- config::Discv5Config,
+ config::Config,
kbucket::{Distance, Key, MAX_NODES_PER_BUCKET},
};
use std::{
@@ -76,7 +76,7 @@ pub struct FindNodeQueryConfig {
}
impl FindNodeQueryConfig {
- pub fn new_from_config(config: &Discv5Config) -> Self {
+ pub fn new_from_config(config: &Config) -> Self {
Self {
parallelism: config.query_parallelism,
num_results: MAX_NODES_PER_BUCKET,
diff --git a/src/query_pool/peers/predicate.rs b/src/query_pool/peers/predicate.rs
index 4768a1c35..3b4442019 100644
--- a/src/query_pool/peers/predicate.rs
+++ b/src/query_pool/peers/predicate.rs
@@ -1,6 +1,6 @@
use super::*;
use crate::{
- config::Discv5Config,
+ config::Config,
kbucket::{Distance, Key, PredicateKey, MAX_NODES_PER_BUCKET},
};
use std::{
@@ -55,7 +55,7 @@ pub(crate) struct PredicateQueryConfig {
}
impl PredicateQueryConfig {
- pub(crate) fn new_from_config(config: &Discv5Config) -> Self {
+ pub(crate) fn new_from_config(config: &Config) -> Self {
Self {
parallelism: config.query_parallelism,
num_results: MAX_NODES_PER_BUCKET,
diff --git a/src/rpc.rs b/src/rpc.rs
index 0040a660c..f7152cba5 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -1,10 +1,12 @@
use enr::{CombinedKey, Enr};
use rlp::{DecoderError, RlpStream};
-use std::net::{IpAddr, Ipv6Addr};
+use std::{
+ convert::TryInto,
+ net::{IpAddr, Ipv6Addr},
+ num::NonZeroU16,
+};
use tracing::{debug, warn};
-type TopicHash = [u8; 32];
-
/// Type to manage the request IDs.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub struct RequestId(pub Vec);
@@ -80,14 +82,6 @@ pub enum RequestBody {
/// The request.
request: Vec,
},
- /// A REGISTERTOPIC request.
- RegisterTopic {
- topic: Vec,
- enr: crate::Enr,
- ticket: Vec,
- },
- /// A TOPICQUERY request.
- TopicQuery { topic: TopicHash },
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -99,7 +93,7 @@ pub enum ResponseBody {
/// Our external IP address as observed by the responder.
ip: IpAddr,
/// Our external UDP port as observed by the responder.
- port: u16,
+ port: NonZeroU16,
},
/// A NODES response.
Nodes {
@@ -117,13 +111,6 @@ pub enum ResponseBody {
/// The response for the talk.
response: Vec,
},
- Ticket {
- ticket: Vec,
- wait_time: u64,
- },
- RegisterConfirmation {
- topic: Vec,
- },
}
impl Request {
@@ -132,8 +119,6 @@ impl Request {
RequestBody::Ping { .. } => 1,
RequestBody::FindNode { .. } => 3,
RequestBody::Talk { .. } => 5,
- RequestBody::RegisterTopic { .. } => 7,
- RequestBody::TopicQuery { .. } => 10,
}
}
@@ -156,10 +141,7 @@ impl Request {
let mut s = RlpStream::new();
s.begin_list(2);
s.append(&id.as_bytes());
- s.begin_list(distances.len());
- for distance in distances {
- s.append(&distance);
- }
+ s.append_list(&distances);
buf.extend_from_slice(&s.out());
buf
}
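The manual `begin_list`/`append` loop for the FINDNODE distances collapses into a single `append_list` call. A minimal standalone check of the equivalence, using the same `rlp` crate the module already depends on:

```rust
use rlp::RlpStream;

fn main() {
    let distances: Vec<u64> = vec![0, 254, 255];

    // Old approach: open a list and append each element.
    let mut manual = RlpStream::new();
    manual.begin_list(distances.len());
    for distance in &distances {
        manual.append(distance);
    }

    // New approach: encode the whole slice in one call.
    let mut single = RlpStream::new();
    single.append_list(&distances);

    assert_eq!(manual.out(), single.out());
}
```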
@@ -172,24 +154,6 @@ impl Request {
buf.extend_from_slice(&s.out());
buf
}
- RequestBody::RegisterTopic { topic, enr, ticket } => {
- let mut s = RlpStream::new();
- s.begin_list(4);
- s.append(&id.as_bytes());
- s.append(&topic);
- s.append(&enr);
- s.append(&ticket);
- buf.extend_from_slice(&s.out());
- buf
- }
- RequestBody::TopicQuery { topic } => {
- let mut s = RlpStream::new();
- s.begin_list(2);
- s.append(&id.as_bytes());
- s.append(&(&topic as &[u8]));
- buf.extend_from_slice(&s.out());
- buf
- }
}
}
}
@@ -201,8 +165,6 @@ impl Response {
ResponseBody::Nodes { .. } => 4,
ResponseBody::NodesRaw { .. } => 4,
ResponseBody::Talk { .. } => 6,
- ResponseBody::Ticket { .. } => 8,
- ResponseBody::RegisterConfirmation { .. } => 9,
}
}
@@ -211,22 +173,12 @@ impl Response {
match self.body {
ResponseBody::Pong { .. } => matches!(req, RequestBody::Ping { .. }),
ResponseBody::Nodes { .. } => {
- matches!(
- req,
- RequestBody::FindNode { .. } | RequestBody::TopicQuery { .. }
- )
+ matches!(req, RequestBody::FindNode { .. })
}
ResponseBody::NodesRaw { .. } => {
- matches!(
- req,
- RequestBody::FindNode { .. } | RequestBody::TopicQuery { .. }
- )
+ matches!(req, RequestBody::FindNode { .. })
}
ResponseBody::Talk { .. } => matches!(req, RequestBody::Talk { .. }),
- ResponseBody::Ticket { .. } => matches!(req, RequestBody::RegisterTopic { .. }),
- ResponseBody::RegisterConfirmation { .. } => {
- matches!(req, RequestBody::RegisterTopic { .. })
- }
}
}
@@ -246,7 +198,7 @@ impl Response {
IpAddr::V4(addr) => s.append(&(&addr.octets() as &[u8])),
IpAddr::V6(addr) => s.append(&(&addr.octets() as &[u8])),
};
- s.append(&port);
+ s.append(&port.get());
buf.extend_from_slice(&s.out());
buf
}
@@ -292,23 +244,6 @@ impl Response {
buf.extend_from_slice(&s.out());
buf
}
- ResponseBody::Ticket { ticket, wait_time } => {
- let mut s = RlpStream::new();
- s.begin_list(3);
- s.append(&id.as_bytes());
- s.append(&ticket);
- s.append(&wait_time);
- buf.extend_from_slice(&s.out());
- buf
- }
- ResponseBody::RegisterConfirmation { topic } => {
- let mut s = RlpStream::new();
- s.begin_list(2);
- s.append(&id.as_bytes());
- s.append(&topic);
- buf.extend_from_slice(&s.out());
- buf
- }
}
}
}
@@ -371,12 +306,6 @@ impl std::fmt::Display for ResponseBody {
ResponseBody::Talk { response } => {
write!(f, "Response: Response {}", hex::encode(response))
}
- ResponseBody::Ticket { ticket, wait_time } => {
- write!(f, "TICKET: Ticket: {ticket:?}, Wait time: {wait_time}")
- }
- ResponseBody::RegisterConfirmation { topic } => {
- write!(f, "REGTOPIC: Registered: {}", hex::encode(topic))
- }
}
}
}
@@ -400,14 +329,6 @@ impl std::fmt::Display for RequestBody {
hex::encode(protocol),
hex::encode(request)
),
- RequestBody::TopicQuery { topic } => write!(f, "TOPICQUERY: topic: {topic:?}"),
- RequestBody::RegisterTopic { topic, enr, ticket } => write!(
- f,
- "RegisterTopic: topic: {}, enr: {}, ticket: {}",
- hex::encode(topic),
- enr.to_base64(),
- hex::encode(ticket)
- ),
}
}
}
@@ -426,8 +347,9 @@ impl Message {
}
let msg_type = data[0];
+ let data = &data[1..];
- let rlp = rlp::Rlp::new(&data[1..]);
+ let rlp = rlp::Rlp::new(data);
let list_len = rlp.item_count().and_then(|size| {
if size < 2 {
@@ -437,6 +359,12 @@ impl Message {
}
})?;
+ // verify there is no extra data
+ let payload_info = rlp.payload_info()?;
+ if data.len() != payload_info.header_len + payload_info.value_len {
+ return Err(DecoderError::RlpInconsistentLengthAndData);
+ }
+
let id = RequestId::decode(rlp.val_at::>(0)?)?;
let message = match msg_type {
@@ -476,8 +404,13 @@ impl Message {
let mut ip = [0u8; 16];
ip.copy_from_slice(&ip_bytes);
let ipv6 = Ipv6Addr::from(ip);
- // If the ipv6 is ipv4 compatible/mapped, simply return the ipv4.
- if let Some(ipv4) = ipv6.to_ipv4() {
+
+ if ipv6.is_loopback() {
+ // Check for the loopback address explicitly, since Ipv6Addr::to_ipv4 maps
+ // the IPv6 loopback (::1) to an IPv4 address.
+ IpAddr::V6(ipv6)
+ } else if let Some(ipv4) = ipv6.to_ipv4() {
+ // If the ipv6 is ipv4 compatible/mapped, simply return the ipv4.
IpAddr::V4(ipv4)
} else {
IpAddr::V6(ipv6)
@@ -488,15 +421,20 @@ impl Message {
return Err(DecoderError::RlpIncorrectListLen);
}
};
- let port = rlp.val_at::(3)?;
- Message::Response(Response {
- id,
- body: ResponseBody::Pong {
- enr_seq: rlp.val_at::(1)?,
- ip,
- port,
- },
- })
+ let raw_port = rlp.val_at::(3)?;
+ if let Ok(port) = raw_port.try_into() {
+ Message::Response(Response {
+ id,
+ body: ResponseBody::Pong {
+ enr_seq: rlp.val_at::(1)?,
+ ip,
+ port,
+ },
+ })
+ } else {
+ debug!("The port number should be non-zero: {raw_port}");
+ return Err(DecoderError::Custom("PONG response port number invalid"));
+ }
}
3 => {
// FindNodeRequest
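The PONG port moves from a bare `u16` to `NonZeroU16`, so a zero port is rejected at decode time (the `DecoderError::Custom("PONG response port number invalid")` branch above) instead of being propagated. A quick standalone illustration of the conversion used here:

```rust
use std::{convert::TryInto, num::NonZeroU16};

fn main() {
    // A realistic port converts cleanly.
    let ok: Result<NonZeroU16, _> = 9000u16.try_into();
    assert!(ok.is_ok());

    // Port 0 fails the conversion and is turned into a decode error above.
    let zero: Result<NonZeroU16, _> = 0u16.try_into();
    assert!(zero.is_err());
}
```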
@@ -584,59 +522,7 @@ impl Message {
}
_ => {
return Err(DecoderError::Custom("Unknown RPC message type"));
- } /*
- * All other RPC messages are currently not supported as per the 5.1 specification.
-
- 7 => {
- // RegisterTopicRequest
- if list_len != 2 {
- debug!("RegisterTopic Request has an invalid RLP list length. Expected 2, found {}", list_len);
- return Err(DecoderError::RlpIncorrectListLen);
- }
- let ticket = rlp.val_at::>(1)?;
- Message::Request(Request {
- id,
- body: RequestBody::RegisterTopic { ticket },
- })
- }
- 8 => {
- // RegisterTopicResponse
- if list_len != 2 {
- debug!("RegisterTopic Response has an invalid RLP list length. Expected 2, found {}", list_len);
- return Err(DecoderError::RlpIncorrectListLen);
- }
- Message::Response(Response {
- id,
- body: ResponseBody::RegisterTopic {
- registered: rlp.val_at::(1)?,
- },
- })
- }
- 9 => {
- // TopicQueryRequest
- if list_len != 2 {
- debug!(
- "TopicQuery Request has an invalid RLP list length. Expected 2, found {}",
- list_len
- );
- return Err(DecoderError::RlpIncorrectListLen);
- }
- let topic = {
- let topic_bytes = rlp.val_at::>(1)?;
- if topic_bytes.len() > 32 {
- debug!("Ticket Request has a topic greater than 32 bytes");
- return Err(DecoderError::RlpIsTooBig);
- }
- let mut topic = [0u8; 32];
- topic[32 - topic_bytes.len()..].copy_from_slice(&topic_bytes);
- topic
- };
- Message::Request(Request {
- id,
- body: RequestBody::TopicQuery { topic },
- })
- }
- */
+ }
};
Ok(message)
@@ -646,7 +532,7 @@ impl Message {
#[cfg(test)]
mod tests {
use super::*;
- use enr::EnrBuilder;
+ use std::net::Ipv4Addr;
#[test]
fn ref_test_encode_request_ping() {
@@ -691,7 +577,11 @@ mod tests {
let port = 5000;
let message = Message::Response(Response {
id,
- body: ResponseBody::Pong { enr_seq, ip, port },
+ body: ResponseBody::Pong {
+ enr_seq,
+ ip,
+ port: port.try_into().unwrap(),
+ },
});
// expected hex output
@@ -808,7 +698,51 @@ mod tests {
body: ResponseBody::Pong {
enr_seq: 15,
ip: "127.0.0.1".parse().unwrap(),
- port: 80,
+ port: 80.try_into().unwrap(),
+ },
+ });
+
+ let encoded = request.clone().encode();
+ let decoded = Message::decode(&encoded).unwrap();
+
+ assert_eq!(request, decoded);
+ }
+
+ #[test]
+ fn encode_decode_ping_response_ipv4_mapped() {
+ let id = RequestId(vec![1]);
+ let request = Message::Response(Response {
+ id: id.clone(),
+ body: ResponseBody::Pong {
+ enr_seq: 15,
+ ip: IpAddr::V6(Ipv4Addr::new(192, 0, 2, 1).to_ipv6_mapped()),
+ port: 80.try_into().unwrap(),
+ },
+ });
+
+ let encoded = request.encode();
+ let decoded = Message::decode(&encoded).unwrap();
+ let expected = Message::Response(Response {
+ id,
+ body: ResponseBody::Pong {
+ enr_seq: 15,
+ ip: IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1)),
+ port: 80.try_into().unwrap(),
+ },
+ });
+
+ assert_eq!(expected, decoded);
+ }
+
+ #[test]
+ fn encode_decode_ping_response_ipv6_loopback() {
+ let id = RequestId(vec![1]);
+ let request = Message::Response(Response {
+ id,
+ body: ResponseBody::Pong {
+ enr_seq: 15,
+ ip: IpAddr::V6(Ipv6Addr::LOCALHOST),
+ port: 80.try_into().unwrap(),
},
});
@@ -837,17 +771,17 @@ mod tests {
#[test]
fn encode_decode_nodes_response() {
let key = CombinedKey::generate_secp256k1();
- let enr1 = EnrBuilder::new("v4")
+ let enr1 = Enr::builder()
.ip4("127.0.0.1".parse().unwrap())
.udp4(500)
.build(&key)
.unwrap();
- let enr2 = EnrBuilder::new("v4")
+ let enr2 = Enr::builder()
.ip4("10.0.0.1".parse().unwrap())
.tcp4(8080)
.build(&key)
.unwrap();
- let enr3 = EnrBuilder::new("v4")
+ let enr3 = Enr::builder()
.ip("10.4.5.6".parse().unwrap())
.build(&key)
.unwrap();
@@ -869,183 +803,21 @@ mod tests {
}
#[test]
- fn encode_decode_ticket_request() {
- let id = RequestId(vec![1]);
- let request = Message::Request(Request {
- id,
- body: RequestBody::Talk {
- protocol: vec![17u8; 32],
- request: vec![1, 2, 3],
- },
- });
-
- let encoded = request.clone().encode();
- let decoded = Message::decode(&encoded).unwrap();
-
- assert_eq!(request, decoded);
+ fn reject_extra_data() {
+ let data = [6, 194, 0, 75];
+ let msg = Message::decode(&data).unwrap();
+ assert_eq!(
+ msg,
+ Message::Response(Response {
+ id: RequestId(vec![0]),
+ body: ResponseBody::Talk { response: vec![75] }
+ })
+ );
+
+ let data2 = [6, 193, 0, 75, 252];
+ Message::decode(&data2).expect_err("should reject extra data");
+
+ let data3 = [6, 194, 0, 75, 252];
+ Message::decode(&data3).expect_err("should reject extra data");
}
-
- /*
- * These RPC messages are not in use yet
- *
- #[test]
- fn ref_test_encode_request_ticket() {
- // reference input
- let id = 1;
- let hash_bytes =
- hex::decode("fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
- .unwrap();
-
- // expected hex output
- let expected_output =
- hex::decode("05e201a0fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
- .unwrap();
-
- let mut topic_hash = [0; 32];
- topic_hash.copy_from_slice(&hash_bytes);
-
- let message = Message::Request(Request {
- id,
- body: RequestBody::Ticket { topic: topic_hash },
- });
- assert_eq!(message.encode(), expected_output);
- }
-
- #[test]
- fn ref_test_encode_request_register_topic() {
- // reference input
- let id = 1;
- let ticket =
- hex::decode("fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
- .unwrap();
-
- // expected hex output
- let expected_output =
- hex::decode("07e201a0fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
- .unwrap();
-
- let message = Message::Request(Request {
- id,
- body: RequestBody::RegisterTopic { ticket },
- });
- assert_eq!(message.encode(), expected_output);
- }
-
- #[test]
- fn ref_test_encode_request_topic_query() {
- // reference input
- let id = 1;
- let hash_bytes =
- hex::decode("fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
- .unwrap();
-
- // expected hex output
- let expected_output =
- hex::decode("09e201a0fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
- .unwrap();
-
- let mut topic_hash = [0; 32];
- topic_hash.copy_from_slice(&hash_bytes);
-
- let message = Message::Request(Request {
- id,
- body: RequestBody::TopicQuery { topic: topic_hash },
- });
- assert_eq!(message.encode(), expected_output);
- }
-
- #[test]
- fn ref_test_encode_response_register_topic() {
- // reference input
- let id = 1;
- let registered = true;
-
- // expected hex output
- let expected_output = hex::decode("08c20101").unwrap();
- let message = Message::Response(Response {
- id,
- body: ResponseBody::RegisterTopic { registered },
- });
- assert_eq!(message.encode(), expected_output);
- }
-
- #[test]
- fn encode_decode_register_topic_request() {
- let request = Message::Request(Request {
- id: 1,
- body: RequestBody::RegisterTopic {
- topic: vec![1,2,3],
- ticket: vec![1, 2, 3, 4, 5],
- },
- });
-
- let encoded = request.clone().encode();
- let decoded = Message::decode(encoded).unwrap();
-
- assert_eq!(request, decoded);
- }
-
- #[test]
- fn encode_decode_register_topic_response() {
- let request = Message::Response(Response {
- id: 0,
- body: ResponseBody::RegisterTopic { registered: true },
- });
-
- let encoded = request.clone().encode();
- let decoded = Message::decode(encoded).unwrap();
-
- assert_eq!(request, decoded);
- }
-
- #[test]
- fn encode_decode_topic_query_request() {
- let request = Message::Request(Request {
- id: 1,
- body: RequestBody::TopicQuery { topic: [17u8; 32] },
- });
-
- let encoded = request.clone().encode();
- let decoded = Message::decode(encoded).unwrap();
-
- assert_eq!(request, decoded);
- }
-
- #[test]
- fn ref_test_encode_response_ticket() {
- // reference input
- let id = 1;
- let ticket = [0; 32].to_vec(); // all 0's
- let wait_time = 5;
-
- // expected hex output
- let expected_output = hex::decode(
- "06e301a0000000000000000000000000000000000000000000000000000000000000000005",
- )
- .unwrap();
-
- let message = Message::Response(Response {
- id,
- body: ResponseBody::Ticket { ticket, wait_time },
- });
- assert_eq!(message.encode(), expected_output);
- }
-
- #[test]
- fn encode_decode_ticket_response() {
- let request = Message::Response(Response {
- id: 0,
- body: ResponseBody::Ticket {
- ticket: vec![1, 2, 3, 4, 5],
- wait_time: 5,
- },
- });
-
- let encoded = request.clone().encode();
- let decoded = Message::decode(encoded).unwrap();
-
- assert_eq!(request, decoded);
- }
-
- */
}
diff --git a/src/service.rs b/src/service.rs
index 80a6f7dba..224aa09b4 100644
--- a/src/service.rs
+++ b/src/service.rs
@@ -25,11 +25,11 @@ use crate::{
NodeStatus, UpdateResult, MAX_NODES_PER_BUCKET,
},
node_info::{NodeAddress, NodeContact, NonContactable},
- packet::MAX_PACKET_SIZE,
+ packet::{ProtocolIdentity, MAX_PACKET_SIZE},
query_pool::{
FindNodeQueryConfig, PredicateQueryConfig, QueryId, QueryPool, QueryPoolState, TargetKey,
},
- rpc, Discv5Config, Discv5Event, Enr,
+ rpc, Config, Enr, Event, IpMode,
};
use delay_map::HashSetDelay;
use enr::{CombinedKey, NodeId};
@@ -38,7 +38,14 @@ use futures::prelude::*;
use more_asserts::debug_unreachable;
use parking_lot::RwLock;
use rpc::*;
-use std::{collections::HashMap, net::SocketAddr, sync::Arc, task::Poll, time::Instant};
+use std::{
+ collections::HashMap,
+ convert::TryInto,
+ net::{IpAddr, SocketAddr},
+ sync::Arc,
+ task::Poll,
+ time::Instant,
+};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, trace, warn};
@@ -137,8 +144,13 @@ pub enum ServiceRequest {
/// - A Predicate Query - Searches for peers closest to a random target that match a specified
/// predicate.
StartQuery(QueryKind, oneshot::Sender>),
- /// Find the ENR of a node given its multiaddr.
- FindEnr(NodeContact, oneshot::Sender>),
+ /// Send a FINDNODE request for nodes that fall within the given set of distances
+ /// to the designated peer, and wait for a response.
+ FindNodeDesignated(
+ NodeContact,
+ Vec,
+ oneshot::Sender, RequestError>>,
+ ),
/// The TALK discv5 RPC function.
Talk(
NodeContact,
@@ -146,16 +158,18 @@ pub enum ServiceRequest {
Vec,
oneshot::Sender, RequestError>>,
),
+ /// The PING discv5 RPC function.
+ Ping(Enr, Option>>),
/// Sets up an event stream where the discv5 server will return various events such as
/// discovered nodes as it traverses the DHT.
- RequestEventStream(oneshot::Sender>),
+ RequestEventStream(oneshot::Sender>),
}
use crate::discv5::PERMIT_BAN_LIST;
pub struct Service {
/// Configuration parameters.
- config: Discv5Config,
+ config: Config,
/// The local ENR of the server.
local_enr: Arc>,
@@ -174,7 +188,7 @@ pub struct Service {
active_requests: FnvHashMap,
/// Keeps track of the number of responses received from a NODES response.
- active_nodes_responses: HashMap,
+ active_nodes_responses: HashMap,
/// A map of votes nodes have made about our external IP address. We accept the majority.
ip_votes: Option,
@@ -198,7 +212,10 @@ pub struct Service {
peers_to_ping: HashSetDelay,
/// A channel that the service emits events on.
- event_stream: Option>,
+ event_stream: Option>,
+
+ // Type of socket we are using
+ ip_mode: IpMode,
}
/// Active RPC request awaiting a response from the handler.
@@ -213,12 +230,24 @@ struct ActiveRequest {
pub callback: Option,
}
+#[derive(Debug)]
+pub struct Pong {
+ /// The current ENR sequence number of the responder.
+ pub enr_seq: u64,
+ /// Our external IP address as observed by the responder.
+ pub ip: IpAddr,
+ /// Our external UDP port as observed by the responder.
+ pub port: u16,
+}
+
/// The kinds of responses we can send back to the discv5 layer.
pub enum CallbackResponse {
- /// A response to a requested ENR.
- Enr(oneshot::Sender>),
+ /// A response to a FINDNODE request.
+ Nodes(oneshot::Sender, RequestError>>),
/// A response from a TALK request
Talk(oneshot::Sender, RequestError>>),
+ /// A PONG response to a PING request
+ Pong(oneshot::Sender>),
}
/// For multiple responses to a FindNodes request, this keeps track of the request count
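The new `ServiceRequest::Ping` variant pairs an ENR with an optional oneshot callback that resolves to the `Pong` summary above. A crate-internal sketch of that round trip, assuming `service_channel` is the `mpsc::Sender` of `ServiceRequest`s returned by `Service::spawn`, `target_enr` is a contactable `Enr`, and the elided callback type is `Result<Pong, RequestError>`:

```rust
// Sketch of the callback plumbing: send a PING over the service channel and
// await the PONG summary on the paired oneshot receiver.
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel();
service_channel
    .send(ServiceRequest::Ping(target_enr, Some(callback_tx)))
    .await
    .expect("service should be running");
match callback_rx.await.expect("service dropped the callback") {
    Ok(pong) => println!(
        "peer sees us at {}:{} (responder enr_seq = {})",
        pong.ip, pong.port, pong.enr_seq
    ),
    Err(e) => println!("ping failed: {e:?}"),
}
```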
@@ -245,12 +274,11 @@ impl Service {
/// `local_enr` is the `ENR` representing the local node. This contains node identifying information, such
/// as IP addresses and ports which we wish to broadcast to other nodes via this discovery
/// mechanism.
- pub async fn spawn(
+ pub async fn spawn(
local_enr: Arc>,
enr_key: Arc>,
kbuckets: Arc>>,
- config: Discv5Config,
- listen_socket: SocketAddr,
+ config: Config,
) -> Result<(oneshot::Sender<()>, mpsc::Sender), std::io::Error> {
// process behaviour-level configuration parameters
let ip_votes = if config.enr_update {
@@ -262,14 +290,11 @@ impl Service {
None
};
+ let ip_mode = IpMode::new_from_listen_config(&config.listen_config);
+
// build the session service
- let (handler_exit, handler_send, handler_recv) = Handler::spawn(
- local_enr.clone(),
- enr_key.clone(),
- listen_socket,
- config.clone(),
- )
- .await?;
+ let (handler_exit, handler_send, handler_recv) =
+ Handler::spawn::(local_enr.clone(), enr_key.clone(), config.clone()).await?;
// create the required channels
let (discv5_send, discv5_recv) = mpsc::channel(30);
@@ -296,9 +321,10 @@ impl Service {
event_stream: None,
exit,
config: config.clone(),
+ ip_mode,
};
- info!("Discv5 Service started");
+ info!(mode = ?service.ip_mode, "Discv5 Service started");
service.start().await;
}));
@@ -307,7 +333,6 @@ impl Service {
/// The main execution loop of the discv5 serviced.
async fn start(&mut self) {
- info!("{:?}", self.config.ip_mode);
loop {
tokio::select! {
_ = &mut self.exit => {
@@ -329,12 +354,15 @@ impl Service {
}
}
}
- ServiceRequest::FindEnr(node_contact, callback) => {
- self.request_enr(node_contact, Some(callback));
+ ServiceRequest::FindNodeDesignated(node_contact, distance, callback) => {
+ self.request_find_node_designated_peer(node_contact, distance, Some(callback));
}
ServiceRequest::Talk(node_contact, protocol, request, callback) => {
self.talk_request(node_contact, protocol, request, callback);
}
+ ServiceRequest::Ping(enr, callback) => {
+ self.send_ping(enr, callback);
+ }
ServiceRequest::RequestEventStream(callback) => {
// the channel size needs to be large to handle many discovered peers
// if we are reporting them on the event stream.
@@ -350,7 +378,7 @@ impl Service {
Some(event) = self.handler_recv.recv() => {
match event {
HandlerOut::Established(enr, socket_addr, direction) => {
- self.send_event(Discv5Event::SessionEstablished(enr.clone(), socket_addr));
+ self.send_event(Event::SessionEstablished(enr.clone(), socket_addr));
self.inject_session_established(enr, direction);
}
HandlerOut::Request(node_address, request) => {
@@ -428,7 +456,7 @@ impl Service {
};
if let Some(enr) = enr {
- self.send_ping(enr);
+ self.send_ping(enr, None);
}
}
}
@@ -558,13 +586,13 @@ impl Service {
to_request_enr = Some(enr);
}
}
- // don't know of the ENR, request the update
+ // don't know the peer, don't request its most recent ENR
_ => {}
}
if let Some(enr) = to_request_enr {
- match NodeContact::try_from_enr(enr, self.config.ip_mode) {
+ match NodeContact::try_from_enr(enr, self.ip_mode) {
Ok(contact) => {
- self.request_enr(contact, None);
+ self.request_find_node_designated_peer(contact, vec![0], None);
}
Err(NonContactable { enr }) => {
debug_unreachable!("Stored ENR is not contactable. {}", enr);
@@ -578,20 +606,24 @@ impl Service {
// build the PONG response
let src = node_address.socket_addr;
- let response = Response {
- id,
- body: ResponseBody::Pong {
- enr_seq: self.local_enr.read().seq(),
- ip: src.ip(),
- port: src.port(),
- },
- };
- debug!("Sending PONG response to {}", node_address);
- if let Err(e) = self
- .handler_send
- .send(HandlerIn::Response(node_address, Box::new(response)))
- {
- warn!("Failed to send response {}", e)
+ if let Ok(port) = src.port().try_into() {
+ let response = Response {
+ id,
+ body: ResponseBody::Pong {
+ enr_seq: self.local_enr.read().seq(),
+ ip: src.ip(),
+ port,
+ },
+ };
+ debug!("Sending PONG response to {}", node_address);
+ if let Err(e) = self
+ .handler_send
+ .send(HandlerIn::Response(node_address, Box::new(response)))
+ {
+ warn!("Failed to send response {}", e);
+ }
+ } else {
+ warn!("The src port number should be non-zero. {src}");
}
}
RequestBody::Talk { protocol, request } => {
@@ -603,13 +635,7 @@ impl Service {
sender: Some(self.handler_send.clone()),
};
- self.send_event(Discv5Event::TalkRequest(req));
- }
- RequestBody::RegisterTopic { .. } => {
- debug!("Received RegisterTopic request which is unimplemented");
- }
- RequestBody::TopicQuery { .. } => {
- debug!("Received TopicQuery request which is unimplemented");
+ self.send_event(Event::TalkRequest(req));
}
}
}
@@ -659,25 +685,9 @@ impl Service {
_ => unreachable!(),
};
- // This could be an ENR request from the outer service. If so respond to the
- // callback and End.
- if let Some(CallbackResponse::Enr(callback)) = active_request.callback.take() {
- // Currently only support requesting for ENR's. Verify this is the case.
- if !distances_requested.is_empty() && distances_requested[0] != 0 {
- error!("Retrieved a callback request that wasn't for a peer's ENR");
- return;
- }
- // This must be for asking for an ENR
- if nodes.len() > 1 {
- warn!(
- "Peer returned more than one ENR for itself. {}",
- active_request.contact
- );
- }
- let response = nodes
- .pop()
- .ok_or(RequestError::InvalidEnr("Peer did not return an ENR"));
- if let Err(e) = callback.send(response) {
+ if let Some(CallbackResponse::Nodes(callback)) = active_request.callback.take()
+ {
+ if let Err(e) = callback.send(Ok(nodes)) {
warn!("Failed to send response in callback {:?}", e)
}
return;
@@ -712,10 +722,9 @@ impl Service {
if nodes.len() < before_len {
// Peer sent invalid ENRs. Blacklist the Node
- warn!(
- "Peer sent invalid ENR. Blacklisting {}",
- active_request.contact
- );
+ let node_id = active_request.contact.node_id();
+ let addr = active_request.contact.socket_addr();
+ warn!(%node_id, %addr, "ENRs received of unsolicited distances. Blacklisting");
let ban_timeout = self.config.ban_duration.map(|v| Instant::now() + v);
PERMIT_BAN_LIST.write().ban(node_address, ban_timeout);
}
@@ -723,10 +732,8 @@ impl Service {
// handle the case that there is more than one response
if total > 1 {
- let mut current_response = self
- .active_nodes_responses
- .remove(&node_id)
- .unwrap_or_default();
+ let mut current_response =
+ self.active_nodes_responses.remove(&id).unwrap_or_default();
debug!(
"Nodes Response: {} of {} received",
@@ -744,7 +751,7 @@ impl Service {
current_response.received_nodes.append(&mut nodes);
self.active_nodes_responses
- .insert(node_id, current_response);
+ .insert(id.clone(), current_response);
self.active_requests.insert(id, active_request);
return;
}
@@ -766,107 +773,122 @@ impl Service {
// in a later response sends a response with a total of 1, all previous nodes
// will be ignored.
// ensure any mapping is removed in this rare case
- self.active_nodes_responses.remove(&node_id);
+ self.active_nodes_responses.remove(&id);
self.discovered(&node_id, nodes, active_request.query_id);
}
ResponseBody::Pong { enr_seq, ip, port } => {
- let socket = SocketAddr::new(ip, port);
- // perform ENR majority-based update if required.
+ // Send the response to the user if this PING was user-initiated
+ if let Some(CallbackResponse::Pong(callback)) = active_request.callback {
+ let response = Pong {
+ enr_seq,
+ ip,
+ port: port.get(),
+ };
+ if let Err(e) = callback.send(Ok(response)) {
+ warn!("Failed to send callback response {:?}", e)
+ };
+ } else {
+ let socket = SocketAddr::new(ip, port.get());
+ // perform ENR majority-based update if required.
- // Only count votes that from peers we have contacted.
- let key: kbucket::Key = node_id.into();
- let should_count = matches!(
+ // Only count votes that are from peers we have contacted.
+ let key: kbucket::Key = node_id.into();
+ let should_count = matches!(
self.kbuckets.write().entry(&key),
kbucket::Entry::Present(_, status)
if status.is_connected() && !status.is_incoming());
- if should_count {
- // get the advertised local addresses
- let (local_ip4_socket, local_ip6_socket) = {
- let local_enr = self.local_enr.read();
- (local_enr.udp4_socket(), local_enr.udp6_socket())
- };
-
- if let Some(ref mut ip_votes) = self.ip_votes {
- ip_votes.insert(node_id, socket);
- let (maybe_ip4_majority, maybe_ip6_majority) = ip_votes.majority();
+ if should_count {
+ // get the advertised local addresses
+ let (local_ip4_socket, local_ip6_socket) = {
+ let local_enr = self.local_enr.read();
+ (local_enr.udp4_socket(), local_enr.udp6_socket())
+ };
- let new_ip4 = maybe_ip4_majority.and_then(|majority| {
- if Some(majority) != local_ip4_socket {
- Some(majority)
- } else {
- None
- }
- });
- let new_ip6 = maybe_ip6_majority.and_then(|majority| {
- if Some(majority) != local_ip6_socket {
- Some(majority)
- } else {
- None
- }
- });
+ if let Some(ref mut ip_votes) = self.ip_votes {
+ ip_votes.insert(node_id, socket);
+ let (maybe_ip4_majority, maybe_ip6_majority) = ip_votes.majority();
- if new_ip4.is_some() || new_ip6.is_some() {
- let mut updated = false;
-
- // Check if our advertised IPV6 address needs to be updated.
- if let Some(new_ip6) = new_ip6 {
- let new_ip6: SocketAddr = new_ip6.into();
- let result = self
- .local_enr
- .write()
- .set_udp_socket(new_ip6, &self.enr_key.read());
- match result {
- Ok(_) => {
- updated = true;
- info!("Local UDP ip6 socket updated to: {}", new_ip6);
- self.send_event(Discv5Event::SocketUpdated(new_ip6));
- }
- Err(e) => {
- warn!("Failed to update local UDP ip6 socket. ip6: {}, error: {:?}", new_ip6, e);
- }
+ let new_ip4 = maybe_ip4_majority.and_then(|majority| {
+ if Some(majority) != local_ip4_socket {
+ Some(majority)
+ } else {
+ None
}
- }
- if let Some(new_ip4) = new_ip4 {
- let new_ip4: SocketAddr = new_ip4.into();
- let result = self
- .local_enr
- .write()
- .set_udp_socket(new_ip4, &self.enr_key.read());
- match result {
- Ok(_) => {
- updated = true;
- info!("Local UDP socket updated to: {}", new_ip4);
- self.send_event(Discv5Event::SocketUpdated(new_ip4));
+ });
+ let new_ip6 = maybe_ip6_majority.and_then(|majority| {
+ if Some(majority) != local_ip6_socket {
+ Some(majority)
+ } else {
+ None
+ }
+ });
+
+ if new_ip4.is_some() || new_ip6.is_some() {
+ let mut updated = false;
+
+ // Check if our advertised IPV6 address needs to be updated.
+ if let Some(new_ip6) = new_ip6 {
+ let new_ip6: SocketAddr = new_ip6.into();
+ let result = self
+ .local_enr
+ .write()
+ .set_udp_socket(new_ip6, &self.enr_key.read());
+ match result {
+ Ok(_) => {
+ updated = true;
+ info!(
+ "Local UDP ip6 socket updated to: {}",
+ new_ip6,
+ );
+ self.send_event(Event::SocketUpdated(new_ip6));
+ }
+ Err(e) => {
+ warn!("Failed to update local UDP ip6 socket. ip6: {}, error: {:?}", new_ip6, e);
+ }
}
- Err(e) => {
- warn!("Failed to update local UDP socket. ip: {}, error: {:?}", new_ip4, e);
+ }
+ if let Some(new_ip4) = new_ip4 {
+ let new_ip4: SocketAddr = new_ip4.into();
+ let result = self
+ .local_enr
+ .write()
+ .set_udp_socket(new_ip4, &self.enr_key.read());
+ match result {
+ Ok(_) => {
+ updated = true;
+ info!("Local UDP socket updated to: {}", new_ip4);
+ self.send_event(Event::SocketUpdated(new_ip4));
+ }
+ Err(e) => {
+ warn!("Failed to update local UDP socket. ip: {}, error: {:?}", new_ip4, e);
+ }
}
}
- }
- if updated {
- self.ping_connected_peers();
+ if updated {
+ self.ping_connected_peers();
+ }
}
}
}
- }
- // check if we need to request a new ENR
- if let Some(enr) = self.find_enr(&node_id) {
- if enr.seq() < enr_seq {
- // request an ENR update
- debug!("Requesting an ENR update from: {}", active_request.contact);
- let request_body = RequestBody::FindNode { distances: vec![0] };
- let active_request = ActiveRequest {
- contact: active_request.contact,
- request_body,
- query_id: None,
- callback: None,
- };
- self.send_rpc_request(active_request);
+ // check if we need to request a new ENR
+ if let Some(enr) = self.find_enr(&node_id) {
+ if enr.seq() < enr_seq {
+ // request an ENR update
+ debug!("Requesting an ENR update from: {}", active_request.contact);
+ let request_body = RequestBody::FindNode { distances: vec![0] };
+ let active_request = ActiveRequest {
+ contact: active_request.contact,
+ request_body,
+ query_id: None,
+ callback: None,
+ };
+ self.send_rpc_request(active_request);
+ }
+ self.connection_updated(node_id, ConnectionStatus::PongReceived(enr));
}
- self.connection_updated(node_id, ConnectionStatus::PongReceived(enr));
}
}
ResponseBody::Talk { response } => {
@@ -880,12 +902,6 @@ impl Service {
_ => error!("Invalid callback for response"),
}
}
- ResponseBody::Ticket { .. } => {
- error!("Received a TICKET response. This is unimplemented and should be unreachable.");
- }
- ResponseBody::RegisterConfirmation { .. } => {
- error!("Received a RegisterConfirmation response. This is unimplemented and should be unreachable.");
- }
}
} else {
warn!(
@@ -898,8 +914,12 @@ impl Service {
// Send RPC Requests //
/// Sends a PING request to a node.
- fn send_ping(&mut self, enr: Enr) {
- match NodeContact::try_from_enr(enr, self.config.ip_mode) {
+ fn send_ping(
+ &mut self,
+ enr: Enr,
+ callback: Option>>,
+ ) {
+ match NodeContact::try_from_enr(enr, self.ip_mode) {
Ok(contact) => {
let request_body = RequestBody::Ping {
enr_seq: self.local_enr.read().seq(),
@@ -908,7 +928,7 @@ impl Service {
contact,
request_body,
query_id: None,
- callback: None,
+ callback: callback.map(CallbackResponse::Pong),
};
self.send_rpc_request(active_request);
}
@@ -934,22 +954,23 @@ impl Service {
};
for enr in connected_peers {
- self.send_ping(enr.clone());
+ self.send_ping(enr.clone(), None);
}
}
/// Request an external node's ENR.
- fn request_enr(
+ fn request_find_node_designated_peer(
&mut self,
contact: NodeContact,
- callback: Option>>,
+ distances: Vec,
+ callback: Option, RequestError>>>,
) {
- let request_body = RequestBody::FindNode { distances: vec![0] };
+ let request_body = RequestBody::FindNode { distances };
let active_request = ActiveRequest {
contact,
request_body,
query_id: None,
- callback: callback.map(CallbackResponse::Enr),
+ callback: callback.map(CallbackResponse::Nodes),
};
self.send_rpc_request(active_request);
}
@@ -1101,7 +1122,7 @@ impl Service {
) {
// find the ENR associated with the query
if let Some(enr) = self.find_enr(&return_peer) {
- match NodeContact::try_from_enr(enr, self.config.ip_mode) {
+ match NodeContact::try_from_enr(enr, self.ip_mode) {
Ok(contact) => {
let active_request = ActiveRequest {
contact,
@@ -1114,7 +1135,8 @@ impl Service {
return;
}
Err(NonContactable { enr }) => {
- error!("Query {} has a non contactable enr: {}", *query_id, enr);
+ // This can happen quite often with IPv6-only nodes
+ debug!("Query {} has a non contactable enr: {}", *query_id, enr);
}
}
} else {
@@ -1150,7 +1172,7 @@ impl Service {
}
}
- fn send_event(&mut self, event: Discv5Event) {
+ fn send_event(&mut self, event: Event) {
if let Some(stream) = self.event_stream.as_mut() {
if let Err(mpsc::error::TrySendError::Closed(_)) = stream.try_send(event) {
// If the stream has been dropped prevent future attempts to send events
@@ -1170,7 +1192,7 @@ impl Service {
// If any of the discovered nodes are in the routing table, and there contains an older ENR, update it.
// If there is an event stream send the Discovered event
if self.config.report_discovered_peers {
- self.send_event(Discv5Event::Discovered(enr.clone()));
+ self.send_event(Event::Discovered(enr.clone()));
}
// ignore peers that don't pass the table filter
@@ -1249,12 +1271,24 @@ impl Service {
state: ConnectionState::Connected,
direction,
};
- match self.kbuckets.write().insert_or_update(&key, enr, status) {
+
+ let insert_result =
+ self.kbuckets
+ .write()
+ .insert_or_update(&key, enr.clone(), status);
+ match insert_result {
InsertResult::Inserted => {
// We added this peer to the table
debug!("New connected node added to routing table: {}", node_id);
self.peers_to_ping.insert(node_id);
- let event = Discv5Event::NodeInserted {
+
+ // PING immediately if the direction is outgoing. This allows us to receive
+ // a PONG without waiting for the ping_interval, making ENR updates faster.
+ if direction == ConnectionDirection::Outgoing {
+ self.send_ping(enr, None);
+ }
+
+ let event = Event::NodeInserted {
node_id,
replaced: None,
};
@@ -1342,20 +1376,34 @@ impl Service {
}
};
if let Some(enr) = optional_enr {
- self.send_ping(enr)
+ self.send_ping(enr, None)
}
}
}
/// The equivalent of libp2p `inject_connected()` for a udp session. We have no stream, but a
/// session key-pair has been negotiated.
- fn inject_session_established(&mut self, enr: Enr, direction: ConnectionDirection) {
+ fn inject_session_established(&mut self, enr: Enr, connection_direction: ConnectionDirection) {
// Ignore sessions with non-contactable ENRs
- if self.config.ip_mode.get_contactable_addr(&enr).is_none() {
+ if self.ip_mode.get_contactable_addr(&enr).is_none() {
return;
}
let node_id = enr.node_id();
+
+ // We never update connection direction if a node already exists in the routing table as we
+ // don't want to promote the direction from incoming to outgoing.
+ let key = kbucket::Key::from(node_id);
+ let direction = match self
+ .kbuckets
+ .read()
+ .get_bucket(&key)
+ .map(|bucket| bucket.get(&key))
+ {
+ Some(Some(node)) => node.status.direction,
+ _ => connection_direction,
+ };
+
debug!(
"Session established with Node: {}, direction: {}",
node_id, direction
@@ -1371,10 +1419,10 @@ impl Service {
// If this is initiated by the user, return an error on the callback. All callbacks
// support a request error.
match active_request.callback {
- Some(CallbackResponse::Enr(callback)) => {
+ Some(CallbackResponse::Nodes(callback)) => {
callback
.send(Err(error))
- .unwrap_or_else(|_| debug!("Couldn't send TALK error response to user"));
+ .unwrap_or_else(|_| debug!("Couldn't send Nodes error response to user"));
return;
}
Some(CallbackResponse::Talk(callback)) => {
@@ -1384,6 +1432,13 @@ impl Service {
.unwrap_or_else(|_| debug!("Couldn't send TALK error response to user"));
return;
}
+ Some(CallbackResponse::Pong(callback)) => {
+ // return the error
+ callback
+ .send(Err(error))
+ .unwrap_or_else(|_| debug!("Couldn't send Pong error response to user"));
+ return;
+ }
None => {
// no callback to send too
}
@@ -1393,13 +1448,14 @@ impl Service {
match active_request.request_body {
// if a failed FindNodes request, ensure we haven't partially received packets. If
// so, process the partially found nodes
- RequestBody::FindNode { .. } => {
- if let Some(nodes_response) = self.active_nodes_responses.remove(&node_id) {
+ RequestBody::FindNode { ref distances } => {
+ if let Some(nodes_response) = self.active_nodes_responses.remove(&id) {
if !nodes_response.received_nodes.is_empty() {
- warn!(
- "NODES Response failed, but was partially processed from: {}",
- active_request.contact
- );
+ let node_id = active_request.contact.node_id();
+ let addr = active_request.contact.socket_addr();
+ let received = nodes_response.received_nodes.len();
+ let expected = distances.len();
+ warn!(%node_id, %addr, %error, %received, %expected, "FINDNODE request failed with partial results");
// if it's a query mark it as success, to process the partial
// collection of peers
self.discovered(
@@ -1447,14 +1503,12 @@ impl Service {
}
/// A future that maintains the routing table and inserts nodes when required. This returns the
- /// `Discv5Event::NodeInserted` variant if a new node has been inserted into the routing table.
- async fn bucket_maintenance_poll(
- kbuckets: &Arc<RwLock<KBucketsTable<NodeId, Enr>>>,
- ) -> Discv5Event {
+ /// [`Event::NodeInserted`] variant if a new node has been inserted into the routing table.
+ async fn bucket_maintenance_poll(kbuckets: &Arc<RwLock<KBucketsTable<NodeId, Enr>>>) -> Event {
future::poll_fn(move |_cx| {
// Drain applied pending entries from the routing table.
if let Some(entry) = kbuckets.write().take_applied_pending() {
- let event = Discv5Event::NodeInserted {
+ let event = Event::NodeInserted {
node_id: entry.inserted.into_preimage(),
replaced: entry.evicted.map(|n| n.key.into_preimage()),
};
@@ -1475,10 +1529,8 @@ impl Service {
let request_body = query.target().rpc_request(return_peer);
Poll::Ready(QueryEvent::Waiting(query.id(), node_id, request_body))
}
- QueryPoolState::Timeout(query) => {
- warn!("Query id: {:?} timed out", query.id());
- Poll::Ready(QueryEvent::TimedOut(Box::new(query)))
- }
+
+ QueryPoolState::Timeout(query) => Poll::Ready(QueryEvent::TimedOut(Box::new(query))),
QueryPoolState::Waiting(None) | QueryPoolState::Idle => Poll::Pending,
})
.await
diff --git a/src/service/test.rs b/src/service/test.rs
index 160b483d8..c85f6a97a 100644
--- a/src/service/test.rs
+++ b/src/service/test.rs
@@ -3,20 +3,26 @@
use super::*;
use crate::{
+ discv5::test::generate_deterministic_keypair,
handler::Handler,
kbucket,
kbucket::{BucketInsertResult, KBucketsTable, NodeStatus},
node_info::NodeContact,
+ packet::{DefaultProtocolId, ProtocolIdentity},
query_pool::{QueryId, QueryPool},
rpc::RequestId,
service::{ActiveRequest, Service},
- Discv5ConfigBuilder, Enr,
+ socket::ListenConfig,
+ ConfigBuilder, Enr,
};
-use enr::{CombinedKey, EnrBuilder};
+use enr::CombinedKey;
use parking_lot::RwLock;
-use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
+use std::{collections::HashMap, net::Ipv4Addr, sync::Arc, time::Duration};
use tokio::sync::{mpsc, oneshot};
+/// Default UDP port for tests requiring UDP exposure. Port 0 lets the OS assign an ephemeral port.
+pub const DEFAULT_UDP_PORT: u16 = 0;
+
fn _connected_state() -> NodeStatus {
NodeStatus {
state: ConnectionState::Connected,
@@ -37,24 +43,23 @@ fn init() {
.try_init();
}
-async fn build_service(
+async fn build_service<P: ProtocolIdentity>(
local_enr: Arc<RwLock<Enr>>,
enr_key: Arc<RwLock<CombinedKey>>,
- listen_socket: SocketAddr,
filters: bool,
) -> Service {
- let config = Discv5ConfigBuilder::new()
+ let listen_config = ListenConfig::Ipv4 {
+ ip: local_enr.read().ip4().unwrap(),
+ port: local_enr.read().udp4().unwrap(),
+ };
+ let config = ConfigBuilder::new(listen_config)
.executor(Box::<crate::executor::TokioExecutor>::default())
.build();
// build the session service
- let (_handler_exit, handler_send, handler_recv) = Handler::spawn(
- local_enr.clone(),
- enr_key.clone(),
- listen_socket,
- config.clone(),
- )
- .await
- .unwrap();
+ let (_handler_exit, handler_send, handler_recv) =
+ Handler::spawn::<P>(local_enr.clone(), enr_key.clone(), config.clone())
+ .await
+ .unwrap();
let (table_filter, bucket_filter) = if filters {
(
@@ -93,6 +98,7 @@ async fn build_service(
event_stream: None,
exit,
config,
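+ // The default `IpMode` (IPv4-only) matches the IPv4 listen config built above.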
+ ip_mode: Default::default(),
}
}
@@ -101,25 +107,22 @@ async fn test_updating_connection_on_ping() {
init();
let enr_key1 = CombinedKey::generate_secp256k1();
let ip = "127.0.0.1".parse().unwrap();
- let enr = EnrBuilder::new("v4")
+ let enr = Enr::builder()
.ip4(ip)
- .udp4(10001)
+ .udp4(DEFAULT_UDP_PORT)
.build(&enr_key1)
.unwrap();
let ip2 = "127.0.0.1".parse().unwrap();
let enr_key2 = CombinedKey::generate_secp256k1();
- let enr2 = EnrBuilder::new("v4")
+ let enr2 = Enr::builder()
.ip4(ip2)
- .udp4(10002)
+ .udp4(DEFAULT_UDP_PORT)
.build(&enr_key2)
.unwrap();
- let socket_addr = enr.udp4_socket().unwrap();
-
- let mut service = build_service(
+ let mut service = build_service::<DefaultProtocolId>(
Arc::new(RwLock::new(enr)),
Arc::new(RwLock::new(enr_key1)),
- socket_addr.into(),
false,
)
.await;
@@ -142,7 +145,7 @@ async fn test_updating_connection_on_ping() {
body: ResponseBody::Pong {
enr_seq: 2,
ip: ip2.into(),
- port: 10002,
+ port: 9000.try_into().unwrap(),
},
};
@@ -165,3 +168,176 @@ async fn test_updating_connection_on_ping() {
let node = buckets.iter_ref().next().unwrap();
assert!(node.status.is_connected())
}
+
+#[tokio::test]
+async fn test_connection_direction_on_inject_session_established() {
+ init();
+
+ let enr_key1 = CombinedKey::generate_secp256k1();
+ let ip = std::net::Ipv4Addr::LOCALHOST;
+ let enr = Enr::builder()
+ .ip4(ip)
+ .udp4(DEFAULT_UDP_PORT)
+ .build(&enr_key1)
+ .unwrap();
+
+ let enr_key2 = CombinedKey::generate_secp256k1();
+ let ip2 = std::net::Ipv4Addr::LOCALHOST;
+ let enr2 = Enr::builder()
+ .ip4(ip2)
+ .udp4(DEFAULT_UDP_PORT)
+ .build(&enr_key2)
+ .unwrap();
+
+ let mut service = build_service::<DefaultProtocolId>(
+ Arc::new(RwLock::new(enr)),
+ Arc::new(RwLock::new(enr_key1)),
+ false,
+ )
+ .await;
+
+ let key = &kbucket::Key::from(enr2.node_id());
+
+ // Test that the existing connection direction is not updated.
+ // Incoming
+ service.inject_session_established(enr2.clone(), ConnectionDirection::Incoming);
+ let status = service.kbuckets.read().iter_ref().next().unwrap().status;
+ assert!(status.is_connected());
+ assert_eq!(ConnectionDirection::Incoming, status.direction);
+
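+ // A later session in the opposite direction must not overwrite the stored direction.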
+ service.inject_session_established(enr2.clone(), ConnectionDirection::Outgoing);
+ let status = service.kbuckets.read().iter_ref().next().unwrap().status;
+ assert!(status.is_connected());
+ assert_eq!(ConnectionDirection::Incoming, status.direction);
+
+ // (disconnected) Outgoing
+ let result = service.kbuckets.write().update_node_status(
+ key,
+ ConnectionState::Disconnected,
+ Some(ConnectionDirection::Outgoing),
+ );
+ assert!(matches!(result, UpdateResult::Updated));
+ service.inject_session_established(enr2.clone(), ConnectionDirection::Incoming);
+ let status = service.kbuckets.read().iter_ref().next().unwrap().status;
+ assert!(status.is_connected());
+ assert_eq!(ConnectionDirection::Outgoing, status.direction);
+}
+
+#[tokio::test]
+async fn test_handling_concurrent_responses() {
+ init();
+
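+ // Two FINDNODE requests to the same peer are in flight at once. The multi-message NODES
+ // response to the first request must be tracked per request id, so that a response to the
+ // second request arriving in between does not disturb the partially received nodes.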
+ // The seed is chosen such that all generated nodes are at distance 256 from the first node.
+ let seed = 1652;
+ let mut keypairs = generate_deterministic_keypair(5, seed);
+
+ let mut service = {
+ let enr_key = keypairs.pop().unwrap();
+ let enr = Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(10005)
+ .build(&enr_key)
+ .unwrap();
+ build_service::<DefaultProtocolId>(
+ Arc::new(RwLock::new(enr)),
+ Arc::new(RwLock::new(enr_key)),
+ false,
+ )
+ .await
+ };
+
+ let node_contact: NodeContact = Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(10006)
+ .build(&keypairs.remove(0))
+ .unwrap()
+ .into();
+ let node_address = node_contact.node_address();
+
+ // Add fake requests
+ // Request1
+ service.active_requests.insert(
+ RequestId(vec![1]),
+ ActiveRequest {
+ contact: node_contact.clone(),
+ request_body: RequestBody::FindNode {
+ distances: vec![254, 255, 256],
+ },
+ query_id: Some(QueryId(1)),
+ callback: None,
+ },
+ );
+ // Request2
+ service.active_requests.insert(
+ RequestId(vec![2]),
+ ActiveRequest {
+ contact: node_contact,
+ request_body: RequestBody::FindNode {
+ distances: vec![254, 255, 256],
+ },
+ query_id: Some(QueryId(2)),
+ callback: None,
+ },
+ );
+
+ assert_eq!(3, keypairs.len());
+ let mut enrs_for_response = keypairs
+ .iter()
+ .enumerate()
+ .map(|(i, key)| {
+ Enr::builder()
+ .ip4(Ipv4Addr::LOCALHOST)
+ .udp4(10007 + i as u16)
+ .build(key)
+ .unwrap()
+ })
+ .collect::<Vec<_>>();
+
+ // The response to `Request1` is split across two NODES messages in total. Handle the first
+ // message here.
+ service.handle_rpc_response(
+ node_address.clone(),
+ Response {
+ id: RequestId(vec![1]),
+ body: ResponseBody::Nodes {
+ total: 2,
+ nodes: vec![enrs_for_response.pop().unwrap()],
+ },
+ },
+ );
+ // The service still has two active requests since we are waiting for the second NODES
+ // response to `Request1`.
+ assert_eq!(2, service.active_requests.len());
+ // Service stores the first response to `Request1` into `active_nodes_responses`.
+ assert!(!service.active_nodes_responses.is_empty());
+
+ // Second, handle a response to *`Request2`* before the second response to `Request1`.
+ service.handle_rpc_response(
+ node_address.clone(),
+ Response {
+ id: RequestId(vec![2]),
+ body: ResponseBody::Nodes {
+ total: 1,
+ nodes: vec![enrs_for_response.pop().unwrap()],
+ },
+ },
+ );
+ // `Request2` is completed so now the number of active requests should be one.
+ assert_eq!(1, service.active_requests.len());
+ // Service still keeps the first response in `active_nodes_responses`.
+ assert!(!service.active_nodes_responses.is_empty());
+
+ // Finally, handle the second response to `Request1`.
+ service.handle_rpc_response(
+ node_address,
+ Response {
+ id: RequestId(vec![1]),
+ body: ResponseBody::Nodes {
+ total: 2,
+ nodes: vec![enrs_for_response.pop().unwrap()],
+ },
+ },
+ );
+ assert!(service.active_requests.is_empty());
+ assert!(service.active_nodes_responses.is_empty());
+}
diff --git a/src/socket/filter/mod.rs b/src/socket/filter/mod.rs
index d7a7da2e4..7bb54ad77 100644
--- a/src/socket/filter/mod.rs
+++ b/src/socket/filter/mod.rs
@@ -7,6 +7,7 @@ use lru::LruCache;
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr},
+ num::NonZeroUsize,
sync::atomic::Ordering,
time::{Duration, Instant},
};
@@ -19,9 +20,16 @@ pub use config::FilterConfig;
use rate_limiter::{LimitKind, RateLimiter};
/// The maximum number of IPs to retain when calculating the number of nodes per IP.
-const KNOWN_ADDRS_SIZE: usize = 500;
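+// NOTE: the `match`/`unreachable!()` construction is used because, at the time of writing,
+// `Option::unwrap` cannot be called in a `const` item; the `None` arm can never be taken for
+// these non-zero literals.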
+const KNOWN_ADDRS_SIZE: NonZeroUsize = match NonZeroUsize::new(500) {
+ Some(non_zero) => non_zero,
+ None => unreachable!(),
+};
/// The number of IPs to retain at any given time that have banned nodes.
-const BANNED_NODES_SIZE: usize = 50;
+const BANNED_NODES_SIZE: NonZeroUsize = match NonZeroUsize::new(50) {
+ Some(non_zero) => non_zero,
+ None => unreachable!(),
+};
+
/// The maximum number of packets to keep record of for metrics if the rate limiter is not
/// specified.
const DEFAULT_PACKETS_PER_SECOND: usize = 20;
diff --git a/src/socket/mod.rs b/src/socket/mod.rs
index e2fec681d..209348efc 100644
--- a/src/socket/mod.rs
+++ b/src/socket/mod.rs
@@ -1,16 +1,19 @@
-use crate::{Executor, IpMode};
+use crate::{packet::ProtocolIdentity, Executor};
use parking_lot::RwLock;
use recv::*;
use send::*;
use socket2::{Domain, Protocol, Socket as Socket2, Type};
use std::{
collections::HashMap,
- io::{Error, ErrorKind},
- net::SocketAddr,
+ io::Error,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
sync::Arc,
time::Duration,
};
-use tokio::sync::{mpsc, oneshot};
+use tokio::{
+ net::UdpSocket,
+ sync::{mpsc, oneshot},
+};
mod filter;
mod recv;
@@ -23,16 +26,35 @@ pub use filter::{
pub use recv::InboundPacket;
pub use send::OutboundPacket;
+/// Configuration for the sockets to listen on.
+///
+/// The default implementation is the UNSPECIFIED IPv4 address (0.0.0.0) with port 9000.
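+///
+/// For example (illustrative only): `ListenConfig::Ipv4 { ip: Ipv4Addr::UNSPECIFIED, port: 9000 }`.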
+#[derive(Clone, Debug)]
+pub enum ListenConfig {
+ Ipv4 {
+ ip: Ipv4Addr,
+ port: u16,
+ },
+ Ipv6 {
+ ip: Ipv6Addr,
+ port: u16,
+ },
+ DualStack {
+ ipv4: Ipv4Addr,
+ ipv4_port: u16,
+ ipv6: Ipv6Addr,
+ ipv6_port: u16,
+ },
+}
+
/// Convenience objects for setting up the recv handler.
pub struct SocketConfig {
/// The executor to spawn the tasks.
pub executor: Box,
- /// The listening socket.
- pub socket_addr: SocketAddr,
/// Configuration details for the packet filter.
pub filter_config: FilterConfig,
/// Type of socket to create.
- pub ip_mode: IpMode,
+ pub listen_config: ListenConfig,
/// If the filter is enabled this sets the default timeout for bans enacted by the filter.
pub ban_duration: Option,
/// The expected responses reference.
@@ -52,34 +74,15 @@ pub struct Socket {
impl Socket {
/// This creates and binds a new UDP socket.
// In general this function can be expanded to handle more advanced socket creation.
- async fn new_socket(
- socket_addr: &SocketAddr,
- ip_mode: IpMode,
- ) -> Result<tokio::net::UdpSocket, Error> {
- match ip_mode {
- IpMode::Ip4 => match socket_addr {
- SocketAddr::V6(_) => Err(Error::new(
- ErrorKind::InvalidInput,
- "Cannot create an ipv4 socket from an ipv6 address",
- )),
- ip4 => tokio::net::UdpSocket::bind(ip4).await,
- },
- IpMode::Ip6 {
- enable_mapped_addresses,
- } => {
- let addr = match socket_addr {
- SocketAddr::V4(_) => Err(Error::new(
- ErrorKind::InvalidInput,
- "Cannot create an ipv6 socket from an ipv4 address",
- )),
- SocketAddr::V6(ip6) => Ok((*ip6).into()),
- }?;
+ async fn new_socket(socket_addr: &SocketAddr) -> Result<UdpSocket, Error> {