Skip to content

Commit

Permalink
Implement Clone for TcpStream (#689)
Browse files Browse the repository at this point in the history
* Implement Clone for TcpStream

* Update examples

* Remove accidentally added examples
  • Loading branch information
Stjepan Glavina authored Jan 28, 2020
1 parent 57974ae commit 1d87583
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
5 changes: 3 additions & 2 deletions examples/tcp-echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ use async_std::task;
async fn process(stream: TcpStream) -> io::Result<()> {
println!("Accepted from: {}", stream.peer_addr()?);

let (reader, writer) = &mut (&stream, &stream);
io::copy(reader, writer).await?;
let mut reader = stream.clone();
let mut writer = stream;
io::copy(&mut reader, &mut writer).await?;

Ok(())
}
Expand Down
5 changes: 3 additions & 2 deletions examples/tcp-ipv4-and-6-echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use async_std::task;
async fn process(stream: TcpStream) -> io::Result<()> {
println!("Accepted from: {}", stream.peer_addr()?);

let (reader, writer) = &mut (&stream, &stream);
io::copy(reader, writer).await?;
let mut reader = stream.clone();
let mut writer = stream;
io::copy(&mut reader, &mut writer).await?;

Ok(())
}
Expand Down
7 changes: 3 additions & 4 deletions src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use crate::future;
use crate::io;
Expand Down Expand Up @@ -75,9 +76,7 @@ impl TcpListener {
/// [`local_addr`]: #method.local_addr
pub async fn bind<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpListener> {
let mut last_err = None;
let addrs = addrs
.to_socket_addrs()
.await?;
let addrs = addrs.to_socket_addrs().await?;

for addr in addrs {
match mio::net::TcpListener::bind(&addr) {
Expand Down Expand Up @@ -121,7 +120,7 @@ impl TcpListener {

let mio_stream = mio::net::TcpStream::from_stream(io)?;
let stream = TcpStream {
watcher: Watcher::new(mio_stream),
watcher: Arc::new(Watcher::new(mio_stream)),
};
Ok((stream, addr))
}
Expand Down
26 changes: 16 additions & 10 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::{IoSlice, IoSliceMut, Read as _, Write as _};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use crate::future;
use crate::io::{self, Read, Write};
Expand Down Expand Up @@ -44,9 +45,9 @@ use crate::task::{Context, Poll};
/// #
/// # Ok(()) }) }
/// ```
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct TcpStream {
pub(super) watcher: Watcher<mio::net::TcpStream>,
pub(super) watcher: Arc<Watcher<mio::net::TcpStream>>,
}

impl TcpStream {
Expand All @@ -71,9 +72,7 @@ impl TcpStream {
/// ```
pub async fn connect<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpStream> {
let mut last_err = None;
let addrs = addrs
.to_socket_addrs()
.await?;
let addrs = addrs.to_socket_addrs().await?;

for addr in addrs {
// mio's TcpStream::connect is non-blocking and may just be in progress
Expand All @@ -84,16 +83,20 @@ impl TcpStream {
Ok(s) => Watcher::new(s),
Err(e) => {
last_err = Some(e);
continue
continue;
}
};

future::poll_fn(|cx| watcher.poll_write_ready(cx)).await;

match watcher.get_ref().take_error() {
Ok(None) => return Ok(TcpStream { watcher }),
Ok(None) => {
return Ok(TcpStream {
watcher: Arc::new(watcher),
});
}
Ok(Some(e)) => last_err = Some(e),
Err(e) => last_err = Some(e)
Err(e) => last_err = Some(e),
}
}

Expand Down Expand Up @@ -369,7 +372,7 @@ impl From<std::net::TcpStream> for TcpStream {
fn from(stream: std::net::TcpStream) -> TcpStream {
let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap();
TcpStream {
watcher: Watcher::new(mio_stream),
watcher: Arc::new(Watcher::new(mio_stream)),
}
}
}
Expand All @@ -391,7 +394,10 @@ cfg_unix! {

impl IntoRawFd for TcpStream {
fn into_raw_fd(self) -> RawFd {
self.watcher.into_inner().into_raw_fd()
// TODO(stjepang): This does not mean `RawFd` is now the sole owner of the file
// descriptor because it's possible that there are other clones of this `TcpStream`
// using it at the same time. We should probably document that behavior.
self.as_raw_fd()
}
}
}
Expand Down
22 changes: 22 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ fn smoke_async_stream_to_std_listener() -> io::Result<()> {

Ok(())
}

#[test]
fn cloned_streams() -> io::Result<()> {
task::block_on(async {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;

let mut stream = TcpStream::connect(&addr).await?;
let mut cloned_stream = stream.clone();
let mut incoming = listener.incoming();
let mut write_stream = incoming.next().await.unwrap()?;
write_stream.write_all(b"Each your doing").await?;

let mut buf = [0; 15];
stream.read_exact(&mut buf[..8]).await?;
cloned_stream.read_exact(&mut buf[8..]).await?;

assert_eq!(&buf[..15], b"Each your doing");

Ok(())
})
}

0 comments on commit 1d87583

Please sign in to comment.