Skip to content

Commit

Permalink
Auto merge of #25 - troy/bug-fix, r=<dry>
Browse files Browse the repository at this point in the history
refactor traits and implement a lock per repo
Adds a lock per repo so that only a single repo request can be processed at a time. This makes it impossible for webhooks to be run concurrently when for a single repo. Previously it was possible for webhooks to run together which sometimes caused issues if multiple events updated the same PR.

Requested-by: TroyKomodo <[email protected]>
  • Loading branch information
scuffle-brawl[bot] authored Dec 31, 2024
2 parents 8ce1c62 + 750ec92 commit beaa4de
Show file tree
Hide file tree
Showing 10 changed files with 512 additions and 155 deletions.
46 changes: 19 additions & 27 deletions server/src/auto_start.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::Arc;

use anyhow::Context;
use diesel_async::{AsyncPgConnection, RunQueryDsl};
Expand All @@ -9,20 +7,14 @@ use scuffle_context::ContextFutExt;

use crate::database::ci_run::CiRun;
use crate::database::pr::Pr;
use crate::database::DatabaseConnection;
use crate::github::merge_workflow::GitHubMergeWorkflow;
use crate::github::repo::GitHubRepoClient;
use crate::BrawlState;

pub struct AutoStartSvc;

pub trait AutoStartConfig: Send + Sync + 'static {
type RepoClient: GitHubRepoClient;

fn repo_client(&self, repo_id: RepositoryId) -> Option<Arc<Self::RepoClient>>;

fn database(
&self,
) -> impl std::future::Future<Output = anyhow::Result<impl DerefMut<Target = AsyncPgConnection> + Send>> + Send;

pub trait AutoStartConfig: BrawlState {
fn interval(&self) -> std::time::Duration;
}

Expand All @@ -36,8 +28,8 @@ impl<C: AutoStartConfig> scuffle_bootstrap::Service<C> for AutoStartSvc {

while tokio::time::sleep(global.interval()).with_context(&ctx).await.is_some() {
let mut conn = global.database().await?;
let runs = get_pending_runs(&mut conn).await?;
process_pending_runs(&mut conn, runs, global.as_ref()).await?;
let runs = get_pending_runs(conn.get()).await?;
process_pending_runs(conn.get(), runs, global.as_ref()).await?;
}

Ok(())
Expand Down Expand Up @@ -79,12 +71,12 @@ async fn process_pending_runs(
global: &impl AutoStartConfig,
) -> anyhow::Result<()> {
for run in runs {
let Some(repo_client) = global.repo_client(RepositoryId(run.github_repo_id as u64)) else {
let Some(repo_client) = global.get_repo(None, RepositoryId(run.github_repo_id as u64)).await else {
tracing::error!("no installation client found for repo {}", run.github_repo_id);
continue;
};

if let Err(e) = handle_run(conn, &run, repo_client.as_ref()).await {
if let Err(e) = handle_run(conn, &run, &repo_client).await {
tracing::error!(
"error handling run (repo id: {}, pr number: {}, run id: {}): {}",
run.github_repo_id,
Expand Down Expand Up @@ -128,15 +120,18 @@ async fn handle_run(
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use std::sync::Arc;

use chrono::Utc;
use octocrab::models::UserId;
use octocrab::models::{InstallationId, UserId};

use super::*;
use crate::database::ci_run::Base;
use crate::database::enums::GithubCiRunStatus;
use crate::database::get_test_connection;
use crate::database::{get_test_connection, DatabaseConnection};
use crate::github::models::PullRequest;
use crate::github::repo::test_utils::MockRepoClient;
use crate::github::repo::RepoClientRef;

fn run() -> CiRun<'static> {
CiRun {
Expand Down Expand Up @@ -453,26 +448,23 @@ mod tests {

struct AutoStartCfg(Arc<MockRepoClient<MockWorkflow>>);

impl AutoStartConfig for AutoStartCfg {
type RepoClient = MockRepoClient<MockWorkflow>;

fn repo_client(&self, _: RepositoryId) -> Option<Arc<Self::RepoClient>> {
Some(self.0.clone())
impl BrawlState for AutoStartCfg {
async fn get_repo(&self, _: Option<InstallationId>, _: RepositoryId) -> Option<impl GitHubRepoClient + 'static> {
Some(RepoClientRef::new(self.0.clone()))
}

fn database(
&self,
) -> impl std::future::Future<Output = anyhow::Result<impl DerefMut<Target = AsyncPgConnection> + Send>> + Send
{
fn database(&self) -> impl std::future::Future<Output = anyhow::Result<impl DatabaseConnection + Send>> + Send {
fn get_conn() -> diesel_async::pooled_connection::bb8::PooledConnection<'static, AsyncPgConnection> {
unreachable!()
}

std::future::ready(Ok(get_conn()))
}
}

impl AutoStartConfig for AutoStartCfg {
fn interval(&self) -> std::time::Duration {
todo!()
unreachable!("this method should not be called")
}
}

Expand Down
43 changes: 19 additions & 24 deletions server/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#![cfg_attr(all(coverage_nightly, test), coverage(off))]

use std::net::SocketAddr;
use std::ops::DerefMut;
use std::sync::Arc;

use anyhow::Context;
Expand All @@ -16,8 +15,9 @@ use scuffle_bootstrap_telemetry::opentelemetry_sdk::metrics::SdkMeterProvider;
use scuffle_bootstrap_telemetry::opentelemetry_sdk::Resource;
use scuffle_bootstrap_telemetry::prometheus_client::registry::Registry;
use scuffle_brawl::database::schema::health_check;
use scuffle_brawl::database::DatabaseConnection;
use scuffle_brawl::github::models::Installation;
use scuffle_brawl::github::repo::{GitHubRepoClient, RepoClient};
use scuffle_brawl::github::repo::GitHubRepoClient;
use scuffle_brawl::github::GitHubService;
use scuffle_metrics::opentelemetry::KeyValue;
use tracing_subscriber::layer::SubscriberExt;
Expand Down Expand Up @@ -171,14 +171,6 @@ impl scuffle_brawl::webhook::WebhookConfig for Global {
Some(self.config.github.webhook_bind)
}

fn get_repo(
&self,
installation_id: InstallationId,
repo_id: RepositoryId,
) -> Option<Arc<impl GitHubRepoClient + 'static>> {
self.github_service.get_client(installation_id)?.get_repo_client(repo_id)
}

async fn add_repo(&self, installation_id: InstallationId, repo_id: RepositoryId) -> anyhow::Result<()> {
let client = self
.github_service
Expand Down Expand Up @@ -206,29 +198,32 @@ impl scuffle_brawl::webhook::WebhookConfig for Global {
self.github_service.delete_installation(installation_id);
Ok(())
}

async fn database(&self) -> anyhow::Result<impl DerefMut<Target = AsyncPgConnection> + Send> {
let conn = self.database.get().await.context("get database connection")?;
Ok(conn)
}
}

impl scuffle_brawl::auto_start::AutoStartConfig for Global {
type RepoClient = RepoClient;

fn interval(&self) -> std::time::Duration {
std::time::Duration::from_secs(self.config.interval_seconds)
}
}

async fn database(&self) -> anyhow::Result<impl DerefMut<Target = AsyncPgConnection> + Send> {
let conn = self.database.get().await.context("get database connection")?;
Ok(conn)
impl scuffle_brawl::BrawlState for Global {
async fn get_repo(
&self,
installation_id: Option<InstallationId>,
repo_id: RepositoryId,
) -> Option<impl GitHubRepoClient + 'static> {
let installation = if let Some(installation_id) = installation_id {
self.github_service.get_client(installation_id)
} else {
self.github_service.get_client_by_repo(repo_id)
}?;

installation.get_repo_client(repo_id).await
}

fn repo_client(&self, repo_id: octocrab::models::RepositoryId) -> Option<Arc<RepoClient>> {
let client = self.github_service.get_client_by_repo(repo_id)?;
let repo = client.get_repo_client(repo_id)?;
Some(repo)
async fn database(&self) -> anyhow::Result<impl DatabaseConnection + Send> {
let conn = self.database.get().await.context("get database connection")?;
Ok(conn)
}
}

Expand Down
13 changes: 7 additions & 6 deletions server/src/command/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::str::FromStr;

use anyhow::Context;
use auto_try::AutoTryCommand;
use diesel_async::AsyncPgConnection;
use dry_run::DryRunCommand;
Expand Down Expand Up @@ -38,12 +39,12 @@ impl BrawlCommand {
context: BrawlCommandContext<'_, R>,
) -> anyhow::Result<()> {
match self {
BrawlCommand::DryRun(command) => dry_run::handle(conn, context, command).await,
BrawlCommand::Merge(command) => merge::handle(conn, context, command).await,
BrawlCommand::Retry => retry::handle(conn, context).await,
BrawlCommand::Cancel => cancel::handle(conn, context).await,
BrawlCommand::Ping => ping::handle(conn, context).await,
BrawlCommand::AutoTry(command) => auto_try::handle(conn, context, command).await,
BrawlCommand::DryRun(command) => dry_run::handle(conn, context, command).await.context("dry run"),
BrawlCommand::Merge(command) => merge::handle(conn, context, command).await.context("merge"),
BrawlCommand::Retry => retry::handle(conn, context).await.context("retry"),
BrawlCommand::Cancel => cancel::handle(conn, context).await.context("cancel"),
BrawlCommand::Ping => ping::handle(conn, context).await.context("ping"),
BrawlCommand::AutoTry(command) => auto_try::handle(conn, context, command).await.context("auto try"),
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions server/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ macro_rules! test_query {
};
}

use std::marker::PhantomData;

use diesel_async::AsyncPgConnection;
#[cfg(test)]
use test_query;

Expand Down Expand Up @@ -51,3 +54,34 @@ pub async fn get_test_connection() -> diesel_async::AsyncPgConnection {

conn
}

pub trait DatabaseConnection {
fn get(&mut self) -> &mut AsyncPgConnection;
}

impl DatabaseConnection for AsyncPgConnection {
fn get(&mut self) -> &mut AsyncPgConnection {
self
}
}

impl DatabaseConnection for diesel_async::pooled_connection::bb8::PooledConnection<'_, AsyncPgConnection> {
fn get(&mut self) -> &mut AsyncPgConnection {
self
}
}

pub struct DatabaseConnectionRef<T, D> {
inner: T,
_db: PhantomData<D>,
}

impl<T, D> DatabaseConnection for DatabaseConnectionRef<T, D>
where
T: AsMut<D>,
D: DatabaseConnection,
{
fn get(&mut self) -> &mut AsyncPgConnection {
self.inner.as_mut().get()
}
}
Loading

0 comments on commit beaa4de

Please sign in to comment.