diff --git a/router/user.go b/router/user.go index a14e59d4..f41965fd 100644 --- a/router/user.go +++ b/router/user.go @@ -8,6 +8,7 @@ import ( "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "github.com/samber/lo" + "github.com/traPtitech/Jomon/ent" "github.com/traPtitech/Jomon/model" "go.uber.org/zap" ) @@ -79,17 +80,39 @@ func (h Handlers) UpdateUserInfo(c echo.Context) error { }) } +func userFromModelUser(u model.User) User { + return User{ + ID: u.ID, + Name: u.Name, + DisplayName: u.DisplayName, + Admin: u.Admin, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + DeletedAt: u.DeletedAt, + } +} + func (h Handlers) GetMe(c echo.Context) error { sess, err := session.Get(h.SessionName, c) if err != nil { h.Logger.Error("failed to get session", zap.Error(err)) return echo.NewHTTPError(http.StatusInternalServerError, err) } - user, ok := sess.Values[sessionUserKey].(User) + userInSession, ok := sess.Values[sessionUserKey].(User) if !ok { h.Logger.Error("failed to parse stored session as user info") return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info") } + modelUser, err := h.Repository.GetUserByID(c.Request().Context(), userInSession.ID) + if err != nil { + if ent.IsNotFound(err) { + h.Logger.Error("failed to find user from DB by ID") + return c.JSON(http.StatusNotFound, err) + } + h.Logger.Error("failed to get user by ID") + return c.JSON(http.StatusInternalServerError, err) + } + user := userFromModelUser(*modelUser) return c.JSON(http.StatusOK, user) } diff --git a/router/user_test.go b/router/user_test.go index b5f28adb..b08047ac 100644 --- a/router/user_test.go +++ b/router/user_test.go @@ -289,7 +289,7 @@ func TestHandlers_GetMe(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - accessUser := makeUser(t, false) + accessUser := makeUser(t, random.Numeric(t, 2) == 1) user := User{ ID: accessUser.ID, Name: accessUser.Name, @@ -321,6 +321,11 @@ func TestHandlers_GetMe(t *testing.T) { sess.Values[sessionUserKey] = user require.NoError(t, sess.Save(c.Request(), c.Response())) + h.Repository.MockUserRepository. + EXPECT(). + GetUserByID(c.Request().Context(), user.ID). + Return(accessUser, nil) + err = h.Handlers.GetMe(c) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code)