diff --git a/neptun/src/device/mod.rs b/neptun/src/device/mod.rs index bc2bf4a..180198a 100644 --- a/neptun/src/device/mod.rs +++ b/neptun/src/device/mod.rs @@ -27,7 +27,7 @@ pub mod tun; pub mod tun; use std::collections::HashMap; -use std::io::{self, BufReader, BufWriter, Write}; +use std::io::{self, BufReader, BufWriter}; use std::mem::MaybeUninit; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::os::fd::RawFd; @@ -1051,12 +1051,15 @@ impl Device { // * Send encapsulated packet to the peer's endpoint let mtu = d.mtu.load(Ordering::Relaxed); + let udp4 = d.udp4.as_ref().expect("Not connected"); + let udp6 = d.udp6.as_ref().expect("Not connected"); + let peers = &d.peers_by_ip; for _ in 0..MAX_ITR { let (element, iter) = unsafe { TX_RING_BUFFER.get_next() }; if element.is_element_free.load(Ordering::Relaxed) { const DATA_OFFSET: usize = 16; - let len = match iface.read(&mut element.data[DATA_OFFSET..mtu]) { + let len = match iface.read(&mut element.data[DATA_OFFSET..mtu + DATA_OFFSET]) { Ok(src) => src.len(), Err(Error::IfaceRead(e)) => { let ek = e.kind(); @@ -1101,7 +1104,61 @@ impl Device { let res = { let mut tun = peer.tunnel.lock(); - tun.queue_encapsulate(len, element, iter, peer.endpoint_ref()) + tun.queue_encapsulate(len, element, iter, peer.endpoint_ref(), &mut t.dst_buf[..]) + }; + + match res { + TunnResult::Done => {} + TunnResult::Err(e) => { + tracing::error!(message = "Encapsulate error", + error = ?e, + public_key = peer.public_key.1) + } + TunnResult::WriteToNetwork(packet) => { + let endpoint = peer.endpoint(); + if let Some(conn) = endpoint.conn.as_ref() { + // Prefer to send using the connected socket + if let Err(err) = conn.send(packet) { + tracing::debug!(message = "Failed to send packet with the connected socket", error = ?err); + drop(endpoint); + peer.shutdown_endpoint(); + } else { + tracing::trace!( + "Pkt -> ConnSock ({:?}), len: {}, dst_addr: {}", + endpoint.addr, + packet.len(), + dst_addr + ); + } + } else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr { + if let Err(err) = udp4.send_to(packet, &addr.into()) { + tracing::warn!(message = "Failed to write packet to network v4", error = ?err, dst = ?addr); + } else { + tracing::trace!( + message = "Writing packet to network v4", + interface = ?t.iface.name(), + packet_length = packet.len(), + src_addr = ?addr, + public_key = peer.public_key.1 + ); + } + } else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr { + if let Err(err) = udp6.send_to(packet, &addr.into()) { + tracing::warn!(message = "Failed to write packet to network v6", error = ?err, dst = ?addr); + } else { + tracing::trace!( + message = "Writing packet to network v6", + interface = ?t.iface.name(), + packet_length = packet.len(), + src_addr = ?addr, + public_key = peer.public_key.1 + ); + } + } else { + tracing::error!("No endpoint"); + } + } + _ => panic!("Unexpected result from encapsulate"), }; } } @@ -1121,10 +1178,6 @@ fn send_to_network( udp4: Arc, udp6: Arc, ) { - // Check whether udp4/6 are there - // let udp4 = udp4.as_ref().expect("Not connected"); - // let udp6 = udp6.as_ref().expect("Not connected"); - while let Ok(msg) = network_rx.recv() { match &msg.res { NeptunResult::Done => {} @@ -1136,11 +1189,38 @@ fn send_to_network( let packet = &msg.data.as_slice()[..*len]; if let Some(conn) = endpoint.conn.as_mut() { // Prefer to send using the connected socket - let _: Result<_, _> = conn.write(packet); + if let Err(err) = conn.send(packet) { + tracing::debug!(message = "Failed to send packet with the connected socket", error = ?err); + drop(endpoint); + // TODO: shutting endpoint here + // peer.shutdown_endpoint(); + } else { + tracing::trace!( + "Pkt -> ConnSock ({:?}), len: {}", + endpoint.addr, + packet.len(), + ); + } } else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr { - let _: Result<_, _> = udp4.send_to(packet, &addr.into()); + if let Err(err) = udp4.send_to(packet, &addr.into()) { + tracing::warn!(message = "Failed to write packet to network v4", error = ?err, dst = ?addr); + } else { + tracing::trace!( + message = "Writing packet to network v4", + packet_length = packet.len(), + src_addr = ?addr, + ); + } } else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr { - let _: Result<_, _> = udp6.send_to(packet, &addr.into()); + if let Err(err) = udp6.send_to(packet, &addr.into()) { + tracing::warn!(message = "Failed to write packet to network v6", error = ?err, dst = ?addr); + } else { + tracing::trace!( + message = "Writing packet to network v6", + packet_length = packet.len(), + src_addr = ?addr, + ); + } } else { tracing::error!("No endpoint"); } diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 42a9098..83e1ec7 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -14,7 +14,7 @@ mod session; mod timers; use crossbeam::channel::{Receiver, Sender}; -use ring_buffers::{EncryptionTaskData, TX_RING_BUFFER}; +use ring_buffers::EncryptionTaskData; use session::{Session, AEAD_SIZE, DATA_OFFSET}; use crate::noise::errors::WireGuardError; @@ -107,7 +107,6 @@ pub struct Tunn { pub peer_static_public: x25519_dalek::PublicKey, encrypt_tx_chan: Option>, - network_tx_chan: Option>, } type MessageType = u32; @@ -274,7 +273,6 @@ impl Tunn { Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) }), encrypt_tx_chan: encrypt_tx_chan, - network_tx_chan: network_tx_chan, }; Ok(tunn) @@ -353,6 +351,7 @@ impl Tunn { element: &'static mut EncryptionTaskData, idx: usize, endpoint: Arc>, + dst: &'a mut [u8], ) -> TunnResult<'a> { let current = self.current; if let Some(ref session) = self.sessions[current % N_SESSIONS] { @@ -384,34 +383,7 @@ impl Tunn { } // Initiate a new handshake if none is in progress - self.initiate_handshake(endpoint, false); - TunnResult::Done - } - - pub fn initiate_handshake<'a>( - &mut self, - endpoint: Arc>, - force_resend: bool, - ) { - // TODO: Have to fix this. This can't be a hardcoded 0th iter - let (dst, _) = unsafe { TX_RING_BUFFER.get_next() }; - { - dst.endpoint = endpoint; - let res = self.format_handshake_initiation(dst.data.as_mut_slice(), force_resend); - match res { - TunnResult::Done => return, - TunnResult::Err(e) => { - tracing::error!(message = "Handshake initiation error", error = ?e); - return; - } - TunnResult::WriteToNetwork(buf) => { - dst.res = NeptunResult::WriteToNetwork(buf.len()) - } - _ => panic!("Unexpected result from handshake initiation"), - } - }; - dst.is_element_free.store(false, Ordering::Relaxed); - let _ = self.network_tx_chan.as_ref().unwrap().send(dst); + self.format_handshake_initiation(dst, false) } /// Receives a UDP datagram from the network and parses it.