Skip to content

Commit

Permalink
fix: Check for ErrorKind::WouldBlock in LazyConfigAcceptor (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr authored Mar 8, 2024
1 parent 330d287 commit d26502c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
27 changes: 22 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};

pub use rustls;
use rustls::server::AcceptedAlert;
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

Expand Down Expand Up @@ -195,6 +196,7 @@ impl TlsAcceptor {
pub struct LazyConfigAcceptor<IO> {
acceptor: rustls::server::Acceptor,
io: Option<IO>,
alert: Option<(rustls::Error, AcceptedAlert)>,
}

impl<IO> LazyConfigAcceptor<IO>
Expand All @@ -206,6 +208,7 @@ where
Self {
acceptor,
io: Some(io),
alert: None,
}
}

Expand Down Expand Up @@ -274,6 +277,22 @@ where
}
};

if let Some((err, mut alert)) = this.alert.take() {
match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
this.alert = Some((err, alert));
return Poll::Pending;
}
Ok(0) | Err(_) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
}
Ok(_) => {
this.alert = Some((err, alert));
continue;
}
};
}

let mut reader = common::SyncReadAdapter { io, cx };
match this.acceptor.read_tls(&mut reader) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
Expand All @@ -287,11 +306,9 @@ where
let io = this.io.take().unwrap();
return Poll::Ready(Ok(StartHandshake { accepted, io }));
}
Ok(None) => continue,
Err((err, mut alert)) => {
let mut writer = common::SyncWriteAdapter { io, cx };
let _ = alert.write(&mut writer); // best effort
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)));
Ok(None) => {}
Err((err, alert)) => {
this.alert = Some((err, alert));
}
}
}
Expand Down
32 changes: 32 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,5 +290,37 @@ async fn acceptor_alert() {
assert_eq!(received, [0x15, 0x03, 0x03, 0x00, 0x02, 0x02, 0x46]);
}

#[tokio::test]
async fn lazy_config_acceptor_alert() {
// Intentionally small so that we have to call alert.write several times
let (mut cstream, sstream) = tokio::io::duplex(2);

let (tx, rx) = oneshot::channel();

tokio::spawn(async move {
// This is write instead of write_all because of the short duplex size, which is necessarily
// symmetrical. We never finish writing because the LazyConfigAcceptor returns an error
let _ = cstream.write(b"not tls").await;
let mut buf = Vec::new();
cstream.read_to_end(&mut buf).await.unwrap();
tx.send(buf).unwrap();
});

let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream);

let Ok(accept_result) = time::timeout(Duration::from_secs(3), acceptor).await else {
panic!("timeout");
};

assert!(accept_result.is_err());

let Ok(Ok(received)) = time::timeout(Duration::from_secs(3), rx).await else {
panic!("failed to receive");
};

let fatal_alert_decode_error = b"\x15\x03\x03\x00\x02\x02\x32";
assert_eq!(received, fatal_alert_decode_error)
}

// Include `utils` module
include!("utils.rs");

0 comments on commit d26502c

Please sign in to comment.