Skip to content

Commit

Permalink
better peer id
Browse files Browse the repository at this point in the history
  • Loading branch information
icewind1991 committed Nov 29, 2024
1 parent e3ca217 commit 659cc0a
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 32 deletions.
59 changes: 59 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ main_error = "0.1.2"
futures-channel = "0.3.31"
log = "0.4.22"
futures-util = "0.3.31"
real-ip = "0.1.0"

[dev-dependencies]
maplit = "1"
Expand Down
88 changes: 65 additions & 23 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod session;

use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};

use crate::session::Session;
use dashmap::DashMap;
Expand All @@ -9,16 +10,19 @@ use futures_util::future::select;
use futures_util::StreamExt;
use futures_util::TryStreamExt;
use main_error::MainResult;
use std::net::{Ipv4Addr, SocketAddr};
use real_ip::{real_ip, IpNet};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, error, info, instrument, warn};
use tracing::{debug, error, info, warn};

type Tx = Sender<Message>;
type PeerMap = DashMap<SocketAddr, Tx>;
type PeerMap = DashMap<PeerId, Tx>;
type Sessions = DashMap<String, Session>;

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
Expand All @@ -32,28 +36,43 @@ pub enum SyncCommand<'a> {
Clients { session: &'a str, count: usize },
}

#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub struct PeerId(IpAddr, u64);

impl Display for PeerId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}#{}", self.0, self.1)
}
}

pub struct Server {
id_counter: AtomicU64,
peers: PeerMap,
sessions: Sessions,
}

impl Server {
fn new() -> Self {
Server {
id_counter: AtomicU64::default(),
peers: PeerMap::with_capacity(128),
sessions: Sessions::with_capacity(64),
}
}

fn send_text<S: Into<String>>(&self, peer: &SocketAddr, text: S) {
fn next_peer_id(&self) -> u64 {
self.id_counter.fetch_add(1, Ordering::Relaxed)
}

fn send_text<S: Into<String>>(&self, peer: &PeerId, text: S) {
if let Some(mut tx) = self.peers.get_mut(peer) {
if let Err(e) = tx.try_send(Message::Text(text.into())) {
error!(%peer, ?e, "failed to send message to client")
}
}
}

pub fn send_command(&self, peer: &SocketAddr, command: &SyncCommand) {
pub fn send_command(&self, peer: &PeerId, command: &SyncCommand) {
self.send_text(peer, serde_json::to_string(command).unwrap())
}

Expand All @@ -64,7 +83,7 @@ impl Server {
}
}

fn handle_command(&self, command: SyncCommand, sender: SocketAddr) {
fn handle_command(&self, command: SyncCommand, sender: PeerId) {
match &command {
SyncCommand::Create { session, token } => {
self.sessions
Expand Down Expand Up @@ -111,13 +130,17 @@ impl Server {
}
}

fn handle_disconnect(&self, peer: &SocketAddr) {
fn handle_disconnect(&self, peer: &PeerId) {
self.peers.remove(peer);
for mut session in self.sessions.iter_mut() {
session.remove_client(peer);
self.send_command(&session.owner, &SyncCommand::Clients {
session: &session.token,
count: session.clients().count(),
})
self.send_command(
&session.owner,
&SyncCommand::Clients {
session: &session.token,
count: session.clients().count(),
},
)
}
}

Expand All @@ -131,30 +154,45 @@ impl Server {
});
}

#[instrument(skip(self, raw_stream))]
async fn handle_connection(&self, raw_stream: TcpStream, addr: SocketAddr) {
debug!("incoming connection");

let ws_stream = tokio_tungstenite::accept_async(raw_stream)
.await
.expect("Error during the websocket handshake occurred");
info!("connection established");
let mut remote_ip = addr.ip();

let ws_stream_res =
tokio_tungstenite::accept_hdr_async(raw_stream, |req: &Request, response: Response| {
if let Some(ip) = real_ip(req.headers(), addr.ip(), TRUSTED_PROXIES) {
remote_ip = ip;
}
Ok::<_, ErrorResponse>(response)
})
.await;
let peer_id = PeerId(remote_ip, self.next_peer_id());
let ws_stream = match ws_stream_res {
Ok(ws_stream) => ws_stream,
Err(error) => {
error!(?error, %peer_id, "error while performing websocket handshake");
return;
}
};

info!(peer = %peer_id, "connection established");

// Insert the write part of this peer to the peer map.
let (tx, rx) = channel(16);
self.peers.insert(addr, tx);
self.peers.insert(peer_id, tx);

let (outgoing, incoming) = ws_stream.split();

let handle_messages = incoming.try_for_each(|msg| async move {
if let Ok(message) = msg.to_text() {
match serde_json::from_str(message) {
Ok(command) => {
debug!(sender = %addr, message = ?command, "Received a message");
self.handle_command(command, addr);
debug!(sender = %peer_id, message = ?command, "Received a message");
self.handle_command(command, peer_id);
}
Err(e) => {
warn!(sender = %addr, message, error = %e, "Error while decoding message");
warn!(sender = %peer_id, message, error = %e, "Error while decoding message");
}
}
} else {
Expand All @@ -169,9 +207,8 @@ impl Server {
let receive_from_others = pin!(receive_from_others);
select(handle_messages, receive_from_others).await;

info!(%addr, "disconnected");
self.peers.remove(&addr);
self.handle_disconnect(&addr);
info!(%peer_id, "disconnected");
self.handle_disconnect(&peer_id);
}
}

Expand Down Expand Up @@ -203,3 +240,8 @@ async fn main() -> MainResult {

Ok(())
}

const TRUSTED_PROXIES: &[IpNet] = &[IpNet::new_assert(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 0)),
8,
)];
17 changes: 8 additions & 9 deletions src/session.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use crate::SyncCommand;
use std::net::SocketAddr;
use crate::{PeerId, SyncCommand};
use std::time::{Duration, Instant};

#[derive(Debug)]
pub struct Session {
pub owner: SocketAddr,
pub owner: PeerId,
owner_token: String,
clients: Vec<SocketAddr>,
clients: Vec<PeerId>,
tick: u64,
playing: bool,
owner_left: Option<Instant>,
Expand All @@ -20,7 +19,7 @@ impl PartialEq for Session {
}

impl Session {
pub fn new(owner: SocketAddr, token: String, owner_token: String) -> Self {
pub fn new(owner: PeerId, token: String, owner_token: String) -> Self {
Session {
owner,
owner_token,
Expand All @@ -32,11 +31,11 @@ impl Session {
}
}

pub fn join(&mut self, client: SocketAddr) {
pub fn join(&mut self, client: PeerId) {
self.clients.push(client);
}

pub fn set_owner(&mut self, owner: SocketAddr, owner_token: &str) -> bool {
pub fn set_owner(&mut self, owner: PeerId, owner_token: &str) -> bool {
if owner_token == self.owner_token {
self.owner = owner;
self.owner_left = None;
Expand All @@ -62,11 +61,11 @@ impl Session {
.into_iter()
}

pub fn clients(&self) -> impl Iterator<Item = &SocketAddr> {
pub fn clients(&self) -> impl Iterator<Item = &PeerId> {
self.clients.iter()
}

pub fn remove_client(&mut self, peer: &SocketAddr) {
pub fn remove_client(&mut self, peer: &PeerId) {
self.clients.retain(|client| client != peer)
}

Expand Down

0 comments on commit 659cc0a

Please sign in to comment.