Skip to content

Commit

Permalink
fix: tokio client task queue fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
filipton committed Aug 25, 2024
1 parent 4f00f8e commit 36af85f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
19 changes: 15 additions & 4 deletions src/client/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use clap::{command, Parser};
use rcgen::CertifiedKey;
use std::sync::Arc;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
time::Instant,
};
use utils::{
certs::{cert_from_str, key_from_str},
Expand All @@ -13,6 +14,8 @@ use utils::{
read_string_from_stream,
};

const MAX_REQUEST_TIME: u128 = 1000;

#[derive(Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
struct Args {
Expand Down Expand Up @@ -81,6 +84,7 @@ async fn connector(args: &Args) -> Result<()> {
let redirect_to_ssl = args.redirect_ssl && !ssl;
let domain = domain.to_string();
let acceptor = acceptor.clone();
let requested_time = Instant::now();

hello_packet[26..42].copy_from_slice(&buf[1..17]);
tokio::task::spawn(async move {
Expand All @@ -91,6 +95,7 @@ async fn connector(args: &Args) -> Result<()> {
redirect_to_ssl,
domain,
acceptor,
requested_time,
)
.await;

Expand All @@ -108,7 +113,12 @@ async fn spawn_tunnel(
redirect_to_ssl: bool,
domain: String,
acceptor: Arc<tokio_rustls::TlsAcceptor>,
request_time: Instant,
) -> Result<()> {
if request_time.elapsed().as_millis() > MAX_REQUEST_TIME {
return Err(anyhow!("Requested time exceeded max request time."));
}

let tunnel_stream = TcpStream::connect(proxy_addr).await?;
tunnel_stream.set_nodelay(true)?;
let mut tunnel_stream = acceptor.accept(tunnel_stream).await?;
Expand All @@ -118,6 +128,7 @@ async fn spawn_tunnel(
local_stream.set_nodelay(true)?;

if redirect_to_ssl {
// for example: "GET / HTTP1.1"
let mut buffer = [0u8; 1];
let mut parts = String::new();
loop {
Expand All @@ -132,12 +143,12 @@ async fn spawn_tunnel(
let path = parts[1];
let redirect = construct_http_redirect(&format!("https://{domain}{path}"));
tunnel_stream.write_all(redirect.as_bytes()).await?;
_ = tunnel_stream.shutdown().await;
} else {
_ = tokio::io::copy_bidirectional(&mut local_stream, &mut tunnel_stream).await;
_ = tunnel_stream.shutdown().await;
_ = local_stream.shutdown().await;
}

_ = tunnel_stream.shutdown().await;
_ = local_stream.shutdown().await;

Ok(())
}
9 changes: 5 additions & 4 deletions src/server/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ where

if let Ok((tunn, _)) = get_tunn_or_error(&state, &host, &mut stream).await {
let rng = state.consts.rng.secure_random;
let mut token = [0u8; 16];
rng.fill(&mut token).unwrap();
let generated_tunnel_id = u128::from_be_bytes(token);
let mut generated_tunnel_id = [0u8; 16];
rng.fill(&mut generated_tunnel_id).unwrap();
let generated_tunnel_id = u128::from_be_bytes(generated_tunnel_id);

let (tx, rx) = tokio::sync::oneshot::channel();
state.insert_tunnel_oneshot(generated_tunnel_id, tx).await;
Expand All @@ -205,8 +205,9 @@ where
tunnel.write_all(&in_buffer[..n]).await?; // relay the first packet
_ = tokio::io::copy_bidirectional(&mut stream, &mut tunnel).await;
_ = tunnel.shutdown().await;
_ = stream.shutdown().await;
}

_ = stream.shutdown().await;
Ok(())
}

Expand Down

0 comments on commit 36af85f

Please sign in to comment.