Skip to content

Commit

Permalink
Implement signing in
Browse files Browse the repository at this point in the history
  • Loading branch information
GrantGryczan committed Jan 1, 2025
1 parent 6b07ea7 commit 599848f
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 9 deletions.
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ strum_macros = "0.26"
thiserror = "2"
tokio = { version = "1", features = ["full"] }
tower = { version = "0.5", features = ["util"] }
tower-cookies = { version = "0.10" }

[lints]
# Last updated for Clippy version: 1.83
Expand Down Expand Up @@ -108,7 +109,6 @@ clippy.manual_instant_elapsed = "warn"
clippy.manual_is_variant_and = "warn"
clippy.manual_let_else = "warn"
clippy.manual_ok_or = "warn"
clippy.map_unwrap_or = "warn"
clippy.match_bool = "warn"
clippy.mixed_read_write_in_expression = "warn"
clippy.mod_module_files = "warn"
Expand Down
5 changes: 5 additions & 0 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ pub enum Error {
/// The requested API route doesn't exist.
#[error("The requested API route doesn't exist.")]
RouteNotFound,

/// Credentials specified in the request (such as email and password) don't match any user.
#[error("The specified user credentials are incorrect.")]
UserCredentialsWrong,
}

impl Error {
Expand All @@ -84,6 +88,7 @@ impl Error {
Self::JsonSyntax(_) => StatusCode::BAD_REQUEST,
Self::ResourceNotFound => StatusCode::NOT_FOUND,
Self::RouteNotFound => StatusCode::NOT_FOUND,
Self::UserCredentialsWrong => StatusCode::FORBIDDEN,
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/api/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use axum::{
routing::{get, post},
Router,
};
use tower_cookies::CookieManagerLayer;

use crate::api;

Expand All @@ -14,6 +15,7 @@ pub mod v1 {
pub mod email_verification;
pub mod password_reset;
pub mod sessions;
pub mod users;
}

Expand All @@ -36,6 +38,8 @@ pub(super) static ROUTER: LazyLock<Router> = LazyLock::new(|| {
"/api/v1/password-reset/password",
post(v1::password_reset::password::post),
)
.route("/api/v1/sessions", post(v1::sessions::post))
.route("/api/v1/users", post(v1::users::post))
.fallback(|| async { api::Error::RouteNotFound })
.layer(CookieManagerLayer::new())
});
4 changes: 2 additions & 2 deletions src/api/routes/v1/password_reset/password.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use axum_macros::debug_handler;
use serde::{Deserialize, Serialize};

use crate::{
api::{self, validation::UserPassword, Json, Query, Response},
api::{self, validation::NewUserPassword, Json, Query, Response},
crypto::{hash_with_salt, hash_without_salt},
db::{self, TxError, TxResult},
id::Token,
Expand All @@ -24,7 +24,7 @@ pub struct PostQuery {
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub struct PostRequest {
/// The user's new password in plain text.
pub password: UserPassword,
pub password: NewUserPassword,
}

/// Sets a new password to fulfill a user's password reset request.
Expand Down
134 changes: 134 additions & 0 deletions src/api/routes/v1/sessions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//! The set of users' sign-in sessions.
use std::sync::LazyLock;

use axum::http::StatusCode;
use axum_macros::debug_handler;
use serde::{Deserialize, Serialize};
use sqlx::Acquire;
use tower_cookies::{
cookie::{time::Duration, SameSite},
Cookie, Cookies,
};

use crate::{
api::{
self,
validation::{UserEmail, UserPassword},
Json, Response,
},
crypto::{hash_without_salt, verify_hash},
db::{self, TxResult},
id::Token,
WEBSITE_ORIGIN,
};

/// The domain for the website.
static WEBSITE_DOMAIN: LazyLock<&str> = LazyLock::new(|| domain_from_origin(&WEBSITE_ORIGIN));

/// How long a session takes to expire after its creation.
const SESSION_MAX_AGE: Duration = Duration::days(60);

/// A `POST` request body for this API route.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub struct PostRequest {
/// The email address of the user signing in.
pub email: UserEmail,

/// The user's password in plain text.
pub password: UserPassword,
}

/// Signs a user in, creating a sign-in session and returning a session cookie.
///
/// # Errors
///
/// See [`crate::api::Error`].
#[debug_handler]
pub async fn post(cookies: Cookies, Json(body): Json<PostRequest>) -> Response<PostResponse> {
let token = db::transaction!(async |tx| -> TxResult<_, api::Error> {
let Some(user) = sqlx::query!(
"SELECT id, password_hash FROM users
WHERE email = $1",
body.email.as_str(),
)
.fetch_optional(tx.as_mut())
.await?
.filter(|user| verify_hash(&body.password, &user.password_hash)) else {
// To prevent user enumeration, send this same error response whether or not the email
// is correct.
return Err(db::TxError::Abort(api::Error::UserCredentialsWrong));
};

let mut token = Token::generate()?;

loop {
// If this loop's query fails from a token conflict, this savepoint is rolled back to
// rather than aborting the entire transaction.
let mut savepoint = tx.begin().await?;

let token_hash = hash_without_salt(&token);

match sqlx::query!(
"INSERT INTO sessions (token_hash, user_id)
VALUES ($1, $2)",
token_hash.as_ref(),
user.id,
)
.execute(savepoint.as_mut())
.await
{
Err(sqlx::Error::Database(error))
if error.constraint() == Some("sessions_pkey") =>
{
token.reroll()?;
continue;
}
result => result?,
};

savepoint.commit().await?;
break;
}

Ok(token)
})
.await?;

cookies.add(
Cookie::build(("token", token.to_string()))
.domain(*WEBSITE_DOMAIN)
.http_only(true)
.max_age(SESSION_MAX_AGE)
.path("/")
.same_site(SameSite::Lax)
.secure(WEBSITE_ORIGIN.starts_with("https:"))
.into(),
);

Ok((StatusCode::OK, Json(PostResponse {})))
}

/// A `POST` response body for this API route.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct PostResponse {
// To reduce the session token's attack surface, it isn't included in the response. It's set as
// an `HttpOnly` cookie instead so browser scripts can't access it.
}

/// Returns the domain from an origin URI string.
///
/// # Panics
///
/// Panics if the origin string doesn't contain "//".
fn domain_from_origin(origin: &str) -> &str {
let start = origin.find("//").expect("origin should contain \"//\"") + 2;
let end = origin[start..]
.find(":")
.map(|index| index + start)
.unwrap_or(origin.len());

&origin[start..end]
}
6 changes: 3 additions & 3 deletions src/api/routes/v1/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use sqlx::Acquire;
use crate::{
api::{
self,
validation::{EmailVerificationCode, UserEmail, UserName, UserPassword},
validation::{EmailVerificationCode, NewUserPassword, UserEmail, UserName},
Json, Response,
},
crypto::{hash_with_salt, verify_hash},
Expand All @@ -29,8 +29,8 @@ pub struct PostRequest {
/// The user's name.
pub name: UserName,

/// The user's password in plain text.
pub password: UserPassword,
/// The user's new password in plain text.
pub password: NewUserPassword,
}

/// Creates a new user.
Expand Down
5 changes: 4 additions & 1 deletion src/api/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ use thiserror::Error;
/// A user's name.
pub type UserName = BoundedString<1, 64>;

/// A user's new password in plain text.
pub type NewUserPassword = BoundedString<8, 256>;

/// A user's password in plain text.
pub type UserPassword = BoundedString<8, 256>;
pub type UserPassword = BoundedString<0, 256>;

/// An unverified email's verification code in plain text.
pub type EmailVerificationCode = BoundedString<6, 6>;
Expand Down
4 changes: 2 additions & 2 deletions src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub(super) async fn handle(request: Request) -> Response {
///
/// Panics if the origin string doesn't contain "//".
fn host_from_origin(origin: &str) -> &str {
let host_index = origin.find("//").expect("origin should contain \"//\"") + 2;
let start = origin.find("//").expect("origin should contain \"//\"") + 2;

&origin[host_index..]
&origin[start..]
}

0 comments on commit 599848f

Please sign in to comment.