From 26f1d043a373609f8a0111648ca57e29c2ec1777 Mon Sep 17 00:00:00 2001 From: sugyan Date: Thu, 16 Jan 2025 00:37:13 +0900 Subject: [PATCH] Update --- atrium-api/src/agent.rs | 22 +- atrium-oauth/identity/src/did.rs | 3 +- atrium-oauth/oauth-client/Cargo.toml | 4 +- .../oauth-client/examples/generate_key.rs | 28 ++ .../oauth-client/src/http_client/dpop.rs | 11 + .../oauth-client/src/oauth_session.rs | 448 +++++++++++++++--- .../oauth-client/src/oauth_session/inner.rs | 96 ++++ .../oauth-client/src/oauth_session/store.rs | 34 ++ atrium-oauth/oauth-client/src/server_agent.rs | 15 + atrium-oauth/oauth-client/src/types.rs | 2 +- .../oauth-client/src/types/client_metadata.rs | 2 +- .../oauth-client/src/types/metadata.rs | 2 +- 12 files changed, 588 insertions(+), 79 deletions(-) create mode 100644 atrium-oauth/oauth-client/examples/generate_key.rs create mode 100644 atrium-oauth/oauth-client/src/oauth_session/inner.rs create mode 100644 atrium-oauth/oauth-client/src/oauth_session/store.rs diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 8f261c27..83679b57 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -25,8 +25,11 @@ pub trait AuthorizationProvider { } pub trait Configure { + /// Set the current endpoint. fn configure_endpoint(&self, endpoint: String); + /// Configures the moderation services to be applied on requests. fn configure_labelers_header(&self, labeler_dids: Option>); + /// Configures the atproto-proxy header to be applied on requests. fn configure_proxy_header(&self, did: Did, service_type: impl AsRef); } @@ -94,15 +97,12 @@ impl Configure for Agent where M: Configure + SessionManager + Send + Sync, { - /// Set the current endpoint. fn configure_endpoint(&self, endpoint: String) { self.session_manager.configure_endpoint(endpoint); } - /// Configures the moderation services to be applied on requests. fn configure_labelers_header(&self, labeler_dids: Option>) { self.session_manager.configure_labelers_header(labeler_dids); } - /// Configures the atproto-proxy header to be applied on requests. fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { self.session_manager.configure_proxy_header(did, service_type); } @@ -192,7 +192,7 @@ where impl XrpcClient for WrapperClient where S: Store<(), U> + AuthorizationProvider + Send + Sync, - T: XrpcClient + Send + Sync, + T: HttpClient + Send + Sync, U: Clone + Send + Sync, { fn base_uri(&self) -> String { @@ -499,7 +499,7 @@ mod tests { // labeler service { agent.configure_proxy_header( - Did::new(String::from("did:fake:service.test")).expect("did should be valid"), + Did::new(String::from("did:fake:service.test"))?, AtprotoServiceType::AtprotoLabeler, ); call_service(&agent.api).await?; @@ -514,7 +514,7 @@ mod tests { // custom service { agent.configure_proxy_header( - Did::new(String::from("did:fake:service.test")).expect("did should be valid"), + Did::new(String::from("did:fake:service.test"))?, "custom_service", ); call_service(&agent.api).await?; @@ -528,10 +528,12 @@ mod tests { } // api_with_proxy { - call_service(&agent.api_with_proxy( - Did::new(String::from("did:fake:service.test")).expect("did should be valid"), - "temp_service", - )) + call_service( + &agent.api_with_proxy( + Did::new(String::from("did:fake:service.test"))?, + "temp_service", + ), + ) .await?; assert_eq!( data.lock().await.as_ref().expect("data should be recorded").headers, diff --git a/atrium-oauth/identity/src/did.rs b/atrium-oauth/identity/src/did.rs index 0b731cb1..9e873904 100644 --- a/atrium-oauth/identity/src/did.rs +++ b/atrium-oauth/identity/src/did.rs @@ -2,10 +2,9 @@ mod common_resolver; mod plc_resolver; mod web_resolver; -use crate::Error; - pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig}; pub use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL; +use crate::Error; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; use atrium_common::resolver::Resolver; diff --git a/atrium-oauth/oauth-client/Cargo.toml b/atrium-oauth/oauth-client/Cargo.toml index 68707935..fb5f3ac9 100644 --- a/atrium-oauth/oauth-client/Cargo.toml +++ b/atrium-oauth/oauth-client/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["atproto", "bluesky", "oauth"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -atrium-api.workspace = true +atrium-api = { workspace = true, features = ["agent"] } atrium-common.workspace = true atrium-identity.workspace = true atrium-xrpc.workspace = true @@ -37,7 +37,7 @@ trait-variant.workspace = true [dev-dependencies] atrium-api = { workspace = true, features = ["bluesky"] } hickory-resolver.workspace = true -p256 = { workspace = true, features = ["pem"] } +p256 = { workspace = true, features = ["pem", "std"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [features] diff --git a/atrium-oauth/oauth-client/examples/generate_key.rs b/atrium-oauth/oauth-client/examples/generate_key.rs new file mode 100644 index 00000000..08228e5c --- /dev/null +++ b/atrium-oauth/oauth-client/examples/generate_key.rs @@ -0,0 +1,28 @@ +use elliptic_curve::pkcs8::EncodePrivateKey; +use elliptic_curve::SecretKey; +use jose_jwa::{Algorithm, Signing}; +use jose_jwk::{Class, Jwk, JwkSet, Key, Parameters}; +use p256::NistP256; +use rand::rngs::ThreadRng; + +fn main() -> Result<(), Box> { + let secret_key = SecretKey::::random(&mut ThreadRng::default()); + let key = Key::from(&secret_key.public_key().into()); + let jwks = JwkSet { + keys: vec![Jwk { + key, + prm: Parameters { + alg: Some(Algorithm::Signing(Signing::Es256)), + kid: Some(String::from("kid01")), + cls: Some(Class::Signing), + ..Default::default() + }, + }], + }; + println!("SECRET KEY:"); + println!("{}", secret_key.to_pkcs8_pem(Default::default())?.as_str()); + + println!("JWKS:"); + println!("{}", serde_json::to_string_pretty(&jwks)?); + Ok(()) +} diff --git a/atrium-oauth/oauth-client/src/http_client/dpop.rs b/atrium-oauth/oauth-client/src/http_client/dpop.rs index 29999662..dbe65f6f 100644 --- a/atrium-oauth/oauth-client/src/http_client/dpop.rs +++ b/atrium-oauth/oauth-client/src/http_client/dpop.rs @@ -187,3 +187,14 @@ where Ok(response) } } + +impl Clone for DpopClient { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + key: self.key.clone(), + nonces: self.nonces.clone(), + is_auth_server: self.is_auth_server, + } + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_session.rs b/atrium-oauth/oauth-client/src/oauth_session.rs index 5e791413..c92ab6c4 100644 --- a/atrium-oauth/oauth-client/src/oauth_session.rs +++ b/atrium-oauth/oauth-client/src/oauth_session.rs @@ -1,17 +1,20 @@ -use crate::{ - http_client::dpop::Error, - server_agent::OAuthServerAgent, - {DpopClient, TokenSet}, +mod inner; +mod store; + +use crate::{http_client::dpop, server_agent::OAuthServerAgent, DpopClient, TokenSet}; +use atrium_api::{ + agent::{CloneWithProxy, Configure, InnerStore, SessionManager}, + types::string::Did, }; -use atrium_api::{agent::SessionManager, types::string::Did}; use atrium_common::store::{memory::MemoryStore, Store}; use atrium_xrpc::{ http::{Request, Response}, - types::AuthorizationToken, - HttpClient, XrpcClient, + Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, }; use jose_jwk::Key; -use std::sync::Arc; +use serde::{de::DeserializeOwned, Serialize}; +use std::{fmt::Debug, sync::Arc}; +use store::MemorySessionStore; pub struct OAuthSession> where @@ -19,98 +22,419 @@ where S: Store, { server_agent: OAuthServerAgent, - dpop_client: DpopClient, + store: Arc>, + inner: inner::Client>, token_set: TokenSet, // TODO: replace with a session store? } impl OAuthSession where - T: HttpClient + Send + Sync + 'static, + T: HttpClient + Send + Sync, { pub(crate) fn new( server_agent: OAuthServerAgent, dpop_key: Key, http_client: Arc, token_set: TokenSet, - ) -> Result { - let dpop_client = DpopClient::new( - dpop_key, - http_client.clone(), - false, - &server_agent.server_metadata.token_endpoint_auth_signing_alg_values_supported, - )?; - Ok(Self { server_agent, dpop_client, token_set }) - } - pub fn dpop_key(&self) -> Key { - self.dpop_client.key.clone() - } - pub fn token_set(&self) -> TokenSet { - self.token_set.clone() + ) -> Result { + let store = Arc::new(InnerStore::new(MemorySessionStore::default(), token_set.aud.clone())); + let inner = inner::Client::new( + Arc::clone(&store), + DpopClient::new( + dpop_key, + http_client.clone(), + false, + &server_agent.server_metadata.token_endpoint_auth_signing_alg_values_supported, + )?, + ); + Ok(Self { server_agent, store, inner, token_set }) } } impl HttpClient for OAuthSession where - T: HttpClient + Send + Sync + 'static, - D: Send + Sync + 'static, - H: Send + Sync + 'static, + T: HttpClient + Send + Sync, + D: Send + Sync, + H: Send + Sync, S: Store + Send + Sync + 'static, - S::Error: std::error::Error + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync, { async fn send_http( &self, request: Request>, ) -> Result>, Box> { - self.dpop_client.send_http(request).await + self.inner.send_http(request).await } } impl XrpcClient for OAuthSession where - T: HttpClient + Send + Sync + 'static, - D: Send + Sync + 'static, - H: Send + Sync + 'static, + T: HttpClient + Send + Sync, + D: Send + Sync, + H: Send + Sync, S: Store + Send + Sync + 'static, - S::Error: std::error::Error + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync, { fn base_uri(&self) -> String { - self.token_set.aud.clone() - } - async fn authorization_token(&self, _is_refresh: bool) -> Option { - Some(AuthorizationToken::Dpop(self.token_set.access_token.clone())) - } - // async fn atproto_proxy_header(&self) -> Option { - // todo!() - // } - // async fn atproto_accept_labelers_header(&self) -> Option> { - // todo!() - // } - // async fn send_xrpc( - // &self, - // request: &XrpcRequest, - // ) -> Result, Error> - // where - // P: Serialize + Send + Sync, - // I: Serialize + Send + Sync, - // O: DeserializeOwned + Send + Sync, - // E: DeserializeOwned + Send + Sync + Debug, - // { - // todo!() - // } + self.inner.base_uri() + } + async fn send_xrpc( + &self, + request: &XrpcRequest, + ) -> Result, Error> + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + self.inner.send_xrpc(request).await + } } impl SessionManager for OAuthSession where - T: HttpClient + Send + Sync + 'static, - D: Send + Sync + 'static, - H: Send + Sync + 'static, + T: HttpClient + Send + Sync, + D: Send + Sync, + H: Send + Sync, S: Store + Send + Sync + 'static, - S::Error: std::error::Error + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync, { async fn did(&self) -> Option { - todo!() + Some(self.token_set.sub.clone()) + } +} + +impl Configure for OAuthSession +where + T: HttpClient + Send + Sync, + S: Store + Send + Sync + 'static, +{ + fn configure_endpoint(&self, endpoint: String) { + self.inner.configure_endpoint(endpoint); + } + fn configure_labelers_header(&self, labeler_dids: Option>) { + self.inner.configure_labelers_header(labeler_dids); + } + fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { + self.inner.configure_proxy_header(did, service_type); + } +} + +impl CloneWithProxy for OAuthSession +where + T: HttpClient + Send + Sync, + S: Store + Send + Sync + 'static, +{ + fn clone_with_proxy(&self, did: Did, service_type: impl AsRef) -> Self { + Self { + server_agent: self.server_agent.clone(), + store: self.store.clone(), + inner: self.inner.clone_with_proxy(did, service_type), + token_set: self.token_set.clone(), + } } } #[cfg(test)] -mod tests {} +mod tests { + use super::*; + use crate::{ + jose::jwt::Claims, resolver::OAuthResolver, types::OAuthTokenType, OAuthResolverConfig, + }; + use atrium_api::{ + agent::{Agent, AtprotoServiceType}, + client::Service, + did_doc::DidDocument, + types::string::Handle, + xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue}, + }; + use atrium_common::resolver::Resolver; + use atrium_identity::{did::DidResolver, handle::HandleResolver}; + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; + use tokio::sync::Mutex; + + #[derive(Default)] + struct RecordData { + host: Option, + headers: HeaderMap, + } + + struct MockHttpClient { + data: Arc>>, + } + + impl HttpClient for MockHttpClient { + async fn send_http( + &self, + request: Request>, + ) -> Result>, Box> { + let mut headers = request.headers().clone(); + let dpop_jwt = headers.remove("dpop").expect("dpop header should be present"); + let payload = dpop_jwt + .to_str() + .expect("dpop header should be valid") + .split('.') + .nth(1) + .expect("dpop header should have 2 parts"); + let claims = URL_SAFE_NO_PAD + .decode(payload) + .ok() + .and_then(|value| serde_json::from_slice::(&value).ok()) + .expect("dpop payload should be valid"); + assert!(claims.registered.iat.is_some()); + assert!(claims.registered.jti.is_some()); + assert_eq!(claims.public.htm, Some(request.method().to_string())); + assert_eq!(claims.public.htu, Some(request.uri().to_string())); + + self.data + .lock() + .await + .replace(RecordData { host: request.uri().host().map(String::from), headers }); + let output = atrium_api::com::atproto::server::get_service_auth::OutputData { + token: String::from("fake_token"), + }; + Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&output)?) + .map_err(|e| e.into()) + } + } + + struct NoopDidResolver; + + impl Resolver for NoopDidResolver { + type Input = Did; + type Output = DidDocument; + type Error = atrium_identity::Error; + async fn resolve(&self, _: &Self::Input) -> Result { + unimplemented!() + } + } + + impl DidResolver for NoopDidResolver {} + + struct NoopHandleResolver; + + impl Resolver for NoopHandleResolver { + type Input = Handle; + type Output = Did; + type Error = atrium_identity::Error; + async fn resolve(&self, _: &Self::Input) -> Result { + unimplemented!() + } + } + + impl HandleResolver for NoopHandleResolver {} + + fn oauth_agent( + data: Arc>>, + ) -> Agent { + let dpop_key = serde_json::from_str::( + r#"{ + "kty": "EC", + "crv": "P-256", + "x": "NIRNgPVAwnVNzN5g2Ik2IMghWcjnBOGo9B-lKXSSXFs", + "y": "iWF-Of43XoSTZxcadO9KWdPTjiCoviSztYw7aMtZZMc", + "d": "9MuCYfKK4hf95p_VRj6cxKJwORTgvEU3vynfmSgFH2M" + }"#, + ) + .expect("key should be valid"); + let http_client = Arc::new(MockHttpClient { data }); + let resolver = Arc::new(OAuthResolver::new( + OAuthResolverConfig { + did_resolver: NoopDidResolver, + handle_resolver: NoopHandleResolver, + authorization_server_metadata: Default::default(), + protected_resource_metadata: Default::default(), + }, + Arc::clone(&http_client), + )); + let keyset = None; + let server_agent = OAuthServerAgent::new( + dpop_key.clone(), + Default::default(), + Default::default(), + resolver, + Arc::clone(&http_client), + keyset, + ) + .expect("failed to create server agent"); + let token_set = TokenSet { + iss: String::from("https://iss.example.com"), + sub: Did::new(String::from("did:fake:sub.test")).expect("did should be valid"), + aud: String::from("https://aud.example.com"), + scope: None, + refresh_token: None, + access_token: String::from("access_token"), + token_type: OAuthTokenType::DPoP, + expires_at: None, + }; + let oauth_session = OAuthSession::new(server_agent, dpop_key, http_client, token_set) + .expect("failed to create oauth session"); + Agent::new(oauth_session) + } + + async fn call_service( + service: &Service, + ) -> Result<(), Error> { + let output = service + .com + .atproto + .server + .get_service_auth( + atrium_api::com::atproto::server::get_service_auth::ParametersData { + aud: Did::new(String::from("did:fake:handle.test")) + .expect("did should be valid"), + exp: None, + lxm: None, + } + .into(), + ) + .await?; + assert_eq!(output.token, "fake_token"); + Ok(()) + } + + #[tokio::test] + async fn test_new() -> Result<(), Box> { + let agent = oauth_agent(Arc::new(Mutex::new(Default::default()))); + assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test")); + Ok(()) + } + + #[tokio::test] + async fn test_configure_endpoint() -> Result<(), Box> { + let data = Arc::new(Mutex::new(Default::default())); + let agent = oauth_agent(Arc::clone(&data)); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").host.as_deref(), + Some("aud.example.com") + ); + agent.configure_endpoint(String::from("https://pds.example.com")); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").host.as_deref(), + Some("pds.example.com") + ); + Ok(()) + } + + #[tokio::test] + async fn test_configure_labelers_header() -> Result<(), Box> { + let data = Arc::new(Mutex::new(Default::default())); + let agent = oauth_agent(Arc::clone(&data)); + // not configured + { + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::new() + ); + } + // configured 1 + { + agent.configure_labelers_header(Some(vec![( + Did::new(String::from("did:fake:labeler.test"))?, + false, + )])); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-accept-labelers"), + HeaderValue::from_static("did:fake:labeler.test"), + )]) + ); + } + // configured 2 + { + agent.configure_labelers_header(Some(vec![ + (Did::new(String::from("did:fake:labeler.test_redact"))?, true), + (Did::new(String::from("did:fake:labeler.test"))?, false), + ])); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-accept-labelers"), + HeaderValue::from_static( + "did:fake:labeler.test_redact;redact, did:fake:labeler.test" + ), + )]) + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_configure_proxy_header() -> Result<(), Box> { + let data = Arc::new(Mutex::new(Default::default())); + let agent = oauth_agent(data.clone()); + // not configured + { + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::new() + ); + } + // labeler service + { + agent.configure_proxy_header( + Did::new(String::from("did:fake:service.test"))?, + AtprotoServiceType::AtprotoLabeler, + ); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#atproto_labeler"), + )]) + ); + } + // custom service + { + agent.configure_proxy_header( + Did::new(String::from("did:fake:service.test"))?, + "custom_service", + ); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#custom_service"), + )]) + ); + } + // api_with_proxy + { + call_service( + &agent.api_with_proxy( + Did::new(String::from("did:fake:service.test"))?, + "temp_service", + ), + ) + .await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#temp_service"), + )]) + ); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#custom_service"), + )]) + ); + } + Ok(()) + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_session/inner.rs b/atrium-oauth/oauth-client/src/oauth_session/inner.rs new file mode 100644 index 00000000..3d6a991b --- /dev/null +++ b/atrium-oauth/oauth-client/src/oauth_session/inner.rs @@ -0,0 +1,96 @@ +use super::store::OAuthSessionStore; +use atrium_api::{ + agent::{CloneWithProxy, Configure, InnerStore, WrapperClient}, + types::string::Did, +}; +use atrium_xrpc::{ + http::{Request, Response}, + Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, +}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{fmt::Debug, sync::Arc}; + +pub struct Client { + inner: WrapperClient, +} + +impl Client { + pub fn new(store: Arc>, xrpc: T) -> Self { + Self { inner: WrapperClient::new(Arc::clone(&store), xrpc) } + } + async fn refresh_token(&self) {} + // https://datatracker.ietf.org/doc/html/rfc6750#section-3 + // fn is_invalid_token_response(result: &Result, Error>) -> bool + // where + // O: DeserializeOwned + Send + Sync, + // E: DeserializeOwned + Send + Sync + Debug, + // { + // todo!() + // } +} + +impl HttpClient for Client +where + S: OAuthSessionStore + Send + Sync, + T: HttpClient + Send + Sync, +{ + async fn send_http( + &self, + request: Request>, + ) -> Result>, Box> { + self.inner.send_http(request).await + } +} + +impl XrpcClient for Client +where + S: OAuthSessionStore + Send + Sync, + T: HttpClient + Send + Sync, +{ + fn base_uri(&self) -> String { + self.inner.base_uri() + } + async fn send_xrpc( + &self, + request: &XrpcRequest, + ) -> Result, Error> + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + // let result = self.inner.send_xrpc(request).await; + // // handle session-refreshes as needed + // if Self::is_invalid_token_response(&result) { + // self.refresh_token().await; + // self.inner.send_xrpc(request).await + // } else { + // result + // } + self.inner.send_xrpc(request).await + } +} + +impl Configure for Client { + fn configure_endpoint(&self, endpoint: String) { + self.inner.configure_endpoint(endpoint) + } + /// Configures the moderation services to be applied on requests. + fn configure_labelers_header(&self, labeler_dids: Option>) { + self.inner.configure_labelers_header(labeler_dids) + } + /// Configures the atproto-proxy header to be applied on requests. + fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { + self.inner.configure_proxy_header(did, service_type) + } +} + +impl CloneWithProxy for Client +where + WrapperClient: CloneWithProxy, +{ + fn clone_with_proxy(&self, did: Did, service_type: impl AsRef) -> Self { + Self { inner: self.inner.clone_with_proxy(did, service_type) } + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_session/store.rs b/atrium-oauth/oauth-client/src/oauth_session/store.rs new file mode 100644 index 00000000..7a7f2312 --- /dev/null +++ b/atrium-oauth/oauth-client/src/oauth_session/store.rs @@ -0,0 +1,34 @@ +use atrium_api::agent::AuthorizationProvider; +use atrium_common::store::{self, memory::MemoryStore, Store}; +use atrium_xrpc::types::AuthorizationToken; + +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] +pub trait OAuthSessionStore: store::Store<(), String> + AuthorizationProvider {} + +#[derive(Default)] +pub struct MemorySessionStore(MemoryStore<(), String>); + +impl OAuthSessionStore for MemorySessionStore {} + +impl Store<(), String> for MemorySessionStore { + type Error = store::memory::Error; + + async fn get(&self, key: &()) -> Result, Self::Error> { + todo!() + } + async fn set(&self, key: (), value: String) -> Result<(), Self::Error> { + todo!() + } + async fn del(&self, key: &()) -> Result<(), Self::Error> { + todo!() + } + async fn clear(&self) -> Result<(), Self::Error> { + todo!() + } +} + +impl AuthorizationProvider for MemorySessionStore { + async fn authorization_token(&self, _: bool) -> Option { + self.0.get(&()).await.ok().flatten().map(AuthorizationToken::Dpop) + } +} diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index b7ac660d..6045358c 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -320,3 +320,18 @@ where Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?) } } + +impl Clone for OAuthServerAgent +where + T: HttpClient + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self { + server_metadata: self.server_metadata.clone(), + client_metadata: self.client_metadata.clone(), + dpop_client: self.dpop_client.clone(), + resolver: Arc::clone(&self.resolver), + keyset: self.keyset.clone(), + } + } +} diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index a5712674..b0cf0af3 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -12,7 +12,7 @@ pub use request::{ PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, TokenRequestParameters, }; -pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse}; +pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse, OAuthTokenType}; use serde::Deserialize; pub use token::TokenSet; diff --git a/atrium-oauth/oauth-client/src/types/client_metadata.rs b/atrium-oauth/oauth-client/src/types/client_metadata.rs index 04f2f2bf..b30a23f2 100644 --- a/atrium-oauth/oauth-client/src/types/client_metadata.rs +++ b/atrium-oauth/oauth-client/src/types/client_metadata.rs @@ -2,7 +2,7 @@ use crate::keyset::Keyset; use jose_jwk::JwkSet; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct OAuthClientMetadata { pub client_id: String, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/atrium-oauth/oauth-client/src/types/metadata.rs b/atrium-oauth/oauth-client/src/types/metadata.rs index 0e40c649..16a9e723 100644 --- a/atrium-oauth/oauth-client/src/types/metadata.rs +++ b/atrium-oauth/oauth-client/src/types/metadata.rs @@ -1,7 +1,7 @@ use atrium_api::types::string::Language; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct OAuthAuthorizationServerMetadata { // https://datatracker.ietf.org/doc/html/rfc8414#section-2 pub issuer: String,