diff --git a/manager.go b/manager.go index b2041d0..a24334b 100644 --- a/manager.go +++ b/manager.go @@ -15,6 +15,15 @@ const ( idLen = 40 ) +var ( + // ErrUnauthorized is returned when no valid session is found. + ErrUnauthorized = errors.New("unauthorized") + + // ErrNotOwner is returned when session's status is being modified + // not by its owner. + ErrNotOwner = errors.New("session can be managed only by its owner") +) + // Manager holds the data needed to properly create sessions // and set them in http responses, extract them from http requests, // validate them and directly communicate with the store. @@ -211,13 +220,16 @@ func (m *Manager) Clone(opts ...setter) *Manager { // the store and sets the proper values of the cookie. func (m *Manager) Init(w http.ResponseWriter, r *http.Request, key string) error { s := m.newSession(r, key) - if s.ExpiresAt.After(time.Time{}) { - if err := m.store.Create(r.Context(), s); err != nil { - return err - } + exp := s.ExpiresAt + if s.ExpiresAt.IsZero() { + s.ExpiresAt = time.Now().Add(time.Hour * 24) // for temporary sessions + } + + if err := m.store.Create(r.Context(), s); err != nil { + return err } - m.setCookie(w, s.ExpiresAt, s.ID) + m.setCookie(w, exp, s.ID) return nil } @@ -265,12 +277,12 @@ func (m *Manager) wrap(rej func(error) http.Handler, next http.Handler) http.Han } if !ok { - rej(errors.New("unauthorized")).ServeHTTP(w, r) + rej(ErrUnauthorized).ServeHTTP(w, r) return } if m.validate && !s.IsValid(r) { - rej(errors.New("unauthorized")).ServeHTTP(w, r) + rej(ErrUnauthorized).ServeHTTP(w, r) return } @@ -320,7 +332,7 @@ func (m *Manager) RevokeByIDExt(ctx context.Context, id string) error { } if s2.UserKey != s1.UserKey { - return errors.New("session can be revoked only by its owner") + return ErrNotOwner } return m.store.DeleteByID(ctx, id) diff --git a/manager_test.go b/manager_test.go index 335dc04..9954122 100644 --- a/manager_test.go +++ b/manager_test.go @@ -244,16 +244,28 @@ func TestInit(t *testing.T) { } } - wasCreateCalled := func(count int, key string) check { + wasCreateCalled := func(count int, key string, t1, t2 time.Time) check { return func(t *testing.T, s *StoreMock, _ *httptest.ResponseRecorder, _ error) { ff := s.CreateCalls() if len(ff) != count { t.Errorf("want %d, got %d", count, len(ff)) } - if len(ff) > 0 && ff[0].S.UserKey != key { + if len(ff) == 0 { + return + } + + if ff[0].S.UserKey != key { t.Errorf("want %q, got %q", key, ff[0].S.UserKey) } + + if !ff[0].S.ExpiresAt.After(t1) { + t.Errorf("want after %s, got %s", t1.String(), ff[0].S.ExpiresAt.String()) + } + + if !ff[0].S.ExpiresAt.Before(t2) { + t.Errorf("want before %s, got %s", t2.String(), ff[0].S.ExpiresAt.String()) + } } } @@ -274,28 +286,30 @@ func TestInit(t *testing.T) { }{ "Error returned by store.Create": { Store: storeStub(errors.New("error")), - ExpiresIn: time.Hour, + ExpiresIn: time.Hour * 24 * 30, Checks: checks( hasErr(true), hasCookie(false), - wasCreateCalled(1, key), + wasCreateCalled(1, key, time.Now().Add(time.Hour*24), + time.Now().Add(time.Hour*24*30+time.Second)), ), }, - "Successful init without expiration field": { + "Successful temporary session init": { Store: storeStub(nil), Checks: checks( hasErr(false), hasCookie(true), - wasCreateCalled(0, ""), + wasCreateCalled(1, key, time.Time{}, time.Now().Add(time.Hour*24+time.Second)), ), }, - "Successful init": { + "Successful permanent session init": { Store: storeStub(nil), - ExpiresIn: time.Hour, + ExpiresIn: time.Hour * 24 * 30, Checks: checks( hasErr(false), hasCookie(true), - wasCreateCalled(1, key), + wasCreateCalled(1, key, time.Now().Add(time.Hour*24), + time.Now().Add(time.Hour*24*30+time.Second)), ), }, }