Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
4t145 committed Jan 3, 2024
1 parent f0c4a8d commit c263353
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 45 deletions.
3 changes: 2 additions & 1 deletion kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ async-trait.workspace = true
itertools.workspace = true
urlencoding.workspace = true
async-compression.workspace = true

http-body-util.workspace = true
hyper-util.workspace = true
kernel-common = { path = "../kernel-common" }
tardis = { workspace = true, features = ["future", "crypto", "tls"] }
http.workspace = true
Expand Down
8 changes: 7 additions & 1 deletion kernel/src/functions/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use crate::plugins::context::SgRoutePluginContext;
use http::{HeaderMap, HeaderValue, Method, Request, Response, StatusCode};
use hyper::{client::HttpConnector, Body, Client, Error};
use hyper::Error;
use hyper_rustls::{ConfigBuilderExt, HttpsConnector};
use kernel_common::inner_model::gateway::SgProtocol;
use tardis::{
Expand Down Expand Up @@ -75,6 +75,12 @@ fn default_client() -> &'static Client<HttpsConnector<HttpConnector>> {
DEFAULT_CLIENT.get().expect("DEFAULT_CLIENT not initialized")
}

pub struct RequestConfig {
timeout: Option<Duration>,

}


pub async fn request(
client: &Client<HttpsConnector<HttpConnector>>,
rule_timeout_ms: Option<u64>,
Expand Down
7 changes: 4 additions & 3 deletions kernel/src/functions/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::net::SocketAddr;
use http::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::HeaderValue;
use hyper::{self};
use hyper::{Body, Request, Response, StatusCode};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use kernel_common::inner_model::gateway::SgProtocol;
use std::sync::Arc;
use tardis::basic::{error::TardisError, result::TardisResult};
Expand All @@ -16,7 +17,7 @@ use tardis::{log, tokio, TardisFuns};

use crate::instance::SgBackendInst;

pub async fn process(gateway_name: Arc<String>, remote_addr: SocketAddr, backend: &SgBackendInst, mut request: Request<Body>) -> TardisResult<Response<Body>> {
pub async fn process(gateway_name: Arc<String>, remote_addr: SocketAddr, backend: &SgBackendInst, mut request: Request<Incoming>) -> TardisResult<Response<Body>> {
let have_upgrade = request
.headers()
.get(CONNECTION)
Expand Down Expand Up @@ -144,7 +145,7 @@ pub async fn process(gateway_name: Arc<String>, remote_addr: SocketAddr, backend
});
let accept_key = TardisFuns::crypto.base64.encode_raw(TardisFuns::crypto.digest.digest_bytes::<algorithm::Sha1>(format!("{request_key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))?);

let mut response = Response::new(Body::empty());
let mut response = Response::new(http_body_util::Empty::new());
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;

response.headers_mut().insert(UPGRADE, HeaderValue::from_static("websocket"));
Expand Down
2 changes: 0 additions & 2 deletions kernel/src/instance.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use crate::plugins::filters::BoxSgPluginFilter;

use http::Method;
use hyper::{client::HttpConnector, Client};
use hyper_rustls::HttpsConnector;

use kernel_common::inner_model::gateway::{SgListener, SgProtocol};
use kernel_common::inner_model::http_route::{SgHttpHeaderMatchType, SgHttpPathMatchType, SgHttpQueryMatchType};
use std::{fmt, vec::Vec};
Expand Down
31 changes: 14 additions & 17 deletions kernel/src/plugins/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use http::{HeaderMap, HeaderName, HeaderValue, Method, Response, StatusCode, Uri, Version};
use hyper::body::Body;
use http_body_util::{BodyExt, Collected};
use hyper::body::{Body, Incoming};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
Expand Down Expand Up @@ -153,9 +154,9 @@ pub struct SgCtxRequest {
pub method: MaybeModified<Method>,
pub uri: MaybeModified<Uri>,
pub version: MaybeModified<Version>,
pub body: Body,
pub headers: MaybeModified<HeaderMap<HeaderValue>>,
pub remote_addr: SocketAddr,
pub body: Incoming,
}

impl SgCtxRequest {
Expand Down Expand Up @@ -254,29 +255,27 @@ impl SgCtxRequest {
&self.remote_addr
}

pub fn take_body(&mut self) -> Body {
pub fn take_body(&mut self) -> Incoming {
std::mem::take(&mut self.body)
}

pub fn replace_body(&mut self, body: impl Into<Body>) -> Body {
pub fn replace_body(&mut self, body: impl Into<Incoming>) -> Incoming {
std::mem::replace(&mut self.body, body.into())
}

#[inline]
pub fn set_body(&mut self, body: impl Into<Body>) {
pub fn set_body(&mut self, body: impl Into<Incoming>) {
let _ = self.replace_body(body);
}

/// it's a shortcut for [take_body](SgCtxRequest) + [hyper::body::to_bytes]
pub async fn take_body_into_bytes(&mut self) -> TardisResult<hyper::body::Bytes> {
let bytes = hyper::body::to_bytes(self.take_body()).await.map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to collect body into bytes: {e}"), ""))?;
Ok(bytes)
self.take_body().collect().await.map(Collected::to_bytes).map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to collect body into bytes: {e}"), ""))
}

/// it's a shortcut for [`take_body`](SgCtxRequest) + [hyper::body::aggregate]
pub async fn take_body_into_buf(&mut self) -> TardisResult<impl hyper::body::Buf> {
let buf = hyper::body::aggregate(self.take_body()).await.map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to aggregate body: {e}"), ""))?;
Ok(buf)
self.take_body().collect().await.map(Collected::aggregate).map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to aggregate body: {e}"), ""))
}

/// # Performance
Expand Down Expand Up @@ -306,7 +305,7 @@ impl Clone for SgCtxRequest {
pub struct SgCtxResponse {
pub status_code: MaybeModified<StatusCode>,
pub headers: MaybeModified<HeaderMap<HeaderValue>>,
pub body: Body,
pub body: Incoming,
resp_err: Option<TardisError>,
}

Expand Down Expand Up @@ -391,30 +390,28 @@ impl SgCtxResponse {
}

#[inline]
pub fn take_body(&mut self) -> Body {
pub fn take_body(&mut self) -> Incoming {
std::mem::take(&mut self.body)
}

#[inline]
pub fn replace_body(&mut self, body: impl Into<Body>) -> Body {
pub fn replace_body(&mut self, body: impl Into<Incoming>) -> Incoming {
std::mem::replace(&mut self.body, body.into())
}

#[inline]
pub fn set_body(&mut self, body: impl Into<Body>) {
pub fn set_body(&mut self, body: impl Into<Incoming>) {
let _ = self.replace_body(body);
}

/// it's a shortcut for [take_body](SgCtxResponse) + [hyper::body::to_bytes]
pub async fn take_body_into_bytes(&mut self) -> TardisResult<hyper::body::Bytes> {
let bytes = hyper::body::to_bytes(self.take_body()).await.map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to collect body into bytes: {e}"), ""))?;
Ok(bytes)
self.take_body().collect().await.map(Collected::to_bytes).map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to collect body into bytes: {e}"), ""))
}

/// it's a shortcut for [take_body](SgCtxResponse) + [hyper::body::aggregate]
pub async fn take_body_into_buf(&mut self) -> TardisResult<impl hyper::body::Buf> {
let buf = hyper::body::aggregate(self.take_body()).await.map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to aggregate body: {e}"), ""))?;
Ok(buf)
self.take_body().collect().await.map(Collected::aggregate).map_err(|e| TardisError::format_error(&format!("[SG.Filter] fail to aggregate body: {e}"), ""))
}

/// # Performance
Expand Down
49 changes: 29 additions & 20 deletions kernel/src/plugins/filters/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use http::Request;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Server};
use hyper::server::conn::{http1, http2};
use hyper::service::service_fn;
use hyper_util::rt::{TokioExecutor, TokioIo};

use serde::{Deserialize, Serialize};
use tardis::chrono::{Duration, Utc};
Expand All @@ -14,6 +15,7 @@ use tardis::{
log,
tokio::{
self,
net::TcpListener,
sync::{watch::Sender, Mutex},
},
};
Expand Down Expand Up @@ -76,7 +78,6 @@ impl Default for SgFilterStatus {
}
}


impl SgPluginFilter for SgFilterStatus {
fn accept(&self) -> super::SgPluginFilterAccept {
super::SgPluginFilterAccept {
Expand Down Expand Up @@ -105,29 +106,37 @@ impl SgPluginFilter for SgFilterStatus {
let title = Arc::new(Mutex::new(self.title.clone()));
let gateway_name = Arc::new(Mutex::new(init_dto.gateway_name.clone()));
let cache_key = Arc::new(Mutex::new(get_cache_key(self, &init_dto.gateway_name)));
let make_svc = make_service_fn(move |_conn| {
let listener = TcpListener::bind(&addr).await.map_err(|e| TardisError::conflict(&format!("[SG.Filter.Status] bind error: {e}"), ""))?;
let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::default());

let task = async move {
log::info!("[SG.Filter.Status] Server started: {addr}");
let title = title.clone();
let gateway_name = gateway_name.clone();
let cache_key = cache_key.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |request: Request<Body>| {
status_plugin::create_status_html(request, gateway_name.clone(), cache_key.clone(), title.clone())
}))
loop {
let conn = tokio::select! {
_ = shutdown_rx.changed() => {
break;
}
next = listener.accept() => {
next
}
};
match conn {
Ok((stream, socket)) => {
let io = TokioIo::new(stream);
server.serve_connection(
io,
service_fn(move |request: Request<()>| status_plugin::create_status_html(request, gateway_name.clone(), cache_key.clone(), title.clone())),
)
}
Err(_) => todo!(),
}
}
});

let server = match Server::try_bind(&addr) {
Ok(server) => server.serve(make_svc),
Err(e) => return Err(TardisError::conflict(&format!("[SG.Filter.Status] bind error: {e}"), "")),
};

let join = tokio::spawn(async move {
log::info!("[SG.Filter.Status] Server started: {addr}");
let server = server.with_graceful_shutdown(async move {
shutdown_rx.changed().await.ok();
});
server.await
});
let join = tokio::spawn(task);
(*shutdown).insert(self.port, (shutdown_tx, join));

#[cfg(feature = "cache")]
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/plugins/filters/status/status_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Status {
}

pub(crate) async fn create_status_html(
_: Request<Body>,
_: Request<()>,
_gateway_name: Arc<Mutex<String>>,
_cache_key: Arc<Mutex<String>>,
title: Arc<Mutex<String>>,
Expand Down

0 comments on commit c263353

Please sign in to comment.