diff --git a/migrations/2024-12-25-001112_label_state/down.sql b/migrations/2024-12-25-001112_label_state/down.sql new file mode 100644 index 0000000..d50fbf9 --- /dev/null +++ b/migrations/2024-12-25-001112_label_state/down.sql @@ -0,0 +1 @@ +ALTER TABLE github_pr DROP COLUMN added_label; diff --git a/migrations/2024-12-25-001112_label_state/up.sql b/migrations/2024-12-25-001112_label_state/up.sql new file mode 100644 index 0000000..eca2373 --- /dev/null +++ b/migrations/2024-12-25-001112_label_state/up.sql @@ -0,0 +1 @@ +ALTER TABLE github_pr ADD COLUMN added_label TEXT; diff --git a/migrations/schema.patch b/migrations/schema.patch index 29b668d..2e39a9f 100644 --- a/migrations/schema.patch +++ b/migrations/schema.patch @@ -1,5 +1,5 @@ ---- migrations/schema.unpatched.rs 2024-12-21 01:00:50.435153015 +0000 -+++ server/src/database/schema.rs 2024-12-21 00:55:41.850491800 +0000 +--- migrations/schema.unpatched.rs 2024-12-25 01:43:39.324295195 +0000 ++++ server/src/database/schema.rs 2024-12-25 01:43:35.592373517 +0000 @@ -1,38 +1,44 @@ +#![cfg_attr(coverage_nightly, coverage(off))] // @generated automatically by Diesel CLI. @@ -85,7 +85,7 @@ /// /// (Automatically generated by Diesel.) status -> GithubPrStatus, -@@ -320,12 +326,13 @@ +@@ -326,12 +332,13 @@ updated_at -> Timestamptz, } } diff --git a/migrations/schema.unpatched.rs b/migrations/schema.unpatched.rs index ecb201f..8b82e12 100644 --- a/migrations/schema.unpatched.rs +++ b/migrations/schema.unpatched.rs @@ -298,6 +298,12 @@ diesel::table! { /// /// (Automatically generated by Diesel.) default_priority -> Nullable, + /// The `added_label` column of the `github_pr` table. + /// + /// Its SQL type is `Nullable`. + /// + /// (Automatically generated by Diesel.) + added_label -> Nullable, } } diff --git a/server/src/command/cancel.rs b/server/src/command/cancel.rs index 4646106..bc5633e 100644 --- a/server/src/command/cancel.rs +++ b/server/src/command/cancel.rs @@ -23,7 +23,7 @@ async fn handle_with_pr( pr: PullRequest, context: BrawlCommandContext<'_, R>, ) -> anyhow::Result<()> { - Pr::new(&pr, context.user.id, context.repo.id()) + let db_pr = Pr::new(&pr, context.user.id, context.repo.id()) .upsert() .get_result(conn) .await @@ -49,7 +49,7 @@ async fn handle_with_pr( context .repo .merge_workflow() - .cancel(&run, context.repo, conn) + .cancel(&run, context.repo, conn, &db_pr) .await .context("cancel ci run")?; @@ -89,8 +89,9 @@ pub mod tests { run: &CiRun<'_>, repo: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, + pr: &Pr<'_>, ) -> anyhow::Result<()> { - let _ = (run, repo, conn); + let _ = (run, repo, conn, pr); if self .cancelled .compare_exchange( diff --git a/server/src/command/dry_run.rs b/server/src/command/dry_run.rs index 347b99f..d97130c 100644 --- a/server/src/command/dry_run.rs +++ b/server/src/command/dry_run.rs @@ -80,6 +80,12 @@ async fn handle_with_pr( let branch = context.repo.config().try_branch(pr.number); + let db_pr = Pr::new(&pr, context.user.id, context.repo.id()) + .upsert() + .get_result(conn) + .await + .context("update pr")?; + if let Some(run) = CiRun::active(context.repo.id(), pr.number) .get_result(conn) .await @@ -90,7 +96,7 @@ async fn handle_with_pr( context .repo .merge_workflow() - .cancel(&run, context.repo, conn) + .cancel(&run, context.repo, conn, &db_pr) .await .context("cancel ci run")?; } else { @@ -115,12 +121,6 @@ async fn handle_with_pr( } } - let db_pr = Pr::new(&pr, context.user.id, context.repo.id()) - .upsert() - .get_result(conn) - .await - .context("update pr")?; - let run = CiRun::insert(context.repo.id(), pr.number) .base_ref(base) .head_commit_sha(command.head_sha.as_deref().unwrap_or_else(|| pr.head.sha.as_ref()).into()) @@ -199,6 +199,7 @@ mod tests { run: &CiRun<'_>, _: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, + _: &Pr<'_>, ) -> anyhow::Result<()> { if self .cancelled diff --git a/server/src/command/merge.rs b/server/src/command/merge.rs index 7c056b1..9a547ee 100644 --- a/server/src/command/merge.rs +++ b/server/src/command/merge.rs @@ -104,7 +104,7 @@ async fn handle_with_pr( .get_result(conn) .await?; - context.repo.merge_workflow().queued(&run, context.repo).await?; + context.repo.merge_workflow().queued(&run, context.repo, conn, &db_pr).await?; Ok(()) } @@ -131,7 +131,13 @@ mod tests { } impl GitHubMergeWorkflow for MockMergeWorkFlow { - async fn queued(&self, _: &CiRun<'_>, _: &impl GitHubRepoClient) -> anyhow::Result<()> { + async fn queued( + &self, + _: &CiRun<'_>, + _: &impl GitHubRepoClient, + _: &mut AsyncPgConnection, + _: &Pr<'_>, + ) -> anyhow::Result<()> { self.queued .compare_exchange( false, diff --git a/server/src/command/retry.rs b/server/src/command/retry.rs index 87182b0..fd5ff9a 100644 --- a/server/src/command/retry.rs +++ b/server/src/command/retry.rs @@ -83,7 +83,7 @@ async fn handle_with_pr( if run.is_dry_run { context.repo.merge_workflow().start(&run, context.repo, conn, &db_pr).await?; } else { - context.repo.merge_workflow().queued(&run, context.repo).await?; + context.repo.merge_workflow().queued(&run, context.repo, conn, &db_pr).await?; } Ok(()) @@ -132,7 +132,13 @@ mod tests { Ok(true) } - async fn queued(&self, _: &CiRun<'_>, _: &impl GitHubRepoClient) -> anyhow::Result<()> { + async fn queued( + &self, + _: &CiRun<'_>, + _: &impl GitHubRepoClient, + _: &mut AsyncPgConnection, + _: &Pr<'_>, + ) -> anyhow::Result<()> { self.queued .compare_exchange( false, diff --git a/server/src/database/pr.rs b/server/src/database/pr.rs index 7fe8894..f860a89 100644 --- a/server/src/database/pr.rs +++ b/server/src/database/pr.rs @@ -30,6 +30,7 @@ pub struct Pr<'a> { pub target_branch: Cow<'a, str>, pub source_branch: Cow<'a, str>, pub latest_commit_sha: Cow<'a, str>, + pub added_label: Option>, pub created_at: chrono::DateTime, pub updated_at: chrono::DateTime, } @@ -42,6 +43,7 @@ pub struct UpdatePr<'a> { pub github_repo_id: i64, #[builder(start_fn)] pub github_pr_number: i32, + pub added_label: Option>>, pub title: Option>, pub body: Option>, pub merge_status: Option, @@ -110,6 +112,7 @@ impl<'a> Pr<'a> { target_branch: Cow::Borrowed(&pr.base.ref_field), source_branch: Cow::Borrowed(&pr.head.ref_field), latest_commit_sha: Cow::Borrowed(&pr.head.sha), + added_label: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), } @@ -138,6 +141,7 @@ impl<'a> Pr<'a> { title: Some(Cow::Borrowed(self.title.as_ref())), body: Some(Cow::Borrowed(self.body.as_ref())), merge_status: Some(self.merge_status), + added_label: Some(self.added_label.as_deref().map(Cow::Borrowed)), assigned_ids: Some(self.assigned_ids.clone()), status: Some(self.status), default_priority: Some(self.default_priority), @@ -236,6 +240,7 @@ mod tests { "github_pr"."target_branch", "github_pr"."source_branch", "github_pr"."latest_commit_sha", + "github_pr"."added_label", "github_pr"."created_at", "github_pr"."updated_at" FROM @@ -263,6 +268,7 @@ mod tests { latest_commit_sha: Cow::Borrowed("test"), source_branch: Cow::Borrowed("test"), target_branch: Cow::Borrowed("test"), + added_label: None, merge_status: GithubPrMergeStatus::NotReady, status: GithubPrStatus::Open, merge_commit_sha: None, @@ -283,6 +289,7 @@ mod tests { "target_branch", "source_branch", "latest_commit_sha", + "added_label", "created_at", "updated_at" ) @@ -301,6 +308,7 @@ mod tests { $9, $10, $11, + DEFAULT, $12, $13 ) -- binds: [1, 1, "test", "test", NotReady, 0, [], Open, "test", "test", "test", 2024-06-20T02:40:00Z, 2024-06-20T02:40:00Z] @@ -360,6 +368,7 @@ mod tests { updated_at: chrono::DateTime::from_timestamp_nanos(1718851200000000000), assigned_ids: vec![], author_id: 0, + added_label: None, default_priority: None, latest_commit_sha: Cow::Borrowed("test"), source_branch: Cow::Borrowed("test"), @@ -384,6 +393,7 @@ mod tests { "target_branch", "source_branch", "latest_commit_sha", + "added_label", "created_at", "updated_at" ) @@ -402,21 +412,23 @@ mod tests { $9, $10, $11, + DEFAULT, $12, $13 ) ON CONFLICT ("github_repo_id", "github_pr_number") DO UPDATE SET - "title" = $14, - "body" = $15, - "merge_status" = $16, - "assigned_ids" = $17, - "status" = $18, - "default_priority" = $19, - "merge_commit_sha" = $20, - "target_branch" = $21, - "latest_commit_sha" = $22, - "updated_at" = $23 + "added_label" = $14, + "title" = $15, + "body" = $16, + "merge_status" = $17, + "assigned_ids" = $18, + "status" = $19, + "default_priority" = $20, + "merge_commit_sha" = $21, + "target_branch" = $22, + "latest_commit_sha" = $23, + "updated_at" = $24 RETURNING "github_pr"."github_repo_id", "github_pr"."github_pr_number", @@ -431,8 +443,9 @@ mod tests { "github_pr"."target_branch", "github_pr"."source_branch", "github_pr"."latest_commit_sha", + "github_pr"."added_label", "github_pr"."created_at", - "github_pr"."updated_at" -- binds: [1, 1, "test", "test", NotReady, 0, [], Open, "test", "test", "test", 2024-06-20T02:40:00Z, 2024-06-20T02:40:00Z, "test", "test", NotReady, [], Open, None, None, "test", "test", 2024-06-20T02:40:00Z] + "github_pr"."updated_at" -- binds: [1, 1, "test", "test", NotReady, 0, [], Open, "test", "test", "test", 2024-06-20T02:40:00Z, 2024-06-20T02:40:00Z, None, "test", "test", NotReady, [], Open, None, None, "test", "test", 2024-06-20T02:40:00Z] "#, } @@ -516,6 +529,7 @@ mod tests { title: Cow::Borrowed("test"), assigned_ids: vec![], author_id: 0, + added_label: None, default_priority: None, merge_status: GithubPrMergeStatus::NotReady, status: GithubPrStatus::Open, @@ -547,6 +561,7 @@ mod tests { assigned_ids: vec![], author_id: 0, default_priority: None, + added_label: None, merge_status: GithubPrMergeStatus::NotReady, status: GithubPrStatus::Open, merge_commit_sha: None, @@ -584,6 +599,7 @@ mod tests { title: Cow::Borrowed("test"), assigned_ids: vec![], author_id: 0, + added_label: None, default_priority: None, merge_status: GithubPrMergeStatus::NotReady, status: GithubPrStatus::Open, diff --git a/server/src/database/schema.rs b/server/src/database/schema.rs index 7a72f8c..5666cc4 100644 --- a/server/src/database/schema.rs +++ b/server/src/database/schema.rs @@ -304,6 +304,12 @@ diesel::table! { /// /// (Automatically generated by Diesel.) default_priority -> Nullable, + /// The `added_label` column of the `github_pr` table. + /// + /// Its SQL type is `Nullable`. + /// + /// (Automatically generated by Diesel.) + added_label -> Nullable, } } diff --git a/server/src/github/label_state.rs b/server/src/github/label_state.rs new file mode 100644 index 0000000..6f1f1ad --- /dev/null +++ b/server/src/github/label_state.rs @@ -0,0 +1,99 @@ +use std::borrow::Cow; + +use diesel_async::{AsyncPgConnection, RunQueryDsl}; + +use super::config::GitHubBrawlLabelsConfig; +use super::repo::GitHubRepoClient; +use crate::database::enums::GithubCiRunStatus; +use crate::database::pr::Pr; + +fn desired_labels<'a>(status: GithubCiRunStatus, is_dry_run: bool, config: &'a GitHubBrawlLabelsConfig) -> Option<&'a str> { + match (status, is_dry_run) { + (GithubCiRunStatus::Queued, false) => { + if let Some(on_merge_queued) = config.on_merge_queued.as_deref() { + return Some(on_merge_queued); + } + } + (GithubCiRunStatus::InProgress, true) => { + if let Some(on_try_in_progress) = config.on_try_in_progress.as_deref() { + return Some(on_try_in_progress); + } + } + (GithubCiRunStatus::InProgress, false) => { + if let Some(on_merge_in_progress) = config.on_merge_in_progress.as_deref() { + return Some(on_merge_in_progress); + } + } + (GithubCiRunStatus::Failure, true) => { + if let Some(on_try_failure) = config.on_try_failure.as_deref() { + return Some(on_try_failure); + } + } + (GithubCiRunStatus::Failure, false) => { + if let Some(on_merge_failure) = config.on_merge_failure.as_deref() { + return Some(on_merge_failure); + } + } + (GithubCiRunStatus::Success, false) => { + if let Some(on_merge_success) = config.on_merge_success.as_deref() { + return Some(on_merge_success); + } + } + _ => {} + } + + None +} + +pub async fn update_labels( + conn: &mut AsyncPgConnection, + pr: &Pr<'_>, + status: GithubCiRunStatus, + is_dry_run: bool, + repo_client: &impl GitHubRepoClient, +) -> anyhow::Result<()> { + let config = repo_client.config(); + let desired_label = desired_labels(status, is_dry_run, &config.labels); + + if desired_label == pr.added_label.as_deref() { + return Ok(()); + } + + if let Some(desired_label) = desired_label { + if let Err(e) = repo_client + .add_labels(pr.github_pr_number as u64, &[desired_label.to_string()]) + .await + { + tracing::error!( + repo_id = pr.github_repo_id, + owner = repo_client.owner(), + repo_name = repo_client.name(), + pr_number = pr.github_pr_number, + "failed to add label: {:#}", + e + ) + } + } + + if let Some(added_label) = pr.added_label.as_deref() { + if let Err(e) = repo_client.remove_label(pr.github_pr_number as u64, added_label).await { + tracing::error!( + repo_id = pr.github_repo_id, + owner = repo_client.owner(), + repo_name = repo_client.name(), + pr_number = pr.github_pr_number, + "failed to remove label: {:#}", + e + ) + } + } + + pr.update() + .added_label(desired_label.map(Cow::Borrowed)) + .build() + .query() + .execute(conn) + .await?; + + Ok(()) +} diff --git a/server/src/github/merge_workflow.rs b/server/src/github/merge_workflow.rs index 0b56655..e2fa68d 100644 --- a/server/src/github/merge_workflow.rs +++ b/server/src/github/merge_workflow.rs @@ -13,6 +13,7 @@ use crate::database::ci_run::{Base, CiRun}; use crate::database::ci_run_check::CiCheck; use crate::database::enums::{GithubCiRunStatus, GithubCiRunStatusCheckStatus}; use crate::database::pr::Pr; +use crate::github::label_state::update_labels; use crate::github::repo::MergeResult; use crate::utils::format_fn; @@ -96,8 +97,9 @@ pub trait GitHubMergeWorkflow: Send + Sync { run: &CiRun<'_>, repo: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, + pr: &Pr<'_>, ) -> impl std::future::Future> + Send { - let _ = (run, repo, conn); + let _ = (run, repo, conn, pr); unimplemented!("GitHubMergeWorkflow::cancel"); #[allow(unreachable_code)] std::future::pending() @@ -108,8 +110,10 @@ pub trait GitHubMergeWorkflow: Send + Sync { &self, run: &CiRun<'_>, repo: &impl GitHubRepoClient, + conn: &mut AsyncPgConnection, + pr: &Pr<'_>, ) -> impl std::future::Future> + Send { - let _ = (run, repo); + let _ = (run, repo, conn, pr); unimplemented!("GitHubMergeWorkflow::queued"); #[allow(unreachable_code)] std::future::pending() @@ -148,8 +152,9 @@ impl GitHubMergeWorkflow for &T { run: &CiRun<'_>, repo: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, + pr: &Pr<'_>, ) -> impl std::future::Future> + Send { - (*self).cancel(run, repo, conn) + (*self).cancel(run, repo, conn, pr) } #[cfg_attr(all(test, coverage_nightly), coverage(off))] @@ -157,8 +162,10 @@ impl GitHubMergeWorkflow for &T { &self, run: &CiRun<'_>, repo: &impl GitHubRepoClient, + conn: &mut AsyncPgConnection, + pr: &Pr<'_>, ) -> impl std::future::Future> + Send { - (*self).queued(run, repo) + (*self).queued(run, repo, conn, pr) } #[cfg_attr(all(test, coverage_nightly), coverage(off))] @@ -179,9 +186,9 @@ impl DefaultMergeWorkflow { async fn fail( &self, run: &CiRun<'_>, + pr: &Pr<'_>, repo: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, - pr_number: i32, message: impl Borrow, ) -> anyhow::Result<()> { if CiRun::update(run.id) @@ -197,7 +204,7 @@ impl DefaultMergeWorkflow { anyhow::bail!("run already completed"); } - repo.send_message(pr_number as u64, message.borrow()) + repo.send_message(pr.github_pr_number as u64, message.borrow()) .await .context("send message")?; @@ -214,6 +221,8 @@ impl DefaultMergeWorkflow { repo.delete_branch(&run.ci_branch).await?; + update_labels(conn, pr, GithubCiRunStatus::Failure, run.is_dry_run, repo).await?; + Ok(()) } @@ -225,8 +234,6 @@ impl DefaultMergeWorkflow { pr: &Pr<'_>, checks: &[&CiCheck<'_>], ) -> anyhow::Result<()> { - println!("success"); - if CiRun::update(run.id) .status(GithubCiRunStatus::Success) .completed_at(chrono::Utc::now()) @@ -292,9 +299,9 @@ impl DefaultMergeWorkflow { Err(e) => { self.fail( run, + pr, repo, conn, - run.github_pr_number, &messages::error( format_fn(|f| writeln!(f, "Tests passed but failed to push to branch {}", pr.target_branch)), e, @@ -332,6 +339,8 @@ impl DefaultMergeWorkflow { } } + update_labels(conn, pr, GithubCiRunStatus::Success, run.is_dry_run, repo).await?; + Ok(()) } } @@ -374,9 +383,9 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { if check.status_check_status == GithubCiRunStatusCheckStatus::Failure { self.fail( run, + pr, repo, conn, - run.github_pr_number, &messages::tests_failed(&check.status_check_name, &check.url), ) .await?; @@ -396,9 +405,9 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { { self.fail( run, + pr, repo, conn, - run.github_pr_number, &messages::tests_timeout( repo.config().timeout_minutes, format_fn(|f| { @@ -437,9 +446,9 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { let Some(branch) = repo.get_ref_latest_commit(&Reference::Branch(branch.to_string())).await? else { self.fail( run, + pr, repo, conn, - run.github_pr_number, &messages::error( format_fn(|f| { writeln!(f, "Failed to find branch {branch}") @@ -500,9 +509,9 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { Ok(MergeResult::Conflict) => { self.fail( run, + pr, repo, conn, - run.github_pr_number, messages::merge_conflict( &pr.source_branch, &pr.target_branch, @@ -517,9 +526,9 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { Err(e) => { self.fail( run, + pr, repo, conn, - run.github_pr_number, &messages::error(format_fn(|f| write!(f, "Failed to create merge commit")), e), ) .await?; @@ -541,9 +550,9 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { Err(e) => { self.fail( run, + pr, repo, conn, - run.github_pr_number, &messages::error( format_fn(|f| write!(f, "Failed to push branch {branch}", branch = run.ci_branch)), e, @@ -572,6 +581,8 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { "ci run started", ); + update_labels(conn, pr, GithubCiRunStatus::InProgress, run.is_dry_run, repo).await?; + Ok(true) } @@ -580,6 +591,7 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { run: &CiRun<'_>, repo: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, + pr: &Pr<'_>, ) -> anyhow::Result<()> { if CiRun::update(run.id) .status(GithubCiRunStatus::Cancelled) @@ -651,10 +663,18 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { repo.delete_branch(&run.ci_branch).await?; } + update_labels(conn, pr, GithubCiRunStatus::Cancelled, run.is_dry_run, repo).await?; + Ok(()) } - async fn queued(&self, run: &CiRun<'_>, repo: &impl GitHubRepoClient) -> anyhow::Result<()> { + async fn queued( + &self, + run: &CiRun<'_>, + repo: &impl GitHubRepoClient, + conn: &mut AsyncPgConnection, + pr: &Pr<'_>, + ) -> anyhow::Result<()> { let requested_by = repo.get_user(UserId(run.requested_by_id as u64)).await?; repo.send_message( @@ -670,6 +690,8 @@ impl GitHubMergeWorkflow for DefaultMergeWorkflow { ) .await?; + update_labels(conn, pr, GithubCiRunStatus::Queued, run.is_dry_run, repo).await?; + Ok(()) } } @@ -810,6 +832,7 @@ mod tests { latest_commit_sha: "latest_commit_sha".into(), created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), + added_label: None, }; diesel::insert_into(crate::database::schema::github_pr::table) @@ -2191,7 +2214,7 @@ mod tests { #[tokio::test] async fn test_ci_queued() { - let (_, client, _, run, mut rx) = ci_run_test_boilerplate( + let (mut conn, client, pr, run, mut rx) = ci_run_test_boilerplate( InsertCiRun::builder(1, 1) .base_ref(Base::Commit(Cow::Borrowed("sha"))) .head_commit_sha(Cow::Borrowed("head_commit_sha")) @@ -2204,7 +2227,7 @@ mod tests { .await; let task = tokio::spawn(async move { - DefaultMergeWorkflow.queued(&run, &client).await.unwrap(); + DefaultMergeWorkflow.queued(&run, &client, &mut conn, &pr).await.unwrap(); }); match rx.recv().await.unwrap() { @@ -2257,7 +2280,7 @@ mod tests { #[tokio::test] async fn test_ci_run_cancel_not_started() { - let (mut conn, client, _, run, _) = ci_run_test_boilerplate( + let (mut conn, client, pr, run, _) = ci_run_test_boilerplate( InsertCiRun::builder(1, 1) .base_ref(Base::Commit(Cow::Borrowed("sha"))) .head_commit_sha(Cow::Borrowed("head_commit_sha")) @@ -2270,7 +2293,7 @@ mod tests { .await; let task = tokio::spawn(async move { - client.merge_workflow().cancel(&run, &client, &mut conn).await.unwrap(); + client.merge_workflow().cancel(&run, &client, &mut conn, &pr).await.unwrap(); conn }); @@ -2286,7 +2309,7 @@ mod tests { #[tokio::test] async fn test_ci_run_cancel_started() { - let (mut conn, client, _, run, mut rx) = ci_run_test_boilerplate( + let (mut conn, client, pr, run, mut rx) = ci_run_test_boilerplate( InsertCiRun::builder(1, 1) .base_ref(Base::Commit(Cow::Borrowed("sha"))) .head_commit_sha(Cow::Borrowed("head_commit_sha")) @@ -2311,7 +2334,7 @@ mod tests { let run = CiRun::latest(RepositoryId(1), 1).get_result(&mut conn).await.unwrap(); let task = tokio::spawn(async move { - client.merge_workflow().cancel(&run, &client, &mut conn).await.unwrap(); + client.merge_workflow().cancel(&run, &client, &mut conn, &pr).await.unwrap(); conn }); diff --git a/server/src/github/mod.rs b/server/src/github/mod.rs index fdb8639..a702c7a 100644 --- a/server/src/github/mod.rs +++ b/server/src/github/mod.rs @@ -9,6 +9,7 @@ use octocrab::Octocrab; pub mod config; pub mod installation; +pub mod label_state; pub mod merge_workflow; pub mod messages; pub mod models; diff --git a/server/src/github/repo.rs b/server/src/github/repo.rs index 233ed61..e98b479 100644 --- a/server/src/github/repo.rs +++ b/server/src/github/repo.rs @@ -15,7 +15,7 @@ use super::config::{GitHubBrawlRepoConfig, Permission, Role}; use super::installation::UserCache; use super::merge_workflow::{DefaultMergeWorkflow, GitHubMergeWorkflow}; use super::messages::{CommitMessage, IssueMessage}; -use super::models::{Commit, PullRequest, Repository, Review, User}; +use super::models::{Commit, Label, PullRequest, Repository, Review, User}; pub struct RepoClient { pub(super) repo: ArcSwap, @@ -127,6 +127,20 @@ pub trait GitHubRepoClient: Send + Sync { /// Delete a branch from the repository fn delete_branch(&self, branch: &str) -> impl std::future::Future> + Send; + /// Add labels to a pull request + fn add_labels( + &self, + issue_number: u64, + labels: &[String], + ) -> impl std::future::Future>> + Send; + + /// Remove a label from a pull request + fn remove_label( + &self, + issue_number: u64, + labels: &str, + ) -> impl std::future::Future>> + Send; + /// Check if a user has a permission fn has_permission( &self, @@ -353,6 +367,24 @@ impl GitHubRepoClient for RepoClient { } } + async fn add_labels(&self, issue_number: u64, labels: &[String]) -> anyhow::Result> { + self.client + .issues_by_id(self.id()) + .add_labels(issue_number, labels) + .await + .context("add labels") + .map(|labels| labels.into_iter().map(Label::from).collect()) + } + + async fn remove_label(&self, issue_number: u64, label: &str) -> anyhow::Result> { + self.client + .issues_by_id(self.id()) + .remove_label(issue_number, label) + .await + .context("remove label") + .map(|labels| labels.into_iter().map(Label::from).collect()) + } + async fn get_commit(&self, sha: &str) -> anyhow::Result> { match self .client @@ -552,6 +584,16 @@ pub mod test_utils { permissions: Vec, result: tokio::sync::oneshot::Sender>, }, + AddLabels { + issue_number: u64, + labels: Vec, + result: tokio::sync::oneshot::Sender>>, + }, + RemoveLabel { + issue_number: u64, + label: String, + result: tokio::sync::oneshot::Sender>>, + }, } impl GitHubRepoClient for MockRepoClient { @@ -724,6 +766,32 @@ pub mod test_utils { .expect("send has permission"); rx.await.expect("recv has permission") } + + async fn add_labels(&self, issue_number: u64, labels: &[String]) -> anyhow::Result> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.actions + .send(MockRepoAction::AddLabels { + issue_number, + labels: labels.to_vec(), + result: tx, + }) + .await + .expect("send add labels"); + rx.await.expect("recv add labels") + } + + async fn remove_label(&self, issue_number: u64, label: &str) -> anyhow::Result> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.actions + .send(MockRepoAction::RemoveLabel { + issue_number, + label: label.to_owned(), + result: tx, + }) + .await + .expect("send remove label"); + rx.await.expect("recv remove label") + } } } diff --git a/server/src/webhook/pull_request.rs b/server/src/webhook/pull_request.rs index 7705198..1269a86 100644 --- a/server/src/webhook/pull_request.rs +++ b/server/src/webhook/pull_request.rs @@ -45,7 +45,7 @@ pub async fn handle_with_pr( match run { Some(run) if !run.is_dry_run => { - repo.merge_workflow().cancel(&run, repo, conn).await?; + repo.merge_workflow().cancel(&run, repo, conn, ¤t).await?; repo.send_message( run.github_pr_number as u64, &messages::error_no_body(format!( @@ -99,6 +99,7 @@ mod tests { run: &CiRun<'_>, _: &impl GitHubRepoClient, conn: &mut AsyncPgConnection, + _: &Pr<'_>, ) -> anyhow::Result<()> { self.cancel.store(true, std::sync::atomic::Ordering::Relaxed);