Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
4t145 committed Jan 5, 2024
1 parent 9184393 commit d439cb2
Show file tree
Hide file tree
Showing 16 changed files with 150 additions and 96 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ rustls = { version = "0.22.1" }
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1.2", features = ["server-auto", "tokio", "client-legacy", "client"] }
http-body-util = { version = "0.1" }
hyper-rustls = { version = "0.24" }
hyper-rustls = { version = "0.25" }
hyper-tls = { version = "0.6.0"}
rustls-pemfile = { version = "1" }
tokio-rustls = { version = "0.25", default-features = false }
Expand Down
30 changes: 19 additions & 11 deletions kernel/src/functions/http_client.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::{
fmt::Debug,
sync::{Arc, OnceLock},
time::Duration,
};

use crate::plugins::context::{SgRouteFilterRequestAction, SgRoutePluginContext};
use http_body_util::Empty;
use hyper::{body::Incoming, Error};
use http_body_util::{combinators::BoxBody, Empty};
use hyper::{
body::{Body, Incoming},
Error,
};
use hyper::{header::HeaderValue, HeaderMap, Method, Request, Response, StatusCode, Uri};
use hyper_rustls::{ConfigBuilderExt, HttpsConnector};
use hyper_util::client::legacy::connect::HttpConnector;
Expand All @@ -14,10 +18,10 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerifier};
use tardis::{
basic::{error::TardisError, result::TardisResult},
log,
tokio::{time::timeout, self},
tokio::{self, time::timeout},
};

type Client = hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, ()>;
type Client = hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, BoxBody<hyper::body::Bytes, TardisError>>;
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5000);

static DEFAULT_CLIENT: OnceLock<Client> = OnceLock::new();
Expand Down Expand Up @@ -139,22 +143,22 @@ impl SgRoutePluginContext {
}
}

pub async fn raw_request<B>(
pub async fn raw_request<B: Debug + Body<Data = hyper::body::Bytes, Error = TardisError>>(
client: Option<&Client>,
method: Method,
url: &Uri,
body: Incoming,
body: B,
headers: &HeaderMap<HeaderValue>,
timeout: Option<Duration>,
) -> TardisResult<Response<Incoming>> {
let timeout_ms = timeout.unwrap_or(DEFAULT_TIMEOUT);
let timeout = timeout.unwrap_or(DEFAULT_TIMEOUT);
let method_str = method.to_string();
let url_str = url.to_string();

if log::level_enabled!(log::Level::TRACE) {
log::trace!("[SG.Client] Request method {method_str} url {url_str} header {headers:?} {body:?}, timeout {timeout_ms} ms",);
log::trace!("[SG.Client] Request method {method_str} url {url_str} header {headers:?} {body:?}, timeout {timeout:?} ms",);
} else if log::level_enabled!(log::Level::DEBUG) {
log::debug!("[SG.Client] Request method {method_str} url {url_str} header {headers:?}, timeout {timeout_ms} ms",);
log::debug!("[SG.Client] Request method {method_str} url {url_str} header {headers:?}, timeout {timeout:?} ms",);
}

let mut req = Request::builder();
Expand All @@ -167,8 +171,12 @@ pub async fn raw_request<B>(
}
req = req.uri(url);
let req = req.body(body).map_err(|error| TardisError::internal_error(&format!("[SG.Route] Build request method {method_str} url {url_str} error:{error}"), ""))?;
let req = if let Some(client) = client { client.request(req) } else { init()?.request(req) };
let response = match tokio::time::timeout(Duration::from_millis(timeout_ms), req).await {
let req = if let Some(client) = client {
client.request(req)
} else {
init()?.request(req)
};
let response = match tokio::time::timeout(timeout, req).await {
Ok(response) => response.map_err(|error: Error| TardisError::custom("502", &format!("[SG.Client] Request method {method_str} url {url_str} error: {error}"), "")),
Err(_) => {
Response::builder().status(StatusCode::GATEWAY_TIMEOUT).body(Empty::default()).map_err(|e| TardisError::internal_error(&format!("[SG.Client] timeout error: {e}"), ""))
Expand Down
3 changes: 1 addition & 2 deletions kernel/src/functions/http_route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use crate::{
filters::{self, BoxSgPluginFilter, SgPluginFilterInitDto},
},
};
use http::{header::UPGRADE, HeaderValue, Request, Response};
use hyper::{body::Incoming, StatusCode};
use hyper::{body::Incoming, StatusCode, Request, Response, header::{HeaderValue, UPGRADE}};

use crate::plugins::context::AvailableBackendInst;
use itertools::Itertools;
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/helpers/url_helper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::str::FromStr;

use http::Uri;
use hyper::Uri;
use tardis::{
basic::{error::TardisError, result::TardisResult},
url::Url,
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub mod functions;
pub mod helpers;
pub mod instance;
pub mod plugins;

mod utils;
#[inline]
pub async fn startup_k8s(namespace: Option<String>) -> TardisResult<()> {
k8s_client::inst(
Expand Down
70 changes: 46 additions & 24 deletions kernel/src/plugins/context.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use http::uri::Builder;
use http::{HeaderMap, HeaderName, HeaderValue, Method, Response, StatusCode, Uri, Version};
use http_body_util::{BodyExt, Collected};
use hyper::body::{Body, Incoming};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Collected, Empty, Full, BodyStream};
use hyper::body::{Body, Bytes, Incoming};
use hyper::header::{HeaderName, HeaderValue};
use hyper::{HeaderMap, Method, Response, StatusCode, Uri, Version, http};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::time::Duration;
use tardis::basic::error::TardisError;
use tardis::basic::result::TardisResult;

use tardis::TardisFuns;

use kernel_common::inner_model::gateway::SgProtocol;
Expand Down Expand Up @@ -60,7 +62,7 @@ impl AvailableBackendInst {
}
}

pub fn build_base_uri(&self) -> Builder {
pub fn build_base_uri(&self) -> hyper::http::uri::Builder {
let scheme = self.protocol.as_ref().unwrap_or(&SgProtocol::Http);
let port = if (self.port == 0 || self.port == 80) && scheme == &SgProtocol::Http || (self.port == 0 || self.port == 443) && scheme == &SgProtocol::Https {
None
Expand Down Expand Up @@ -163,7 +165,7 @@ pub struct SgCtxRequest {
pub version: MaybeModified<Version>,
pub headers: MaybeModified<HeaderMap<HeaderValue>>,
pub remote_addr: SocketAddr,
pub body: Incoming,
pub body: BoxBody<Bytes, TardisError>,
}

impl SgCtxRequest {
Expand Down Expand Up @@ -267,16 +269,22 @@ impl SgCtxRequest {
&self.remote_addr
}

pub fn take_body(&mut self) -> Incoming {
std::mem::take(&mut self.body)
pub fn take_body(&mut self) -> BoxBody<Bytes, TardisError> {
self.replace_body(Empty::default())
}

pub fn replace_body(&mut self, body: impl Into<Incoming>) -> Incoming {
std::mem::replace(&mut self.body, body.into())
pub fn replace_body<B>(&mut self, body: B) -> BoxBody<Bytes, TardisError>
where
B: Body<Data = Bytes, Error = TardisError>,
{
std::mem::replace(&mut self.body, body.boxed())
}

#[inline]
pub fn set_body(&mut self, body: impl Into<Incoming>) {
pub fn set_body<B>(&mut self, body: B)
where
B: Body<Data = Bytes, Error = TardisError>,
{
let _ = self.replace_body(body);
}

Expand All @@ -294,7 +302,8 @@ impl SgCtxRequest {
/// this method will read all of the body and clone it, and it's body will become an once stream which holds the whole body.
pub async fn dump_body(&mut self) -> TardisResult<hyper::body::Bytes> {
let bytes = self.take_body_into_bytes().await?;
self.set_body(bytes.clone());
let body = Full::new(bytes.clone()).map_err(crate::utils::never::<Infallible, TardisError>);
self.set_body(BoxBody::new(body));
Ok(bytes)
}
}
Expand Down Expand Up @@ -401,36 +410,45 @@ impl SgCtxResponse {
Ok(())
}

#[inline]
pub fn take_body(&mut self) -> Incoming {
std::mem::take(&mut self.body)
pub fn take_body(&mut self) -> BoxBody<Bytes, TardisError> {
self.replace_body(Empty::default())
}

#[inline]
pub fn replace_body(&mut self, body: impl Into<Incoming>) -> Incoming {
std::mem::replace(&mut self.body, body.into())
pub fn take_body_stream(&mut self) -> BodyStream<BoxBody<Bytes, TardisError>> {
BodyStream::new(self.take_body())
}

pub fn replace_body<B>(&mut self, body: B) -> BoxBody<Bytes, TardisError>
where
B: Body<Data = Bytes, Error = TardisError>,
{
std::mem::replace(&mut self.body, body.boxed())
}

#[inline]
pub fn set_body(&mut self, body: impl Into<Incoming>) {
pub fn set_body<B>(&mut self, body: B)
where
B: Body<Data = Bytes, Error = TardisError>,
{
let _ = self.replace_body(body);
}

/// it's a shortcut for [take_body](SgCtxResponse) + [hyper::body::to_bytes]
/// it's a shortcut for [take_body](SgCtxRequest) + [hyper::body::to_bytes]
pub async fn take_body_into_bytes(&mut self) -> TardisResult<hyper::body::Bytes> {
self.take_body().collect().await.map(Collected::to_bytes).map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to collect body into bytes: {e}"), ""))
}

/// it's a shortcut for [take_body](SgCtxResponse) + [hyper::body::aggregate]
/// it's a shortcut for [`take_body`](SgCtxRequest) + [hyper::body::aggregate]
pub async fn take_body_into_buf(&mut self) -> TardisResult<impl hyper::body::Buf> {
self.take_body().collect().await.map(Collected::aggregate).map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to aggregate body: {e}"), ""))
}

/// # Performance
/// This method will read **all** of the body and **clone** it, and it's body will become an once stream which holds the whole body.
/// this method will read all of the body and clone it, and it's body will become an once stream which holds the whole body.
pub async fn dump_body(&mut self) -> TardisResult<hyper::body::Bytes> {
let bytes = self.take_body_into_bytes().await?;
self.set_body(bytes.clone());
let body = Full::new(bytes.clone()).map_err(crate::utils::never::<Infallible, TardisError>);
self.set_body(BoxBody::new(body));
Ok(bytes)
}
}
Expand Down Expand Up @@ -653,6 +671,10 @@ impl SgRoutePluginContext {
}
}

pub fn get_timeout(&self) -> Option<Duration> {
self.get_timeout_ms().map(Duration::from_millis)
}

pub fn get_rule_matched(&self) -> Option<SgHttpRouteMatchInst> {
self.chosen_route_rule.as_ref().and_then(|r| r.matched_match.clone())
}
Expand Down
11 changes: 6 additions & 5 deletions kernel/src/plugins/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod status;
use async_trait::async_trait;

use core::fmt;
use std::future::Future;
use serde_json::Value;
use std::collections::HashMap;

Expand Down Expand Up @@ -146,9 +147,9 @@ pub trait SgPluginFilter: Send + Sync + 'static {
}
}

async fn init(&mut self, init_dto: &SgPluginFilterInitDto) -> TardisResult<()>;
fn init(&mut self, init_dto: &SgPluginFilterInitDto) -> impl Future<Output = TardisResult<()>> + Send;

async fn destroy(&self) -> TardisResult<()>;
fn destroy(&self) -> impl Future<Output = TardisResult<()>> + Send;

/// Request Filtering:
///
Expand All @@ -158,7 +159,7 @@ pub trait SgPluginFilter: Send + Sync + 'static {
/// instance.
/// - `ctx`: A mutable context object that holds information about the
/// request and allows for modifications.
async fn req_filter(&self, id: &str, ctx: SgRoutePluginContext) -> TardisResult<(bool, SgRoutePluginContext)>;
fn req_filter(&self, id: &str, ctx: SgRoutePluginContext) -> impl Future<Output = TardisResult<(bool, SgRoutePluginContext)>> + Send;

/// Response Filtering:
///
Expand All @@ -168,7 +169,7 @@ pub trait SgPluginFilter: Send + Sync + 'static {
/// instance.
/// - `ctx`: A mutable context object that holds information about the
/// request and allows for modifications.
async fn resp_filter(&self, id: &str, ctx: SgRoutePluginContext) -> TardisResult<(bool, SgRoutePluginContext)>;
fn resp_filter(&self, id: &str, ctx: SgRoutePluginContext) -> impl Future<Output = TardisResult<(bool, SgRoutePluginContext)>> + Send;

fn boxed(self) -> BoxSgPluginFilter
where
Expand All @@ -178,7 +179,7 @@ pub trait SgPluginFilter: Send + Sync + 'static {
}
}

pub fn http_common_modify_path(uri: &http::Uri, modify_path: &Option<SgHttpPathModifier>, matched_match_inst: Option<&SgHttpRouteMatchInst>) -> TardisResult<Option<http::Uri>> {
pub fn http_common_modify_path(uri: &hyper::Uri, modify_path: &Option<SgHttpPathModifier>, matched_match_inst: Option<&SgHttpRouteMatchInst>) -> TardisResult<Option<hyper::Uri>> {
if let Some(modify_path) = &modify_path {
let mut uri = Url::parse(&uri.to_string())?;
match modify_path.kind {
Expand Down
38 changes: 22 additions & 16 deletions kernel/src/plugins/filters/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ use std::{cmp::Ordering, pin::Pin};

use crate::def_filter;
use async_compression::tokio::bufread::{BrotliDecoder, BrotliEncoder, DeflateDecoder, DeflateEncoder, GzipDecoder, GzipEncoder};
use http::{header, HeaderValue};
use hyper::body::Body;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, BodyStream, StreamBody};
use hyper::body::{Body, Frame};
use hyper::http::{header, HeaderValue};
use serde::{Deserialize, Serialize};
use tardis::basic::error::TardisError;
use tardis::futures_util::{Stream, StreamExt};
use tardis::{basic::result::TardisResult, futures_util::TryStreamExt, tokio::io::BufReader};
use tokio_util::io::{ReaderStream, StreamReader};

Expand Down Expand Up @@ -66,7 +70,6 @@ impl CompressionType {
}
}


impl SgPluginFilter for SgFilterCompression {
fn accept(&self) -> super::SgPluginFilterAccept {
super::SgPluginFilterAccept {
Expand Down Expand Up @@ -97,6 +100,7 @@ impl SgPluginFilter for SgFilterCompression {
fn convert_error(err: hyper::Error) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
let read_stream_mapper = |b: std::io::Result<hyper::body::Bytes>| b.map(Frame::data).map_err(TardisError::from);
if desired_response_encoding == resp_encode_type {
return Ok((true, ctx));
}
Expand All @@ -107,10 +111,12 @@ impl SgPluginFilter for SgFilterCompression {
CompressionType::Br => ctx.response.set_header(header::CONTENT_ENCODING, CompressionType::Br.into())?,
}
}
let mut body = ctx.response.take_body();
body = if let Some(resp_encode_type) = resp_encode_type {
if let Some(resp_encode_type) = resp_encode_type {
let s = StreamExt::map(ctx.response.take_body_stream(), |maybe_frame| {
maybe_frame.map_err(std::io::Error::other).map(|f| f.into_data().unwrap_or_default())
});
ctx.response.remove_header(header::CONTENT_LENGTH)?;
let bytes_reader = StreamReader::new(body.map_err(convert_error));
let bytes_reader = StreamReader::new(s);
let mut read_stream: Pin<Box<dyn tardis::tokio::io::AsyncRead + Send>> = match resp_encode_type {
CompressionType::Gzip => Box::pin(GzipDecoder::new(bytes_reader)),
CompressionType::Deflate => Box::pin(DeflateDecoder::new(bytes_reader)),
Expand All @@ -123,22 +129,22 @@ impl SgPluginFilter for SgFilterCompression {
CompressionType::Br => Box::pin(BrotliEncoder::new(BufReader::new(read_stream))),
};
}
Body::wrap_stream(ReaderStream::new(read_stream))
ctx.response.set_body(StreamBody::new(ReaderStream::new(read_stream).map(read_stream_mapper)));
} else if let Some(desired_response_encoding) = desired_response_encoding {
let s = StreamExt::map(ctx.response.take_body_stream(), |maybe_frame| {
maybe_frame.map_err(std::io::Error::other).map(|f| f.into_data().unwrap_or_default())
});
ctx.response.remove_header(header::CONTENT_LENGTH)?;
let bytes_reader = StreamReader::new(body.map_err(convert_error));
ctx.response.get_headers_mut().insert(hyper::header::TRANSFER_ENCODING, hyper::header::HeaderValue::from_static("chunked"));
let bytes_reader = StreamReader::new(s);
match desired_response_encoding {
CompressionType::Gzip => Body::wrap_stream(ReaderStream::new(GzipEncoder::new(bytes_reader))),
CompressionType::Deflate => Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(bytes_reader))),
CompressionType::Br => Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(bytes_reader))),
CompressionType::Gzip => ctx.response.set_body(StreamBody::new(ReaderStream::new(GzipEncoder::new(bytes_reader)).map(read_stream_mapper))),
CompressionType::Deflate => ctx.response.set_body(StreamBody::new(ReaderStream::new(DeflateEncoder::new(bytes_reader)).map(read_stream_mapper))),
CompressionType::Br => ctx.response.set_body(StreamBody::new(ReaderStream::new(BrotliEncoder::new(bytes_reader)).map(read_stream_mapper))),
}
} else {
body
// ctx.response.take_body()
};

ctx.response.set_body(body);
// let body = ctx.response.dump_body().await?;
// dbg!(body);
Ok((true, ctx))
}
}
Expand Down
3 changes: 1 addition & 2 deletions kernel/src/plugins/filters/header_modifier.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use http::HeaderName;
use hyper::http::header::HeaderName;

use super::{SgPluginFilter, SgPluginFilterAccept, SgPluginFilterInitDto, SgPluginFilterKind, SgRoutePluginContext};
use crate::def_filter;
Expand All @@ -10,7 +10,6 @@ use tardis::basic::result::TardisResult;

def_filter!(SG_FILTER_HEADER_MODIFIER_CODE, SgFilterHeaderModifierDef, SgFilterHeaderModifier);


impl SgPluginFilter for SgFilterHeaderModifier {
fn accept(&self) -> SgPluginFilterAccept {
SgPluginFilterAccept {
Expand Down
Loading

0 comments on commit d439cb2

Please sign in to comment.