Skip to content

Commit

Permalink
change Resolver type signature
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 15, 2024
1 parent ef343d2 commit c6d51fe
Show file tree
Hide file tree
Showing 18 changed files with 57 additions and 74 deletions.
45 changes: 19 additions & 26 deletions atrium-common/src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub trait Resolver {
fn resolve(
&self,
input: &Self::Input,
) -> impl Future<Output = core::result::Result<Option<Self::Output>, Self::Error>>;
) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>;
}

#[cfg(test)]
Expand Down Expand Up @@ -52,10 +52,17 @@ mod tests {
type Output = String;
type Error = Error;

async fn resolve(&self, input: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(
&self,
input: &Self::Input,
) -> core::result::Result<Self::Output, Self::Error> {
sleep(Duration::from_millis(10)).await;
*self.counts.write().await.entry(input.clone()).or_default() += 1;
Ok(self.data.get(input).cloned())
if let Some(value) = self.data.get(input) {
Ok(value.clone())
} else {
Err(Error::NotFound)
}
}
}

Expand Down Expand Up @@ -87,12 +94,8 @@ mod tests {
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => {
assert_eq!(result.expect("failed to resolve").as_deref(), Some(value))
}
None => {
assert_eq!(result.expect("failed to resolve").as_deref(), None)
}
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert!(result.is_err()),
}
}
assert_eq!(
Expand All @@ -119,12 +122,8 @@ mod tests {
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => {
assert_eq!(result.expect("failed to resolve").as_deref(), Some(value))
}
None => {
assert_eq!(result.expect("failed to resolve").as_deref(), None)
}
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert!(result.is_err()),
}
}
assert_eq!(
Expand Down Expand Up @@ -152,12 +151,8 @@ mod tests {
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => {
assert_eq!(result.expect("failed to resolve").as_deref(), Some(value))
}
None => {
assert_eq!(result.expect("failed to resolve").as_deref(), None)
}
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert!(result.is_err()),
}
}
assert_eq!(
Expand All @@ -178,12 +173,12 @@ mod tests {
}));
for _ in 0..10 {
let result = resolver.resolve(&String::from("k1")).await;
assert_eq!(result.expect("failed to resolve").as_deref(), Some("v1"));
assert_eq!(result.expect("failed to resolve"), "v1");
}
sleep(Duration::from_millis(10)).await;
for _ in 0..10 {
let result = resolver.resolve(&String::from("k1")).await;
assert_eq!(result.expect("failed to resolve").as_deref(), Some("v1"));
assert_eq!(result.expect("failed to resolve"), "v1");
}
assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect());
}
Expand Down Expand Up @@ -212,9 +207,7 @@ mod tests {
Some(value) => {
assert_eq!(result.expect("failed to resolve").as_deref(), Some(value))
}
None => {
assert_eq!(result.expect("failed to resolve").as_deref(), None)
}
None => assert!(result.is_err()),
}
}
assert_eq!(
Expand Down
17 changes: 6 additions & 11 deletions atrium-common/src/resolver/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,12 @@ where
type Output = R::Output;
type Error = R::Error;

async fn resolve(&self, input: &Self::Input) -> Result<Option<Self::Output>, Self::Error> {
match self.cache.get(input).await {
Some(cached) => Ok(Some(cached)),
None => {
let result = self.inner.resolve(input).await?;

if let Some(result) = result.as_ref().cloned() {
self.cache.set(input.clone(), result.clone()).await;
}
Ok(result)
}
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
if let Some(output) = self.cache.get(input).await {
return Ok(output);
}
let output = self.inner.resolve(input).await?;
self.cache.set(input.clone(), output.clone()).await;
Ok(output)
}
}
2 changes: 2 additions & 0 deletions atrium-common/src/resolver/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub type Result<T> = core::result::Result<T, Error>;

#[derive(Error, Debug)]
pub enum Error {
#[error("resource not found")]
NotFound,
#[error("dns resolver error: {0}")]
DnsResolver(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
Expand Down
8 changes: 4 additions & 4 deletions atrium-common/src/resolver/throttled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ where
R::Output: Clone + Send + Sync + 'static,
{
type Input = R::Input;
type Output = R::Output;
type Output = Option<R::Output>;
type Error = R::Error;

async fn resolve(&self, input: &Self::Input) -> Result<Option<Self::Output>, Self::Error> {
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
match self.pending.entry(input.clone()) {
Entry::Occupied(occupied) => {
let tx = occupied.get().lock().await.clone();
Expand All @@ -34,9 +34,9 @@ where
let (tx, _) = channel(1);
vacant.insert(Arc::new(Mutex::new(tx.clone())));
let result = self.inner.resolve(input).await;
let _ = tx.send(result.as_ref().cloned().transpose().and_then(Result::ok));
tx.send(result.as_ref().ok().cloned()).ok();
self.pending.remove(input);
result
result.map(Some)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion atrium-common/src/types/cached/impl/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ where
};
Self { inner: Arc::new(Mutex::new(store)), expiration: config.time_to_live }
}
async fn get(&self, key: &Self::Input) -> Option<Self::Output> {
async fn get(&self, key: &Self::Input) -> Self::Output {
let mut cache = self.inner.lock().await;
if let Some(ValueWithInstant { value, instant }) = cache.get(key) {
if let Some(expiration) = self.expiration {
Expand Down
2 changes: 1 addition & 1 deletion atrium-oauth/identity/src/did/common_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ where
type Output = DidDocument;
type Error = Error;

async fn resolve(&self, did: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, did: &Self::Input) -> Result<Self::Output> {
match did.strip_prefix("did:").and_then(|s| s.split_once(':').map(|(method, _)| method)) {
Some("plc") => self.plc_resolver.resolve(did).await,
Some("web") => self.web_resolver.resolve(did).await,
Expand Down
2 changes: 1 addition & 1 deletion atrium-oauth/identity/src/did/plc_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
type Output = DidDocument;
type Error = Error;

async fn resolve(&self, did: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, did: &Self::Input) -> Result<Self::Output> {
let uri = Builder::from(self.plc_directory_url.parse::<Uri>()?)
.path_and_query(format!("/{}", did.as_str()))
.build()?;
Expand Down
2 changes: 1 addition & 1 deletion atrium-oauth/identity/src/did/web_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where
type Output = DidDocument;
type Error = Error;

async fn resolve(&self, did: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, did: &Self::Input) -> Result<Self::Output> {
let document_url = format!(
"https://{}/.well-known/did.json",
did.as_str()
Expand Down
1 change: 1 addition & 0 deletions atrium-oauth/identity/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl From<resolver::Error> for Error {
resolver::Error::SerdeJson(error) => Error::SerdeJson(error),
resolver::Error::SerdeHtmlForm(error) => Error::SerdeHtmlForm(error),
resolver::Error::Uri(error) => Error::Uri(error),
resolver::Error::NotFound => Error::NotFound,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions atrium-oauth/identity/src/handle/appview_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ where
type Output = Did;
type Error = Error;

async fn resolve(&self, handle: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, handle: &Self::Input) -> Result<Self::Output> {
let uri = Builder::from(self.service_url.parse::<Uri>()?)
.path_and_query(format!(
"/xrpc/com.atproto.identity.resolveHandle?{}",
Expand All @@ -49,7 +49,7 @@ where
.await
.map_err(Error::HttpClient)?;
if res.status().is_success() {
Ok(Some(serde_json::from_slice::<resolve_handle::OutputData>(res.body())?.did))
Ok(serde_json::from_slice::<resolve_handle::OutputData>(res.body())?.did)
} else {
Err(Error::HttpStatus(res.status()))
}
Expand Down
2 changes: 1 addition & 1 deletion atrium-oauth/identity/src/handle/atproto_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ where
type Output = Did;
type Error = Error;

async fn resolve(&self, handle: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, handle: &Self::Input) -> Result<Self::Output> {
let d_fut = self.dns.resolve(handle);
let h_fut = self.http.resolve(handle);
if let Ok(did) = d_fut.await {
Expand Down
4 changes: 2 additions & 2 deletions atrium-oauth/identity/src/handle/dns_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ where
type Output = Did;
type Error = Error;

async fn resolve(&self, handle: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, handle: &Self::Input) -> Result<Self::Output> {
for result in self
.dns_txt_resolver
.resolve(&format!("{SUBDOMAIN}.{}", handle.as_ref()))
.await
.map_err(Error::DnsResolver)?
{
if let Some(did) = result.strip_prefix(PREFIX) {
return Some(did.parse::<Did>().map_err(|e| Error::Did(e.to_string()))).transpose();
return did.parse::<Did>().map_err(|e| Error::Did(e.to_string()));
}
}
Err(Error::NotFound)
Expand Down
4 changes: 2 additions & 2 deletions atrium-oauth/identity/src/handle/well_known_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where
type Output = Did;
type Error = Error;

async fn resolve(&self, handle: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, handle: &Self::Input) -> Result<Self::Output> {
let url = format!("https://{}{WELL_KNWON_PATH}", handle.as_str());
// TODO: no-cache?
let res = self
Expand All @@ -41,7 +41,7 @@ where
.map_err(Error::HttpClient)?;
if res.status().is_success() {
let text = String::from_utf8_lossy(res.body()).to_string();
Some(text.parse::<Did>().map_err(|e| Error::Did(e.to_string()))).transpose()
text.parse::<Did>().map_err(|e| Error::Did(e.to_string()))
} else {
Err(Error::HttpStatus(res.status()))
}
Expand Down
15 changes: 5 additions & 10 deletions atrium-oauth/identity/src/identity_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,13 @@ where
type Output = ResolvedIdentity;
type Error = Error;

async fn resolve(&self, input: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
let document =
match input.parse::<AtIdentifier>().map_err(|e| Error::AtIdentifier(e.to_string()))? {
AtIdentifier::Did(did) => {
let result = self.did_resolver.resolve(&did).await?;
result.ok_or_else(|| Error::NotFound)?
}
AtIdentifier::Did(did) => self.did_resolver.resolve(&did).await?,
AtIdentifier::Handle(handle) => {
let result = self.handle_resolver.resolve(&handle).await?;
let did = result.ok_or_else(|| Error::NotFound)?;
let result = self.did_resolver.resolve(&did).await?;
let document = result.ok_or_else(|| Error::NotFound)?;
let did = self.handle_resolver.resolve(&handle).await?;
let document = self.did_resolver.resolve(&did).await?;
if let Some(aka) = &document.also_known_as {
if !aka.contains(&format!("at://{}", handle.as_str())) {
return Err(Error::DidDocument(format!(
Expand All @@ -67,6 +62,6 @@ where
document.id
)));
};
Ok(Some(ResolvedIdentity { did: document.id, pds: service }))
Ok(ResolvedIdentity { did: document.id, pds: service })
}
}
4 changes: 1 addition & 3 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ where
} else {
self.client_metadata.redirect_uris[0].clone()
};
let result = self.resolver.resolve(input.as_ref()).await?;
let (metadata, identity) =
result.ok_or_else(|| Error::Identity(atrium_identity::Error::NotFound))?;
let (metadata, identity) = self.resolver.resolve(input.as_ref()).await?;
let Some(dpop_key) = Self::generate_dpop_key(&metadata) else {
return Err(Error::Authorize("none of the algorithms worked".into()));
};
Expand Down
9 changes: 4 additions & 5 deletions atrium-oauth/oauth-client/src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ where
&self,
input: &str,
) -> Result<(OAuthAuthorizationServerMetadata, ResolvedIdentity)> {
let result = self.identity_resolver.resolve(input).await;
let identity = result.and_then(|result| result.ok_or_else(|| Error::NotFound))?;
let identity = self.identity_resolver.resolve(input).await?;
let metadata = self.get_resource_server_metadata(&identity.pds).await?;
Ok((metadata, identity))
}
Expand Down Expand Up @@ -193,15 +192,15 @@ where
type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>);
type Error = Error;

async fn resolve(&self, input: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
// Allow using an entryway, or PDS url, directly as login input (e.g.
// when the user forgot their handle, or when the handle does not
// resolve to a DID)
Ok(if input.starts_with("https://") {
Some((self.resolve_from_service(input.as_ref()).await?, None))
(self.resolve_from_service(input.as_ref()).await?, None)
} else {
let (metadata, identity) = self.resolve_from_identity(input).await?;
Some((metadata, Some(identity)))
(metadata, Some(identity))
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ where
type Output = OAuthAuthorizationServerMetadata;
type Error = Error;

async fn resolve(&self, issuer: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, issuer: &Self::Input) -> Result<Self::Output> {
let uri = Builder::from(issuer.parse::<Uri>()?)
.path_and_query("/.well-known/oauth-authorization-server")
.build()?;
Expand All @@ -38,7 +38,7 @@ where
let metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body())?;
// https://datatracker.ietf.org/doc/html/rfc8414#section-3.3
if &metadata.issuer == issuer {
Ok(Some(metadata))
Ok(metadata)
} else {
Err(Error::AuthorizationServerMetadata(format!(
"invalid issuer: {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ where
type Output = OAuthProtectedResourceMetadata;
type Error = Error;

async fn resolve(&self, resource: &Self::Input) -> Result<Option<Self::Output>> {
async fn resolve(&self, resource: &Self::Input) -> Result<Self::Output> {
let uri = Builder::from(resource.parse::<Uri>()?)
.path_and_query("/.well-known/oauth-protected-resource")
.build()?;
Expand All @@ -38,7 +38,7 @@ where
let metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body())?;
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-08#section-3.3
if &metadata.resource == resource {
Ok(Some(metadata))
Ok(metadata)
} else {
Err(Error::ProtectedResourceMetadata(format!(
"invalid resource: {}",
Expand Down

0 comments on commit c6d51fe

Please sign in to comment.