diff --git a/docs/chapter2/section1/0_setup.md b/docs/chapter2/section1/0_setup.md index e1343094..a94e239b 100644 --- a/docs/chapter2/section1/0_setup.md +++ b/docs/chapter2/section1/0_setup.md @@ -23,7 +23,7 @@ DB_DATABASE="world" 4. 以下のコマンドを実行し、クレートの依存関係を追加しましょう。 ```sh -$ cargo add axum anyhow serde serde_json tokio bcrypt --features tokio/full,serde/derive,axum/macros +$ cargo add axum axum-extra anyhow serde serde_json tokio bcrypt --features tokio/full,serde/derive,axum/macros,axum-extra/typed-header $ cargo add async-session tracing tracing-subscriber --features tracing-subscriber/env-filter,tracing-subscriber/fmt $ cargo add tower-http --features add-extension,trace,fs ``` diff --git a/docs/chapter2/section1/1_account.md b/docs/chapter2/section1/1_account.md index 97c8b362..0f91725b 100644 --- a/docs/chapter2/section1/1_account.md +++ b/docs/chapter2/section1/1_account.md @@ -374,6 +374,8 @@ pub async fn sign_up( ::: code-group <<<@/chapter2/section1/src/1_account/auth.rs{rs:line-numbers}[auth.rs] <<<@/chapter2/section1/src/1_account/users.rs{rs:line-numbers}[users.rs] +<<<@/chapter2/section1/src/1_account/main.rs{rs:line-numbers}[main.rs] +<<<@/chapter2/section1/src/1_account/repository.rs{rs:line-numbers}[repository.rs] ::: 最後に、`handler.rs` に、先ほど書いたハンドラーを追加しましょう。 @@ -413,6 +415,7 @@ $ task db ```sql mysql> USE world; mysql> SELECT * FROM users; +mysql> SELECT * FROM user_passwords; ``` ![](images/3/database1-user.png) diff --git a/docs/chapter2/section1/2_session.md b/docs/chapter2/section1/2_session.md index 46bab42d..b132f724 100644 --- a/docs/chapter2/section1/2_session.md +++ b/docs/chapter2/section1/2_session.md @@ -1,277 +1,528 @@ # セッション管理機構の実装 ## セッションストアを設定する -`main.go`に以下を追加しましょう。 -```go -func main() { - (省略) - // usersテーブルが存在しなかったら、usersテーブルを作成する - _, err = db.Exec("CREATE TABLE IF NOT EXISTS users (Username VARCHAR(255) PRIMARY KEY, HashedPass VARCHAR(255))") - if err != nil { - log.Fatal(err) - } - - // セッションの情報を記憶するための場所をデータベース上に設定 // [!code ++] - store, err := mysqlstore.NewMySQLStoreFromConnection(db.DB, "sessions", "/", 60*60*24*14, []byte("secret-token")) // [!code ++] - if err != nil { // [!code ++] - log.Fatal(err) // [!code ++] - } // [!code ++] - - h := handler.NewHandler(db) - e := echo.New() - e.Use(middleware.Logger()) // ログを取るミドルウェアを追加 // [!code ++] - e.Use(session.Middleware(store)) // セッション管理のためのミドルウェアを追加 // [!code ++] - - e.POST("/signup", h.SignUpHandler) - (省略) +`repository.rs`に以下を追加しましょう。 +```rs +use async_sqlx_session::MySqlSessionStore; // [!code ++] +use sqlx::mysql::MySqlConnectOptions; +use sqlx::mysql::MySqlPool; +use std::env; + +pub mod country; +pub mod users; + +#[derive(Clone)] +pub struct Repository { + pool: MySqlPool, + session_store: MySqlSessionStore, // [!code ++] } -``` -これらはセッションストアの設定です。 -最初に、セッションの情報を記憶するための場所をデータベース上に設定します。 +impl Repository { + pub async fn connect() -> anyhow::Result { + let options = get_options()?; + let pool = sqlx::MySqlPool::connect_with(options).await?; -この仕組みを使用するために、 `e.Use(session.Middleware(store))` を含めてセッションストアを使ってね〜、って echo に命令しています。 + let session_store = // [!code ++] + MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions"); // [!code ++] -`e.Use(middleware.Logger())` は文字通りログを取るものです。ついでに入れましょう。 + Ok(Self { + pool, + session_store, // [!code ++] + }) + } -:::tip -`"secret-token"`は、暗号化/復号化の際に使われる秘密鍵です。 -実際に運用するときはこの"secret-token"を独自の値にしてください。環境変数などで管理するのが良いでしょう。 -::: + pub async fn migrate(&self) -> anyhow::Result<()> { + sqlx::migrate!("./migrations").run(&self.pool).await?; + Ok(()) + } +} +...(省略) +``` -## LoginHandler の実装 -続いて、`LoginHandler` を `handler.go` に実装していきましょう。 +これらはセッションストアの設定です。 +セッションの情報を記憶するための場所をデータベース上に設定して、`session_store` からアクセスできるようにしています。 + +## `login` ハンドラの実装 +続いて、`login` ハンドラを `handler/auth.rs` に実装していきましょう。 -```go -func (h *Handler) LoginHandler(c echo.Context) error { // [!code ++] +```rs +pub async fn login( // [!code ++] + State(state): State, // [!code ++] + Json(body): Json, // [!code ++] +) -> Result { // [!code ++] } // [!code ++] ``` -`LoginHandler` の外に以下の構造体を追加します。 -```go -type User struct { // [!code ++] - Username string `json:"username,omitempty" db:"Username"` // [!code ++] - HashedPass string `json:"-" db:"HashedPass"` // [!code ++] + +`login` ハンドラの外に以下の構造体を追加します。 +```rs +#[derive(Deserialize)] // [!code ++] +pub struct Login { // [!code ++] + pub username: String, // [!code ++] + pub password: String, // [!code ++] } // [!code ++] ``` -`LoginHandler` を実装していきます。 -```go -func (h *Handler) LoginHandler(c echo.Context) error { - // リクエストを受け取り、reqに格納する // [!code ++] - var req LoginRequestBody // [!code ++] - err := c.Bind(&req) // [!code ++] - if err != nil { // [!code ++] - return c.String(http.StatusBadRequest, "bad request body") // [!code ++] - } // [!code ++] - - // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) // [!code ++] - if req.Password == "" || req.Username == "" { // [!code ++] - return c.String(http.StatusBadRequest, "Username or Password is empty") // [!code ++] - } // [!code ++] - - // データベースからユーザーを取得する // [!code ++] - user := User{} // [!code ++] - err = h.db.Get(&user, "SELECT * FROM users WHERE username=?", req.Username) // [!code ++] - if err != nil { // [!code ++] - if errors.Is(err, sql.ErrNoRows) { // [!code ++] - return c.NoContent(http.StatusUnauthorized) // [!code ++] - } else { // [!code ++] - log.Println(err) // [!code ++] - return c.NoContent(http.StatusInternalServerError) // [!code ++] - } // [!code ++] - } // [!code ++] + +`login` ハンドラの中身を実装する前に、必要になるデータベース操作のメソッドを追加します。ここで必要になるのは以下の 2 つです。 + +- `username` から `id` を取得するメソッド +- `id` と `password` の組が登録されているものと一致するかを確認するメソッド + +この 2 つを `repository/users.rs` に追加します。 +```rs +use super::Repository; + +impl Repository { + pub async fn is_exist_username(&self, username: String) -> sqlx::Result { + ...(省略) + } + + pub async fn create_user(&self, username: String) -> sqlx::Result { + ...(省略) + } + + pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result { // [!code ++] + let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?") // [!code ++] + .bind(&username) // [!code ++] + .fetch_one(&self.pool) // [!code ++] + .await?; // [!code ++] + Ok(result) // [!code ++] + } // [!code ++] + + pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> { + ...(省略) + } + + pub async fn verify_user_password(&self, id: u64, password: String) -> anyhow::Result { // [!code ++] + let hash = // [!code ++] + sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?") // [!code ++] + .bind(id) // [!code ++] + .fetch_one(&self.pool) // [!code ++] + .await?; // [!code ++] + + Ok(bcrypt::verify(password, &hash)?) // [!code ++] + } // [!code ++] } ``` -req への代入は signUpHandler と同じです。UserName と Password が入っているかも確認しましょう。 -パスワードの一致チェックをするために、データベースからユーザーを取得してきましょう。 +データベースに保存されているパスワードはハッシュ化されています。 -ユーザーが存在しなかった場合は `sql.ErrNoRows` というエラーが返ってきます。 +ハッシュ化は不可逆な処理なので、ハッシュ化されたものから原文を調べることはできません。確認する際はもらったパスワードをハッシュ化することで行います。 +`bcrypt::verify` によってパスワードの検証ができます。 + +`handler/auth.rs` に戻り、`login` ハンドラを実装していきます。 + +```rs +pub async fn login( // [!code ++] + State(state): State, // [!code ++] + Json(body): Json, // [!code ++] +) -> Result { // [!code ++] + // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) // [!code ++] + if body.username.is_empty() || body.password.is_empty() { // [!code ++] + return Err(StatusCode::BAD_REQUEST); // [!code ++] + } // [!code ++] + + // データベースからユーザーを取得する // [!code ++] + let id = state // [!code ++] + .get_user_id_by_name(body.username.clone()) // [!code ++] + .await // [!code ++] + .map_err(|e| match e { // [!code ++] + sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED, // [!code ++] + _ => StatusCode::INTERNAL_SERVER_ERROR, // [!code ++] + })?; // [!code ++] +} // [!code ++] +``` + +ユーザーが存在しなかった場合は `sqlx::Error::RowNotFound` というエラーが返ってきます。 もしそのエラーなら 401 (Unauthorized)、そうでなければ 500 (Internal Server Error) です。 もし 404 (Not Found) とすると、「このユーザーはパスワードが違うのではなく存在しないんだ」という事がわかってしまい(このユーザーは存在していてパスワードは違う事も分かります)、セキュリティ上のリスクに繋がります。 -:::tip -ここで、エラーチェックは基本的に errors.Is を使いましょう。 -参考: -::: -```go -func (h *Handler) LoginHandler(c echo.Context) error { - (省略) - // データベースからユーザーを取得する - user := User{} - err = h.db.Get(&user, "SELECT * FROM users WHERE username=?", req.Username) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return c.NoContent(http.StatusUnauthorized) - } else { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - } - // パスワードが一致しているかを確かめる // [!code ++] - err = bcrypt.CompareHashAndPassword([]byte(user.HashedPass), []byte(req.Password)) // [!code ++] - if err != nil { // [!code ++] - if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { // [!code ++] - return c.NoContent(http.StatusUnauthorized) // [!code ++] - } else { // [!code ++] - return c.NoContent(http.StatusInternalServerError) // [!code ++] - } // [!code ++] - } // [!code ++] +```rs +pub async fn login( + State(state): State, + Json(body): Json, +) -> Result { + // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) + if body.username.is_empty() || body.password.is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + // データベースからユーザーを取得する + let id = state + .get_user_id_by_name(body.username.clone()) + .await + .map_err(|e| match e { + sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + })?; + + // パスワードが一致しているかを確かめる // [!code ++] + if !state // [!code ++] + .verify_user_password(id, body.password.clone()) // [!code ++] + .await // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? // [!code ++] + { // [!code ++] + return Err(StatusCode::UNAUTHORIZED); // [!code ++] + } // [!code ++] } ``` -データベースに保存されているパスワードはハッシュ化されています。 +データベースでエラーが起きた場合や、検証の操作に失敗した場合は 500 (Internal Server Error), パスワードが間違っていた場合 401 (Unauthorized) を返却しています。 + +```rs +pub async fn login( + State(state): State, + Json(body): Json, +) -> Result { + ...(省略) + + // パスワードが一致しているかを確かめる + if !state + .verify_user_password(id, body.password.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + { + return Err(StatusCode::UNAUTHORIZED); + } + + // セッションストアに登録する // [!code ++] + let session_id = state // [!code ++] + .create_user_session(id.to_string()) // [!code ++] + .await // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; // [!code ++] + +} +``` -ハッシュ化は不可逆な処理なので、ハッシュ化されたものから原文を調べることはできません。確認する際はもらったパスワードをハッシュ化することで行います。 +`id` をセッションストアに登録して、セッション id を取得します。 -これは `bcrypt.CompareHashAndPassword` が行ってくれるのでそれに乗っかりましょう。 - -- この関数はハッシュが一致すれば返り値が `nil` となります -- 一致しない場合、 `bcrypt.ErrMismatchedHashAndPassword` が返ってきます -- 処理中にこれ以外の問題が発生した場合は、返り値はエラー型の何かです - -従って、これらのエラーの内容に応じて、 500 (Internal Server Error), 401 (Unauthorized) を返却するか、処理を続行するかを選択していきます。 -```go -func (h *Handler) LoginHandler(c echo.Context) error { - (省略) - // パスワードが一致しているかを確かめる - err = bcrypt.CompareHashAndPassword([]byte(user.HashedPass), []byte(req.Password)) - if err != nil { - if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return c.NoContent(http.StatusUnauthorized) - } else { - return c.NoContent(http.StatusInternalServerError) - } - } - - // セッションストアに登録する // [!code ++] - sess, err := session.Get("sessions", c) // [!code ++] - if err != nil { // [!code ++] - log.Println(err) // [!code ++] - return c.String(http.StatusInternalServerError, "something wrong in getting session") // [!code ++] - } // [!code ++] - sess.Values["userName"] = req.Username // [!code ++] - sess.Save(c.Request(), c.Response()) // [!code ++] - - return c.NoContent(http.StatusOK) // [!code ++] +ここで用いる、セッションストアに登録するメソッド `create_user_session` を実装していきます。 + +ファイル `repository/users_session.rs` を作成し、以下を記述してください。 + +```rs +use anyhow::Context; // [!code ++] +use async_session::{Session, SessionStore}; // [!code ++] + +use super::Repository; // [!code ++] + +impl Repository { // [!code ++] + pub async fn create_user_session(&self, user_id: String) -> anyhow::Result { // [!code ++] + let mut session = Session::new(); // [!code ++] + + session // [!code ++] + .insert("user_id", user_id) // [!code ++] + .with_context(|| "Failed to insert user_id")?; // [!code ++] + + let session_id = self // [!code ++] + .session_store // [!code ++] + .store_session(session) // [!code ++] + .await // [!code ++] + .with_context(|| "Failed to store session")? // [!code ++] + .with_context(|| "Failed to create session")?; // [!code ++] + + Ok(session_id) // [!code ++] + } // [!code ++] +} // [!code ++] +``` + +セッションに `user_id` を登録し、セッションストアに保存します。 +セッション id を返却します。 + +`handler/auth.rs` に戻り、ヘッダーにセッション id を設定する処理を追加します。 + +```rs +pub async fn login( + State(state): State, + Json(body): Json, +) -> Result { + ...(省略) + + // セッションストアに登録する + let session_id = state + .create_user_session(id.to_string()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // クッキーをセットする // [!code ++] + let mut headers = header::HeaderMap::new(); // [!code ++] + + headers.insert( // [!code ++] + header::SET_COOKIE, // [!code ++] + format!("session_id={}; HttpOnly; SameSite=Strict", session_id) // [!code ++] + .parse() // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, // [!code ++] + ); // [!code ++] + + Ok((StatusCode::OK, headers)) // [!code ++] } ``` -セッションストアに登録します。 -セッションの `userName` という値にそのユーザーの名前を格納していることは覚えておきましょう。 -ここまで書いたら、 `LoginHandler` を使えるようにしましょう。 +ここまで書いたら、 `login` ハンドラを使えるようにしましょう。 +`handler.rs` に以下を追加してください。 -```go -func main() { - (省略) - e.Use(session.Middleware(store)) // セッション管理のためのミドルウェアを追加 +```rs +pub fn make_router(app_state: Repository) -> Router { + let city_router = Router::new() + .route("/cities/:city_name", get(country::get_city_handler)) + .route("/cities", post(country::post_city_handler)); - e.POST("/signup", h.SignUpHandler) - e.POST("/login", h.LoginHandler) // [!code ++] + let auth_router = Router::new() + .route("/signup", post(auth::sign_up)) + .route("/login", post(auth::login)); // [!code ++] - e.GET("/cities/:cityName", h.GetCityInfoHandler) - (省略) + Router::new() + .nest("/", city_router) + .nest("/", auth_router) + .with_state(app_state) } ``` :::details ここまでの全体像 ::: code-group -<<<@/chapter2/section1/src/2_session/main.go{go:line-numbers}[main.go] -<<<@/chapter2/section1/src/2_session/handler.go{go:line-numbers}[handler.go] +<<<@/chapter2/section1/src/2_session/auth.rs{rs:line-numbers}[auth.rs] +<<<@/chapter2/section1/src/2_session/users.rs{rs:line-numbers}[users.rs] +<<<@/chapter2/section1/src/2_session/users_session.rs{rs:line-numbers}[users_session.rs] +<<<@/chapter2/section1/src/2_session/repository.rs{rs:line-numbers}[repository.rs] ::: -## userAuthMiddleware の実装 +## Middleware の実装 -続いて、`userAuthMiddleware` を実装します。 +続いて、`auth_middleware` を実装します。 まず、これは Handler ではなく Middleware と呼ばれます。 送られてくるリクエストは、Middleware を経由して、 Handler に流れていきます。 -Middleware から次の Middleware/Handler を呼び出す際は `next(c)` と記述します。 Middleware の実装は難しいので、なんとなく理解できれば十分です。 - -以下を`handler.go`に追加しましょう。 -```go -func UserAuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { // [!code ++] - return func(c echo.Context) error { // [!code ++] - sess, err := session.Get("sessions", c) // [!code ++] - if err != nil { // [!code ++] - log.Println(err) // [!code ++] - return c.String(http.StatusInternalServerError, "something wrong in getting session") // [!code ++] - } // [!code ++] - if sess.Values["userName"] == nil { // [!code ++] - return c.String(http.StatusUnauthorized, "please login") // [!code ++] - } // [!code ++] - c.Set("userName", sess.Values["userName"].(string)) // [!code ++] - return next(c) // [!code ++] - } // [!code ++] +Middleware から次の Middleware/Handler を呼び出す際は `next.run(req)` と記述します。 + +以下を`handler/auth.rs`に追加してください。 + +```rs +pub async fn auth_middleware( // [!code ++] + State(state): State, // [!code ++] + TypedHeader(cookie): TypedHeader, // [!code ++] + mut req: Request, // [!code ++] + next: Next, // [!code ++] +) -> Result { // [!code ++] + + // セッションIDを取得する // [!code ++] + let session_id = cookie // [!code ++] + .get("session_id") // [!code ++] + .ok_or(StatusCode::UNAUTHORIZED)? // [!code ++] + .to_string(); // [!code ++] + + // セッションストアからユーザーIDを取得する // [!code ++] + let user_id = state // [!code ++] + .get_user_id_by_session_id(&session_id) // [!code ++] + .await // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? // [!code ++] + .ok_or(StatusCode::UNAUTHORIZED)?; // [!code ++] + + // リクエストにユーザーIDを追加する // [!code ++] + req.extensions_mut().insert(user_id); // [!code ++] + + // 次のミドルウェアを呼び出す // [!code ++] + Ok(next.run(req).await) // [!code ++] } // [!code ++] ``` -関数が関数を呼び出していて混乱しそうですが、 2 行目から 13 行目が本質で、外側はおまじないと考えて良いです。 - この Middleware はリクエストを送ったユーザーがログインしているのかをチェックし、 -ログインしているなら Context (`c`) にそのユーザーの UserName を設定します。 +ログインしているならリクエスト(`req`) に `user_id` を追加します。 + +Cookie からセッション id を取得し、セッションストアからユーザー id を取得します。 +ここで、セッション id がなかった場合や、不正なセッション id だった場合は 401 (Unauthorized) を返却します。 +正しくログインされていれば、次の Middleware/Handler を呼び出します。 + +ここで使用した、 `get_user_id_by_session_id` メソッドを `repository/users_session.rs` に追加します。 + +```rs +pub async fn get_user_id_by_session_id( // [!code ++] + &self, // [!code ++] + session_id: &String, // [!code ++] +) -> anyhow::Result> { // [!code ++] + let session = self // [!code ++] + .session_store // [!code ++] + .load_session(session_id.clone()) // [!code ++] + .await // [!code ++] + .with_context(|| "Failed to load session")?; // [!code ++] + + Ok(session.and_then(|s| s.get::("user_id"))) // [!code ++] +} // [!code ++] +``` -セッションを取得し、ログイン時に設定した `userName` の値を確認しに行きます。 +最後に、Middleware を設定しましょう。 +`handler.rs` に以下を追加してください。 -ここで名前が入っていればリクエストの送信者はログイン済みで、そうでなければログインをしていないことが分かります。 +```rs +use axum::{ + middleware::from_fn_with_state, // [!code ++] + routing::{get, post}, + Router, +}; -これを利用して、ログインしていない場合には処理をここで止めて 401 (Unauthorized) を返却し、していれば次の処理 (`next(c)`) -に進みます。 +use crate::repository::Repository; -最後に、Middleware を設定しましょう。 -グループ機能を利用して、 `withAuth` に設定されてるエンドポイントは `userAuthMiddleware` を処理してから処理する、という設定をします。 - -```go -func main() { - (省略) - e.POST("/login", h.LoginHandler) - - e.GET("/cities/:cityName", h.GetCityInfoHandler) // [!code --] - e.POST("/cities", h.PostCityHandler) // [!code --] - withAuth := e.Group("") // [!code ++] - withAuth.Use(handler.UserAuthMiddleware) // [!code ++] - withAuth.GET("/cities/:cityName", h.GetCityInfoHandler) // [!code ++] - withAuth.POST("/cities", h.PostCityHandler) // [!code ++] - - err = e.Start(":8080") - (省略) +mod auth; +mod country; + +pub fn make_router(app_state: Repository) -> Router { + let city_router = Router::new() + .route("/cities/:city_name", get(country::get_city_handler)) + .route("/cities", post(country::post_city_handler)); + .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); // [!code ++] + + let auth_router = Router::new() + .route("/signup", post(auth::sign_up)) + .route("/login", post(auth::login)) + .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); // [!code ++] + + ...(省略) } ``` これで、この章の目標である「ログインしないと利用できないようにする」が達成されました。 -## GetMeHandler の実装 +## logout ハンドラの実装 + +ログアウト機能をまだ実装していなかったので、 `logout` ハンドラを実装していきます。 + +まず、`handler/auth.rs` に以下を追加してください。 + +```rs +pub async fn logout( // [!code ++] + State(state): State, // [!code ++] + TypedHeader(cookie): TypedHeader, // [!code ++] +) -> Result { // [!code ++] + // セッションIDを取得する // [!code ++] + let session_id = cookie // [!code ++] + .get("session_id") // [!code ++] + .ok_or(StatusCode::UNAUTHORIZED)? // [!code ++] + .to_string(); // [!code ++] + + // セッションストアから削除する // [!code ++] + state // [!code ++] + .delete_user_session(session_id) // [!code ++] + .await // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; // [!code ++] + + // クッキーを削除する // [!code ++] + let mut headers = header::HeaderMap::new(); // [!code ++] + headers.insert( // [!code ++] + header::SET_COOKIE, // [!code ++] + "session_id=; HttpOnly; SameSite=Strict; Max-Age=0" // [!code ++] + .parse() // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, // [!code ++] + ); // [!code ++] + + Ok((StatusCode::OK, headers)) // [!code ++] +} // [!code ++] +``` + +ログアウトするときは、ログインするときとは逆にセッションと Cookie を削除します。 -最後に、 `GetMeHandler` を実装します。叩いたときに自分の情報が返ってくるエンドポイントです。 +ここで呼び出す `delete_user_session` メソッドを `repository/users_session.rs` に追加します。 -以下を `handler.go` に追加しましょう。 -```go -type Me struct { // [!code ++] - Username string `json:"username,omitempty" db:"username"` // [!code ++] +```rs +pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> { // [!code ++] + let session = self // [!code ++] + .session_store // [!code ++] + .load_session(session_id.clone()) // [!code ++] + .await // [!code ++] + .with_context(|| "Failed to load session")? // [!code ++] + .with_context(|| "Failed to find session")?; // [!code ++] + + self.session_store // [!code ++] + .destroy_session(session) // [!code ++] + .await // [!code ++] + .with_context(|| "Failed to destroy session")?; // [!code ++] + + Ok(()) // [!code ++] } // [!code ++] ``` -```go -func GetMeHandler(c echo.Context) error { // [!code ++] - return c.JSON(http.StatusOK, Me{ // [!code ++] - Username: c.Get("userName").(string), // [!code ++] - }) // [!code ++] + +セッション ID からセッションを取得し、セッションストアから削除します。 + +最後に、`handler.rs` に `logout` ハンドラを追加します。 + +```rs +let auth_router = Router::new() + .route("/signup", post(auth::sign_up)) + .route("/login", post(auth::login)) + .route("/logout", post(auth::logout)) // [!code ++] + .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); +``` + + +## me ハンドラの実装 + +最後に、 `me` ハンドラを実装します。叩いたときに自分の情報が返ってくるエンドポイントです。 + +以下を `handler/auth.rs` に追加してください。 + +```rs +#[derive(Serialize)] // [!code ++] +pub struct Me { // [!code ++] + pub username: String, // [!code ++] +}// [!code ++] +``` + +```rs +pub async fn me(State(state): State, req: Request) -> Result, StatusCode> { // [!code ++] + // リクエストからユーザーIDを取得する // [!code ++] + let user_id = req // [!code ++] + .extensions() // [!code ++] + .get::() // [!code ++] + .ok_or(StatusCode::UNAUTHORIZED)? // [!code ++] + .to_string(); // [!code ++] + + // データベースからユーザー名を取得する // [!code ++] + let username = state // [!code ++] + .get_user_name_by_id( // [!code ++] + user_id // [!code ++] + .parse() // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, // [!code ++] + ) // [!code ++] + .await // [!code ++] + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; // [!code ++] + + Ok(Json(Me { username })) // [!code ++] } // [!code ++] ``` -アクセスしているユーザーの`userName`をセッションから取得して返しています。 -`userAuthMiddleware` を実行したあとなので、`c.Get("userName").(string)` によって userName を取得できます。 - -`main.go`に`withAuth.GET("/me", handler.GetMeHandler)`を追加しましょう。 -```go -func main() { - (省略) - withAuth := e.Group("") - withAuth.Use(handler.UserAuthMiddleware) - withAuth.GET("/me", handler.GetMeHandler) // [!code ++] - withAuth.GET("/cities/:cityName", h.GetCityInfoHandler) - withAuth.POST("/cities", h.PostCityHandler) - - err = e.Start(":8080") - (省略) +リクエストからユーザー ID を取得し、データベースからユーザー名を取得します。 + +ここで呼び出す `get_user_name_by_id` メソッドを `repository/users.rs` に追加します。 + +```rs +impl Repository { + ...(省略) + + pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> { // [!code ++] + let session = self// [!code ++] + .session_store// [!code ++] + .load_session(session_id.clone()) // [!code ++] + .await// [!code ++] + .with_context(|| "Failed to load session")? // [!code ++] + .with_context(|| "Failed to find session")?; // [!code ++] + + self.session_store // [!code ++] + .destroy_session(session) // [!code ++] + .await // [!code ++] + .with_context(|| "Failed to destroy session")?; // [!code ++] + + Ok(()) // [!code ++] + } // [!code ++] + + ...(省略) } ``` + +最後に、`handler.rs` に `me` ハンドラを追加します。 + +```rs +let auth_router = Router::new() + .route("/signup", post(auth::sign_up)) + .route("/login", post(auth::login)) + .route("/logout", post(auth::logout)) + .route("/me", get(auth::me)) // [!code ++] + .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); +``` diff --git a/docs/chapter2/section1/3_verify.md b/docs/chapter2/section1/3_verify.md index 6da8b6e3..e0b2eeaf 100644 --- a/docs/chapter2/section1/3_verify.md +++ b/docs/chapter2/section1/3_verify.md @@ -6,8 +6,14 @@ :::details 完成形 ::: code-group -<<<@/chapter2/section1/src/final/main.go{go:line-numbers}[main.go] -<<<@/chapter2/section1/src/final/handler.go{go:line-numbers}[handler.go] +<<<@/chapter2/section1/src/final/main.rs{rs:line-numbers}[main.rs] +<<<@/chapter2/section1/src/final/handler.rs{rs:line-numbers}[handler.rs] +<<<@/chapter2/section1/src/final/repository.rs{rs:line-numbers}[repository.rs] +<<<@/chapter2/section1/src/final/handler/country.rs{rs:line-numbers}[handler/country.rs] +<<<@/chapter2/section1/src/final/repository/country.rs{rs:line-numbers}[repository/country.rs] +<<<@/chapter2/section1/src/final/handler/auth.rs{rs:line-numbers}[handler/auth.rs] +<<<@/chapter2/section1/src/final/repository/users.rs{rs:line-numbers}[repository/users.rs] +<<<@/chapter2/section1/src/final/repository/users_session.rs{rs:line-numbers}[repository/users_session.rs] ::: ## 検証 diff --git a/docs/chapter2/section1/src/1_account/main.rs b/docs/chapter2/section1/src/1_account/main.rs new file mode 100644 index 00000000..8823e79b --- /dev/null +++ b/docs/chapter2/section1/src/1_account/main.rs @@ -0,0 +1,21 @@ +use tower_http::trace::TraceLayer; +use tracing_subscriber::EnvFilter; + +mod handler; +mod repository; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::try_from_default_env().unwrap_or("info".into())) + .init(); + + let app_state = repository::Repository::connect().await?; + app_state.migrate().await?; + let app = handler::make_router(app_state).layer(TraceLayer::new_for_http()); + let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?; + + tracing::info!("listening on {}", listener.local_addr()?); + axum::serve(listener, app).await.unwrap(); + Ok(()) +} diff --git a/docs/chapter2/section1/src/1_account/repository.rs b/docs/chapter2/section1/src/1_account/repository.rs new file mode 100644 index 00000000..f7d595d0 --- /dev/null +++ b/docs/chapter2/section1/src/1_account/repository.rs @@ -0,0 +1,44 @@ +use sqlx::mysql::MySqlConnectOptions; +use sqlx::mysql::MySqlPool; +use std::env; + +pub mod country; +pub mod users; + +#[derive(Clone)] +pub struct Repository { + pool: MySqlPool, +} + +impl Repository { + pub async fn connect() -> anyhow::Result { + let options = get_options()?; + let pool = sqlx::MySqlPool::connect_with(options).await?; + Ok(Self { + pool, + }) + } + pub async fn migrate(&self) -> anyhow::Result<()> { + sqlx::migrate!("./migrations").run(&self.pool).await?; + Ok(()) + } +} + +fn get_options() -> anyhow::Result { + let host = env::var("DB_HOSTNAME")?; + let port = env::var("DB_PORT")?.parse()?; + let username = env::var("DB_USERNAME")?; + let password = env::var("DB_PASSWORD")?; + let database = env::var("DB_DATABASE")?; + let timezone = Some(String::from("Asia/Tokyo")); + let collation = String::from("utf8mb4_unicode_ci"); + + Ok(MySqlConnectOptions::new() + .host(&host) + .port(port) + .username(&username) + .password(&password) + .database(&database) + .timezone(timezone) + .collation(&collation)) +} diff --git a/docs/chapter2/section1/src/2_session/auth.rs b/docs/chapter2/section1/src/2_session/auth.rs new file mode 100644 index 00000000..f9b71c25 --- /dev/null +++ b/docs/chapter2/section1/src/2_session/auth.rs @@ -0,0 +1,96 @@ +use axum::{ + extract::State, + http::{header, StatusCode}, + response::IntoResponse, + Json, +}; +use serde::Deserialize; + +use crate::repository::Repository; + +#[derive(Deserialize)] +pub struct SignUp { + pub username: String, + pub password: String, +} + +pub async fn sign_up( + State(state): State, + Json(body): Json, +) -> Result { + // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) + if body.username.is_empty() || body.password.is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + // 登録しようとしているユーザーが既にデータベース内に存在したら409 Conflictを返す + if let Ok(true) = state.is_exist_username(body.username.clone()).await { + return Err(StatusCode::CONFLICT); + } + + // ユーザーを作成する + let id = state + .create_user(body.username.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // パスワードを保存する + state + .save_user_password(id as i32, body.password.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(StatusCode::CREATED) +} + +#[derive(Deserialize)] +pub struct Login { + pub username: String, + pub password: String, +} + +pub async fn login( + State(state): State, + Json(body): Json, +) -> Result { + // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) + if body.username.is_empty() || body.password.is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + // データベースからユーザーを取得する + let id = state + .get_user_id_by_name(body.username.clone()) + .await + .map_err(|e| match e { + sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + })?; + + // パスワードが一致しているかを確かめる + if !state + .verify_user_password(id, body.password.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + { + return Err(StatusCode::UNAUTHORIZED); + } + + // セッションストアに登録する + let session_id = state + .create_user_session(id.to_string()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // クッキーをセットする + let mut headers = header::HeaderMap::new(); + + headers.insert( + header::SET_COOKIE, + format!("session_id={}; HttpOnly; SameSite=Strict", session_id) + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + ); + + Ok((StatusCode::OK, headers)) +} diff --git a/docs/chapter2/section1/src/2_session/handler.go b/docs/chapter2/section1/src/2_session/handler.go deleted file mode 100644 index ad0cbbe8..00000000 --- a/docs/chapter2/section1/src/2_session/handler.go +++ /dev/null @@ -1,167 +0,0 @@ -package handler - -import ( - "database/sql" - "errors" - "github.com/jmoiron/sqlx" - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" - "golang.org/x/crypto/bcrypt" - "log" - "net/http" -) - -type Handler struct { - db *sqlx.DB -} - -func NewHandler(db *sqlx.DB) *Handler { - return &Handler{db: db} -} - -type City struct { - ID int `json:"id,omitempty" db:"ID"` - Name sql.NullString `json:"name,omitempty" db:"Name"` - CountryCode sql.NullString `json:"countryCode,omitempty" db:"CountryCode"` - District sql.NullString `json:"district,omitempty" db:"District"` - Population sql.NullInt64 `json:"population,omitempty" db:"Population"` -} - -type LoginRequestBody struct { - Username string `json:"username,omitempty" form:"username"` - Password string `json:"password,omitempty" form:"password"` -} - -type User struct { - Username string `json:"username,omitempty" db:"Username"` - HashedPass string `json:"-" db:"HashedPass"` -} - -func (h *Handler) SignUpHandler(c echo.Context) error { - // リクエストを受け取り、reqに格納する - req := LoginRequestBody{} - err := c.Bind(&req) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "bad request body") - } - - // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) - if req.Password == "" || req.Username == "" { - return c.String(http.StatusBadRequest, "Username or Password is empty") - } - - // 登録しようとしているユーザーが既にデータベース内に存在するかチェック - var count int - err = h.db.Get(&count, "SELECT COUNT(*) FROM users WHERE Username=?", req.Username) - if err != nil { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - // 存在したら409 Conflictを返す - if count > 0 { - return c.String(http.StatusConflict, "Username is already used") - } - - // パスワードをハッシュ化する - hashedPass, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) - // ハッシュ化に失敗したら500 InternalServerErrorを返す - if err != nil { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - - // ユーザーを登録する - _, err = h.db.Exec("INSERT INTO users (Username, HashedPass) VALUES (?, ?)", req.Username, hashedPass) - // 登録に失敗したら500 InternalServerErrorを返す - if err != nil { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - // 登録に成功したら201 Createdを返す - return c.NoContent(http.StatusCreated) -} - -func (h *Handler) LoginHandler(c echo.Context) error { - // リクエストを受け取り、reqに格納する - var req LoginRequestBody - err := c.Bind(&req) - if err != nil { - return c.String(http.StatusBadRequest, "bad request body") - } - - // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) - if req.Password == "" || req.Username == "" { - return c.String(http.StatusBadRequest, "Username or Password is empty") - } - - // データベースからユーザーを取得する - user := User{} - err = h.db.Get(&user, "SELECT * FROM users WHERE username=?", req.Username) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return c.NoContent(http.StatusUnauthorized) - } else { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - } - // パスワードが一致しているかを確かめる - err = bcrypt.CompareHashAndPassword([]byte(user.HashedPass), []byte(req.Password)) - if err != nil { - if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return c.NoContent(http.StatusUnauthorized) - } else { - return c.NoContent(http.StatusInternalServerError) - } - } - - // セッションストアに登録する - sess, err := session.Get("sessions", c) - if err != nil { - log.Println(err) - return c.String(http.StatusInternalServerError, "something wrong in getting session") - } - sess.Values["userName"] = req.Username - sess.Save(c.Request(), c.Response()) - - return c.NoContent(http.StatusOK) -} - -func (h *Handler) GetCityInfoHandler(c echo.Context) error { - cityName := c.Param("cityName") - - var city City - err := h.db.Get(&city, "SELECT * FROM city WHERE Name=?", cityName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return c.NoContent(http.StatusNotFound) - } - log.Printf("failed to get city data: %s\n", err) - return c.NoContent(http.StatusInternalServerError) - } - - return c.JSON(http.StatusOK, city) -} - -func (h *Handler) PostCityHandler(c echo.Context) error { - var city City - err := c.Bind(&city) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "bad request body") - } - - result, err := h.db.Exec("INSERT INTO city (Name, CountryCode, District, Population) VALUES (?, ?, ?, ?)", city.Name, city.CountryCode, city.District, city.Population) - if err != nil { - log.Printf("failed to insert city data: %s\n", err) - return c.NoContent(http.StatusInternalServerError) - } - - id, err := result.LastInsertId() - if err != nil { - log.Printf("failed to get last insert id: %s\n", err) - return c.NoContent(http.StatusInternalServerError) - } - city.ID = int(id) - - return c.JSON(http.StatusCreated, city) -} diff --git a/docs/chapter2/section1/src/2_session/main.go b/docs/chapter2/section1/src/2_session/main.go deleted file mode 100644 index ca0cc067..00000000 --- a/docs/chapter2/section1/src/2_session/main.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "github.com/srinathgs/mysqlstore" - "github.com/traPtitech/naro-template-backend/handler" - "log" - "os" - "time" - - "github.com/go-sql-driver/mysql" - - "github.com/jmoiron/sqlx" - "github.com/joho/godotenv" -) - -func main() { - // .envファイルから環境変数を読み込み - err := godotenv.Load(".env") - if err != nil { - log.Fatal(err) - } - - // データーベースの設定 - jst, err := time.LoadLocation("Asia/Tokyo") - if err != nil { - log.Fatal(err) - } - conf := mysql.Config{ - User: os.Getenv("DB_USERNAME"), - Passwd: os.Getenv("DB_PASSWORD"), - Net: "tcp", - Addr: os.Getenv("DB_HOSTNAME") + ":" + os.Getenv("DB_PORT"), - DBName: os.Getenv("DB_DATABASE"), - ParseTime: true, - Collation: "utf8mb4_unicode_ci", - Loc: jst, - } - - // データベースに接続 - db, err := sqlx.Open("mysql", conf.FormatDSN()) - if err != nil { - log.Fatal(err) - } - - // usersテーブルが存在しなかったら、usersテーブルを作成する - _, err = db.Exec("CREATE TABLE IF NOT EXISTS users (Username VARCHAR(255) PRIMARY KEY, HashedPass VARCHAR(255))") - if err != nil { - log.Fatal(err) - } - - // セッションの情報を記憶するための場所をデータベース上に設定 - store, err := mysqlstore.NewMySQLStoreFromConnection(db.DB, "sessions", "/", 60*60*24*14, []byte("secret-token")) - if err != nil { - log.Fatal(err) - } - - h := handler.NewHandler(db) - e := echo.New() - e.Use(middleware.Logger()) // ログを取るミドルウェアを追加 - e.Use(session.Middleware(store)) // セッション管理のためのミドルウェアを追加 - - e.POST("/signup", h.SignUpHandler) - e.POST("/login", h.LoginHandler) - - e.GET("/cities/:cityName", h.GetCityInfoHandler) - e.POST("/cities", h.PostCityHandler) - - err = e.Start(":8080") - if err != nil { - log.Fatal(err) - } -} diff --git a/docs/chapter2/section1/src/2_session/repository.rs b/docs/chapter2/section1/src/2_session/repository.rs new file mode 100644 index 00000000..1e8baf69 --- /dev/null +++ b/docs/chapter2/section1/src/2_session/repository.rs @@ -0,0 +1,53 @@ +use async_sqlx_session::MySqlSessionStore; +use sqlx::mysql::MySqlConnectOptions; +use sqlx::mysql::MySqlPool; +use std::env; + +pub mod country; +pub mod users; +pub mod users_session; + +#[derive(Clone)] +pub struct Repository { + pool: MySqlPool, + session_store: MySqlSessionStore, +} + +impl Repository { + pub async fn connect() -> anyhow::Result { + let options = get_options()?; + let pool = sqlx::MySqlPool::connect_with(options).await?; + + let session_store = + MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions"); + + Ok(Self { + pool, + session_store, + }) + } + + pub async fn migrate(&self) -> anyhow::Result<()> { + sqlx::migrate!("./migrations").run(&self.pool).await?; + Ok(()) + } +} + +fn get_options() -> anyhow::Result { + let host = env::var("DB_HOSTNAME")?; + let port = env::var("DB_PORT")?.parse()?; + let username = env::var("DB_USERNAME")?; + let password = env::var("DB_PASSWORD")?; + let database = env::var("DB_DATABASE")?; + let timezone = Some(String::from("Asia/Tokyo")); + let collation = String::from("utf8mb4_unicode_ci"); + + Ok(MySqlConnectOptions::new() + .host(&host) + .port(port) + .username(&username) + .password(&password) + .database(&database) + .timezone(timezone) + .collation(&collation)) +} diff --git a/docs/chapter2/section1/src/2_session/users.rs b/docs/chapter2/section1/src/2_session/users.rs new file mode 100644 index 00000000..55f75768 --- /dev/null +++ b/docs/chapter2/section1/src/2_session/users.rs @@ -0,0 +1,49 @@ +use super::Repository; + +impl Repository { + pub async fn is_exist_username(&self, username: String) -> sqlx::Result { + let result = sqlx::query("SELECT * FROM users WHERE username = ?") + .bind(&username) + .fetch_optional(&self.pool) + .await?; + Ok(result.is_some()) + } + + pub async fn create_user(&self, username: String) -> sqlx::Result { + let result = sqlx::query("INSERT INTO users (username) VALUES (?)") + .bind(&username) + .execute(&self.pool) + .await?; + Ok(result.last_insert_id()) + } + + pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result { + let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?") + .bind(&username) + .fetch_one(&self.pool) + .await?; + Ok(result) + } + + pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> { + let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?; + + sqlx::query("INSERT INTO user_passwords (id, hashed_pass) VALUES (?, ?)") + .bind(id) + .bind(hash) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn verify_user_password(&self, id: u64, password: String) -> anyhow::Result { + let hash = + sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?") + .bind(id) + .fetch_one(&self.pool) + .await?; + + Ok(bcrypt::verify(password, &hash)?) + } +} diff --git a/docs/chapter2/section1/src/2_session/users_session.rs b/docs/chapter2/section1/src/2_session/users_session.rs new file mode 100644 index 00000000..346a684d --- /dev/null +++ b/docs/chapter2/section1/src/2_session/users_session.rs @@ -0,0 +1,23 @@ +use anyhow::Context; +use async_session::{Session, SessionStore}; + +use super::Repository; + +impl Repository { + pub async fn create_user_session(&self, user_id: String) -> anyhow::Result { + let mut session = Session::new(); + + session + .insert("user_id", user_id) + .with_context(|| "Failed to insert user_id")?; + + let session_id = self + .session_store + .store_session(session) + .await + .with_context(|| "Failed to store session")? + .with_context(|| "Failed to create session")?; + + Ok(session_id) + } +} diff --git a/docs/chapter2/section1/src/final/handler.go b/docs/chapter2/section1/src/final/handler.go deleted file mode 100644 index 7913ae31..00000000 --- a/docs/chapter2/section1/src/final/handler.go +++ /dev/null @@ -1,192 +0,0 @@ -package handler - -import ( - "database/sql" - "errors" - "github.com/jmoiron/sqlx" - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" - "golang.org/x/crypto/bcrypt" - "log" - "net/http" -) - -type Handler struct { - db *sqlx.DB -} - -func NewHandler(db *sqlx.DB) *Handler { - return &Handler{db: db} -} - -type City struct { - ID int `json:"id,omitempty" db:"ID"` - Name sql.NullString `json:"name,omitempty" db:"Name"` - CountryCode sql.NullString `json:"countryCode,omitempty" db:"CountryCode"` - District sql.NullString `json:"district,omitempty" db:"District"` - Population sql.NullInt64 `json:"population,omitempty" db:"Population"` -} - -type LoginRequestBody struct { - Username string `json:"username,omitempty" form:"username"` - Password string `json:"password,omitempty" form:"password"` -} - -type User struct { - Username string `json:"username,omitempty" db:"Username"` - HashedPass string `json:"-" db:"HashedPass"` -} - -type Me struct { - Username string `json:"username,omitempty" db:"username"` -} - -func (h *Handler) SignUpHandler(c echo.Context) error { - // リクエストを受け取り、reqに格納する - req := LoginRequestBody{} - err := c.Bind(&req) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "bad request body") - } - - // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) - if req.Password == "" || req.Username == "" { - return c.String(http.StatusBadRequest, "Username or Password is empty") - } - - // 登録しようとしているユーザーが既にデータベース内に存在するかチェック - var count int - err = h.db.Get(&count, "SELECT COUNT(*) FROM users WHERE Username=?", req.Username) - if err != nil { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - // 存在したら409 Conflictを返す - if count > 0 { - return c.String(http.StatusConflict, "Username is already used") - } - - // パスワードをハッシュ化する - hashedPass, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) - // ハッシュ化に失敗したら500 InternalServerErrorを返す - if err != nil { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - - // ユーザーを登録する - _, err = h.db.Exec("INSERT INTO users (Username, HashedPass) VALUES (?, ?)", req.Username, hashedPass) - // 登録に失敗したら500 InternalServerErrorを返す - if err != nil { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - // 登録に成功したら201 Createdを返す - return c.NoContent(http.StatusCreated) -} - -func (h *Handler) LoginHandler(c echo.Context) error { - // リクエストを受け取り、reqに格納する - var req LoginRequestBody - err := c.Bind(&req) - if err != nil { - return c.String(http.StatusBadRequest, "bad request body") - } - - // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) - if req.Password == "" || req.Username == "" { - return c.String(http.StatusBadRequest, "Username or Password is empty") - } - - // データベースからユーザーを取得する - user := User{} - err = h.db.Get(&user, "SELECT * FROM users WHERE username=?", req.Username) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return c.NoContent(http.StatusUnauthorized) - } else { - log.Println(err) - return c.NoContent(http.StatusInternalServerError) - } - } - // パスワードが一致しているかを確かめる - err = bcrypt.CompareHashAndPassword([]byte(user.HashedPass), []byte(req.Password)) - if err != nil { - if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return c.NoContent(http.StatusUnauthorized) - } else { - return c.NoContent(http.StatusInternalServerError) - } - } - - // セッションストアに登録する - sess, err := session.Get("sessions", c) - if err != nil { - log.Println(err) - return c.String(http.StatusInternalServerError, "something wrong in getting session") - } - sess.Values["userName"] = req.Username - sess.Save(c.Request(), c.Response()) - - return c.NoContent(http.StatusOK) -} - -func UserAuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - sess, err := session.Get("sessions", c) - if err != nil { - log.Println(err) - return c.String(http.StatusInternalServerError, "something wrong in getting session") - } - if sess.Values["userName"] == nil { - return c.String(http.StatusUnauthorized, "please login") - } - c.Set("userName", sess.Values["userName"].(string)) - return next(c) - } -} - -func GetMeHandler(c echo.Context) error { - return c.JSON(http.StatusOK, Me{ - Username: c.Get("userName").(string), - }) -} - -func (h *Handler) GetCityInfoHandler(c echo.Context) error { - cityName := c.Param("cityName") - - var city City - err := h.db.Get(&city, "SELECT * FROM city WHERE Name=?", cityName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return c.NoContent(http.StatusNotFound) - } - log.Printf("failed to get city data: %s\n", err) - return c.NoContent(http.StatusInternalServerError) - } - - return c.JSON(http.StatusOK, city) -} - -func (h *Handler) PostCityHandler(c echo.Context) error { - var city City - err := c.Bind(&city) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "bad request body") - } - - result, err := h.db.Exec("INSERT INTO city (Name, CountryCode, District, Population) VALUES (?, ?, ?, ?)", city.Name, city.CountryCode, city.District, city.Population) - if err != nil { - log.Printf("failed to insert city data: %s\n", err) - return c.NoContent(http.StatusInternalServerError) - } - - id, err := result.LastInsertId() - if err != nil { - log.Printf("failed to get last insert id: %s\n", err) - return c.NoContent(http.StatusInternalServerError) - } - city.ID = int(id) - - return c.JSON(http.StatusCreated, city) -} diff --git a/docs/chapter2/section1/src/final/handler.rs b/docs/chapter2/section1/src/final/handler.rs new file mode 100644 index 00000000..f22842c3 --- /dev/null +++ b/docs/chapter2/section1/src/final/handler.rs @@ -0,0 +1,32 @@ +use axum::{ + middleware::from_fn_with_state, + routing::{get, post}, + Router, +}; + +use crate::repository::Repository; + +mod auth; +mod country; + +pub fn make_router(app_state: Repository) -> Router { + let city_router = Router::new() + .route("/cities/:city_name", get(country::get_city_handler)) + .route("/cities", post(country::post_city_handler)) + .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); + + let auth_router = Router::new() + .route("/signup", post(auth::sign_up)) + .route("/login", post(auth::login)) + .route("/logout", post(auth::logout)) + .route("/me", get(auth::me)) + .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); + + let ping_router = Router::new().route("/ping", get(|| async { "pong" })); + + Router::new() + .nest("/", city_router) + .nest("/", auth_router) + .nest("/", ping_router) + .with_state(app_state) +} diff --git a/docs/chapter2/section1/src/final/handler/auth.rs b/docs/chapter2/section1/src/final/handler/auth.rs new file mode 100644 index 00000000..5c86e993 --- /dev/null +++ b/docs/chapter2/section1/src/final/handler/auth.rs @@ -0,0 +1,178 @@ +use axum::{ + extract::{Request, State}, + http::{header, StatusCode}, + middleware::Next, + response::IntoResponse, + Json, +}; +use axum_extra::{headers::Cookie, TypedHeader}; +use serde::{Deserialize, Serialize}; + +use crate::repository::Repository; + +pub async fn auth_middleware( + State(state): State, + TypedHeader(cookie): TypedHeader, + mut req: Request, + next: Next, +) -> Result { + // セッションIDを取得する + let session_id = cookie + .get("session_id") + .ok_or(StatusCode::UNAUTHORIZED)? + .to_string(); + + // セッションストアからユーザーIDを取得する + let user_id = state + .get_user_id_by_session_id(&session_id) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::UNAUTHORIZED)?; + + // リクエストにユーザーIDを追加する + req.extensions_mut().insert(user_id); + + // 次のミドルウェアを呼び出す + Ok(next.run(req).await) +} + +#[derive(Deserialize)] +pub struct SignUp { + pub username: String, + pub password: String, +} + +pub async fn sign_up( + State(state): State, + Json(body): Json, +) -> Result { + // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) + if body.username.is_empty() || body.password.is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + // 登録しようとしているユーザーが既にデータベース内に存在したら409 Conflictを返す + if let Ok(true) = state.is_exist_username(body.username.clone()).await { + return Err(StatusCode::CONFLICT); + } + + // ユーザーを作成する + let id = state + .create_user(body.username.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // パスワードを保存する + state + .save_user_password(id as i32, body.password.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(StatusCode::CREATED) +} + +#[derive(Deserialize)] +pub struct Login { + pub username: String, + pub password: String, +} + +pub async fn login( + State(state): State, + Json(body): Json, +) -> Result { + // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す) + if body.username.is_empty() || body.password.is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + // データベースからユーザーを取得する + let id = state + .get_user_id_by_name(body.username.clone()) + .await + .map_err(|e| match e { + sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + })?; + + // パスワードが一致しているかを確かめる + if !state + .verify_user_password(id, body.password.clone()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + { + return Err(StatusCode::UNAUTHORIZED); + } + + // セッションストアに登録する + let session_id = state + .create_user_session(id.to_string()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // クッキーをセットする + let mut headers = header::HeaderMap::new(); + + headers.insert( + header::SET_COOKIE, + format!("session_id={}; HttpOnly; SameSite=Strict", session_id) + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + ); + + Ok((StatusCode::OK, headers)) +} + +pub async fn logout( + State(state): State, + TypedHeader(cookie): TypedHeader, +) -> Result { + // セッションIDを取得する + let session_id = cookie + .get("session_id") + .ok_or(StatusCode::UNAUTHORIZED)? + .to_string(); + + // セッションストアから削除する + state + .delete_user_session(session_id) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // クッキーを削除する + let mut headers = header::HeaderMap::new(); + headers.insert( + header::SET_COOKIE, + "session_id=; HttpOnly; SameSite=Strict; Max-Age=0" + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + ); + + Ok((StatusCode::OK, headers)) +} + +#[derive(Serialize)] +pub struct Me { + pub username: String, +} + +pub async fn me(State(state): State, req: Request) -> Result, StatusCode> { + // リクエストからユーザーIDを取得する + let user_id = req + .extensions() + .get::() + .ok_or(StatusCode::UNAUTHORIZED)? + .to_string(); + + // データベースからユーザー名を取得する + let username = state + .get_user_name_by_id( + user_id + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(Me { username })) +} diff --git a/docs/chapter2/section1/src/final/handler/country.rs b/docs/chapter2/section1/src/final/handler/country.rs new file mode 100644 index 00000000..b6cc22eb --- /dev/null +++ b/docs/chapter2/section1/src/final/handler/country.rs @@ -0,0 +1,162 @@ +use crate::repository::{country::City, Repository}; +use axum::{ + extract::rejection::JsonRejection, + extract::{Path, State}, + http::StatusCode, + Json, +}; + +pub async fn get_city_handler( + State(state): State, + Path(city_name): Path, +) -> Result, StatusCode> { + let city = Repository::get_city_by_name(&state, city_name).await; + match city { + Ok(city) => Ok(Json(city)), + Err(sqlx::Error::RowNotFound) => Err(StatusCode::NOT_FOUND), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + } +} + +pub async fn post_city_handler( + State(state): State, + query: Result, JsonRejection>, +) -> Result, StatusCode> { + match query { + Ok(Json(city)) => { + let result = Repository::create_city(&state, city).await; + match result { + Ok(city) => Ok(Json(city)), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + } + } + Err(_) => Err(StatusCode::BAD_REQUEST), + } +} + +use std::collections::HashMap; + +//与えられた City のリストから国ごとの人口の和を計算する +#[allow(dead_code)] +pub fn sum_population_by_country(cities: Vec) -> HashMap { + let mut map = HashMap::new(); + for city in cities { + if city.country_code.is_empty() { + continue; + } + let entry = map.entry(city.country_code).or_insert(0); + *entry += city.population; + } + map +} + +// #[cfg(test)] 属性を追加したモジュールはテストモジュールとして扱われる +#[cfg(test)] +mod tests { + use super::{sum_population_by_country, City}; + use std::collections::HashMap; + + #[test] + fn test_sum_population_by_country_empty() { + // ここにテストを追加する + let cities = vec![]; + let result = sum_population_by_country(cities); + assert!(result.is_empty()); + } + + #[test] + fn test_sum_population_by_country_single() { + let cities = vec![ + City { + id: Some(1), + name: "Tokyo".to_string(), + country_code: "JPN".to_string(), + district: "Tokyo".to_string(), + population: 100, + }, + City { + id: Some(2), + name: "Osaka".to_string(), + country_code: "JPN".to_string(), + district: "Osaka".to_string(), + population: 200, + }, + ]; + + let mut expected = HashMap::new(); + expected.insert("JPN".to_string(), 300); + + let result = sum_population_by_country(cities); + + assert_eq!(result, expected); + } + + #[test] + fn test_sum_population_by_country_multiple() { + let cities = vec![ + City { + id: Some(1), + name: "Tokyo".to_string(), + country_code: "JPN".to_string(), + district: "Tokyo".to_string(), + population: 100, + }, + City { + id: Some(2), + name: "Osaka".to_string(), + country_code: "JPN".to_string(), + district: "Osaka".to_string(), + population: 200, + }, + City { + id: Some(3), + name: "New York".to_string(), + country_code: "USA".to_string(), + district: "New York".to_string(), + population: 300, + }, + City { + id: Some(4), + name: "Los Angeles".to_string(), + country_code: "USA".to_string(), + district: "California".to_string(), + population: 400, + }, + ]; + + let mut expected = HashMap::new(); + expected.insert("JPN".to_string(), 300); + expected.insert("USA".to_string(), 700); + + let result = sum_population_by_country(cities); + + assert_eq!(result, expected); + } + + #[test] + fn test_sum_population_by_country_empty_country_code() { + let cities = vec![ + City { + id: Some(1), + name: "Tokyo".to_string(), + country_code: "JPN".to_string(), + district: "Tokyo".to_string(), + population: 100, + }, + City { + id: Some(2), + name: "Osaka".to_string(), + country_code: "".to_string(), + district: "Osaka".to_string(), + population: 200, + }, + ]; + + let mut expected = HashMap::new(); + expected.insert("JPN".to_string(), 100); + + let result = sum_population_by_country(cities); + + assert_eq!(result, expected); + } +} diff --git a/docs/chapter2/section1/src/final/main.go b/docs/chapter2/section1/src/final/main.go deleted file mode 100644 index de157b5c..00000000 --- a/docs/chapter2/section1/src/final/main.go +++ /dev/null @@ -1,78 +0,0 @@ -package main - -import ( - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "github.com/srinathgs/mysqlstore" - "github.com/traPtitech/naro-template-backend/handler" - "log" - "os" - "time" - - "github.com/go-sql-driver/mysql" - - "github.com/jmoiron/sqlx" - "github.com/joho/godotenv" -) - -func main() { - // .envファイルから環境変数を読み込み - err := godotenv.Load(".env") - if err != nil { - log.Fatal(err) - } - - // データーベースの設定 - jst, err := time.LoadLocation("Asia/Tokyo") - if err != nil { - log.Fatal(err) - } - conf := mysql.Config{ - User: os.Getenv("DB_USERNAME"), - Passwd: os.Getenv("DB_PASSWORD"), - Net: "tcp", - Addr: os.Getenv("DB_HOSTNAME") + ":" + os.Getenv("DB_PORT"), - DBName: os.Getenv("DB_DATABASE"), - ParseTime: true, - Collation: "utf8mb4_unicode_ci", - Loc: jst, - } - - // データベースに接続 - db, err := sqlx.Open("mysql", conf.FormatDSN()) - if err != nil { - log.Fatal(err) - } - - // usersテーブルが存在しなかったら、usersテーブルを作成する - _, err = db.Exec("CREATE TABLE IF NOT EXISTS users (Username VARCHAR(255) PRIMARY KEY, HashedPass VARCHAR(255))") - if err != nil { - log.Fatal(err) - } - - // セッションの情報を記憶するための場所をデータベース上に設定 - store, err := mysqlstore.NewMySQLStoreFromConnection(db.DB, "sessions", "/", 60*60*24*14, []byte("secret-token")) - if err != nil { - log.Fatal(err) - } - - h := handler.NewHandler(db) - e := echo.New() - e.Use(middleware.Logger()) // ログを取るミドルウェアを追加 - e.Use(session.Middleware(store)) // セッション管理のためのミドルウェアを追加 - - e.POST("/signup", h.SignUpHandler) - e.POST("/login", h.LoginHandler) - - withAuth := e.Group("") - withAuth.Use(handler.UserAuthMiddleware) - withAuth.GET("/me", handler.GetMeHandler) - withAuth.GET("/cities/:cityName", h.GetCityInfoHandler) - withAuth.POST("/cities", h.PostCityHandler) - - err = e.Start(":8080") - if err != nil { - log.Fatal(err) - } -} diff --git a/docs/chapter2/section1/src/final/main.rs b/docs/chapter2/section1/src/final/main.rs new file mode 100644 index 00000000..8823e79b --- /dev/null +++ b/docs/chapter2/section1/src/final/main.rs @@ -0,0 +1,21 @@ +use tower_http::trace::TraceLayer; +use tracing_subscriber::EnvFilter; + +mod handler; +mod repository; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::try_from_default_env().unwrap_or("info".into())) + .init(); + + let app_state = repository::Repository::connect().await?; + app_state.migrate().await?; + let app = handler::make_router(app_state).layer(TraceLayer::new_for_http()); + let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?; + + tracing::info!("listening on {}", listener.local_addr()?); + axum::serve(listener, app).await.unwrap(); + Ok(()) +} diff --git a/docs/chapter2/section1/src/final/repository.rs b/docs/chapter2/section1/src/final/repository.rs new file mode 100644 index 00000000..1e8baf69 --- /dev/null +++ b/docs/chapter2/section1/src/final/repository.rs @@ -0,0 +1,53 @@ +use async_sqlx_session::MySqlSessionStore; +use sqlx::mysql::MySqlConnectOptions; +use sqlx::mysql::MySqlPool; +use std::env; + +pub mod country; +pub mod users; +pub mod users_session; + +#[derive(Clone)] +pub struct Repository { + pool: MySqlPool, + session_store: MySqlSessionStore, +} + +impl Repository { + pub async fn connect() -> anyhow::Result { + let options = get_options()?; + let pool = sqlx::MySqlPool::connect_with(options).await?; + + let session_store = + MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions"); + + Ok(Self { + pool, + session_store, + }) + } + + pub async fn migrate(&self) -> anyhow::Result<()> { + sqlx::migrate!("./migrations").run(&self.pool).await?; + Ok(()) + } +} + +fn get_options() -> anyhow::Result { + let host = env::var("DB_HOSTNAME")?; + let port = env::var("DB_PORT")?.parse()?; + let username = env::var("DB_USERNAME")?; + let password = env::var("DB_PASSWORD")?; + let database = env::var("DB_DATABASE")?; + let timezone = Some(String::from("Asia/Tokyo")); + let collation = String::from("utf8mb4_unicode_ci"); + + Ok(MySqlConnectOptions::new() + .host(&host) + .port(port) + .username(&username) + .password(&password) + .database(&database) + .timezone(timezone) + .collation(&collation)) +} diff --git a/docs/chapter2/section1/src/final/repository/country.rs b/docs/chapter2/section1/src/final/repository/country.rs new file mode 100644 index 00000000..cd3812c4 --- /dev/null +++ b/docs/chapter2/section1/src/final/repository/country.rs @@ -0,0 +1,43 @@ +use super::Repository; + +#[derive(sqlx::FromRow, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct City { + #[sqlx(rename = "ID")] + pub id: Option, + #[sqlx(rename = "Name")] + pub name: String, + #[sqlx(rename = "CountryCode")] + pub country_code: String, + #[sqlx(rename = "District")] + pub district: String, + #[sqlx(rename = "Population")] + pub population: i32, +} + +impl Repository { + pub async fn get_city_by_name(&self, city_name: String) -> sqlx::Result { + sqlx::query_as::<_, City>("SELECT * FROM city WHERE Name = ?") + .bind(&city_name) + .fetch_one(&self.pool) + .await + } + + pub async fn create_city(&self, city: City) -> sqlx::Result { + let result = sqlx::query( + "INSERT INTO city (Name, CountryCode, District, Population) VALUES (?, ?, ?, ?)", + ) + .bind(&city.name) + .bind(&city.country_code) + .bind(&city.district) + .bind(city.population) + .execute(&self.pool) + .await?; + + let id = result.last_insert_id() as i32; + Ok(City { + id: Some(id), + ..city + }) + } +} diff --git a/docs/chapter2/section1/src/final/repository/users.rs b/docs/chapter2/section1/src/final/repository/users.rs new file mode 100644 index 00000000..6652b2b7 --- /dev/null +++ b/docs/chapter2/section1/src/final/repository/users.rs @@ -0,0 +1,57 @@ +use super::Repository; + +impl Repository { + pub async fn is_exist_username(&self, username: String) -> sqlx::Result { + let result = sqlx::query("SELECT * FROM users WHERE username = ?") + .bind(&username) + .fetch_optional(&self.pool) + .await?; + Ok(result.is_some()) + } + + pub async fn create_user(&self, username: String) -> sqlx::Result { + let result = sqlx::query("INSERT INTO users (username) VALUES (?)") + .bind(&username) + .execute(&self.pool) + .await?; + Ok(result.last_insert_id()) + } + + pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result { + let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?") + .bind(&username) + .fetch_one(&self.pool) + .await?; + Ok(result) + } + + pub async fn get_user_name_by_id(&self, id: u64) -> sqlx::Result { + let result = sqlx::query_scalar("SELECT username FROM users WHERE id = ?") + .bind(id) + .fetch_one(&self.pool) + .await?; + Ok(result) + } + + pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> { + let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?; + + sqlx::query("INSERT INTO user_passwords (id, hashed_pass) VALUES (?, ?)") + .bind(id) + .bind(hash) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn verify_user_password(&self, id: u64, password: String) -> anyhow::Result { + let hash = + sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?") + .bind(id) + .fetch_one(&self.pool) + .await?; + + Ok(bcrypt::verify(password, &hash)?) + } +} diff --git a/docs/chapter2/section1/src/final/repository/users_session.rs b/docs/chapter2/section1/src/final/repository/users_session.rs new file mode 100644 index 00000000..d5e408dd --- /dev/null +++ b/docs/chapter2/section1/src/final/repository/users_session.rs @@ -0,0 +1,52 @@ +use anyhow::Context; +use async_session::{Session, SessionStore}; + +use super::Repository; + +impl Repository { + pub async fn create_user_session(&self, user_id: String) -> anyhow::Result { + let mut session = Session::new(); + + session + .insert("user_id", user_id) + .with_context(|| "Failed to insert user_id")?; + + let session_id = self + .session_store + .store_session(session) + .await + .with_context(|| "Failed to store session")? + .with_context(|| "Failed to create session")?; + + Ok(session_id) + } + + pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> { + let session = self + .session_store + .load_session(session_id.clone()) + .await + .with_context(|| "Failed to load session")? + .with_context(|| "Failed to find session")?; + + self.session_store + .destroy_session(session) + .await + .with_context(|| "Failed to destroy session")?; + + Ok(()) + } + + pub async fn get_user_id_by_session_id( + &self, + session_id: &String, + ) -> anyhow::Result> { + let session = self + .session_store + .load_session(session_id.clone()) + .await + .with_context(|| "Failed to load session")?; + + Ok(session.and_then(|s| s.get::("user_id"))) + } +}