Skip to content

Commit

Permalink
Handle handshake not through separate thread
Browse files Browse the repository at this point in the history
  • Loading branch information
Hasan6979 committed Jan 17, 2025
1 parent edb6d3c commit 139cc26
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 41 deletions.
100 changes: 90 additions & 10 deletions neptun/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub mod tun;
pub mod tun;

use std::collections::HashMap;
use std::io::{self, BufReader, BufWriter, Write};
use std::io::{self, BufReader, BufWriter};
use std::mem::MaybeUninit;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::fd::RawFd;
Expand Down Expand Up @@ -1051,12 +1051,15 @@ impl Device {
// * Send encapsulated packet to the peer's endpoint
let mtu = d.mtu.load(Ordering::Relaxed);

let udp4 = d.udp4.as_ref().expect("Not connected");
let udp6 = d.udp6.as_ref().expect("Not connected");

let peers = &d.peers_by_ip;
for _ in 0..MAX_ITR {
let (element, iter) = unsafe { TX_RING_BUFFER.get_next() };
if element.is_element_free.load(Ordering::Relaxed) {
const DATA_OFFSET: usize = 16;
let len = match iface.read(&mut element.data[DATA_OFFSET..mtu]) {
let len = match iface.read(&mut element.data[DATA_OFFSET..mtu + DATA_OFFSET]) {
Ok(src) => src.len(),
Err(Error::IfaceRead(e)) => {
let ek = e.kind();
Expand Down Expand Up @@ -1101,7 +1104,61 @@ impl Device {

let res = {
let mut tun = peer.tunnel.lock();
tun.queue_encapsulate(len, element, iter, peer.endpoint_ref())
tun.queue_encapsulate(len, element, iter, peer.endpoint_ref(), &mut t.dst_buf[..])
};

match res {
TunnResult::Done => {}
TunnResult::Err(e) => {
tracing::error!(message = "Encapsulate error",
error = ?e,
public_key = peer.public_key.1)
}
TunnResult::WriteToNetwork(packet) => {
let endpoint = peer.endpoint();
if let Some(conn) = endpoint.conn.as_ref() {
// Prefer to send using the connected socket
if let Err(err) = conn.send(packet) {
tracing::debug!(message = "Failed to send packet with the connected socket", error = ?err);
drop(endpoint);
peer.shutdown_endpoint();
} else {
tracing::trace!(
"Pkt -> ConnSock ({:?}), len: {}, dst_addr: {}",
endpoint.addr,
packet.len(),
dst_addr
);
}
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
if let Err(err) = udp4.send_to(packet, &addr.into()) {
tracing::warn!(message = "Failed to write packet to network v4", error = ?err, dst = ?addr);
} else {
tracing::trace!(
message = "Writing packet to network v4",
interface = ?t.iface.name(),
packet_length = packet.len(),
src_addr = ?addr,
public_key = peer.public_key.1
);
}
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
if let Err(err) = udp6.send_to(packet, &addr.into()) {
tracing::warn!(message = "Failed to write packet to network v6", error = ?err, dst = ?addr);
} else {
tracing::trace!(
message = "Writing packet to network v6",
interface = ?t.iface.name(),
packet_length = packet.len(),
src_addr = ?addr,
public_key = peer.public_key.1
);
}
} else {
tracing::error!("No endpoint");
}
}
_ => panic!("Unexpected result from encapsulate"),
};
}
}
Expand All @@ -1121,10 +1178,6 @@ fn send_to_network(
udp4: Arc<Socket>,
udp6: Arc<Socket>,
) {
// Check whether udp4/6 are there
// let udp4 = udp4.as_ref().expect("Not connected");
// let udp6 = udp6.as_ref().expect("Not connected");

while let Ok(msg) = network_rx.recv() {
match &msg.res {
NeptunResult::Done => {}
Expand All @@ -1136,11 +1189,38 @@ fn send_to_network(
let packet = &msg.data.as_slice()[..*len];
if let Some(conn) = endpoint.conn.as_mut() {
// Prefer to send using the connected socket
let _: Result<_, _> = conn.write(packet);
if let Err(err) = conn.send(packet) {
tracing::debug!(message = "Failed to send packet with the connected socket", error = ?err);
drop(endpoint);
// TODO: shutting endpoint here
// peer.shutdown_endpoint();
} else {
tracing::trace!(
"Pkt -> ConnSock ({:?}), len: {}",
endpoint.addr,
packet.len(),
);
}
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
let _: Result<_, _> = udp4.send_to(packet, &addr.into());
if let Err(err) = udp4.send_to(packet, &addr.into()) {
tracing::warn!(message = "Failed to write packet to network v4", error = ?err, dst = ?addr);
} else {
tracing::trace!(
message = "Writing packet to network v4",
packet_length = packet.len(),
src_addr = ?addr,
);
}
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
let _: Result<_, _> = udp6.send_to(packet, &addr.into());
if let Err(err) = udp6.send_to(packet, &addr.into()) {
tracing::warn!(message = "Failed to write packet to network v6", error = ?err, dst = ?addr);
} else {
tracing::trace!(
message = "Writing packet to network v6",
packet_length = packet.len(),
src_addr = ?addr,
);
}
} else {
tracing::error!("No endpoint");
}
Expand Down
34 changes: 3 additions & 31 deletions neptun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod session;
mod timers;

use crossbeam::channel::{Receiver, Sender};
use ring_buffers::{EncryptionTaskData, TX_RING_BUFFER};
use ring_buffers::EncryptionTaskData;
use session::{Session, AEAD_SIZE, DATA_OFFSET};

use crate::noise::errors::WireGuardError;
Expand Down Expand Up @@ -107,7 +107,6 @@ pub struct Tunn {
pub peer_static_public: x25519_dalek::PublicKey,

encrypt_tx_chan: Option<Sender<usize>>,
network_tx_chan: Option<Sender<&'static EncryptionTaskData>>,
}

type MessageType = u32;
Expand Down Expand Up @@ -274,7 +273,6 @@ impl Tunn {
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
}),
encrypt_tx_chan: encrypt_tx_chan,
network_tx_chan: network_tx_chan,
};

Ok(tunn)
Expand Down Expand Up @@ -353,6 +351,7 @@ impl Tunn {
element: &'static mut EncryptionTaskData,
idx: usize,
endpoint: Arc<parking_lot::RwLock<Endpoint>>,
dst: &'a mut [u8],
) -> TunnResult<'a> {
let current = self.current;
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
Expand Down Expand Up @@ -384,34 +383,7 @@ impl Tunn {
}

// Initiate a new handshake if none is in progress
self.initiate_handshake(endpoint, false);
TunnResult::Done
}

pub fn initiate_handshake<'a>(
&mut self,
endpoint: Arc<parking_lot::RwLock<Endpoint>>,
force_resend: bool,
) {
// TODO: Have to fix this. This can't be a hardcoded 0th iter
let (dst, _) = unsafe { TX_RING_BUFFER.get_next() };
{
dst.endpoint = endpoint;
let res = self.format_handshake_initiation(dst.data.as_mut_slice(), force_resend);
match res {
TunnResult::Done => return,
TunnResult::Err(e) => {
tracing::error!(message = "Handshake initiation error", error = ?e);
return;
}
TunnResult::WriteToNetwork(buf) => {
dst.res = NeptunResult::WriteToNetwork(buf.len())
}
_ => panic!("Unexpected result from handshake initiation"),
}
};
dst.is_element_free.store(false, Ordering::Relaxed);
let _ = self.network_tx_chan.as_ref().unwrap().send(dst);
self.format_handshake_initiation(dst, false)
}

/// Receives a UDP datagram from the network and parses it.
Expand Down

0 comments on commit 139cc26

Please sign in to comment.