diff --git a/Cargo.toml b/Cargo.toml index 6d25f59..54186f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ authors = [ [features] default = ["transport_utp"] -transport_utp = ["libutp-rs"] +transport_utp = ["async-std-utp"] [dependencies] async-std = { version = "1.9.0", features = ["unstable"] } @@ -30,7 +30,7 @@ hex = "0.4.3" pretty-hash = "0.4.1" hyperswarm-dht = { git = "https://github.com/Frando/hyperswarm-dht.git", branch = "hyperspace" } colmeia-hyperswarm-mdns = { git = "https://github.com/bltavares/colmeia.git", rev = "e92ab71981356197a21592b7ce6854e209582985" } -libutp-rs = { git = "https://github.com/Frando/libutp-rs.git", branch = "feat/clone", optional = true } +async-std-utp = { version = "0.1.1", optional = true } [dev-dependencies] env_logger = "0.8.3" diff --git a/src/transport/combined.rs b/src/transport/combined.rs index 2b398aa..f90db6f 100644 --- a/src/transport/combined.rs +++ b/src/transport/combined.rs @@ -163,7 +163,7 @@ impl CombinedStream { match self { Self::Tcp(stream) => stream.peer_addr().unwrap(), #[cfg(feature = "transport_utp")] - Self::Utp(stream) => stream.peer_addr(), + Self::Utp(stream) => stream.peer_addr().unwrap(), } } diff --git a/src/transport/utp.rs b/src/transport/utp.rs index fc93ca4..20eefa1 100644 --- a/src/transport/utp.rs +++ b/src/transport/utp.rs @@ -1,7 +1,6 @@ -use async_compat::Compat; -use futures::stream::FuturesUnordered; -use futures_lite::{AsyncRead, AsyncWrite, Stream}; -use libutp_rs::{Connect as ConnectFut, UtpContext, UtpListener, UtpSocket}; +use async_std_utp::{UtpListener, UtpSocket}; +use futures::FutureExt; +use futures::{future::BoxFuture, stream::FuturesUnordered, Stream}; use std::fmt; use std::io; use std::net::{SocketAddr, ToSocketAddrs}; @@ -10,12 +9,16 @@ use std::task::{Context, Poll}; use super::{Connection, Transport}; +pub(crate) use async_std_utp::UtpStream; + const PROTOCOL: &'static str = "utp"; +type ConnectFut = BoxFuture<'static, io::Result>; + pub struct UtpTransport { - context: UtpContext, pending_connects: FuturesUnordered, - incoming: UtpListener, + context: UtpListener, + incoming: Option>>, } impl fmt::Debug for UtpTransport { @@ -30,12 +33,11 @@ impl UtpTransport { A: ToSocketAddrs + Send, { let addr = local_addr.to_socket_addrs()?.next().unwrap(); - let context = UtpContext::bind(addr)?; - let incoming = context.listener(); + let context = UtpListener::bind(addr).await?; Ok(Self { context, - incoming, pending_connects: FuturesUnordered::new(), + incoming: None, }) } } @@ -43,7 +45,11 @@ impl UtpTransport { impl Transport for UtpTransport { type Connection = UtpStream; fn connect(&mut self, peer_addr: SocketAddr) { - let fut = self.context.connect(peer_addr); + let fut = async move { + let socket = UtpSocket::connect(peer_addr).await?; + Ok(UtpStream::from(socket)) + } + .boxed(); self.pending_connects.push(fut); } } @@ -51,8 +57,14 @@ impl Transport for UtpTransport { impl Stream for UtpTransport { type Item = io::Result::Connection>>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let incoming = Pin::new(&mut self.incoming).poll_next(cx); + if self.incoming.is_none() { + let contex = self.context.clone(); + self.incoming = async move { contex.accept().await }.boxed().into(); + } + let incoming = self.incoming.as_mut().unwrap().poll_unpin(cx); + let incoming = incoming.map(|poll| Some(poll.map(|(socket, _)| UtpStream::from(socket)))); if let Some(conn) = into_connection(incoming, false) { + self.incoming = None; return Poll::Ready(Some(conn)); } @@ -65,7 +77,7 @@ impl Stream for UtpTransport { } fn into_connection( - poll: Poll>>, + poll: Poll>>, is_initiator: bool, ) -> Option>> { match poll { @@ -73,68 +85,9 @@ fn into_connection( Poll::Ready(None) => None, Poll::Ready(Some(Err(e))) => Some(Err(e)), Poll::Ready(Some(Ok(stream))) => { - let stream = UtpStream::new(stream); - let peer_addr = stream.peer_addr(); + let peer_addr = stream.peer_addr().unwrap(); let conn = Connection::new(stream, peer_addr, is_initiator, PROTOCOL.into()); Some(Ok(conn)) } } } - -pub struct UtpStream { - inner: Compat, -} - -impl fmt::Debug for UtpStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("UtpStream").finish() - } -} - -impl UtpStream { - pub fn new(socket: UtpSocket) -> Self { - Self { - inner: Compat::new(socket), - } - } - - pub fn peer_addr(&self) -> SocketAddr { - self.inner.get_ref().peer_addr() - } -} - -impl AsyncRead for UtpStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl AsyncWrite for UtpStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx) - } -} - -impl Clone for UtpStream { - fn clone(&self) -> Self { - Self { - inner: Compat::new(self.inner.get_ref().clone()), - } - } -}