From 918439347b67d3c55e8db661707f226e6467597e Mon Sep 17 00:00:00 2001 From: 4t145 Date: Thu, 4 Jan 2024 18:24:31 +0800 Subject: [PATCH] update --- Cargo.toml | 7 +- kernel-common/src/inner_model/gateway.rs | 22 +- kernel/Cargo.toml | 3 +- kernel/src/config/config_by_k8s.rs | 2 +- kernel/src/functions/http_client.rs | 186 ++++++------ kernel/src/functions/http_route.rs | 23 +- kernel/src/functions/server.rs | 286 ++++++------------ kernel/src/functions/websocket.rs | 2 +- kernel/src/instance.rs | 4 +- kernel/src/plugins/context.rs | 34 ++- kernel/src/plugins/filters/retry.rs | 2 +- kernel/src/plugins/filters/status.rs | 4 +- .../plugins/filters/status/status_plugin.rs | 4 +- 13 files changed, 267 insertions(+), 312 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 64090ecd..8dd73b40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,11 +39,14 @@ tardis = { version = "=0.1.0-rc.7" } # Http http = { version = "0.2" } -rustls = { version = "0.21.0" } +rustls = { version = "0.22.1" } hyper = { version = "1", features = ["full"] } +hyper-util = { version = "0.1.2", features = ["server-auto", "tokio", "client-legacy", "client"] } +http-body-util = { version = "0.1" } hyper-rustls = { version = "0.24" } +hyper-tls = { version = "0.6.0"} rustls-pemfile = { version = "1" } -tokio-rustls = { version = "0.24", default-features = false } +tokio-rustls = { version = "0.25", default-features = false } # K8s kube = { version = "0.85", features = ["runtime", "derive"] } diff --git a/kernel-common/src/inner_model/gateway.rs b/kernel-common/src/inner_model/gateway.rs index 3a29e63b..507f7709 100644 --- a/kernel-common/src/inner_model/gateway.rs +++ b/kernel-common/src/inner_model/gateway.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, str::FromStr}; +use std::{fmt::Display, str::FromStr, net::IpAddr}; use super::plugin_filter::SgRouteFilter; use serde::{Deserialize, Serialize}; @@ -46,7 +46,7 @@ pub struct SgListener { /// Name is the name of the Listener. This name MUST be unique within a Gateway. pub name: String, /// Ip bound to the Listener. Default is 0.0.0.0 - pub ip: Option, + pub ip: Option, /// Port is the network port. Multiple listeners may use the same port, subject /// to the Listener compatibility rules. pub port: u16, @@ -77,17 +77,23 @@ pub enum SgProtocol { Wss, } -impl Display for SgProtocol { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl SgProtocol { + pub const fn as_str(&self) -> &'static str { match self { - SgProtocol::Http => write!(f, "http"), - SgProtocol::Https => write!(f, "https"), - SgProtocol::Ws => write!(f, "ws"), - SgProtocol::Wss => write!(f, "wss"), + SgProtocol::Http => "http", + SgProtocol::Https => "https", + SgProtocol::Ws => "ws", + SgProtocol::Wss => "wss", } } } +impl Display for SgProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + /// GatewayTLSConfig describes a TLS configuration. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct SgTlsConfig { diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index fc697836..d107c087 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -32,10 +32,11 @@ urlencoding.workspace = true async-compression.workspace = true http-body-util.workspace = true hyper-util.workspace = true +hyper-tls.workspace = true kernel-common = { path = "../kernel-common" } tardis = { workspace = true, features = ["future", "crypto", "tls"] } http.workspace = true -rustls = { workspace = true, features = ["dangerous_configuration"] } +rustls = { workspace = true } hyper.workspace = true hyper-rustls.workspace = true rustls-pemfile.workspace = true diff --git a/kernel/src/config/config_by_k8s.rs b/kernel/src/config/config_by_k8s.rs index b3770975..ea71e5e8 100644 --- a/kernel/src/config/config_by_k8s.rs +++ b/kernel/src/config/config_by_k8s.rs @@ -680,7 +680,7 @@ async fn process_http_route_config(mut http_route_objs: Vec) -> }, http_route_obj.spec.inner.parent_refs.as_ref().ok_or_else(|| TardisError::format_error("[SG.Config] HttpRoute [spec.parentRefs] is required", ""))?[0].name ); - let priority=http_route_obj.annotations().get(kernel_common::constants::ANNOTATION_RESOURCE_PRIORITY).and_then(|a| a.parse::().ok()).unwrap_or(0); + let priority = http_route_obj.annotations().get(kernel_common::constants::ANNOTATION_RESOURCE_PRIORITY).and_then(|a| a.parse::().ok()).unwrap_or(0); let http_route_config = SgHttpRoute { name: get_k8s_obj_unique(&http_route_obj), gateway_name: rel_gateway_name, diff --git a/kernel/src/functions/http_client.rs b/kernel/src/functions/http_client.rs index 67972cd5..9749145f 100644 --- a/kernel/src/functions/http_client.rs +++ b/kernel/src/functions/http_client.rs @@ -3,33 +3,37 @@ use std::{ time::Duration, }; -use crate::plugins::context::SgRoutePluginContext; -use http::{HeaderMap, HeaderValue, Method, Request, Response, StatusCode}; -use hyper::Error; +use crate::plugins::context::{SgRouteFilterRequestAction, SgRoutePluginContext}; +use http_body_util::Empty; +use hyper::{body::Incoming, Error}; +use hyper::{header::HeaderValue, HeaderMap, Method, Request, Response, StatusCode, Uri}; use hyper_rustls::{ConfigBuilderExt, HttpsConnector}; +use hyper_util::client::legacy::connect::HttpConnector; use kernel_common::inner_model::gateway::SgProtocol; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerifier}; use tardis::{ basic::{error::TardisError, result::TardisResult}, log, - tokio::time::timeout, + tokio::{time::timeout, self}, }; -const DEFAULT_TIMEOUT_MS: u64 = 5000; +type Client = hyper_util::client::legacy::Client, ()>; +const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5000); -static DEFAULT_CLIENT: OnceLock>> = OnceLock::new(); +static DEFAULT_CLIENT: OnceLock = OnceLock::new(); -pub fn init() -> TardisResult<&'static Client>> { +pub fn init() -> TardisResult<&'static Client> { if DEFAULT_CLIENT.get().is_none() { let _ = DEFAULT_CLIENT.set(do_init(false)?); } Ok(default_client()) } -pub fn get_ignore_validation_clint() -> TardisResult>> { +pub fn get_ignore_validation_clint() -> TardisResult { do_init(true) } -fn do_init(ignore_validation: bool) -> TardisResult>> { +fn do_init(ignore_validation: bool) -> TardisResult { fn get_tls_config(ignore: bool) -> rustls::ClientConfig { if ignore { get_rustls_config_dangerous() @@ -37,9 +41,9 @@ fn do_init(ignore_validation: bool) -> TardisResult rustls::ClientConfig { config } - +#[derive(Debug)] pub struct NoCertificateVerification {} -impl rustls::client::ServerCertVerifier for NoCertificateVerification { +impl ServerCertVerifier for NoCertificateVerification { + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls::pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls::pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + Ok(HandshakeSignatureValid::assertion()) + } + fn verify_server_cert( &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], - _server_name: &rustls::ServerName, - _scts: &mut dyn Iterator, - _ocsp: &[u8], - _now: std::time::SystemTime, - ) -> Result { - Ok(rustls::client::ServerCertVerified::assertion()) + end_entity: &rustls::pki_types::CertificateDer<'_>, + intermediates: &[rustls::pki_types::CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) } } #[inline] -fn default_client() -> &'static Client> { +fn default_client() -> &'static Client { DEFAULT_CLIENT.get().expect("DEFAULT_CLIENT not initialized") } pub struct RequestConfig { - timeout: Option, - + pub timeout: Option, } - -pub async fn request( - client: &Client>, - rule_timeout_ms: Option, - redirect: bool, - mut ctx: SgRoutePluginContext, -) -> TardisResult { - if redirect { - ctx = do_request(client, &ctx.request.get_uri().to_string(), rule_timeout_ms, ctx).await?; +impl SgRoutePluginContext { + pub async fn request(&mut self, client: &Client, mut config: RequestConfig) -> TardisResult<()> { + if SgRouteFilterRequestAction::Redirect == self.get_action() { + self.do_request(client, &config).await?; + } + if let Some(backend) = self.get_chosen_backend_mut() { + let mut base_uri = backend.build_base_uri(); + if let Some(prq) = self.request.get_uri().path_and_query() { + base_uri.path_and_query(prq) + } + let uri = base_uri.build()?; + config.timeout = backend.timeout_ms.map(Duration::from_millis).or(config.timeout); + *self.request.get_uri_mut() = uri; + self.do_request(client, &config) + } + Ok(()) } - if let Some(backend) = ctx.get_chose_backend() { - let scheme = backend.protocol.as_ref().unwrap_or(&SgProtocol::Http); - let host = format!("{}{}", backend.name_or_host, backend.namespace.as_ref().map(|n| format!(".{n}")).unwrap_or("".to_string())); - let port = if (backend.port == 0 || backend.port == 80) && scheme == &SgProtocol::Http || (backend.port == 0 || backend.port == 443) && scheme == &SgProtocol::Https { - "".to_string() - } else { - format!(":{}", backend.port) + pub async fn do_request(&mut self, client: &Client, config: &RequestConfig) -> TardisResult<()> { + let ctx = match raw_request( + Some(client), + self.request.get_method().clone(), + self.request.get_uri(), + self.request.take_body(), + self.request.get_headers(), + config.timeout, + ) + .await + { + Ok(response) => self.resp(response.status(), response.headers().clone(), response.into_body()), + Err(e) => self.resp_from_error(e), }; - let url = format!("{}://{}{}{}", scheme, host, port, ctx.request.get_uri().path_and_query().map(|p| p.as_str()).unwrap_or("")); - let timeout_ms = if let Some(timeout_ms) = backend.timeout_ms { Some(timeout_ms) } else { rule_timeout_ms }; - ctx = do_request(client, &url, timeout_ms, ctx).await?; - ctx.set_chose_backend(backend); + Ok(ctx) } - Ok(ctx) } -async fn do_request(client: &Client>, url: &str, timeout_ms: Option, mut ctx: SgRoutePluginContext) -> TardisResult { - let ctx = match raw_request( - Some(client), - ctx.request.get_method().clone(), - url, - ctx.request.take_body(), - ctx.request.get_headers(), - timeout_ms, - ) - .await - { - Ok(response) => ctx.resp(response.status(), response.headers().clone(), response.into_body()), - Err(e) => ctx.resp_from_error(e), - }; - Ok(ctx) -} - -pub async fn raw_request( - client: Option<&Client>>, +pub async fn raw_request( + client: Option<&Client>, method: Method, - url: &str, - body: Body, + url: &Uri, + body: Incoming, headers: &HeaderMap, - timeout_ms: Option, -) -> TardisResult> { - let timeout_ms = timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS); + timeout: Option, +) -> TardisResult> { + let timeout_ms = timeout.unwrap_or(DEFAULT_TIMEOUT); let method_str = method.to_string(); let url_str = url.to_string(); @@ -152,10 +168,10 @@ pub async fn raw_request( req = req.uri(url); let req = req.body(body).map_err(|error| TardisError::internal_error(&format!("[SG.Route] Build request method {method_str} url {url_str} error:{error}"), ""))?; let req = if let Some(client) = client { client.request(req) } else { init()?.request(req) }; - let response = match timeout(Duration::from_millis(timeout_ms), req).await { + let response = match tokio::time::timeout(Duration::from_millis(timeout_ms), req).await { Ok(response) => response.map_err(|error: Error| TardisError::custom("502", &format!("[SG.Client] Request method {method_str} url {url_str} error: {error}"), "")), Err(_) => { - Response::builder().status(StatusCode::GATEWAY_TIMEOUT).body(Body::empty()).map_err(|e| TardisError::internal_error(&format!("[SG.Client] timeout error: {e}"), "")) + Response::builder().status(StatusCode::GATEWAY_TIMEOUT).body(Empty::default()).map_err(|e| TardisError::internal_error(&format!("[SG.Client] timeout error: {e}"), "")) } }?; Ok(response) @@ -167,15 +183,13 @@ mod tests { use hyper::body::Body; use tardis::{basic::result::TardisResult, tokio}; - use crate::plugins::context::AvailableBackendInst; - use crate::{ - functions::http_client::{init, request}, - plugins::context::SgRoutePluginContext, - }; - use hyper::{client::HttpConnector, Client}; + use crate::plugins::context::{AvailableBackendInst, SgRouteFilterRequestAction}; + use crate::{functions::http_client::init, plugins::context::SgRoutePluginContext}; use hyper_rustls::HttpsConnector; use kernel_common::inner_model::gateway::SgProtocol; + use super::Client; + #[tokio::test] async fn test_request() -> TardisResult<()> { let client = init().unwrap(); @@ -340,19 +354,17 @@ mod tests { // Because this unit test depends on the external url, // it may be due to the failure of the external url, so add retry - async fn retry_test_request( - client: &Client>, - rule_timeout_ms: Option, - redirect: bool, - mut ctx: SgRoutePluginContext, - ) -> TardisResult { + async fn retry_test_request(client: &Client, rule_timeout_ms: Option, redirect: bool, mut ctx: SgRoutePluginContext) -> TardisResult { + if redirect { + ctx.set_action(SgRouteFilterRequestAction::Redirect); + } let clone_body = ctx.request.dump_body().await?; let mut clone_ctx = ctx.clone(); clone_ctx.request.set_body(clone_body); - let mut result = request(client, rule_timeout_ms, redirect, ctx).await?; - if !result.response.get_status_code().is_success() { - result = request(client, rule_timeout_ms, redirect, clone_ctx).await?; + ctx.request(client, rule_timeout_ms).await?; + if !ctx.response.get_status_code().is_success() { + clone_ctx.request(client, rule_timeout_ms).await?; } - Ok(result) + Ok(clone_ctx) } } diff --git a/kernel/src/functions/http_route.rs b/kernel/src/functions/http_route.rs index 27b7197f..96942e80 100644 --- a/kernel/src/functions/http_route.rs +++ b/kernel/src/functions/http_route.rs @@ -1,5 +1,7 @@ +use std::time::Duration; use std::{collections::HashMap, net::SocketAddr}; +use crate::functions::http_client::RequestConfig; use crate::instance::{SgBackendInst, SgGatewayInst, SgHttpHeaderMatchInst, SgHttpQueryMatchInst}; use crate::{ instance::{SgHttpPathMatchInst, SgHttpRouteInst, SgHttpRouteMatchInst, SgHttpRouteRuleInst}, @@ -9,7 +11,7 @@ use crate::{ }, }; use http::{header::UPGRADE, HeaderValue, Request, Response}; -use hyper::{Body, StatusCode}; +use hyper::{body::Incoming, StatusCode}; use crate::plugins::context::AvailableBackendInst; use itertools::Itertools; @@ -250,7 +252,7 @@ async fn get(name: &str) -> TardisResult> { } } -pub async fn process(gateway_name: Arc, req_scheme: &str, (remote_addr, local_addr): (SocketAddr, SocketAddr), mut request: Request) -> TardisResult> { +pub async fn process(gateway_name: Arc, req_scheme: &str, (remote_addr, local_addr): (SocketAddr, SocketAddr), mut request: Request) -> TardisResult> { if request.uri().host().is_none() && request.headers().contains_key("Host") { *request.uri_mut() = format!( "{}://{}{}", @@ -381,18 +383,17 @@ pub async fn process(gateway_name: Arc, req_scheme: &str, (remote_addr, let mut ctx = if ctx.get_action() == &SgRouteFilterRequestAction::Response { ctx } else { - let rule_timeout = if let Some(matched_rule_inst) = matched_rule_inst { + let timeout = matched_rule_inst.map(|i| i.timeout_ms).flatten().map(Duration::from_millis); + if let Some(matched_rule_inst) = matched_rule_inst { matched_rule_inst.timeout_ms } else { None }; - match backend { Some(b) => log::debug!("[SG.Request] matched backend: {}", b), None => log::info!("[SG.Request] matched no backend"), } - - http_client::request(&gateway_inst.client, rule_timeout, ctx.get_action() == &SgRouteFilterRequestAction::Redirect, ctx).await? + ctx.request(&gateway_inst.client, RequestConfig { timeout }).await?; }; if log::level_enabled!(log::Level::TRACE) { @@ -420,7 +421,7 @@ pub async fn process(gateway_name: Arc, req_scheme: &str, (remote_addr, process_response_headers(ctx).await?.build_response().await } -fn process_request_headers(request: &mut Request, remote_addr: SocketAddr) -> TardisResult<()> { +fn process_request_headers(request: &mut Request, remote_addr: SocketAddr) -> TardisResult<()> { const X_FORWARDED_FOR: &str = "X-Forwarded-For"; let real_ip = remote_addr.ip().to_string(); let forwarded_for = match request.headers().get(X_FORWARDED_FOR) { @@ -472,7 +473,7 @@ async fn process_response_headers(mut ctx: SgRoutePluginContext) -> TardisResult /// 1. Exact domain match: "example.com" -> Handles exact hostname "example.com" /// 2. Wildcard domain match: "*.example.com" -> Handles any subdomain of "example.com" /// 3. Unspecified domain match: "*" -> Handles any hostname not matched by the above rules -fn match_route_process<'a>(req: &Request, routes: &'a [SgHttpRouteInst]) -> (Option<&'a SgHttpRouteInst>, Option<&'a SgHttpRouteRuleInst>, Option<&'a SgHttpRouteMatchInst>) { +fn match_route_process<'a>(req: &Request, routes: &'a [SgHttpRouteInst]) -> (Option<&'a SgHttpRouteInst>, Option<&'a SgHttpRouteRuleInst>, Option<&'a SgHttpRouteMatchInst>) { let (highest, second, lowest) = match_route_insts_with_hostname_priority(req.uri().host(), routes); let matched_hostname_route_priorities = [highest, second, lowest]; @@ -585,7 +586,7 @@ fn match_route_insts_with_hostname_priority<'a>( } } -fn match_rule_inst<'a>(req: &Request, rule_matches: Option<&'a Vec>) -> (bool, Option<&'a SgHttpRouteMatchInst>) { +fn match_rule_inst<'a>(req: &Request, rule_matches: Option<&'a Vec>) -> (bool, Option<&'a SgHttpRouteMatchInst>) { if let Some(matches) = rule_matches { for rule_match in matches { if let Some(method) = &rule_match.method { @@ -696,7 +697,7 @@ fn match_listeners_hostname_and_port(hostname: Option<&str>, port: u16, listener async fn process_req_filters_http( gateway_name: String, remote_addr: SocketAddr, - request: Request, + request: Request, backend_filters: Option<&[(String, BoxSgPluginFilter)]>, rule_filters: Option<&[(String, BoxSgPluginFilter)]>, route_filters: &[(String, BoxSgPluginFilter)], @@ -723,7 +724,7 @@ async fn process_req_filters_http( async fn process_req_filters_ws( gateway_name: String, remote_addr: SocketAddr, - request: &Request, + request: &Request, backend_filters: Option<&[(String, BoxSgPluginFilter)]>, rule_filters: Option<&[(String, BoxSgPluginFilter)]>, route_filters: &[(String, BoxSgPluginFilter)], diff --git a/kernel/src/functions/server.rs b/kernel/src/functions/server.rs index 4e80817d..aa5990f9 100644 --- a/kernel/src/functions/server.rs +++ b/kernel/src/functions/server.rs @@ -7,14 +7,13 @@ use std::{ use core::task::{Context, Poll}; use http::{HeaderValue, Request, Response, StatusCode}; -use hyper::server::conn::{AddrIncoming, AddrStream}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::Server; -use hyper::{server::accept::Accept, Body}; +use http_body_util::StreamBody; +use hyper::{body::{Incoming, Body}, service::service_fn}; +use hyper_util::rt::{TokioExecutor, TokioIo}; use kernel_common::inner_model::gateway::{SgGateway, SgProtocol, SgTlsMode}; use lazy_static::lazy_static; -use rustls::{PrivateKey, ServerConfig}; +use rustls::{ServerConfig, pki_types::{PrivateKeyDer, PrivateSec1KeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer}}; use serde_json::json; use std::pin::Pin; use std::sync::Arc; @@ -43,6 +42,29 @@ lazy_static! { static ref START_JOIN_HANDLE: Arc>>> = <_>::default(); } +pub async fn serve(shutdown_rx: tokio::sync::watch::Receiver<()>, listener: tokio::net::TcpListener, handler: impl Fn(tokio::net::TcpStream, SocketAddr) + Send + Sync + 'static) { + loop { + let conn = tokio::select! { + _ = shutdown_rx.changed() => { + break; + } + next = listener.accept() => { + next + } + }; + match conn { + Ok((socket, addr)) => { + log::debug!("[SG.Server] Accepting from: {}", addr); + handler(socket, addr); + } + Err(e) => { + log::error!("[SG.Server] Error: {}", e); + break; + } + } + } +} + pub async fn init(gateway_conf: &SgGateway) -> TardisResult> { if gateway_conf.listeners.is_empty() { return Err(TardisError::bad_request("[SG.Server] Missing Listeners", "")); @@ -71,100 +93,81 @@ pub async fn init(gateway_conf: &SgGateway) -> TardisResult> { let gateway_name = Arc::new(gateway_conf.name.to_string()); let mut server_insts: Vec = Vec::new(); for listener in &gateway_conf.listeners { - let ip = listener.ip.as_deref().unwrap_or("0.0.0.0"); - let addr = if ip.contains('.') { - let ip: Ipv4Addr = ip.parse().map_err(|_| TardisError::bad_request(&format!("[SG.Server] IP {ip} is not legal"), ""))?; - SocketAddr::new(std::net::IpAddr::V4(ip), listener.port) - } else { - let ip: Ipv6Addr = ip.parse().map_err(|_| TardisError::bad_request(&format!("[SG.Server] IP {ip} is not legal"), ""))?; - SocketAddr::new(std::net::IpAddr::V6(ip), listener.port) - }; - let mut shutdown_rx = shutdown_tx.subscribe(); + let ip = listener.ip.unwrap_or(std::net::Ipv4Addr::UNSPECIFIED.into()); + let addr = SocketAddr::new(ip, listener.port); let gateway_name = gateway_name.clone(); let protocol = listener.protocol.to_string(); - if let Some(tls) = &listener.tls { - log::debug!("[SG.Server] Tls is init...mode:{:?}", tls.mode); - if SgTlsMode::Terminate == tls.mode { - let tls_cfg = { - let certs = rustls_pemfile::certs(&mut tls.tls.cert.as_bytes()) - .map_err(|error| TardisError::bad_request(&format!("[SG.Server] Tls certificates not legal: {error}"), ""))?; - let certs = certs.into_iter().map(rustls::Certificate).collect::>(); - let key = rustls_pemfile::read_all(&mut tls.tls.key.as_bytes()) - .map_err(|error| TardisError::bad_request(&format!("[SG.Server] Tls private keys not legal: {error}"), ""))?; - if key.is_empty() { - return Err(TardisError::bad_request("[SG.Server] not found Tls private key", "")); - } - let mut selected_key = None; - for k in key { - selected_key = match k { - rustls_pemfile::Item::X509Certificate(_) => continue, - rustls_pemfile::Item::RSAKey(k) => Some(k), - rustls_pemfile::Item::PKCS8Key(k) => Some(k), - rustls_pemfile::Item::ECKey(k) => Some(k), - _ => continue, - }; - if selected_key.is_some() { - break; - } - } - if let Some(selected_key) = selected_key { - let key = PrivateKey(selected_key); - let mut cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, key) - .map_err(|error| TardisError::bad_request(&format!("[SG.Server] Tls not legal: {error}"), ""))?; - cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; - sync::Arc::new(cfg) - } else { - return Err(TardisError::not_implemented("[SG.Server] Tls encoding not supported ", "")); - } - }; + let tcp_listener = tokio::net::TcpListener::bind(&addr).await.map_err(|error| TardisError::bad_request(&format!("[SG.Server] Bind address error: {error}"), ""))?; + let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::default()); + - let incoming = AddrIncoming::bind(&addr).map_err(|error| TardisError::bad_request(&format!("[SG.Server] Bind address error: {error}"), ""))?; - let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(make_service_fn(move |client: &TlsStream| { - let protocol = Arc::new(protocol.clone()); - let remote_and_local_addr = match &client.state { - State::Handshaking(addr) => ( - addr.get_ref().expect("[SG.server.init] can't get addr").remote_addr(), - addr.get_ref().expect("[SG.server.init] can't get addr").local_addr(), - ), - State::Streaming(addr) => (addr.get_ref().0.remote_addr(), addr.get_ref().0.local_addr()), + + let server = 'create_server: { + if let Some(tls_cfg) = &listener.tls { + log::debug!("[SG.Server] Tls is init...mode:{:?}", tls_cfg.mode); + if SgTlsMode::Terminate == tls_cfg.mode { + let tls_cfg = { + let certs = rustls_pemfile::certs(&mut tls_cfg.tls.cert.as_bytes()) + .map_err(|error| TardisError::bad_request(&format!("[SG.Server] Tls certificates not legal: {error}"), ""))?; + let certs = certs.into_iter().map(rustls::HandshakeType::Certificate).collect::>(); + let key = rustls_pemfile::read_all(&mut tls_cfg.tls.key.as_bytes()) + .map_err(|error| TardisError::bad_request(&format!("[SG.Server] Tls private keys not legal: {error}"), ""))?; + if key.is_empty() { + return Err(TardisError::bad_request("[SG.Server] not found Tls private key", "")); + } + let mut selected_key = key.into_iter().find_map(|k| { + match k { + rustls_pemfile::Item::RSAKey(k) => Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(k))), + rustls_pemfile::Item::PKCS8Key(k) => Some(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(k))), + rustls_pemfile::Item::ECKey(k) => Some(PrivateKeyDer::Sec1(PrivateSec1KeyDer::from(k))), + _ => None, + } + }); + if let Some(key_der) = selected_key { + let mut cfg = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key_der) + .map_err(|error| TardisError::bad_request(&format!("[SG.Server] Tls not legal: {error}"), ""))?; + cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + sync::Arc::new(cfg) + } else { + return Err(TardisError::not_implemented("[SG.Server] Tls encoding not supported ", "")); + } }; - let gateway_name = gateway_name.clone(); - async move { Ok::<_, Infallible>(service_fn(move |req| process(gateway_name.clone(), protocol.clone(), remote_and_local_addr, req))) } - })); - let server = server.with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); - }); - server_insts.push(SgServerInst { addr, server: server.boxed() }); - } else { - let server = Server::bind(&addr).serve(make_service_fn(move |client: &AddrStream| { - let protocol = Arc::new(protocol.clone()); - let remote_addr = client.remote_addr(); - let local_addr = client.local_addr(); - let gateway_name = gateway_name.clone(); - async move { Ok::<_, Infallible>(service_fn(move |req| process(gateway_name.clone(), protocol.clone(), (remote_addr, local_addr), req))) } - })); - let server = server.with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); - }); - server_insts.push(SgServerInst { addr, server: server.boxed() }); + let tls_acceptor = tokio_rustls::TlsAcceptor::from(tls_cfg); + break 'create_server serve(shutdown_rx, tcp_listener, move |stream, peer_addr| { + tokio::spawn(async move { + let stream = tls_acceptor.accept(stream).await; + match stream { + Ok(s) => { + server.serve_connection( + TokioIo::new(stream), + service_fn(move |request: Request| process(gateway_name.clone(), protocol.clone(), (peer_addr, addr), request)), + ); + }, + Err(e) => { + log::error!("[SG.Server] Tls handshake error: {}", e); + }, + } + }) + }).boxed() + } } - } else { - let server = Server::bind(&addr).serve(make_service_fn(move |client: &AddrStream| { - let protocol = Arc::new(protocol.clone()); - let remote_and_local_addr = (client.remote_addr(), client.local_addr()); - let gateway_name = gateway_name.clone(); - async move { Ok::<_, Infallible>(service_fn(move |req| process(gateway_name.clone(), protocol.clone(), remote_and_local_addr, req))) } - })); - let server = server.with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); - }); - server_insts.push(SgServerInst { addr, server: server.boxed() }); - } + serve(shutdown_rx, tcp_listener, move |stream, peer_addr| { + tokio::spawn(async move { + server.serve_connection( + TokioIo::new(stream), + service_fn(move |request: Request| process(gateway_name.clone(), protocol.clone(), (peer_addr, addr), request)), + ); + }) + }).boxed() + }; + + + server_insts.push(SgServerInst { addr, server }); + } let mut shutdown = SHUTDOWN_TX.lock().await; @@ -173,12 +176,12 @@ pub async fn init(gateway_conf: &SgGateway) -> TardisResult> { Ok(server_insts) } -async fn process( +async fn process( gateway_name: Arc, req_scheme: Arc, (remote_addr, local_addr): (SocketAddr, SocketAddr), - request: Request, -) -> Result, hyper::Error> { + request: Request, +) -> Result, hyper::Error> { let method = request.method().to_string().clone(); let uri = request.uri().to_string().clone(); let response = http_route::process(gateway_name, req_scheme.as_str(), (remote_addr, local_addr), request).await; @@ -208,7 +211,7 @@ async fn process( result } -fn into_http_error(error: TardisError) -> Result, hyper::Error> { +fn into_http_error(error: TardisError) -> Result, hyper::Error> { let status_code = match error.code.parse::() { Ok(code) => match StatusCode::from_u16(code) { Ok(status_code) => status_code, @@ -277,97 +280,6 @@ pub async fn shutdown(gateway_name: &str) -> TardisResult<()> { Ok(()) } -struct TlsAcceptor { - config: Arc, - incoming: AddrIncoming, -} - -impl TlsAcceptor { - pub fn new(config: Arc, incoming: AddrIncoming) -> TlsAcceptor { - TlsAcceptor { config, incoming } - } -} - -impl Accept for TlsAcceptor { - type Conn = TlsStream; - type Error = io::Error; - - fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - let pin = self.get_mut(); - match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { - Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), - Some(Err(e)) => Poll::Ready(Some(Err(e))), - None => Poll::Ready(None), - } - } -} - -enum State { - Handshaking(tokio_rustls::Accept), - Streaming(tokio_rustls::server::TlsStream), -} - -struct TlsStream { - state: State, -} - -impl TlsStream { - fn new(stream: AddrStream, config: Arc) -> TlsStream { - let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); - TlsStream { - state: State::Handshaking(accept), - } - } -} - -impl AsyncRead for TlsStream { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll> { - let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_read(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - }, - State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), - } - } -} - -impl AsyncWrite for TlsStream { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_write(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - }, - State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { - State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { - State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), - } - } -} - pub struct SgServerInst { pub addr: SocketAddr, pub server: Pin> + std::marker::Send>>, diff --git a/kernel/src/functions/websocket.rs b/kernel/src/functions/websocket.rs index ddd8ce19..8e8e57db 100644 --- a/kernel/src/functions/websocket.rs +++ b/kernel/src/functions/websocket.rs @@ -17,7 +17,7 @@ use tardis::{log, tokio, TardisFuns}; use crate::instance::SgBackendInst; -pub async fn process(gateway_name: Arc, remote_addr: SocketAddr, backend: &SgBackendInst, mut request: Request) -> TardisResult> { +pub async fn process(gateway_name: Arc, remote_addr: SocketAddr, backend: &SgBackendInst, mut request: Request) -> TardisResult> { let have_upgrade = request .headers() .get(CONNECTION) diff --git a/kernel/src/instance.rs b/kernel/src/instance.rs index d2492d13..c995e80c 100644 --- a/kernel/src/instance.rs +++ b/kernel/src/instance.rs @@ -2,6 +2,8 @@ use crate::plugins::filters::BoxSgPluginFilter; use http::Method; use hyper_rustls::HttpsConnector; +use hyper_util::client::legacy::Client; +use hyper_util::client::legacy::connect::HttpConnector; use kernel_common::inner_model::gateway::{SgListener, SgProtocol}; use kernel_common::inner_model::http_route::{SgHttpHeaderMatchType, SgHttpPathMatchType, SgHttpQueryMatchType}; use std::{fmt, vec::Vec}; @@ -10,7 +12,7 @@ use tardis::regex::Regex; pub(crate) struct SgGatewayInst { pub filters: Vec<(String, BoxSgPluginFilter)>, pub routes: Vec, - pub client: Client>, + pub client: Client, ()>, pub listeners: Vec, } diff --git a/kernel/src/plugins/context.rs b/kernel/src/plugins/context.rs index 49797ef9..3b1811d5 100644 --- a/kernel/src/plugins/context.rs +++ b/kernel/src/plugins/context.rs @@ -1,3 +1,4 @@ +use http::uri::Builder; use http::{HeaderMap, HeaderName, HeaderValue, Method, Response, StatusCode, Uri, Version}; use http_body_util::{BodyExt, Collected}; use hyper::body::{Body, Incoming}; @@ -59,15 +60,21 @@ impl AvailableBackendInst { } } - pub fn get_base_url(&self) -> String { + pub fn build_base_uri(&self) -> Builder { let scheme = self.protocol.as_ref().unwrap_or(&SgProtocol::Http); - let host = format!("{}{}", self.name_or_host, self.namespace.as_ref().map(|n| format!(".{n}")).unwrap_or("".to_string())); let port = if (self.port == 0 || self.port == 80) && scheme == &SgProtocol::Http || (self.port == 0 || self.port == 443) && scheme == &SgProtocol::Https { - "".to_string() + None } else { - format!(":{}", self.port) + Some(self.port) }; - format!("{}://{}{}", scheme, host, port) + let mut auth = String::from(self.name_or_host); + if let Some(ref namespace) = self.namespace { + write!(auth, ".{namespace}"); + } + if let Some(port) = port { + write!(auth, ":{port}"); + } + Uri::builder().scheme(scheme.as_str()).authority(auth) } } @@ -160,7 +167,7 @@ pub struct SgCtxRequest { } impl SgCtxRequest { - pub fn new(method: Method, uri: Uri, version: Version, headers: HeaderMap, body: Body, remote_addr: SocketAddr) -> Self { + pub fn new(method: Method, uri: Uri, version: Version, headers: HeaderMap, body: Incoming, remote_addr: SocketAddr) -> Self { Self { method: MaybeModified::new(method), uri: MaybeModified::new(uri), @@ -201,6 +208,11 @@ impl SgCtxRequest { self.uri.get_raw() } + #[inline] + pub fn get_uri_mut(&self) -> &mut Uri { + self.uri.get_mut() + } + #[inline] pub fn get_version(&self) -> &Version { &self.version @@ -493,7 +505,7 @@ impl SgRoutePluginContext { uri: Uri, version: Version, headers: HeaderMap, - body: Body, + body: Incoming, remote_addr: SocketAddr, gateway_name: String, chose_route_rule: Option, @@ -537,7 +549,7 @@ impl SgRoutePluginContext { } /// The following two methods can only be used to fill in the context [resp] [resp_from_error] - pub fn resp(mut self, status_code: StatusCode, headers: HeaderMap, body: Body) -> Self { + pub fn resp(mut self, status_code: StatusCode, headers: HeaderMap, body: Incoming) -> Self { self.response.status_code.reset(status_code); self.response.headers.reset(headers); self.response.body = body; @@ -564,7 +576,7 @@ impl SgRoutePluginContext { } /// build response from Context - pub async fn build_response(&mut self) -> TardisResult> { + pub async fn build_response(&mut self) -> TardisResult> { if let Some(err) = &self.response.resp_err { return Err(err.clone()); } @@ -621,6 +633,10 @@ impl SgRoutePluginContext { self.chosen_backend.clone() } + pub fn get_chosen_backend_mut(&mut self) -> Option<&mut AvailableBackendInst> { + self.chosen_backend.as_mut() + } + pub fn get_chose_backend_name(&self) -> Option { self.get_chose_backend().map(|b| b.name_or_host) } diff --git a/kernel/src/plugins/filters/retry.rs b/kernel/src/plugins/filters/retry.rs index e4e21b2a..629a0764 100644 --- a/kernel/src/plugins/filters/retry.rs +++ b/kernel/src/plugins/filters/retry.rs @@ -138,7 +138,7 @@ fn choose_backend_url(ctx: &mut SgRoutePluginContext) -> String { } else { available_backend.get(0) }; - backend.map(|backend| backend.get_base_url()).unwrap_or_else(|| "".to_string()) + backend.map(|backend| backend.build_base_uri()).unwrap_or_else(|| "".to_string()) } else { ctx.request.get_uri().to_string() } diff --git a/kernel/src/plugins/filters/status.rs b/kernel/src/plugins/filters/status.rs index 68db82fb..40ca632f 100644 --- a/kernel/src/plugins/filters/status.rs +++ b/kernel/src/plugins/filters/status.rs @@ -131,7 +131,9 @@ impl SgPluginFilter for SgFilterStatus { service_fn(move |request: Request<()>| status_plugin::create_status_html(request, gateway_name.clone(), cache_key.clone(), title.clone())), ) } - Err(_) => todo!(), + Err(e) => { + log::error!("[SG.Filter.Status] accept error: {e}"); + }, } } }; diff --git a/kernel/src/plugins/filters/status/status_plugin.rs b/kernel/src/plugins/filters/status/status_plugin.rs index a768623f..aba92a99 100644 --- a/kernel/src/plugins/filters/status/status_plugin.rs +++ b/kernel/src/plugins/filters/status/status_plugin.rs @@ -44,7 +44,7 @@ pub(crate) async fn create_status_html( _gateway_name: Arc>, _cache_key: Arc>, title: Arc>, -) -> Result, hyper::Error> { +) -> Result, hyper::Error> { let keys; #[cfg(feature = "cache")] { @@ -87,7 +87,7 @@ pub(crate) async fn create_status_html( let title = &title.lock().await; let html = STATUS_TEMPLATE.replace("{title}", title).replace("{status}", &service_html); - Ok(Response::new(Body::from(html))) + Ok(Response::new(html)) } #[cfg(feature = "cache")]