diff --git a/src/client/main.rs b/src/client/main.rs index 1e2c8ea..7e219a8 100644 --- a/src/client/main.rs +++ b/src/client/main.rs @@ -25,8 +25,8 @@ struct Args { #[arg(short, long, value_parser = parse_socketaddr, default_value = "127.0.0.1:80", env = "ADDR")] addr: SocketAddr, - #[arg(long, value_parser = parse_socketaddr, default_value = "127.0.0.1:80", env = "NON_SSL_ADDR")] - non_ssl_addr: Option, + #[arg(long, value_parser = parse_socketaddr, env = "SSL_ADDR")] + ssl_addr: Option, #[arg(long, env = "HASH")] hash: u64, @@ -36,9 +36,14 @@ struct Args { #[arg(short, long, action, env = "REDIRECT_SSL")] redirect_ssl: bool, +} - #[arg(long, action, env = "OWN_SSL")] - own_ssl: bool, +#[derive(Debug)] +struct TunnelSettings { + proxy_addr: SocketAddr, + ssl_addr: SocketAddr, + nonssl_addr: SocketAddr, + redirect_ssl: bool, } #[tokio::main] @@ -72,7 +77,8 @@ async fn connector(args: &Args) -> Result<()> { let stream = TcpStream::connect(&args.proxy_addr).await?; let mut stream = acceptor.accept(stream).await?; - let mut hello_packet = generate_hello_packet(0, &args.token, &args.hash, args.own_ssl)?; + let mut hello_packet = + generate_hello_packet(0, &args.token, &args.hash, args.ssl_addr.is_some())?; stream.write_all(&hello_packet).await?; let nonssl_port = stream.read_u16().await?; @@ -94,23 +100,21 @@ async fn connector(args: &Args) -> Result<()> { stream.read_exact(&mut buf).await?; let ssl = first_byte == 0x01; - let redirect_ssl = args.redirect_ssl; - - let addr = args.addr; - let non_ssl = args.non_ssl_addr.clone(); - let proxy_addr = args.proxy_addr; let domain = domain.to_string(); let acceptor = acceptor.clone(); let requested_time = Instant::now(); + let settings = TunnelSettings { + proxy_addr: args.proxy_addr, + ssl_addr: args.ssl_addr.unwrap_or(args.addr), + nonssl_addr: args.addr, + redirect_ssl: args.redirect_ssl, + }; hello_packet[26..42].copy_from_slice(&buf[0..16]); tokio::task::spawn(async move { let res = spawn_tunnel( hello_packet, - addr, - non_ssl, - proxy_addr, - redirect_ssl, + settings, ssl, ssl_port, domain, @@ -128,11 +132,8 @@ async fn connector(args: &Args) -> Result<()> { async fn spawn_tunnel( hello_packet: [u8; 80], - local_ssl_addr: SocketAddr, - local_non_ssl_addr: Option, - proxy_addr: SocketAddr, + settings: TunnelSettings, ssl: bool, - redirect_to_ssl: bool, ssl_port: u16, domain: String, acceptor: Arc, @@ -142,20 +143,19 @@ async fn spawn_tunnel( return Err(anyhow!("Requested time exceeded max request time.")); } - let tunnel_stream = TcpStream::connect(proxy_addr).await?; + let tunnel_stream = TcpStream::connect(settings.proxy_addr).await?; tunnel_stream.set_nodelay(true)?; let mut tunnel_stream = acceptor.accept(tunnel_stream).await?; tunnel_stream.write_all(&hello_packet).await?; - let local_addr = match (ssl, local_non_ssl_addr.is_some()) { - (true, _) => local_ssl_addr, - (false, true) => local_non_ssl_addr.unwrap(), - (false, false) => local_ssl_addr, + let local_addr = match ssl { + true => settings.ssl_addr, + false => settings.nonssl_addr, }; let mut local_stream = TcpStream::connect(local_addr).await?; local_stream.set_nodelay(true)?; - let redirect_to_ssl = redirect_to_ssl && !ssl && local_non_ssl_addr.is_none(); + let redirect_to_ssl = settings.redirect_ssl && !ssl; if redirect_to_ssl { // for example: "GET / HTTP1.1" let mut buffer = [0u8; 1];