Skip to content

Commit

Permalink
Use tx for all db mutations
Browse files Browse the repository at this point in the history
  • Loading branch information
tarkah committed Nov 26, 2024
1 parent 88bb2aa commit bf96268
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 77 deletions.
7 changes: 5 additions & 2 deletions crates/avalanche/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn build(request: api::Request<api::v1::avalanche::Build>, context: Contex
.sub
.parse::<endpoint::Id>()
.map_err(Error::InvalidEndpoint)?;
let endpoint = Endpoint::get(&context.state.db, endpoint_id)
let endpoint = Endpoint::get(context.state.db.acquire().await?.as_mut(), endpoint_id)
.await
.map_err(Error::LoadEndpoint)?;

Expand Down Expand Up @@ -84,14 +84,17 @@ pub enum Error {
/// Failed to load endpoint from DB
#[error("load endpoint")]
LoadEndpoint(#[source] database::Error),
/// Database error
#[error("database")]
Database(#[from] database::Error),
}

impl From<&Error> for http::StatusCode {
fn from(error: &Error) -> Self {
match error {
Error::MissingRequestToken => http::StatusCode::UNAUTHORIZED,
Error::MissingRemotes | Error::InvalidEndpoint(_) => http::StatusCode::BAD_REQUEST,
Error::LoadEndpoint(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
Error::LoadEndpoint(_) | Error::Database(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
Error::BuildInProgress => http::StatusCode::SERVICE_UNAVAILABLE,
}
}
Expand Down
52 changes: 33 additions & 19 deletions crates/service/src/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ impl Account {
}

/// Get the account for [`Id`] from the provided [`Database`]
pub async fn get(db: &Database, id: Id) -> Result<Self, Error> {
pub async fn get<'a, T>(conn: &'a mut T, id: Id) -> Result<Self, Error>
where
&'a mut T: database::Executor<'a>,
{
let account: Account = sqlx::query_as(
"
SELECT
Expand All @@ -74,18 +77,21 @@ impl Account {
",
)
.bind(id.0)
.fetch_one(&db.pool)
.fetch_one(conn)
.await?;

Ok(account)
}

/// Lookup an account using `username` and `publickey` from the provided [`Database`]
pub async fn lookup_with_credentials(
db: &Database,
pub async fn lookup_with_credentials<'a, T>(
conn: &'a mut T,
username: &str,
public_key: &EncodedPublicKey,
) -> Result<Self, Error> {
) -> Result<Self, Error>
where
&'a mut T: database::Executor<'a>,
{
let account: Account = sqlx::query_as(
"
SELECT
Expand All @@ -104,14 +110,14 @@ impl Account {
)
.bind(username)
.bind(public_key.to_string())
.fetch_one(&db.pool)
.fetch_one(conn)
.await?;

Ok(account)
}

/// Create / update this account to the provided [`Database`]
pub async fn save<'c>(&self, conn: impl sqlx::Executor<'c, Database = sqlx::Sqlite>) -> Result<(), Error> {
pub async fn save<'a>(&self, tx: &mut database::Transaction<'a>) -> Result<(), Error> {
sqlx::query(
"
INSERT INTO account
Expand All @@ -138,7 +144,7 @@ impl Account {
.bind(&self.email)
.bind(&self.name)
.bind(self.public_key.to_string())
.execute(conn)
.execute(tx.as_mut())
.await?;

Ok(())
Expand Down Expand Up @@ -197,7 +203,12 @@ pub struct Token {

impl Token {
/// Set the account's bearer token & expiration
pub async fn set(db: &Database, id: Id, encoded: impl ToString, expiration: DateTime<Utc>) -> Result<(), Error> {
pub async fn set<'a>(
tx: &mut database::Transaction<'a>,
id: Id,
encoded: impl ToString,
expiration: DateTime<Utc>,
) -> Result<(), Error> {
sqlx::query(
"
INSERT INTO account_token
Expand All @@ -215,14 +226,17 @@ impl Token {
.bind(id.0)
.bind(encoded.to_string())
.bind(expiration)
.execute(&db.pool)
.execute(tx.as_mut())
.await?;

Ok(())
}

/// Get the account token for [`Id`] from the provided [`Database`]
pub async fn get(db: &Database, id: Id) -> Result<Token, Error> {
pub async fn get<'a, T>(conn: &'a mut T, id: Id) -> Result<Token, Error>
where
&'a mut T: database::Executor<'a>,
{
let token: Token = sqlx::query_as(
"
SELECT
Expand All @@ -233,7 +247,7 @@ impl Token {
",
)
.bind(id.0)
.fetch_one(&db.pool)
.fetch_one(conn)
.await?;

Ok(token)
Expand Down Expand Up @@ -262,6 +276,8 @@ pub struct Admin {
)
)]
pub(crate) async fn sync_admin(db: &Database, admin: Admin) -> Result<(), Error> {
let mut tx = db.begin().await?;

let account: Option<Id> = sqlx::query_as(
"
SELECT
Expand All @@ -279,22 +295,20 @@ pub(crate) async fn sync_admin(db: &Database, admin: Admin) -> Result<(), Error>
.bind(&admin.name)
.bind(&admin.email)
.bind(admin.public_key.to_string())
.fetch_optional(&db.pool)
.fetch_optional(tx.as_mut())
.await?;

if account.is_some() {
return Ok(());
}

let mut transaction = db.transaction().await?;

sqlx::query(
"
DELETE FROM account
WHERE type = 'admin';
",
)
.execute(transaction.as_mut())
.execute(tx.as_mut())
.await?;

Account {
Expand All @@ -305,10 +319,10 @@ pub(crate) async fn sync_admin(db: &Database, admin: Admin) -> Result<(), Error>
email: Some(admin.email.clone()),
public_key: admin.public_key.clone(),
}
.save(transaction.as_mut())
.save(&mut tx)
.await?;

transaction.commit().await?;
tx.commit().await?;

debug!("Admin account synced");

Expand All @@ -325,6 +339,6 @@ pub enum Error {

impl From<sqlx::Error> for Error {
fn from(error: sqlx::Error) -> Self {
Error::Database(database::Error::from(error))
Error::Database(database::Error::Execute(error))
}
}
36 changes: 25 additions & 11 deletions crates/service/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ pub struct EndpointAuth {

impl EndpointAuth {
async fn verified_tokens(&self, public_key: &PublicKey) -> Result<Tokens, EndpointAuthError> {
let tokens = endpoint::Tokens::get(&self.db, self.endpoint).await?;
let tokens = endpoint::Tokens::get(self.db.acquire().await?.as_mut(), self.endpoint).await?;

Ok(Tokens {
bearer_token: tokens
Expand All @@ -323,8 +323,10 @@ impl AuthStorage for EndpointAuth {
const REFRESH_ENABLED: bool = true;

async fn tokens(&self) -> Result<Tokens, EndpointAuthError> {
let endpoint = Endpoint::get(&self.db, self.endpoint).await?;
let account = Account::get(&self.db, endpoint.account).await?;
let mut conn = self.db.acquire().await?;

let endpoint = Endpoint::get(conn.as_mut(), self.endpoint).await?;
let account = Account::get(conn.as_mut(), endpoint.account).await?;

let public_key = account.public_key.decoded()?;

Expand All @@ -339,8 +341,10 @@ impl AuthStorage for EndpointAuth {
)
)]
async fn token_refreshed(&self, purpose: token::Purpose, token: &str) -> Result<Tokens, Self::Error> {
let mut endpoint = Endpoint::get(&self.db, self.endpoint).await?;
let account = Account::get(&self.db, endpoint.account).await?;
let mut tx = self.db.begin().await?;

let mut endpoint = Endpoint::get(tx.as_mut(), self.endpoint).await?;
let account = Account::get(tx.as_mut(), endpoint.account).await?;

let public_key = account.public_key.decoded()?;

Expand All @@ -360,9 +364,11 @@ impl AuthStorage for EndpointAuth {
bearer_token: tokens.bearer_token.as_ref().map(|token| token.encoded.clone()),
access_token: tokens.access_token.as_ref().map(|token| token.encoded.clone()),
}
.save(&self.db, self.endpoint)
.save(&mut tx, self.endpoint)
.await?;
endpoint.save(&self.db).await?;
endpoint.save(&mut tx).await?;

tx.commit().await?;

info!("Token refreshed, endpoint operational");

Expand All @@ -374,7 +380,9 @@ impl AuthStorage for EndpointAuth {

error!("Invalid signature");

endpoint.save(&self.db).await?;
endpoint.save(&mut tx).await?;

tx.commit().await?;

Err(EndpointAuthError::InvalidRefreshToken)
}
Expand All @@ -384,7 +392,9 @@ impl AuthStorage for EndpointAuth {

error!("Invalid token");

endpoint.save(&self.db).await?;
endpoint.save(&mut tx).await?;

tx.commit().await?;

Err(EndpointAuthError::InvalidRefreshToken)
}
Expand All @@ -399,7 +409,9 @@ impl AuthStorage for EndpointAuth {
)
)]
async fn token_refresh_failed(&self, purpose: token::Purpose, error: &reqwest::Error) -> Result<(), Self::Error> {
let mut endpoint = Endpoint::get(&self.db, self.endpoint).await?;
let mut tx = self.db.begin().await?;

let mut endpoint = Endpoint::get(tx.as_mut(), self.endpoint).await?;

endpoint.status = endpoint::Status::Unreachable;

Expand All @@ -414,7 +426,9 @@ impl AuthStorage for EndpointAuth {
}
}

endpoint.save(&self.db).await?;
endpoint.save(&mut tx).await?;

tx.commit().await?;

Ok(())
}
Expand Down
65 changes: 53 additions & 12 deletions crates/service/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
use std::path::Path;

use sqlx::{sqlite::SqliteConnectOptions, Pool, Sqlite, Transaction};
use sqlx::{pool::PoolConnection, sqlite::SqliteConnectOptions, Pool, Sqlite, SqliteConnection};
use thiserror::Error;

/// Service database
#[derive(Debug, Clone)]
pub struct Database {
/// Connection pool to the underlying SQLITE database
pub pool: Pool<Sqlite>,
pool: Pool<Sqlite>,
}

impl Database {
Expand All @@ -25,26 +25,67 @@ impl Database {
}

async fn connect(options: SqliteConnectOptions) -> Result<Self, Error> {
let pool = sqlx::SqlitePool::connect_with(options).await?;
let pool = sqlx::SqlitePool::connect_with(options).await.map_err(Error::Connect)?;

sqlx::migrate!("src/database/migrations").run(&pool).await?;
sqlx::migrate!("src/database/migrations")
.run(&pool)
.await
.map_err(Error::Migrate)?;

Ok(Self { pool })
}

/// Acquire a database connection
pub async fn acquire(&self) -> Result<PoolConnection<Sqlite>, Error> {
self.pool.acquire().await.map_err(Error::Acquire)
}

/// Begin a database transaction
pub async fn transaction(&self) -> Result<Transaction<'static, Sqlite>, Error> {
Ok(self.pool.begin().await?)
pub async fn begin(&self) -> Result<Transaction, Error> {
Ok(Transaction(self.pool.begin().await.map_err(Error::Commit)?))
}
}

/// A database transaction
pub struct Transaction<'a>(sqlx::Transaction<'a, Sqlite>);

impl<'a> Transaction<'a> {
/// Commit the transaction
pub async fn commit(self) -> Result<(), Error> {
self.0.commit().await.map_err(Error::Commit)
}
}

impl<'a> AsMut<SqliteConnection> for Transaction<'a> {
fn as_mut(&mut self) -> &mut SqliteConnection {
self.0.as_mut()
}
}

/// Provides a database connection for executing queries
pub trait Executor<'a>: sqlx::Executor<'a, Database = Sqlite> {}

impl<'a, T> Executor<'a> for &'a mut T where &'a mut T: sqlx::Executor<'a, Database = Sqlite> {}

/// A database error
#[derive(Debug, Error)]
pub enum Error {
/// Sqlx error
#[error("sqlx")]
Sqlx(#[from] sqlx::Error),
/// Sqlx migration error
#[error("sqlx migration")]
Migrate(#[from] sqlx::migrate::MigrateError),
/// Failed to connect
#[error("failed to connect")]
Connect(#[source] sqlx::Error),
/// Migrations failed
#[error("migrations failed")]
Migrate(#[source] sqlx::migrate::MigrateError),
/// Acquire connection
#[error("acquire connection")]
Acquire(#[source] sqlx::Error),
/// Begin transaction
#[error("begin transaction")]
Begin(#[source] sqlx::Error),
/// Commit transaction
#[error("commit transaction")]
Commit(#[source] sqlx::Error),
/// Execute query
#[error("execute query")]
Execute(#[source] sqlx::Error),
}
Loading

0 comments on commit bf96268

Please sign in to comment.