Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
siketyan committed Dec 31, 2023
1 parent a425ca7 commit bc081b2
Show file tree
Hide file tree
Showing 14 changed files with 182 additions and 107 deletions.
12 changes: 7 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ path = "src/lib.rs"

[features]
default = []
hyper = ["dep:tokio","dep:hyper", "dep:hyper-rustls", "dep:tokio-stream","dep:tokio-tungstenite", "dep:tokio-tungstenite", "dep:signal-hook", "dep:signal-hook-tokio"]
hyper = ["dep:tokio", "dep:http-body-util", "dep:hyper", "dep:hyper-rustls", "dep:hyper-util", "dep:tokio-stream","dep:tokio-tungstenite", "dep:tokio-tungstenite", "dep:signal-hook", "dep:signal-hook-tokio"]
axum = ["hyper", "dep:axum", "dep:tower"]

[dependencies]
Expand All @@ -39,20 +39,22 @@ hex = "0.4"
tracing = "0.1"
ring = "0.17"
lazy_static = "1.4"
http = "0.2"
http = "1.0"
async-trait = "0.1"
bytes = "1"
rand = "0.8"
async-recursion="1.0"
mime = "0.3"
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
url = { version = "2.5", features = ["serde"]}
hyper = { version ="0.14", features = ["http2","server", "client", "h2", "stream"], optional = true }
http-body-util = { version = "0.1", optional = true }
hyper = { version ="1.0", features = ["http2","server", "client"], optional = true }
hyper-util = { version = "0.1", features = ["client", "client-legacy", "server"], optional = true }
tokio = { version = "1", features = ["bytes","rt-multi-thread","signal","tracing"], optional = true }
tokio-stream = { version = "0.1", optional = true }
hyper-rustls = { version="0.24", features = ["rustls-native-certs", "http2"], optional = true }
tokio-tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"], optional = true }
axum = { version = "0.6", optional = true }
axum = { version = "0.7", optional = true }
tower = { version = "0.4", optional = true }
serde_urlencoded = "0.7.1"

Expand All @@ -68,7 +70,7 @@ cargo-husky = { version = "1", default-features = false, features = ["run-for-al
cargo-audit = "0.18"
tracing-subscriber = { version ="0.3", features = ["env-filter"] }
hyper-proxy = "0.9"
hyper = { version ="0.14", features = ["full"] }
hyper = { version ="1.0", features = ["full"] }
tokio = { version = "1", features = ["full"] }

[package.metadata.release]
Expand Down
2 changes: 1 addition & 1 deletion src/axum_support/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::hyper_tokio::SlackClientHyperConnector;
use crate::listener::SlackClientEventsListenerEnvironment;
use hyper::client::connect::Connect;
use hyper_util::client::legacy::connect::Connect;
use std::sync::Arc;

mod slack_events_middleware;
Expand Down
9 changes: 4 additions & 5 deletions src/axum_support/slack_events_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ use crate::listener::SlackClientEventsListenerEnvironment;
use crate::prelude::hyper_ext::HyperExtensions;
use crate::signature_verifier::SlackEventSignatureVerifier;
use crate::{SlackClientHttpConnector, SlackSigningSecret};
use axum::body::BoxBody;
use axum::response::IntoResponse;
use axum::{body::Body, http::Request, response::Response};
use futures_util::future::BoxFuture;
use hyper::client::connect::Connect;
use hyper_util::client::legacy::connect::Connect;
use std::convert::Infallible;
use std::marker::PhantomData;
use std::sync::Arc;
Expand Down Expand Up @@ -104,7 +103,7 @@ where
);
Ok(Response::builder()
.status(http_status)
.body(BoxBody::default())
.body(Body::default())
.unwrap())
} else {
*verified_request.body_mut() = Body::from(verified_body);
Expand All @@ -126,7 +125,7 @@ where
);
Ok(Response::builder()
.status(http_status)
.body(BoxBody::default())
.body(Body::default())
.unwrap())
}
}
Expand All @@ -141,7 +140,7 @@ where
);
Ok(Response::builder()
.status(http_status)
.body(BoxBody::default())
.body(Body::default())
.unwrap())
}
}
Expand Down
24 changes: 15 additions & 9 deletions src/axum_support/slack_oauth_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use crate::axum_support::SlackEventsAxumListener;
use crate::hyper_tokio::hyper_ext::HyperExtensions;
use crate::listener::{SlackClientEventsListenerEnvironment, UserCallbackFunction};
use crate::prelude::SlackOAuthListenerConfig;
use axum::response::Response;
use axum::body::Body;
use axum::response::{IntoResponse, Response};
use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use http::Request;
use hyper::client::connect::Connect;
use hyper::Body;
use http_body_util::BodyExt;
use hyper_util::client::legacy::connect::Connect;
use rvstruct::ValueStruct;
use std::future::Future;
use std::sync::Arc;
Expand All @@ -22,7 +23,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
pub fn slack_oauth_install(
&self,
config: &SlackOAuthListenerConfig,
) -> impl Fn(Request<Body>) -> BoxFuture<'static, Response<Body>> + 'static + Send + Clone {
) -> impl Fn(Request<Body>) -> BoxFuture<'static, Response> + 'static + Send + Clone {
let environment = self.environment.clone();
let config = config.clone();
move |_| {
Expand All @@ -41,7 +42,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
],
);
debug!("Redirecting to Slack OAuth authorize: {}", &full_uri);
HyperExtensions::hyper_redirect_to(full_uri.as_ref())
HyperExtensions::hyper_redirect_to(full_uri.as_ref()).map(|r| r.into_response())
}
.map(|res| Self::handle_error(environment, res))
.boxed()
Expand All @@ -66,7 +67,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
let err_config = config.clone();

async move {
let params = HyperExtensions::parse_query_params(&req);
let params = HyperExtensions::parse_query_params(req.uri());
debug!("Received Slack OAuth callback: {:?}", &params);

match (params.get("code"), params.get("error")) {
Expand Down Expand Up @@ -105,6 +106,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
)
.await;
HyperExtensions::hyper_redirect_to(&config.redirect_installed_url)
.map(|r| r.into_response())
}
Err(err) => {
error!("Slack OAuth error: {}", &err);
Expand All @@ -116,6 +118,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
HyperExtensions::hyper_redirect_to(
&config.redirect_error_redirect_url,
)
.map(|r| r.into_response())
}
}
}
Expand All @@ -134,6 +137,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
req.uri().query().map_or("".into(), |q| format!("?{}", &q))
);
HyperExtensions::hyper_redirect_to(&redirect_error_url)
.map(|r| r.into_response())
}
_ => {
error!("Slack OAuth cancelled with unknown reason");
Expand All @@ -146,6 +150,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
environment.user_state.clone(),
);
HyperExtensions::hyper_redirect_to(&config.redirect_error_redirect_url)
.map(|r| r.into_response())
}
}
}
Expand All @@ -163,6 +168,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
);
HyperExtensions::hyper_redirect_to(&err_config.redirect_error_redirect_url)
.unwrap()
.into_response()
}
})
.boxed()
Expand Down Expand Up @@ -195,8 +201,8 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {

fn handle_error(
environment: Arc<SlackClientEventsListenerEnvironment<SlackClientHyperConnector<H>>>,
result: AnyStdResult<Response<hyper::Body>>,
) -> Response<hyper::Body> {
result: AnyStdResult<Response>,
) -> Response {
match result {
Err(err) => {
let http_status = (environment.error_handler)(
Expand All @@ -206,7 +212,7 @@ impl<H: 'static + Send + Sync + Connect + Clone> SlackEventsAxumListener<H> {
);
Response::builder()
.status(http_status)
.body(hyper::Body::empty())
.body(Body::empty())
.unwrap()
}
Ok(result) => result,
Expand Down
27 changes: 15 additions & 12 deletions src/hyper_tokio/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use crate::models::{SlackClientId, SlackClientSecret};
use crate::*;
use async_recursion::async_recursion;
use futures::future::{BoxFuture, FutureExt};
use hyper::client::*;
use hyper::body::{Body, Incoming};
use hyper::http::StatusCode;
use hyper::{Body, Request};
use hyper::Request;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::*;
use hyper_util::rt::TokioExecutor;
use rvstruct::ValueStruct;

use crate::prelude::hyper_ext::HyperExtensions;
Expand All @@ -21,13 +23,14 @@ use url::Url;

#[derive(Clone, Debug)]
pub struct SlackClientHyperConnector<H: Send + Sync + Clone + connect::Connect> {
hyper_connector: Client<H>,
hyper_connector: Client<H, Incoming>,
tokio_rate_controller: Option<Arc<SlackTokioRateController>>,
}

pub type SlackClientHyperHttpsConnector = SlackClientHyperConnector<HttpsConnector<HttpConnector>>;
pub type SlackClientHyperHttpsConnector =
SlackClientHyperConnector<HttpsConnector<connect::HttpConnector>>;

impl SlackClientHyperConnector<HttpsConnector<HttpConnector>> {
impl SlackClientHyperConnector<HttpsConnector<connect::HttpConnector>> {
pub fn new() -> Self {
let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_native_roots()
Expand All @@ -38,18 +41,18 @@ impl SlackClientHyperConnector<HttpsConnector<HttpConnector>> {
}
}

impl From<HttpsConnector<HttpConnector>>
for SlackClientHyperConnector<HttpsConnector<HttpConnector>>
impl From<HttpsConnector<connect::HttpConnector>>
for SlackClientHyperConnector<HttpsConnector<connect::HttpConnector>>
{
fn from(https_connector: hyper_rustls::HttpsConnector<HttpConnector>) -> Self {
fn from(https_connector: HttpsConnector<connect::HttpConnector>) -> Self {
Self::with_connector(https_connector)
}
}

impl<H: 'static + Send + Sync + Clone + connect::Connect> SlackClientHyperConnector<H> {
pub fn with_connector(connector: H) -> Self {
Self {
hyper_connector: Client::builder().build::<_, hyper::Body>(connector),
hyper_connector: Client::builder(TokioExecutor::new()).build::<_, Incoming>(connector),
tokio_rate_controller: None,
}
}
Expand All @@ -65,7 +68,7 @@ impl<H: 'static + Send + Sync + Clone + connect::Connect> SlackClientHyperConnec

async fn send_http_request<'a, RS>(
&'a self,
request: Request<Body>,
request: Request<Incoming>,
context: SlackClientApiCallContext<'a>,
) -> ClientResult<RS>
where
Expand Down Expand Up @@ -185,7 +188,7 @@ impl<H: 'static + Send + Sync + Clone + connect::Connect> SlackClientHyperConnec
retried: usize,
) -> ClientResult<RS>
where
R: Fn() -> ClientResult<Request<Body>> + Send + Sync,
R: Fn() -> ClientResult<Request<Incoming>> + Send + Sync,
RS: for<'de> serde::de::Deserialize<'de> + Send,
{
match (
Expand Down Expand Up @@ -223,7 +226,7 @@ impl<H: 'static + Send + Sync + Clone + connect::Connect> SlackClientHyperConnec
context: SlackClientApiCallContext<'a>,
) -> ClientResult<RS>
where
R: Fn() -> ClientResult<Request<Body>> + Send + Sync,
R: Fn() -> ClientResult<Request<Incoming>> + Send + Sync,
RS: for<'de> serde::de::Deserialize<'de> + Send,
{
match result {
Expand Down
25 changes: 13 additions & 12 deletions src/hyper_tokio/hyper_ext.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
use crate::signature_verifier::*;
use crate::{AnyStdResult, SlackApiToken};
use base64::prelude::*;
use bytes::Buf;
use bytes::{Buf, Bytes};
use futures_util::TryFutureExt;
use http::request::Parts;
use http::{Request, Response, Uri};
use hyper::body::HttpBody;
use hyper::Body;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper::body::{Body, Incoming};
use mime::Mime;
use rvstruct::ValueStruct;
use std::collections::HashMap;
use std::convert::Infallible;
use std::io::Read;
use url::Url;

pub struct HyperExtensions;

impl HyperExtensions {
pub fn parse_query_params(request: &Request<Body>) -> HashMap<String, String> {
request
.uri()
.query()
pub fn parse_query_params(uri: &Uri) -> HashMap<String, String> {
uri.query()
.map(|v| {
url::form_urlencoded::parse(v.as_bytes())
.into_owned()
Expand All @@ -30,11 +30,12 @@ impl HyperExtensions {

pub fn hyper_redirect_to(
url: &str,
) -> Result<Response<Body>, Box<dyn std::error::Error + Send + Sync>> {
) -> Result<Response<BoxBody<Bytes, Infallible>>, Box<dyn std::error::Error + Send + Sync>>
{
Response::builder()
.status(hyper::http::StatusCode::FOUND)
.header(hyper::header::LOCATION, url)
.body(Body::empty())
.body(Empty::new().boxed())
.map_err(|e| e.into())
}

Expand Down Expand Up @@ -75,10 +76,10 @@ impl HyperExtensions {

pub async fn http_body_to_string<T>(body: T) -> AnyStdResult<String>
where
T: HttpBody,
T: Body,
T::Error: std::error::Error + Sync + Send + 'static,
{
let http_body = hyper::body::aggregate(body).await?;
let http_body = body.collect().await?.aggregate();
let mut http_reader = http_body.reader();
let mut http_body_str = String::new();
http_reader.read_to_string(&mut http_body_str)?;
Expand All @@ -94,7 +95,7 @@ impl HyperExtensions {
}

pub async fn decode_signed_response(
req: Request<Body>,
req: Request<Incoming>,
signature_verifier: &SlackEventSignatureVerifier,
) -> AnyStdResult<(String, Parts)> {
match (
Expand Down
Loading

0 comments on commit bc081b2

Please sign in to comment.