From d3891decc0533914e5248e5f85e7ebe65f676097 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Sun, 4 Aug 2019 23:23:14 +0200 Subject: [PATCH] cleanup and fix examples --- README.md | 2 +- examples/client/Cargo.toml | 5 +-- examples/client/src/main.rs | 18 ++++----- examples/server/Cargo.toml | 4 +- examples/server/src/main.rs | 59 ++++++++++++++-------------- rustfmt.toml | 1 + src/common/test_stream.rs | 76 +++++++++++++++++++++++++------------ src/test_0rtt.rs | 25 ++++++------ tests/test.rs | 32 ++++++++-------- 9 files changed, 126 insertions(+), 96 deletions(-) create mode 100644 rustfmt.toml diff --git a/README.md b/README.md index d1e7a30..e0fbf37 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ See [examples/server](examples/server/src/main.rs). You can run it with: ```sh cd examples/server -cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der +cargo run -- 127.0.0.1:8000 --cert mycert.der --key mykey.der ``` ### License & Origin diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index feec249..d3b67cc 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -5,9 +5,8 @@ authors = ["quininer "] edition = "2018" [dependencies] -futures = { package = "futures-preview", version = "0.3.0-alpha.16", features = ["io-compat"] } -romio = "0.3.0-alpha.8" +futures-preview = "0.3.0-alpha.17" +async-std = { path = "../../../async-std" } structopt = "0.2" tokio-rustls = { path = "../.." } webpki-roots = "0.16" -tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 6416db2..602475f 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -7,13 +7,11 @@ use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; use structopt::StructOpt; -use romio::TcpStream; +use async_std::net::TcpStream; +use async_std::task; +use async_std::io as aio; use futures::prelude::*; -use futures::executor; -use futures::compat::{ AsyncRead01CompatExt, AsyncWrite01CompatExt }; use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef }; -use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; - #[derive(StructOpt)] struct Options { @@ -56,9 +54,9 @@ fn main() -> io::Result<()> { } let connector = TlsConnector::from(Arc::new(config)); - let fut = async { + task::block_on(async { let stream = TcpStream::connect(&addr).await?; - let (mut stdin, mut stdout) = (tokio_stdin(0).compat(), tokio_stdout(0).compat()); + let (stdin, mut stdout) = (aio::stdin(), aio::stdout()); let domain = DNSNameRef::try_from_ascii_str(&domain) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; @@ -66,14 +64,12 @@ fn main() -> io::Result<()> { let mut stream = connector.connect(domain, stream).await?; stream.write_all(content.as_bytes()).await?; - let (mut reader, mut writer) = stream.split(); + let (reader, mut writer) = stream.split(); future::try_join( reader.copy_into(&mut stdout), stdin.copy_into(&mut writer) ).await?; Ok(()) - }; - - executor::block_on(fut) + }) } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 9da4423..713b3cd 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -5,7 +5,7 @@ authors = ["quininer "] edition = "2018" [dependencies] -futures = { package = "futures-preview", version = "0.3.0-alpha.16" } -romio = "0.3.0-alpha.8" +futures-preview = "0.3.0-alpha.17" structopt = "0.2" tokio-rustls = { path = "../.." } +async-std = { path = "../../../async-std" } \ No newline at end of file diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index a2a3b13..9df842f 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -1,35 +1,35 @@ #![feature(async_await)] +use async_std::net::TcpListener; +use async_std::task; +use futures::executor; +use futures::prelude::*; +use futures::task::SpawnExt; use std::fs::File; -use std::sync::Arc; +use std::io::{self, BufReader}; use std::net::ToSocketAddrs; -use std::path::{ PathBuf, Path }; -use std::io::{ self, BufReader }; +use std::path::{Path, PathBuf}; +use std::sync::Arc; use structopt::StructOpt; -use futures::task::SpawnExt; -use futures::prelude::*; -use futures::executor; -use romio::TcpListener; -use tokio_rustls::rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; -use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys}; +use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig}; use tokio_rustls::TlsAcceptor; - #[derive(StructOpt)] struct Options { addr: String, /// cert file - #[structopt(short="c", long="cert", parse(from_os_str))] + #[structopt(short = "c", long = "cert", parse(from_os_str))] cert: PathBuf, /// key file - #[structopt(short="k", long="key", parse(from_os_str))] + #[structopt(short = "k", long = "key", parse(from_os_str))] key: PathBuf, /// echo mode - #[structopt(short="e", long="echo-mode")] - echo: bool + #[structopt(short = "e", long = "echo-mode")] + echo: bool, } fn load_certs(path: &Path) -> io::Result> { @@ -42,11 +42,12 @@ fn load_keys(path: &Path) -> io::Result> { .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) } - fn main() -> io::Result<()> { let options = Options::from_args(); - let addr = options.addr.to_socket_addrs()? + let addr = options + .addr + .to_socket_addrs()? .next() .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?; let certs = load_certs(&options.cert)?; @@ -55,12 +56,13 @@ fn main() -> io::Result<()> { let mut pool = executor::ThreadPool::new()?; let mut config = ServerConfig::new(NoClientAuth::new()); - config.set_single_cert(certs, keys.remove(0)) + config + .set_single_cert(certs, keys.remove(0)) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; let acceptor = TlsAcceptor::from(Arc::new(config)); - let fut = async { - let mut listener = TcpListener::bind(&addr)?; + task::block_on(async { + let listener = TcpListener::bind(&addr).await?; let mut incoming = listener.incoming(); while let Some(stream) = incoming.next().await { @@ -72,17 +74,19 @@ fn main() -> io::Result<()> { let mut stream = acceptor.accept(stream).await?; if flag_echo { - let (mut reader, mut writer) = stream.split(); + let (reader, mut writer) = stream.split(); let n = reader.copy_into(&mut writer).await?; println!("Echo: {} - {}", peer_addr, n); } else { - stream.write_all( - &b"HTTP/1.0 200 ok\r\n\ + stream + .write_all( + &b"HTTP/1.0 200 ok\r\n\ Connection: close\r\n\ Content-length: 12\r\n\ \r\n\ - Hello world!"[..] - ).await?; + Hello world!"[..], + ) + .await?; stream.flush().await?; println!("Hello: {}", peer_addr); } @@ -90,11 +94,10 @@ fn main() -> io::Result<()> { Ok(()) as io::Result<()> }; - pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))) + .unwrap(); } Ok(()) - }; - - executor::block_on(fut) + }) } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..c51666e --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2018" \ No newline at end of file diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 4d953c4..9c1bcc9 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -1,33 +1,37 @@ +use super::Stream; +use futures::executor; +use futures::io::{AsyncRead, AsyncWrite}; +use futures::prelude::*; +use futures::task::{noop_waker_ref, Context}; +use rustls::internal::pemfile::{certs, rsa_private_keys}; +use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session}; +use std::io::{self, BufReader, Cursor, Read, Write}; use std::pin::Pin; -use std::task::Poll; use std::sync::Arc; -use futures::prelude::*; -use futures::task::{ Context, noop_waker_ref }; -use futures::executor; -use futures::io::{ AsyncRead, AsyncWrite }; -use std::io::{ self, Read, Write, BufReader, Cursor }; +use std::task::Poll; use webpki::DNSNameRef; -use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use rustls::{ - ServerConfig, ClientConfig, - ServerSession, ClientSession, - Session, NoClientAuth -}; -use super::Stream; - struct Good<'a>(&'a mut dyn Session); impl<'a> AsyncRead for Good<'a> { - fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + mut buf: &mut [u8], + ) -> Poll> { Poll::Ready(self.0.write_tls(buf.by_ref())) } } impl<'a> AsyncWrite for Good<'a> { - fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { let len = self.0.read_tls(buf.by_ref())?; - self.0.process_new_packets() + self.0 + .process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; Poll::Ready(Ok(len)) } @@ -44,13 +48,21 @@ impl<'a> AsyncWrite for Good<'a> { struct Bad(bool); impl AsyncRead for Bad { - fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _: &mut [u8], + ) -> Poll> { Poll::Ready(Ok(0)) } } impl AsyncWrite for Bad { - fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { if self.0 { Poll::Pending } else { @@ -105,13 +117,22 @@ fn stream_bad() -> io::Result<()> { let mut bad = Bad(true); let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); - assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!( + future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, + 8 + ); + assert_eq!( + future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, + 8 + ); let r = future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer assert!(r < 1024); let mut cx = Context::from_waker(noop_waker_ref()); - assert!(stream.as_mut_pin().poll_write(&mut cx, &[0x01]).is_pending()); + assert!(stream + .as_mut_pin() + .poll_write(&mut cx, &[0x01]) + .is_pending()); Ok(()) as io::Result<()> }; @@ -154,7 +175,10 @@ fn stream_handshake_eof() -> io::Result<()> { let mut cx = Context::from_waker(noop_waker_ref()); let r = stream.complete_io(&mut cx); - assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); + assert_eq!( + r.map_err(|err| err.kind()), + Poll::Ready(Err(io::ErrorKind::UnexpectedEof)) + ); Ok(()) as io::Result<()> }; @@ -201,7 +225,11 @@ fn make_pair() -> (ServerSession, ClientSession) { (server, client) } -fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll> { +fn do_handshake( + client: &mut ClientSession, + server: &mut ServerSession, + cx: &mut Context<'_>, +) -> Poll> { let mut good = Good(server); let mut stream = Stream::new(&mut good, client); diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 8c8db6c..315e99e 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -1,22 +1,21 @@ -use std::io; -use std::sync::Arc; -use std::net::ToSocketAddrs; +use crate::{client::TlsStream, TlsConnector}; use futures::executor; use futures::prelude::*; use romio::tcp::TcpStream; use rustls::ClientConfig; -use crate::{ TlsConnector, client::TlsStream }; - +use std::io; +use std::net::ToSocketAddrs; +use std::sync::Arc; -async fn get(config: Arc, domain: &str, rtt0: bool) - -> io::Result<(TlsStream, String)> -{ +async fn get( + config: Arc, + domain: &str, + rtt0: bool, +) -> io::Result<(TlsStream, String)> { let connector = TlsConnector::from(config).early_data(rtt0); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - let addr = (domain, 443) - .to_socket_addrs()? - .next().unwrap(); + let addr = (domain, 443).to_socket_addrs()?.next().unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); let mut buf = Vec::new(); @@ -31,7 +30,9 @@ async fn get(config: Arc, domain: &str, rtt0: bool) #[test] fn test_0rtt() { let mut config = ClientConfig::new(); - config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); config.enable_early_data = true; let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; diff --git a/tests/test.rs b/tests/test.rs index 18075f9..d4d28dc 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,30 +1,31 @@ #![feature(async_await)] -use std::{ io, thread }; -use std::io::{ BufReader, Cursor }; -use std::sync::Arc; -use std::sync::mpsc::channel; -use std::net::SocketAddr; -use lazy_static::lazy_static; -use futures::prelude::*; +use async_std::net::{TcpListener, TcpStream}; use futures::executor; +use futures::prelude::*; use futures::task::SpawnExt; -use async_std::net::{ TcpListener, TcpStream }; -use rustls::{ ServerConfig, ClientConfig }; -use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_rustls::{ TlsConnector, TlsAcceptor }; +use lazy_static::lazy_static; +use rustls::internal::pemfile::{certs, rsa_private_keys}; +use rustls::{ClientConfig, ServerConfig}; +use std::io::{BufReader, Cursor}; +use std::net::SocketAddr; +use std::sync::mpsc::channel; +use std::sync::Arc; +use std::{io, thread}; +use tokio_rustls::{TlsAcceptor, TlsConnector}; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -lazy_static!{ +lazy_static! { static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, keys.pop().unwrap()) + config + .set_single_cert(cert, keys.pop().unwrap()) .expect("invalid key or certificate"); let acceptor = TlsAcceptor::from(Arc::new(config)); @@ -48,8 +49,9 @@ lazy_static!{ reader.copy_into(&mut write).await?; Ok(()) as io::Result<()> } - .unwrap_or_else(|err| eprintln!("{:?}", err)) - ).unwrap(); + .unwrap_or_else(|err| eprintln!("{:?}", err)), + ) + .unwrap(); } Ok(()) as io::Result<()>