Skip to content

Commit

Permalink
cleanup and fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Aug 4, 2019
1 parent ac4ccfe commit d3891de
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 96 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ See [examples/server](examples/server/src/main.rs). You can run it with:

```sh
cd examples/server
cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der
cargo run -- 127.0.0.1:8000 --cert mycert.der --key mykey.der
```

### License & Origin
Expand Down
5 changes: 2 additions & 3 deletions examples/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ authors = ["quininer <[email protected]>"]
edition = "2018"

[dependencies]
futures = { package = "futures-preview", version = "0.3.0-alpha.16", features = ["io-compat"] }
romio = "0.3.0-alpha.8"
futures-preview = "0.3.0-alpha.17"
async-std = { path = "../../../async-std" }
structopt = "0.2"
tokio-rustls = { path = "../.." }
webpki-roots = "0.16"
tokio-stdin-stdout = "0.1"
18 changes: 7 additions & 11 deletions examples/client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ use std::sync::Arc;
use std::net::ToSocketAddrs;
use std::io::BufReader;
use structopt::StructOpt;
use romio::TcpStream;
use async_std::net::TcpStream;
use async_std::task;
use async_std::io as aio;
use futures::prelude::*;
use futures::executor;
use futures::compat::{ AsyncRead01CompatExt, AsyncWrite01CompatExt };
use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef };
use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout };


#[derive(StructOpt)]
struct Options {
Expand Down Expand Up @@ -56,24 +54,22 @@ fn main() -> io::Result<()> {
}
let connector = TlsConnector::from(Arc::new(config));

let fut = async {
task::block_on(async {
let stream = TcpStream::connect(&addr).await?;
let (mut stdin, mut stdout) = (tokio_stdin(0).compat(), tokio_stdout(0).compat());
let (stdin, mut stdout) = (aio::stdin(), aio::stdout());

let domain = DNSNameRef::try_from_ascii_str(&domain)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;

let mut stream = connector.connect(domain, stream).await?;
stream.write_all(content.as_bytes()).await?;

let (mut reader, mut writer) = stream.split();
let (reader, mut writer) = stream.split();
future::try_join(
reader.copy_into(&mut stdout),
stdin.copy_into(&mut writer)
).await?;

Ok(())
};

executor::block_on(fut)
})
}
4 changes: 2 additions & 2 deletions examples/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors = ["quininer <[email protected]>"]
edition = "2018"

[dependencies]
futures = { package = "futures-preview", version = "0.3.0-alpha.16" }
romio = "0.3.0-alpha.8"
futures-preview = "0.3.0-alpha.17"
structopt = "0.2"
tokio-rustls = { path = "../.." }
async-std = { path = "../../../async-std" }
59 changes: 31 additions & 28 deletions examples/server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
#![feature(async_await)]

use async_std::net::TcpListener;
use async_std::task;
use futures::executor;
use futures::prelude::*;
use futures::task::SpawnExt;
use std::fs::File;
use std::sync::Arc;
use std::io::{self, BufReader};
use std::net::ToSocketAddrs;
use std::path::{ PathBuf, Path };
use std::io::{ self, BufReader };
use std::path::{Path, PathBuf};
use std::sync::Arc;
use structopt::StructOpt;
use futures::task::SpawnExt;
use futures::prelude::*;
use futures::executor;
use romio::TcpListener;
use tokio_rustls::rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig };
use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys};
use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
use tokio_rustls::TlsAcceptor;


#[derive(StructOpt)]
struct Options {
addr: String,

/// cert file
#[structopt(short="c", long="cert", parse(from_os_str))]
#[structopt(short = "c", long = "cert", parse(from_os_str))]
cert: PathBuf,

/// key file
#[structopt(short="k", long="key", parse(from_os_str))]
#[structopt(short = "k", long = "key", parse(from_os_str))]
key: PathBuf,

/// echo mode
#[structopt(short="e", long="echo-mode")]
echo: bool
#[structopt(short = "e", long = "echo-mode")]
echo: bool,
}

fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
Expand All @@ -42,11 +42,12 @@ fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
}


fn main() -> io::Result<()> {
let options = Options::from_args();

let addr = options.addr.to_socket_addrs()?
let addr = options
.addr
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?;
let certs = load_certs(&options.cert)?;
Expand All @@ -55,12 +56,13 @@ fn main() -> io::Result<()> {

let mut pool = executor::ThreadPool::new()?;
let mut config = ServerConfig::new(NoClientAuth::new());
config.set_single_cert(certs, keys.remove(0))
config
.set_single_cert(certs, keys.remove(0))
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let acceptor = TlsAcceptor::from(Arc::new(config));

let fut = async {
let mut listener = TcpListener::bind(&addr)?;
task::block_on(async {
let listener = TcpListener::bind(&addr).await?;
let mut incoming = listener.incoming();

while let Some(stream) = incoming.next().await {
Expand All @@ -72,29 +74,30 @@ fn main() -> io::Result<()> {
let mut stream = acceptor.accept(stream).await?;

if flag_echo {
let (mut reader, mut writer) = stream.split();
let (reader, mut writer) = stream.split();
let n = reader.copy_into(&mut writer).await?;
println!("Echo: {} - {}", peer_addr, n);
} else {
stream.write_all(
&b"HTTP/1.0 200 ok\r\n\
stream
.write_all(
&b"HTTP/1.0 200 ok\r\n\
Connection: close\r\n\
Content-length: 12\r\n\
\r\n\
Hello world!"[..]
).await?;
Hello world!"[..],
)
.await?;
stream.flush().await?;
println!("Hello: {}", peer_addr);
}

Ok(()) as io::Result<()>
};

pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap();
pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err)))
.unwrap();
}

Ok(())
};

executor::block_on(fut)
})
}
1 change: 1 addition & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
edition = "2018"
76 changes: 52 additions & 24 deletions src/common/test_stream.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
use super::Stream;
use futures::executor;
use futures::io::{AsyncRead, AsyncWrite};
use futures::prelude::*;
use futures::task::{noop_waker_ref, Context};
use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::pin::Pin;
use std::task::Poll;
use std::sync::Arc;
use futures::prelude::*;
use futures::task::{ Context, noop_waker_ref };
use futures::executor;
use futures::io::{ AsyncRead, AsyncWrite };
use std::io::{ self, Read, Write, BufReader, Cursor };
use std::task::Poll;
use webpki::DNSNameRef;
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use rustls::{
ServerConfig, ClientConfig,
ServerSession, ClientSession,
Session, NoClientAuth
};
use super::Stream;


struct Good<'a>(&'a mut dyn Session);

impl<'a> AsyncRead for Good<'a> {
fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.write_tls(buf.by_ref()))
}
}

impl<'a> AsyncWrite for Good<'a> {
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &[u8],
) -> Poll<io::Result<usize>> {
let len = self.0.read_tls(buf.by_ref())?;
self.0.process_new_packets()
self.0
.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(len))
}
Expand All @@ -44,13 +48,21 @@ impl<'a> AsyncWrite for Good<'a> {
struct Bad(bool);

impl AsyncRead for Bad {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0))
}
}

impl AsyncWrite for Bad {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if self.0 {
Poll::Pending
} else {
Expand Down Expand Up @@ -105,13 +117,22 @@ fn stream_bad() -> io::Result<()> {

let mut bad = Bad(true);
let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(
future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
8
);
assert_eq!(
future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
8
);
let r = future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer
assert!(r < 1024);

let mut cx = Context::from_waker(noop_waker_ref());
assert!(stream.as_mut_pin().poll_write(&mut cx, &[0x01]).is_pending());
assert!(stream
.as_mut_pin()
.poll_write(&mut cx, &[0x01])
.is_pending());

Ok(()) as io::Result<()>
};
Expand Down Expand Up @@ -154,7 +175,10 @@ fn stream_handshake_eof() -> io::Result<()> {

let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.complete_io(&mut cx);
assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof)));
assert_eq!(
r.map_err(|err| err.kind()),
Poll::Ready(Err(io::ErrorKind::UnexpectedEof))
);

Ok(()) as io::Result<()>
};
Expand Down Expand Up @@ -201,7 +225,11 @@ fn make_pair() -> (ServerSession, ClientSession) {
(server, client)
}

fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn do_handshake(
client: &mut ClientSession,
server: &mut ServerSession,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let mut good = Good(server);
let mut stream = Stream::new(&mut good, client);

Expand Down
25 changes: 13 additions & 12 deletions src/test_0rtt.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
use std::io;
use std::sync::Arc;
use std::net::ToSocketAddrs;
use crate::{client::TlsStream, TlsConnector};
use futures::executor;
use futures::prelude::*;
use romio::tcp::TcpStream;
use rustls::ClientConfig;
use crate::{ TlsConnector, client::TlsStream };

use std::io;
use std::net::ToSocketAddrs;
use std::sync::Arc;

async fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool)
-> io::Result<(TlsStream<TcpStream>, String)>
{
async fn get(
config: Arc<ClientConfig>,
domain: &str,
rtt0: bool,
) -> io::Result<(TlsStream<TcpStream>, String)> {
let connector = TlsConnector::from(config).early_data(rtt0);
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);

let addr = (domain, 443)
.to_socket_addrs()?
.next().unwrap();
let addr = (domain, 443).to_socket_addrs()?.next().unwrap();
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
let mut buf = Vec::new();

Expand All @@ -31,7 +30,9 @@ async fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool)
#[test]
fn test_0rtt() {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config.enable_early_data = true;
let config = Arc::new(config);
let domain = "mozilla-modern.badssl.com";
Expand Down
Loading

0 comments on commit d3891de

Please sign in to comment.