Skip to content

Commit

Permalink
apiserver: ensure at most one session per device (#398)
Browse files Browse the repository at this point in the history
* Database migrations to ensure at most one session per device
* fix: don't sync Kolide devices continuously if integration is disabled
* Inserts to session table will overwrite existing session with matching device id
* Delete invalid sessions from session store
* Write tests for session cache store
* Bump Go version in asdf

co-authored-by: @sechmann
  • Loading branch information
kimtore authored Jul 3, 2024
1 parent bbbae2b commit 0535aab
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .tool-versions
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
golang 1.22.4
golang 1.22.5
protoc 23.4
protoc-gen-go 1.31.0
protoc-gen-go-grpc 1.3.0
28 changes: 15 additions & 13 deletions cmd/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,23 +358,25 @@ func run(log *logrus.Entry, cfg config.Config) error {
cancel()
}()

untilContextDone := func(ctx context.Context, interval time.Duration, f func(context.Context) error) {
ticker := time.NewTicker(interval)
for {
if err := f(ctx); err != nil {
log.WithError(err).Error("run until program done wrapper")
}
if cfg.KolideIntegrationEnabled {
untilContextDone := func(ctx context.Context, interval time.Duration, f func(context.Context) error) {
ticker := time.NewTicker(interval)
for {
if err := f(ctx); err != nil {
log.WithError(err).Error("run until program done wrapper")
}

select {
case <-ticker.C:
case <-ctx.Done():
return
select {
case <-ticker.C:
case <-ctx.Done():
return
}
}
}
}

// sync all devices continuously
go untilContextDone(ctx, 1*time.Minute, grpcHandler.UpdateAllDevices)
// sync all devices continuously
go untilContextDone(ctx, 1*time.Minute, grpcHandler.UpdateAllDevices)
}

// initialize gateway metrics
gateways, err := db.ReadGateways(ctx)
Expand Down
16 changes: 16 additions & 0 deletions internal/apiserver/auth/sessionstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,26 @@ func (store *sessionStore) Get(ctx context.Context, key string) (*pb.Session, er
return session, nil
}

// Delete all sessions belonging to a specific device.
// The cache store MUST be locked before calling this function.
func (store *sessionStore) deleteSessionsForDeviceIDWithAssumedLock(deviceID int64) {
for key, session := range store.cache {
if session.GetDevice().GetId() == deviceID {
delete(store.cache, key)
}
}
}

func (store *sessionStore) Set(ctx context.Context, session *pb.Session) error {
if session.GetDevice() == nil {
return fmt.Errorf("store session in database: device info not given")
}

store.lock.Lock()
defer store.lock.Unlock()

store.deleteSessionsForDeviceIDWithAssumedLock(session.GetDevice().GetId())

err := store.db.AddSessionInfo(ctx, session)
if err != nil {
return fmt.Errorf("store session in database: %w", err)
Expand Down
66 changes: 64 additions & 2 deletions internal/apiserver/auth/sessionstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth_test
import (
"context"
"errors"
"fmt"
"strconv"
"testing"
"time"
Expand All @@ -21,7 +22,8 @@ func TestSessionStore_SetAndGetFromCache(t *testing.T) {
store := auth.NewSessionStore(db)

session := &pb.Session{
Key: "abc",
Key: "abc",
Device: &pb.Device{},
}

db.On("AddSessionInfo", mock.Anything, session).Return(nil).Once()
Expand Down Expand Up @@ -63,7 +65,8 @@ func TestSessionStore_Errors(t *testing.T) {
store := auth.NewSessionStore(db)

session := &pb.Session{
Key: "abc",
Key: "abc",
Device: &pb.Device{},
}
dbError := errors.New("error from database")

Expand Down Expand Up @@ -153,3 +156,62 @@ func TestSessionStore_UpdateDevice(t *testing.T) {
assert.NoError(t, err)
assert.True(t, sess.GetDevice().GetLastSeen().AsTime().Equal(updatedDevice.GetLastSeen().AsTime()))
}

// Test that existing sessions with the same device id are removed.
func TestSessionStore_ReplaceOnSet(t *testing.T) {
ctx := context.Background()
db := database.NewMockDatabase(t)
store := auth.NewSessionStore(db)

now := time.Now()

// Return from database layer
db.EXPECT().AddSessionInfo(mock.Anything, mock.Anything).Return(nil).Times(3)
db.EXPECT().ReadSessionInfo(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("oops")).Once()

assert.NoError(t, store.Set(ctx, &pb.Session{
Key: "old_key_1",
Device: &pb.Device{
Id: 123,
LastSeen: timestamppb.New(now),
Serial: "old",
},
}))

assert.NoError(t, store.Set(ctx, &pb.Session{
Key: "old_key_2",
Device: &pb.Device{
Id: 456,
LastSeen: timestamppb.New(now),
Serial: "old",
},
}))

// Assert that the device is stored
session, _ := store.Get(ctx, "old_key_1")
assert.Equal(t, session.GetDevice().GetId(), int64(123))
assert.Equal(t, session.GetDevice().GetSerial(), "old")

assert.NoError(t, store.Set(ctx, &pb.Session{
Key: "new_key_1",
Device: &pb.Device{
Id: 123,
LastSeen: timestamppb.New(now),
Serial: "new",
},
}))

// Assert that the old key is deleted and the new key refers to the correct device
_, err := store.Get(ctx, "old_key_1")
assert.Error(t, err)
session, err = store.Get(ctx, "new_key_1")
assert.NoError(t, err)
assert.Equal(t, session.GetDevice().GetId(), int64(123))
assert.Equal(t, session.GetDevice().GetSerial(), "new")

// Assert that the other device still exists in its original state
session, _ = store.Get(ctx, "old_key_2")
assert.Equal(t, session.GetDevice().GetId(), int64(456))
assert.Equal(t, session.GetDevice().GetSerial(), "old")

}
6 changes: 5 additions & 1 deletion internal/apiserver/database/queries/sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ LIMIT 1;

-- name: AddSession :exec
INSERT INTO sessions (key, expiry, device_id, object_id)
VALUES (@key, @expiry, @device_id, @object_id);
VALUES (@key, @expiry, @device_id, @object_id)
ON CONFLICT (device_id) DO UPDATE
SET key = excluded.key,
expiry = excluded.expiry,
object_id = excluded.object_id;

-- name: AddSessionAccessGroupID :exec
INSERT INTO session_access_group_ids (session_key, group_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- No changes necessary

DROP INDEX sessions_device_id_unique;
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- Delete all sessions for all device_id's, except the most recent session for each user.
DELETE FROM sessions AS s
WHERE key !=
(SELECT key FROM sessions
WHERE device_id = s.device_id
ORDER BY expiry DESC
LIMIT 1)
;

CREATE UNIQUE INDEX sessions_device_id_unique ON sessions (device_id);
4 changes: 4 additions & 0 deletions internal/apiserver/sqlc/sessions.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0535aab

Please sign in to comment.