diff --git a/codecov.yml b/codecov.yml index 5590299..93e5950 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,9 +1,7 @@ coverage: status: - project: - default: - target: auto - threshold: 10% + project: off + patch: off github_checks: annotations: false diff --git a/server/src/auto_start.rs b/server/src/auto_start.rs index e814545..4b36e3b 100644 --- a/server/src/auto_start.rs +++ b/server/src/auto_start.rs @@ -1,6 +1,4 @@ use std::collections::HashMap; -use std::ops::DerefMut; -use std::sync::Arc; use anyhow::Context; use diesel_async::{AsyncPgConnection, RunQueryDsl}; @@ -9,20 +7,14 @@ use scuffle_context::ContextFutExt; use crate::database::ci_run::CiRun; use crate::database::pr::Pr; +use crate::database::DatabaseConnection; use crate::github::merge_workflow::GitHubMergeWorkflow; use crate::github::repo::GitHubRepoClient; +use crate::BrawlState; pub struct AutoStartSvc; -pub trait AutoStartConfig: Send + Sync + 'static { - type RepoClient: GitHubRepoClient; - - fn repo_client(&self, repo_id: RepositoryId) -> Option>; - - fn database( - &self, - ) -> impl std::future::Future + Send>> + Send; - +pub trait AutoStartConfig: BrawlState { fn interval(&self) -> std::time::Duration; } @@ -36,8 +28,8 @@ impl scuffle_bootstrap::Service for AutoStartSvc { while tokio::time::sleep(global.interval()).with_context(&ctx).await.is_some() { let mut conn = global.database().await?; - let runs = get_pending_runs(&mut conn).await?; - process_pending_runs(&mut conn, runs, global.as_ref()).await?; + let runs = get_pending_runs(conn.get()).await?; + process_pending_runs(conn.get(), runs, global.as_ref()).await?; } Ok(()) @@ -79,12 +71,12 @@ async fn process_pending_runs( global: &impl AutoStartConfig, ) -> anyhow::Result<()> { for run in runs { - let Some(repo_client) = global.repo_client(RepositoryId(run.github_repo_id as u64)) else { + let Some(repo_client) = global.get_repo(None, RepositoryId(run.github_repo_id as u64)).await else { tracing::error!("no installation client found for repo {}", run.github_repo_id); continue; }; - if let Err(e) = handle_run(conn, &run, repo_client.as_ref()).await { + if let Err(e) = handle_run(conn, &run, &repo_client).await { tracing::error!( "error handling run (repo id: {}, pr number: {}, run id: {}): {}", run.github_repo_id, @@ -128,15 +120,18 @@ async fn handle_run( #[cfg(test)] #[cfg_attr(coverage_nightly, coverage(off))] mod tests { + use std::sync::Arc; + use chrono::Utc; - use octocrab::models::UserId; + use octocrab::models::{InstallationId, UserId}; use super::*; use crate::database::ci_run::Base; use crate::database::enums::GithubCiRunStatus; - use crate::database::get_test_connection; + use crate::database::{get_test_connection, DatabaseConnection}; use crate::github::models::PullRequest; use crate::github::repo::test_utils::MockRepoClient; + use crate::github::repo::RepoClientRef; fn run() -> CiRun<'static> { CiRun { @@ -453,26 +448,23 @@ mod tests { struct AutoStartCfg(Arc>); - impl AutoStartConfig for AutoStartCfg { - type RepoClient = MockRepoClient; - - fn repo_client(&self, _: RepositoryId) -> Option> { - Some(self.0.clone()) + impl BrawlState for AutoStartCfg { + async fn get_repo(&self, _: Option, _: RepositoryId) -> Option { + Some(RepoClientRef::new(self.0.clone())) } - fn database( - &self, - ) -> impl std::future::Future + Send>> + Send - { + fn database(&self) -> impl std::future::Future> + Send { fn get_conn() -> diesel_async::pooled_connection::bb8::PooledConnection<'static, AsyncPgConnection> { unreachable!() } std::future::ready(Ok(get_conn())) } + } + impl AutoStartConfig for AutoStartCfg { fn interval(&self) -> std::time::Duration { - todo!() + unreachable!("this method should not be called") } } diff --git a/server/src/bin/server.rs b/server/src/bin/server.rs index 77e8b63..0133cfd 100644 --- a/server/src/bin/server.rs +++ b/server/src/bin/server.rs @@ -2,7 +2,6 @@ #![cfg_attr(all(coverage_nightly, test), coverage(off))] use std::net::SocketAddr; -use std::ops::DerefMut; use std::sync::Arc; use anyhow::Context; @@ -16,8 +15,9 @@ use scuffle_bootstrap_telemetry::opentelemetry_sdk::metrics::SdkMeterProvider; use scuffle_bootstrap_telemetry::opentelemetry_sdk::Resource; use scuffle_bootstrap_telemetry::prometheus_client::registry::Registry; use scuffle_brawl::database::schema::health_check; +use scuffle_brawl::database::DatabaseConnection; use scuffle_brawl::github::models::Installation; -use scuffle_brawl::github::repo::{GitHubRepoClient, RepoClient}; +use scuffle_brawl::github::repo::GitHubRepoClient; use scuffle_brawl::github::GitHubService; use scuffle_metrics::opentelemetry::KeyValue; use tracing_subscriber::layer::SubscriberExt; @@ -171,14 +171,6 @@ impl scuffle_brawl::webhook::WebhookConfig for Global { Some(self.config.github.webhook_bind) } - fn get_repo( - &self, - installation_id: InstallationId, - repo_id: RepositoryId, - ) -> Option> { - self.github_service.get_client(installation_id)?.get_repo_client(repo_id) - } - async fn add_repo(&self, installation_id: InstallationId, repo_id: RepositoryId) -> anyhow::Result<()> { let client = self .github_service @@ -202,33 +194,36 @@ impl scuffle_brawl::webhook::WebhookConfig for Global { Ok(()) } - fn delete_installation(&self, installation_id: InstallationId) -> anyhow::Result<()> { + async fn delete_installation(&self, installation_id: InstallationId) -> anyhow::Result<()> { self.github_service.delete_installation(installation_id); Ok(()) } - - async fn database(&self) -> anyhow::Result + Send> { - let conn = self.database.get().await.context("get database connection")?; - Ok(conn) - } } impl scuffle_brawl::auto_start::AutoStartConfig for Global { - type RepoClient = RepoClient; - fn interval(&self) -> std::time::Duration { std::time::Duration::from_secs(self.config.interval_seconds) } +} - async fn database(&self) -> anyhow::Result + Send> { - let conn = self.database.get().await.context("get database connection")?; - Ok(conn) +impl scuffle_brawl::BrawlState for Global { + async fn get_repo( + &self, + installation_id: Option, + repo_id: RepositoryId, + ) -> Option { + let installation = if let Some(installation_id) = installation_id { + self.github_service.get_client(installation_id) + } else { + self.github_service.get_client_by_repo(repo_id) + }?; + + installation.get_repo_client(repo_id).await } - fn repo_client(&self, repo_id: octocrab::models::RepositoryId) -> Option> { - let client = self.github_service.get_client_by_repo(repo_id)?; - let repo = client.get_repo_client(repo_id)?; - Some(repo) + async fn database(&self) -> anyhow::Result { + let conn = self.database.get().await.context("get database connection")?; + Ok(conn) } } diff --git a/server/src/command/mod.rs b/server/src/command/mod.rs index 2677962..75d91ab 100644 --- a/server/src/command/mod.rs +++ b/server/src/command/mod.rs @@ -1,5 +1,6 @@ use std::str::FromStr; +use anyhow::Context; use auto_try::AutoTryCommand; use diesel_async::AsyncPgConnection; use dry_run::DryRunCommand; @@ -38,12 +39,12 @@ impl BrawlCommand { context: BrawlCommandContext<'_, R>, ) -> anyhow::Result<()> { match self { - BrawlCommand::DryRun(command) => dry_run::handle(conn, context, command).await, - BrawlCommand::Merge(command) => merge::handle(conn, context, command).await, - BrawlCommand::Retry => retry::handle(conn, context).await, - BrawlCommand::Cancel => cancel::handle(conn, context).await, - BrawlCommand::Ping => ping::handle(conn, context).await, - BrawlCommand::AutoTry(command) => auto_try::handle(conn, context, command).await, + BrawlCommand::DryRun(command) => dry_run::handle(conn, context, command).await.context("dry run"), + BrawlCommand::Merge(command) => merge::handle(conn, context, command).await.context("merge"), + BrawlCommand::Retry => retry::handle(conn, context).await.context("retry"), + BrawlCommand::Cancel => cancel::handle(conn, context).await.context("cancel"), + BrawlCommand::Ping => ping::handle(conn, context).await.context("ping"), + BrawlCommand::AutoTry(command) => auto_try::handle(conn, context, command).await.context("auto try"), } } } diff --git a/server/src/database/mod.rs b/server/src/database/mod.rs index 13a4f62..2b8637e 100644 --- a/server/src/database/mod.rs +++ b/server/src/database/mod.rs @@ -22,6 +22,9 @@ macro_rules! test_query { }; } +use std::marker::PhantomData; + +use diesel_async::AsyncPgConnection; #[cfg(test)] use test_query; @@ -51,3 +54,40 @@ pub async fn get_test_connection() -> diesel_async::AsyncPgConnection { conn } + +pub trait DatabaseConnection { + fn get(&mut self) -> &mut AsyncPgConnection; +} + +impl DatabaseConnection for AsyncPgConnection { + fn get(&mut self) -> &mut AsyncPgConnection { + self + } +} + +impl DatabaseConnection for diesel_async::pooled_connection::bb8::PooledConnection<'_, AsyncPgConnection> { + fn get(&mut self) -> &mut AsyncPgConnection { + self + } +} + +impl DatabaseConnection for &mut T { + fn get(&mut self) -> &mut AsyncPgConnection { + DatabaseConnection::get(*self) + } +} + +pub struct DatabaseConnectionRef { + inner: T, + _db: PhantomData, +} + +impl DatabaseConnection for DatabaseConnectionRef +where + T: AsMut, + D: DatabaseConnection, +{ + fn get(&mut self) -> &mut AsyncPgConnection { + self.inner.as_mut().get() + } +} diff --git a/server/src/github/installation.rs b/server/src/github/installation.rs index a7cb3a8..e11e269 100644 --- a/server/src/github/installation.rs +++ b/server/src/github/installation.rs @@ -1,11 +1,11 @@ use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; use anyhow::Context; use axum::http; -use futures::TryStreamExt; +use futures::{StreamExt, TryStreamExt}; use moka::future::Cache; use octocrab::models::{InstallationRepositories, RepositoryId, UserId}; use octocrab::{GitHubError, Octocrab}; @@ -14,13 +14,15 @@ use parking_lot::Mutex; use super::config::GitHubBrawlRepoConfig; use super::merge_workflow::DefaultMergeWorkflow; use super::models::{Installation, Repository, User}; -use super::repo::RepoClient; +use super::repo::{GitHubRepoClient, RepoClient, RepoClientRef}; +use super::repo_lock::{LockGuard, RepoLock}; pub struct InstallationClient { client: Octocrab, installation: Mutex, repositories: Mutex>>, - user_cache: UserCache, + state: OrgState, + repo_lock: RepoLock, } #[derive(Debug, thiserror::Error)] @@ -36,19 +38,19 @@ pub enum GitHubBrawlRepoConfigError { } #[derive(Debug, Clone)] -pub struct UserCache { +pub struct OrgState { users: Cache>, users_by_name: Cache>, teams: Cache<(String, String), Vec>, } -impl Default for UserCache { +impl Default for OrgState { fn default() -> Self { Self::new(Duration::from_secs(60), 100) } } -impl UserCache { +impl OrgState { pub fn new(ttl: Duration, capacity: u64) -> Self { Self { users: Cache::builder().max_capacity(capacity).time_to_live(ttl).build(), @@ -58,13 +60,25 @@ impl UserCache { } } +struct LockedRepoClient { + client: Arc, + _guard: LockGuard, +} + +impl AsRef for LockedRepoClient { + fn as_ref(&self) -> &RepoClient { + &self.client + } +} + impl InstallationClient { pub fn new(client: Octocrab, installation: Installation) -> Self { Self { client, installation: Mutex::new(installation), repositories: Mutex::new(HashMap::new()), - user_cache: UserCache::default(), + state: OrgState::default(), + repo_lock: RepoLock::new(), } } @@ -120,7 +134,7 @@ impl InstallationClient { repo, config, self.client.clone(), - self.user_cache.clone(), + self.state.clone(), DefaultMergeWorkflow, ))); } @@ -129,8 +143,12 @@ impl InstallationClient { } pub async fn fetch_repositories(self: &Arc) -> anyhow::Result<()> { - let mut repositories = Vec::new(); let mut page = 1; + + let mut futures = futures::stream::FuturesUnordered::new(); + + let mut ids = HashSet::new(); + loop { let resp: InstallationRepositories = self .client @@ -138,31 +156,32 @@ impl InstallationClient { .await .context("get installation repositories")?; - repositories.extend(resp.repositories); - - if repositories.len() >= resp.total_count as usize { + ids.extend(resp.repositories.iter().map(|repo| repo.id)); + + futures.extend(resp.repositories.into_iter().map(|repo| async move { + let repo: Repository = repo.into(); + let r = self.set_repository(repo.clone()).await; + if let Err(e) = r { + tracing::error!( + id = %repo.id, + name = %repo.name, + owner = %repo.owner.login, + "failed to set repository: {:#}", + e + ); + } + })); + + if futures.len() >= resp.total_count as usize { break; } page += 1; } - let mut repos = HashMap::new(); - for repo in repositories { - let config = self.get_repo_config(repo.id).await?; - repos.insert( - repo.id, - Arc::new(RepoClient::new( - repo.into(), - config, - self.client.clone(), - self.user_cache.clone(), - DefaultMergeWorkflow, - )), - ); - } + futures.collect::>().await; - *self.repositories.lock() = repos; + self.repositories.lock().retain(|id, _| ids.contains(id)); Ok(()) } @@ -175,8 +194,12 @@ impl InstallationClient { self.repositories.lock().contains_key(&repo_id) } - pub fn get_repo_client(&self, repo_id: RepositoryId) -> Option> { - self.repositories.lock().get(&repo_id).cloned() + pub async fn get_repo_client(&self, repo_id: RepositoryId) -> Option { + let client = self.repositories.lock().get(&repo_id).cloned()?; + Some(RepoClientRef::new(LockedRepoClient { + client, + _guard: self.repo_lock.lock(repo_id).await, + })) } pub async fn fetch_repository(self: &Arc, id: RepositoryId) -> anyhow::Result<()> { @@ -198,12 +221,16 @@ impl InstallationClient { } pub fn update_installation(&self, installation: Installation) { - tracing::info!("updated installation: {} ({})", installation.account.login, installation.id); + tracing::info!( + id = %installation.id, + owner = %installation.account.login, + "updated installation", + ); *self.installation.lock() = installation; } } -impl UserCache { +impl OrgState { pub async fn get_user(&self, client: &Octocrab, user_id: UserId) -> anyhow::Result> { self.users .try_get_with::<_, octocrab::Error>(user_id, async { @@ -482,7 +509,7 @@ mod tests { let installation_client = task.await.unwrap(); assert_eq!(installation_client.repositories(), vec![RepositoryId(1296269)]); - assert!(installation_client.get_repo_client(RepositoryId(1296269)).is_some()); + assert!(installation_client.get_repo_client(RepositoryId(1296269)).await.is_some()); } #[tokio::test] @@ -539,7 +566,7 @@ mod tests { let installation_client = task.await.unwrap(); assert!(installation_client.has_repository(RepositoryId(899726767))); - assert!(installation_client.get_repo_client(RepositoryId(899726767)).is_some()); + assert!(installation_client.get_repo_client(RepositoryId(899726767)).await.is_some()); } #[tokio::test] @@ -603,7 +630,7 @@ mod tests { #[tokio::test] async fn test_user_cache() { let (client, mut handle) = mock_octocrab(); - let cache = UserCache::new(std::time::Duration::from_millis(100), 100); + let cache = OrgState::new(std::time::Duration::from_millis(100), 100); let task = tokio::spawn(async move { cache.get_user(&client, UserId(49777269)).await.unwrap(); // cache miss diff --git a/server/src/github/merge_workflow.rs b/server/src/github/merge_workflow.rs index d096d4a..244a6a8 100644 --- a/server/src/github/merge_workflow.rs +++ b/server/src/github/merge_workflow.rs @@ -180,6 +180,7 @@ impl GitHubMergeWorkflow for &T { } } +#[derive(Debug, Clone, Copy, Default)] pub struct DefaultMergeWorkflow; impl DefaultMergeWorkflow { @@ -244,7 +245,8 @@ impl DefaultMergeWorkflow { .context("update")? == 0 { - anyhow::bail!("run already completed"); + tracing::warn!("run already completed"); + return Ok(()); } let checks_message = format_fn(|f| { diff --git a/server/src/github/mod.rs b/server/src/github/mod.rs index a702c7a..2a287f9 100644 --- a/server/src/github/mod.rs +++ b/server/src/github/mod.rs @@ -14,6 +14,7 @@ pub mod merge_workflow; pub mod messages; pub mod models; pub mod repo; +pub mod repo_lock; pub struct GitHubService { client: Octocrab, @@ -149,7 +150,7 @@ mod test_utils { use octocrab::{AuthState, Octocrab, OctocrabBuilder}; use super::config::GitHubBrawlRepoConfig; - use super::installation::UserCache; + use super::installation::OrgState; use super::merge_workflow::GitHubMergeWorkflow; use super::models::{Repository, User}; use super::repo::RepoClient; @@ -227,13 +228,13 @@ mod test_utils { (client, handle) } - pub fn mock_repo_client( + pub async fn mock_repo_client( octocrab: Octocrab, repo: Repository, config: GitHubBrawlRepoConfig, merge_workflow: MockMergeWorkflow, ) -> RepoClient { - RepoClient::new(repo, config, octocrab, UserCache::default(), merge_workflow) + RepoClient::new(repo, config, octocrab, OrgState::default(), merge_workflow) } pub fn default_repo() -> Repository { diff --git a/server/src/github/repo.rs b/server/src/github/repo.rs index 9e360c4..587ff95 100644 --- a/server/src/github/repo.rs +++ b/server/src/github/repo.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; +use std::marker::PhantomData; use std::sync::Arc; use std::time::Duration; @@ -12,7 +14,7 @@ use octocrab::params::{self}; use octocrab::{GitHubError, Octocrab}; use super::config::{GitHubBrawlRepoConfig, Permission, Role}; -use super::installation::UserCache; +use super::installation::OrgState; use super::merge_workflow::{DefaultMergeWorkflow, GitHubMergeWorkflow}; use super::messages::{CommitMessage, IssueMessage}; use super::models::{Commit, Label, PullRequest, Repository, Review, User, WorkflowRun}; @@ -21,11 +23,80 @@ pub struct RepoClient { pub(super) repo: ArcSwap, pub(super) config: ArcSwap, client: Octocrab, - user_cache: UserCache, + org_state: OrgState, merge_workflow: W, role_users: Cache>, } +pub struct RepoClientRef(T, PhantomData); + +impl RepoClientRef +where + T: AsRef, + C: GitHubRepoClient, +{ + pub fn new(client: T) -> Self { + Self(client, PhantomData) + } +} + +// Helper macro to forward all methods from the inner client +macro_rules! forward_fns { + ( + $(fn $fn:ident(&self $(,$arg:ident: $arg_ty:ty)*$(,)?) -> $ret:ty;)* + ) => { + $( + #[cfg_attr(all(test, coverage_nightly), coverage(off))] + fn $fn(&self $(,$arg: $arg_ty)*) -> $ret { + self.0.as_ref().$fn($($arg),*) + } + )* + }; +} + +impl GitHubRepoClient for RepoClientRef +where + T: AsRef + Send + Sync, + C: GitHubRepoClient, +{ + type MergeWorkflow<'a> + = C::MergeWorkflow<'a> + where + C: 'a, + T: 'a; + + // Forward all methods from the inner client + forward_fns! { + fn id(&self) -> RepositoryId; + fn add_labels(&self, issue_number: u64, labels: &[String]) -> impl std::future::Future>> + Send; + fn remove_label(&self, issue_number: u64, labels: &str) -> impl std::future::Future>> + Send; + fn branch_workflows(&self, branch: &str) -> impl std::future::Future>> + Send; + fn can_merge(&self, user_id: UserId) -> impl std::future::Future> + Send; + fn can_try(&self, user_id: UserId) -> impl std::future::Future> + Send; + fn can_review(&self, user_id: UserId) -> impl std::future::Future> + Send; + fn config(&self) -> Arc; + fn owner(&self) -> String; + fn name(&self) -> String; + fn merge_workflow(&self) -> Self::MergeWorkflow<'_>; + fn cancel_workflow_run(&self, run_id: RunId) -> impl std::future::Future> + Send; + fn has_permission(&self, user_id: UserId, permissions: &[Permission]) -> impl std::future::Future> + Send; + fn get_user(&self, user_id: UserId) -> impl std::future::Future>> + Send; + fn create_commit(&self, message: String, parents: Vec, tree: String) -> impl std::future::Future> + Send; + fn get_ref_latest_commit(&self, gh_ref: ¶ms::repos::Reference) -> impl std::future::Future>> + Send; + fn push_branch(&self, branch: &str, sha: &str, force: bool) -> impl std::future::Future> + Send; + fn delete_branch(&self, branch: &str) -> impl std::future::Future> + Send; + fn get_pull_request(&self, number: u64) -> impl std::future::Future> + Send; + fn get_role_members(&self, role: Role) -> impl std::future::Future>> + Send; + fn get_reviewers(&self, pr_number: u64) -> impl std::future::Future>> + Send; + fn send_message(&self, issue_number: u64, message: &IssueMessage) -> impl std::future::Future> + Send; + fn get_commit(&self, sha: &str) -> impl std::future::Future>> + Send; + fn create_merge(&self, message: &CommitMessage, base_sha: &str, head_sha: &str) -> impl std::future::Future> + Send; + fn commit_link(&self, sha: &str) -> String; + fn workflow_run_link(&self, run_id: RunId) -> String; + fn pr_link(&self, pr_number: u64) -> String; + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum MergeResult { Success(Commit), @@ -188,7 +259,7 @@ impl RepoClient { repo: Repository, config: GitHubBrawlRepoConfig, client: Octocrab, - user_cache: UserCache, + user_cache: OrgState, merge_workflow: W, ) -> Self { tracing::info!( @@ -202,7 +273,7 @@ impl RepoClient { repo: ArcSwap::from_pointee(repo), config: ArcSwap::from_pointee(config), client, - user_cache, + org_state: user_cache, merge_workflow, role_users: Cache::builder() .max_capacity(50) @@ -239,7 +310,7 @@ impl GitHubRepoClient for RepoClient { } async fn get_user(&self, user_id: UserId) -> anyhow::Result> { - self.user_cache.get_user(&self.client, user_id).await + self.org_state.get_user(&self.client, user_id).await } async fn get_pull_request(&self, number: u64) -> anyhow::Result { @@ -458,13 +529,13 @@ impl GitHubRepoClient for RepoClient { } } Permission::Team(team) => { - let users = self.user_cache.get_team_users(&self.client, &owner, team).await?; + let users = self.org_state.get_team_users(&self.client, &owner, team).await?; if users.contains(&user_id) { return Ok(true); } } Permission::User(user) => { - if let Some(user) = self.user_cache.get_user_by_name(&self.client, user).await? { + if let Some(user) = self.org_state.get_user_by_name(&self.client, user).await? { if user.id == user_id { return Ok(true); } @@ -528,6 +599,7 @@ impl GitHubRepoClient for RepoClient { pub mod test_utils { use super::*; + #[derive(Debug)] pub struct MockRepoClient { pub id: RepositoryId, pub owner: String, @@ -890,7 +962,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; assert_eq!(repo_client.id(), RepositoryId(899726767)); assert_eq!(repo_client.owner(), "ScuffleCloud"); @@ -907,7 +980,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { let user = repo_client.get_user(UserId(49777269)).await.unwrap().unwrap(); @@ -955,7 +1029,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { let pull_request = repo_client.get_pull_request(22).await.unwrap(); @@ -998,7 +1073,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { let members = repo_client.get_role_members(Role::Admin).await.unwrap(); @@ -1146,7 +1222,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client @@ -1192,7 +1269,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client @@ -1234,7 +1312,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client @@ -1292,7 +1371,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client.delete_branch("feature/1").await.unwrap(); @@ -1356,7 +1436,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client @@ -1451,7 +1532,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { assert!(repo_client @@ -1614,7 +1696,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client.get_reviewers(22).await.unwrap(); @@ -1653,7 +1736,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client @@ -1766,7 +1850,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { assert!(matches!( @@ -1927,7 +2012,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client.add_labels(1, &["queued".to_string()]).await.unwrap(); @@ -1972,7 +2058,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo_client.remove_label(1, "some_label").await.unwrap(); @@ -2011,7 +2098,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; assert_eq!( repo.workflow_run_link(RunId(1)), "https://github.com/ScuffleCloud/ci-testing/actions/runs/1" @@ -2026,7 +2114,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; assert_eq!(repo.pr_link(1), "https://github.com/ScuffleCloud/ci-testing/pull/1"); } @@ -2038,7 +2127,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; assert_eq!(repo.owner(), "ScuffleCloud"); } @@ -2050,7 +2140,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; assert_eq!(repo.name(), "ci-testing"); } @@ -2062,7 +2153,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; assert_eq!( repo.commit_link("b7f8cd1bd474d5be1802377c9a0baea5eb59fcb6"), "https://github.com/ScuffleCloud/ci-testing/commit/b7f8cd1bd474d5be1802377c9a0baea5eb59fcb6" @@ -2078,7 +2170,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo.branch_workflows("main").await.unwrap(); }); @@ -2114,7 +2207,8 @@ mod tests { default_repo(), GitHubBrawlRepoConfig::default(), MockMergeWorkflow(1), - ); + ) + .await; let task = tokio::spawn(async move { repo.cancel_workflow_run(RunId(1)).await.unwrap(); }); diff --git a/server/src/github/repo_lock.rs b/server/src/github/repo_lock.rs new file mode 100644 index 0000000..e409abe --- /dev/null +++ b/server/src/github/repo_lock.rs @@ -0,0 +1,91 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use octocrab::models::RepositoryId; +use parking_lot::Mutex; +use tokio::sync::OwnedSemaphorePermit; + +#[derive(Debug, Clone)] +pub struct RepoLock { + active_repos: Arc>>, +} + +#[derive(Debug, Clone)] +struct Lock { + count: Arc, + semaphore: Arc, +} + +pub struct LockGuard { + _inner: LockGuardInner, + _permit: OwnedSemaphorePermit, +} + +struct LockGuardInner { + repo_id: RepositoryId, + count: Arc, + repo_lock: RepoLock, +} + +impl LockGuardInner { + fn new(repo_id: RepositoryId, count: Arc, repo_lock: RepoLock) -> Self { + count.fetch_add(1, Ordering::Relaxed); + Self { + repo_id, + count, + repo_lock, + } + } + + fn into_guard(self, permit: OwnedSemaphorePermit) -> LockGuard { + LockGuard { + _inner: self, + _permit: permit, + } + } +} + +impl Drop for LockGuardInner { + fn drop(&mut self) { + if self.count.fetch_sub(1, Ordering::Relaxed) == 1 { + let mut lock = self.repo_lock.active_repos.lock(); + // The reason we do an extra load here is that the mutex is required to do a + // increment, therefore its possible that the count might have been + // incremented by another thread after we decremented it to 0, but + // before we could aquire the lock. So the additional check is to ensure we + // the count is 0 before deleting the entry. + if self.count.load(Ordering::Relaxed) == 0 { + lock.remove(&self.repo_id); + } + } + } +} + +impl Lock { + fn new() -> Self { + Self { + count: Arc::new(AtomicUsize::new(0)), + semaphore: Arc::new(tokio::sync::Semaphore::new(1)), + } + } +} + +impl RepoLock { + pub fn new() -> Self { + Self { + active_repos: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub async fn lock(&self, repo_id: RepositoryId) -> LockGuard { + let (inner, semaphore) = { + let mut lock = self.active_repos.lock(); + let lock = lock.entry(repo_id).or_insert_with(Lock::new); + let inner = LockGuardInner::new(repo_id, lock.count.clone(), self.clone()); + (inner, lock.semaphore.clone()) + }; + + inner.into_guard(semaphore.acquire_owned().await.expect("semaphore has no permits")) + } +} diff --git a/server/src/lib.rs b/server/src/lib.rs index 1865461..153ed88 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,5 +1,9 @@ #![cfg_attr(coverage_nightly, feature(coverage_attribute))] +use database::DatabaseConnection; +use github::repo::GitHubRepoClient; +use octocrab::models::{InstallationId, RepositoryId}; + pub mod auto_start; pub mod command; pub mod database; @@ -7,3 +11,13 @@ pub mod github; pub mod webhook; mod utils; + +pub trait BrawlState: Send + Sync + 'static { + fn get_repo( + &self, + installation_id: Option, + repo_id: RepositoryId, + ) -> impl std::future::Future> + Send; + + fn database(&self) -> impl std::future::Future> + Send; +} diff --git a/server/src/webhook/mod.rs b/server/src/webhook/mod.rs index 55b4529..42943e7 100644 --- a/server/src/webhook/mod.rs +++ b/server/src/webhook/mod.rs @@ -1,5 +1,4 @@ use std::net::SocketAddr; -use std::ops::DerefMut; use std::sync::Arc; use anyhow::Context; @@ -17,21 +16,17 @@ mod check_event; mod parse; mod pull_request; -use crate::command::BrawlCommandContext; -use crate::github::models::Installation; +use crate::command::{BrawlCommand, BrawlCommandContext}; +use crate::database::DatabaseConnection; +use crate::github::models::{Installation, User}; use crate::github::repo::GitHubRepoClient; +use crate::BrawlState; -pub trait WebhookConfig: Send + Sync + 'static { +pub trait WebhookConfig: BrawlState { fn webhook_secret(&self) -> &str; fn bind_address(&self) -> Option; - fn get_repo( - &self, - installation_id: InstallationId, - repo_id: RepositoryId, - ) -> Option>; - fn add_repo( &self, installation_id: InstallationId, @@ -49,11 +44,44 @@ pub trait WebhookConfig: Send + Sync + 'static { installation: Installation, ) -> impl std::future::Future> + Send; - fn delete_installation(&self, installation_id: InstallationId) -> anyhow::Result<()>; + fn delete_installation( + &self, + installation_id: InstallationId, + ) -> impl std::future::Future> + Send; + + #[doc(hidden)] + #[cfg_attr(all(test, coverage_nightly), coverage(off))] + fn handle_command( + &self, + conn: &mut AsyncPgConnection, + command: BrawlCommand, + context: BrawlCommandContext<'_, R>, + ) -> impl std::future::Future> + Send { + async move { command.handle(conn, context).await } + } + + #[doc(hidden)] + #[cfg_attr(all(test, coverage_nightly), coverage(off))] + fn handle_pull_request( + &self, + conn: &mut AsyncPgConnection, + client: &R, + pr_number: u64, + user: User, + ) -> impl std::future::Future> + Send { + async move { pull_request::handle(client, conn, pr_number, user).await } + } - fn database( + #[doc(hidden)] + #[cfg_attr(all(test, coverage_nightly), coverage(off))] + fn handle_check_run( &self, - ) -> impl std::future::Future + Send>> + Send; + conn: &mut AsyncPgConnection, + client: &R, + check_run: serde_json::Value, + ) -> impl std::future::Future> + Send { + async move { check_event::handle(client, conn, check_run).await } + } } fn router(global: Arc) -> axum::Router { @@ -148,20 +176,22 @@ async fn handle_webhook_action(global: &impl WebhookConfig, action: WebhookEvent repo_id, user, } => { - let Some(repo_client) = global.get_repo(installation_id, repo_id) else { + let Some(repo_client) = global.get_repo(Some(installation_id), repo_id).await else { return Err(anyhow::anyhow!("repo client not found")); }; global .database() .await? + .get() .transaction(|conn| { Box::pin(async move { - command - .handle( + global + .handle_command( conn, + command, BrawlCommandContext { - repo: repo_client.as_ref(), + repo: &repo_client, user, pr_number, }, @@ -169,7 +199,8 @@ async fn handle_webhook_action(global: &impl WebhookConfig, action: WebhookEvent .await }) }) - .await?; + .await + .context("command")?; Ok(()) } @@ -179,17 +210,19 @@ async fn handle_webhook_action(global: &impl WebhookConfig, action: WebhookEvent pr_number, user, } => { - let Some(repo_client) = global.get_repo(installation_id, repo_id) else { + let Some(repo_client) = global.get_repo(Some(installation_id), repo_id).await else { return Err(anyhow::anyhow!("repo client not found")); }; global .database() .await? + .get() .transaction(|conn| { - Box::pin(async move { pull_request::handle(repo_client.as_ref(), conn, pr_number, user).await }) + Box::pin(async move { global.handle_pull_request(conn, &repo_client, pr_number, user).await }) }) - .await?; + .await + .context("pull request")?; Ok(()) } @@ -198,41 +231,502 @@ async fn handle_webhook_action(global: &impl WebhookConfig, action: WebhookEvent installation_id, repo_id, } => { - let Some(repo_client) = global.get_repo(installation_id, repo_id) else { + let Some(repo_client) = global.get_repo(Some(installation_id), repo_id).await else { return Err(anyhow::anyhow!("repo client not found")); }; global .database() .await? - .transaction(|conn| { - Box::pin(async move { check_event::handle(repo_client.as_ref(), conn, check_run).await }) - }) - .await?; + .get() + .transaction(|conn| Box::pin(async move { global.handle_check_run(conn, &repo_client, check_run).await })) + .await + .context("check run")?; Ok(()) } WebhookEventAction::DeleteInstallation { installation_id } => { - global.delete_installation(installation_id)?; + global + .delete_installation(installation_id) + .await + .context("delete installation")?; Ok(()) } WebhookEventAction::AddRepository { installation_id, repo_id, } => { - global.add_repo(installation_id, repo_id).await?; + global.add_repo(installation_id, repo_id).await.context("add repository")?; Ok(()) } WebhookEventAction::RemoveRepository { installation_id, repo_id, } => { - global.remove_repo(installation_id, repo_id).await?; + global + .remove_repo(installation_id, repo_id) + .await + .context("remove repository")?; Ok(()) } WebhookEventAction::UpdateInstallation { installation } => { - global.update_installation(installation).await?; + global + .update_installation(installation) + .await + .context("update installation")?; Ok(()) } } } + +#[cfg(test)] +#[cfg_attr(all(test, coverage_nightly), coverage(off))] +mod tests { + use std::sync::Arc; + + use octocrab::models::UserId; + use tokio::sync::{mpsc, oneshot}; + + use super::*; + use crate::database::get_test_connection; + use crate::github::merge_workflow::DefaultMergeWorkflow; + use crate::github::repo::test_utils::MockRepoClient; + use crate::github::repo::RepoClientRef; + + #[derive(Debug)] + enum Action { + Command { + command: BrawlCommand, + pr_number: u64, + tx: oneshot::Sender>, + }, + PullRequest { + pr_number: u64, + tx: oneshot::Sender>, + }, + CheckRun { + check_run: serde_json::Value, + tx: oneshot::Sender>, + }, + DeleteInstallation { + installation_id: InstallationId, + tx: oneshot::Sender>, + }, + AddRepository { + installation_id: InstallationId, + repo_id: RepositoryId, + tx: oneshot::Sender>, + }, + RemoveRepository { + installation_id: InstallationId, + repo_id: RepositoryId, + tx: oneshot::Sender>, + }, + UpdateInstallation { + installation: Installation, + tx: oneshot::Sender>, + }, + GetRepo { + installation_id: Option, + repo_id: RepositoryId, + tx: oneshot::Sender>>>, + }, + } + + struct MockWebhookConfig { + tx: mpsc::Sender, + } + + impl MockWebhookConfig { + fn new() -> (Self, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(1); + (Self { tx }, rx) + } + } + + impl BrawlState for MockWebhookConfig { + async fn database(&self) -> anyhow::Result { + Ok(get_test_connection().await) + } + + async fn get_repo( + &self, + _installation_id: Option, + repo_id: RepositoryId, + ) -> Option { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::GetRepo { + installation_id: _installation_id, + repo_id, + tx, + }) + .await + .expect("send get repo action"); + let repo = rx.await.expect("get repo"); + + repo.map(RepoClientRef::new) + } + } + + impl WebhookConfig for MockWebhookConfig { + async fn add_repo(&self, _installation_id: InstallationId, repo_id: RepositoryId) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::AddRepository { + installation_id: _installation_id, + repo_id, + tx, + }) + .await + .context("send add repository action")?; + rx.await.context("add repository")? + } + + async fn remove_repo(&self, _installation_id: InstallationId, repo_id: RepositoryId) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::RemoveRepository { + installation_id: _installation_id, + repo_id, + tx, + }) + .await + .context("send remove repository action")?; + rx.await.context("remove repository")? + } + + async fn handle_check_run( + &self, + _conn: &mut AsyncPgConnection, + _client: &R, + check_run: serde_json::Value, + ) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::CheckRun { check_run, tx }) + .await + .context("send check run action")?; + rx.await.context("check run")? + } + + async fn handle_command( + &self, + _conn: &mut AsyncPgConnection, + command: BrawlCommand, + context: BrawlCommandContext<'_, R>, + ) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::Command { + command, + pr_number: context.pr_number, + tx, + }) + .await + .context("send command action")?; + rx.await.context("command")? + } + + async fn handle_pull_request( + &self, + _conn: &mut AsyncPgConnection, + _client: &R, + pr_number: u64, + _user: User, + ) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::PullRequest { pr_number, tx }) + .await + .context("send pull request action")?; + rx.await.context("pull request")? + } + + async fn update_installation(&self, installation: Installation) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::UpdateInstallation { installation, tx }) + .await + .context("send update installation action")?; + rx.await.context("update installation")? + } + + async fn delete_installation(&self, installation_id: InstallationId) -> anyhow::Result<()> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::DeleteInstallation { installation_id, tx }) + .await + .context("send delete installation action")?; + rx.await.context("delete installation")? + } + + fn bind_address(&self) -> Option { + unimplemented!("bind address") + } + + fn webhook_secret(&self) -> &str { + unimplemented!("webhook secret") + } + } + + #[tokio::test] + async fn test_handle_webhook_action_add_repository() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::AddRepository { + installation_id: InstallationId(1), + repo_id: RepositoryId(1), + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::AddRepository { + installation_id, + repo_id, + tx, + } => { + assert_eq!(installation_id, InstallationId(1)); + assert_eq!(repo_id, RepositoryId(1)); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } + + #[tokio::test] + async fn test_handle_webhook_action_remove_repository() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::RemoveRepository { + installation_id: InstallationId(1), + repo_id: RepositoryId(1), + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::RemoveRepository { + installation_id, + repo_id, + tx, + } => { + assert_eq!(installation_id, InstallationId(1)); + assert_eq!(repo_id, RepositoryId(1)); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } + + #[tokio::test] + async fn test_handle_webhook_action_delete_installation() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::DeleteInstallation { + installation_id: InstallationId(1), + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::DeleteInstallation { installation_id, tx } => { + assert_eq!(installation_id, InstallationId(1)); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } + + #[tokio::test] + async fn test_handle_webhook_action_update_installation() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::UpdateInstallation { + installation: Installation { + id: InstallationId(1), + account: User { + id: UserId(1), + login: "test".to_string(), + }, + }, + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::UpdateInstallation { installation, tx } => { + assert_eq!(installation.id, InstallationId(1)); + assert_eq!(installation.account.id, UserId(1)); + assert_eq!(installation.account.login, "test"); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } + + #[tokio::test] + async fn test_handle_webhook_action_check_run() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::CheckRun { + check_run: serde_json::Value::default(), + installation_id: InstallationId(1), + repo_id: RepositoryId(1), + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::GetRepo { + installation_id, + repo_id, + tx, + } => { + assert_eq!(installation_id, Some(InstallationId(1))); + assert_eq!(repo_id, RepositoryId(1)); + let (client, _) = MockRepoClient::new(DefaultMergeWorkflow); + tx.send(Some(Arc::new(client))).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + match rx.recv().await.expect("action") { + Action::CheckRun { check_run, tx } => { + assert_eq!(check_run, serde_json::Value::default()); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } + + #[tokio::test] + async fn test_handle_webhook_action_pull_request() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::PullRequest { + installation_id: InstallationId(1), + repo_id: RepositoryId(1), + pr_number: 1, + user: User { + id: UserId(1), + login: "test".to_string(), + }, + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::GetRepo { + installation_id, + repo_id, + tx, + } => { + assert_eq!(installation_id, Some(InstallationId(1))); + assert_eq!(repo_id, RepositoryId(1)); + let (client, _) = MockRepoClient::new(DefaultMergeWorkflow); + tx.send(Some(Arc::new(client))).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + match rx.recv().await.expect("action") { + Action::PullRequest { pr_number, tx } => { + assert_eq!(pr_number, 1); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } + + #[tokio::test] + async fn test_handle_webhook_action_command() { + let (config, mut rx) = MockWebhookConfig::new(); + + let task = tokio::spawn(async move { + handle_webhook_action( + &config, + WebhookEventAction::Command { + installation_id: InstallationId(1), + command: BrawlCommand::Cancel, + repo_id: RepositoryId(1), + pr_number: 1, + user: User { + id: UserId(1), + login: "test".to_string(), + }, + }, + ) + .await + .expect("handle webhook action"); + }); + + match rx.recv().await.expect("action") { + Action::GetRepo { + installation_id, + repo_id, + tx, + } => { + assert_eq!(installation_id, Some(InstallationId(1))); + assert_eq!(repo_id, RepositoryId(1)); + let (client, _) = MockRepoClient::new(DefaultMergeWorkflow); + tx.send(Some(Arc::new(client))).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + match rx.recv().await.expect("action") { + Action::Command { command, pr_number, tx } => { + assert_eq!(command, BrawlCommand::Cancel); + assert_eq!(pr_number, 1); + tx.send(Ok(())).expect("send result"); + } + r => panic!("unexpected action: {:?}", r), + } + + task.await.expect("task"); + } +}