Skip to content

Commit

Permalink
migrate to async-std
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Aug 4, 2019
1 parent 9daf87a commit ac4ccfe
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 115 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@ appveyor = { repository = "quininer/tokio-rustls" }

[dependencies]
smallvec = "0.6"
tokio-io = { git = "https://github.com/tokio-rs/tokio" }
futures-core-preview = "0.3.0-alpha.17"
futures-preview = "0.3.0-alpha.17"
rustls = "0.15"
webpki = "0.19"

[features]
early-data = []

[dev-dependencies]
tokio = { git = "https://github.com/tokio-rs/tokio" }
lazy_static = "1"
webpki-roots = "0.16"
async-std = { path = "../async-std" }
55 changes: 31 additions & 24 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,18 @@ impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.io.prepare_uninitialized_buffer(buf)
}

fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
let this = self.get_mut();

let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
let (pos, data) = &mut this.early_data;

// complete handshake
Expand All @@ -96,7 +96,8 @@ where
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
let len =
futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}
Expand All @@ -109,8 +110,8 @@ where
}
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
Expand All @@ -127,7 +128,7 @@ where
Poll::Ready(Ok(0))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending
Poll::Pending => Poll::Pending,
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
Expand All @@ -139,10 +140,14 @@ impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

match this.state {
#[cfg(feature = "early-data")]
Expand All @@ -155,9 +160,10 @@ where
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending,
Err(err) => return Poll::Ready(Err(err))
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Poll::Pending
}
Err(err) => return Poll::Ready(Err(err)),
};
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
Expand All @@ -171,7 +177,8 @@ where
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
let len =
futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}
Expand All @@ -187,20 +194,20 @@ where

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.state.writeable() {
self.session.send_close_notify();
self.state.shutdown_write();
}

let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
stream.as_mut_pin().poll_shutdown(cx)
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_close(cx)
}
}
90 changes: 47 additions & 43 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use std::pin::Pin;
use std::task::{ Poll, Context };
use std::marker::Unpin;
use std::io::{ self, Read, Write };
use futures::io::{AsyncRead, AsyncWrite};
use rustls::Session;
use tokio_io::{ AsyncRead, AsyncWrite };
use futures_core as futures;

use std::io::{self, Read, Write};
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{Context, Poll};

pub struct Stream<'a, IO, S> {
pub io: &'a mut IO,
pub session: &'a mut S,
pub eof: bool
pub eof: bool,
}

trait WriteTls<IO: AsyncWrite, S: Session> {
Expand All @@ -21,7 +19,7 @@ trait WriteTls<IO: AsyncWrite, S: Session> {
enum Focus {
Empty,
Readable,
Writable
Writable,
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Expand Down Expand Up @@ -51,14 +49,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
fn complete_read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}
Expand All @@ -68,30 +66,33 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
let n = match self.session.read_tls(&mut reader) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
Err(err) => return Poll::Ready(Err(err))
Err(err) => return Poll::Ready(Err(err)),
};

self.session.process_new_packets()
.map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_tls(cx);
self.session.process_new_packets().map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_tls(cx);

io::Error::new(io::ErrorKind::InvalidData, err)
})?;
io::Error::new(io::ErrorKind::InvalidData, err)
})?;

Poll::Ready(Ok(n))
}

fn complete_write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
match self.write_tls(cx) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
result => Poll::Ready(result),
}
}

fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll<io::Result<(usize, usize)>> {
fn complete_inner_io(
&mut self,
cx: &mut Context,
focus: Focus,
) -> Poll<io::Result<(usize, usize)>> {
let mut wrlen = 0;
let mut rdlen = 0;

Expand All @@ -104,9 +105,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(n)) => wrlen += n,
Poll::Pending => {
write_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

Expand All @@ -115,7 +116,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

Expand All @@ -129,23 +130,23 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
(true, true, _) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
return Poll::Ready(Err(err));
},
}
(_, false, true) => {
let would_block = match focus {
Focus::Empty => rdlen == 0 && wrlen == 0,
Focus::Readable => rdlen == 0,
Focus::Writable => wrlen == 0
Focus::Writable => wrlen == 0,
};

return if would_block {
Poll::Pending
} else {
Poll::Ready(Ok((rdlen, wrlen)))
};
},
}
(_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))),
(_, true, true) => return Poll::Pending,
(..) => ()
(..) => (),
}
}
}
Expand All @@ -157,21 +158,21 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Str

struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_write(self.cx, buf) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}

fn flush(&mut self) -> io::Result<()> {
match Pin::new(&mut self.io).poll_flush(self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}
Expand All @@ -182,21 +183,25 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Str
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();

while this.session.wants_read() {
match this.complete_inner_io(cx, Focus::Readable) {
Poll::Ready(Ok((0, _))) => break,
Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

match this.session.read(buf) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
result => Poll::Ready(result),
}
}
}
Expand All @@ -207,16 +212,15 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'

let len = match this.session.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending,
Err(err) => return Poll::Ready(Err(err))
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
Err(err) => return Poll::Ready(Err(err)),
};
while this.session.wants_write() {
match this.complete_inner_io(cx, Focus::Writable) {
Poll::Ready(Ok(_)) => (),
Poll::Pending if len != 0 => break,
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

Expand All @@ -228,7 +232,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
Ok(0) => Poll::Pending,
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(err) => Poll::Ready(Err(err))
Err(err) => Poll::Ready(Err(err)),
}
}
}
Expand All @@ -243,13 +247,13 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
Pin::new(&mut this.io).poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();

while this.session.wants_write() {
futures::ready!(this.complete_inner_io(cx, Focus::Writable))?;
}
Pin::new(&mut this.io).poll_shutdown(cx)
Pin::new(&mut this.io).poll_close(cx)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use rustls::{
use super::Stream;


struct Good<'a>(&'a mut Session);
struct Good<'a>(&'a mut dyn Session);

impl<'a> AsyncRead for Good<'a> {
fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll<io::Result<usize>> {
Expand Down
Loading

0 comments on commit ac4ccfe

Please sign in to comment.