diff --git a/cmd/apiserver/main.go b/cmd/apiserver/main.go index c9df1cd9..96f99177 100644 --- a/cmd/apiserver/main.go +++ b/cmd/apiserver/main.go @@ -318,9 +318,9 @@ func run(log *logrus.Entry, cfg config.Config) error { } updateDevice := func(event *kolidepb.DeviceEvent) error { - device, err := kolide.LookupDevice(ctx, db, event) + device, err := db.ReadDeviceByExternalID(ctx, event.GetExternalID()) if err != nil { - return err + return fmt.Errorf("read device with external_id=%v: %w", event.GetExternalID(), err) } changed := false diff --git a/go.mod b/go.mod index b431c8cd..c1806c63 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/jackmordaunt/icns/v2 v2.2.6 github.com/kelseyhightower/envconfig v1.4.0 github.com/lestrrat-go/jwx v1.2.29 - github.com/nais/kolide-event-handler v0.0.0-20240614075259-de023eff2206 + github.com/nais/kolide-event-handler v0.0.0-20240614084216-95ac8998fd8f github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da github.com/prometheus/client_golang v1.17.0 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index 2f1c6e3b..565d75ee 100644 --- a/go.sum +++ b/go.sum @@ -512,10 +512,8 @@ github.com/moricho/tparallel v0.3.1 h1:fQKD4U1wRMAYNngDonW5XupoB/ZGJHdpzrWqgyg9k github.com/moricho/tparallel v0.3.1/go.mod h1:leENX2cUv7Sv2qDgdi0D0fCftN8fRC67Bcn8pqzeYNI= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nais/kolide-event-handler v0.0.0-20240613124908-c26ee6800776 h1:W5G0sRZnAOvr7Gt6Su3BUpf0dqVotuokB1Z1sKUPYQI= -github.com/nais/kolide-event-handler v0.0.0-20240613124908-c26ee6800776/go.mod h1:1Ta1n1Q+EtH7EIHuU3/svAkCQ8AWk9/qfs6nBj0lhxE= -github.com/nais/kolide-event-handler v0.0.0-20240614075259-de023eff2206 h1:AT0hiUcqyJyfP48JVgtiGXeIbsuKOdpyKBn1jgSpsFA= -github.com/nais/kolide-event-handler v0.0.0-20240614075259-de023eff2206/go.mod h1:7Dl7mqto/Jb4Ng8NntCAsnnYCX2clIuOUd6X7b5s+7o= +github.com/nais/kolide-event-handler v0.0.0-20240614084216-95ac8998fd8f h1:AuoUs1nEs0XUyQjwgwDygPC4zE5+EvQGkMNr5TfmPlw= +github.com/nais/kolide-event-handler v0.0.0-20240614084216-95ac8998fd8f/go.mod h1:1Ta1n1Q+EtH7EIHuU3/svAkCQ8AWk9/qfs6nBj0lhxE= github.com/nakabonne/nestif v0.3.1 h1:wm28nZjhQY5HyYPx+weN3Q65k6ilSBxDb8v5S81B81U= github.com/nakabonne/nestif v0.3.1/go.mod h1:9EtoZochLn5iUprVDmDjqGKPofoUEBL8U4Ngq6aY7OE= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= diff --git a/internal/apiserver/database/database.go b/internal/apiserver/database/database.go index 4368cc13..977effcc 100644 --- a/internal/apiserver/database/database.go +++ b/internal/apiserver/database/database.go @@ -397,6 +397,19 @@ func (db *ApiServerDB) ReadDeviceById(ctx context.Context, deviceID int64) (*pb. return sqlcDeviceToPbDevice(*device) } +func (db *ApiServerDB) ReadDeviceByExternalID(ctx context.Context, externalID string) (*pb.Device, error) { + id := sql.NullString{ + String: externalID, + Valid: true, + } + device, err := db.queries.GetDeviceByExternalID(ctx, id) + if err != nil { + return nil, err + } + + return sqlcDeviceToPbDevice(*device) +} + func (db *ApiServerDB) ReadGateways(ctx context.Context) ([]*pb.Gateway, error) { rows, err := db.queries.GetGateways(ctx) if err != nil { diff --git a/internal/apiserver/database/interface.go b/internal/apiserver/database/interface.go index bc0e1a9b..9aab1352 100644 --- a/internal/apiserver/database/interface.go +++ b/internal/apiserver/database/interface.go @@ -16,6 +16,7 @@ type APIServer interface { AddDevice(ctx context.Context, device *pb.Device) error ReadDevice(ctx context.Context, publicKey string) (*pb.Device, error) ReadDeviceById(ctx context.Context, deviceID int64) (*pb.Device, error) + ReadDeviceByExternalID(ctx context.Context, externalID string) (*pb.Device, error) ReadGateways(ctx context.Context) ([]*pb.Gateway, error) ReadGateway(ctx context.Context, name string) (*pb.Gateway, error) ReadDeviceBySerialPlatform(ctx context.Context, serial string, platform string) (*pb.Device, error) diff --git a/internal/apiserver/database/mock_api_server.go b/internal/apiserver/database/mock_api_server.go index c9046c31..261e6786 100644 --- a/internal/apiserver/database/mock_api_server.go +++ b/internal/apiserver/database/mock_api_server.go @@ -271,6 +271,65 @@ func (_c *MockAPIServer_ReadDevice_Call) RunAndReturn(run func(context.Context, return _c } +// ReadDeviceByExternalID provides a mock function with given fields: ctx, externalID +func (_m *MockAPIServer) ReadDeviceByExternalID(ctx context.Context, externalID string) (*pb.Device, error) { + ret := _m.Called(ctx, externalID) + + if len(ret) == 0 { + panic("no return value specified for ReadDeviceByExternalID") + } + + var r0 *pb.Device + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*pb.Device, error)); ok { + return rf(ctx, externalID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *pb.Device); ok { + r0 = rf(ctx, externalID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*pb.Device) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, externalID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockAPIServer_ReadDeviceByExternalID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadDeviceByExternalID' +type MockAPIServer_ReadDeviceByExternalID_Call struct { + *mock.Call +} + +// ReadDeviceByExternalID is a helper method to define mock.On call +// - ctx context.Context +// - externalID string +func (_e *MockAPIServer_Expecter) ReadDeviceByExternalID(ctx interface{}, externalID interface{}) *MockAPIServer_ReadDeviceByExternalID_Call { + return &MockAPIServer_ReadDeviceByExternalID_Call{Call: _e.mock.On("ReadDeviceByExternalID", ctx, externalID)} +} + +func (_c *MockAPIServer_ReadDeviceByExternalID_Call) Run(run func(ctx context.Context, externalID string)) *MockAPIServer_ReadDeviceByExternalID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockAPIServer_ReadDeviceByExternalID_Call) Return(_a0 *pb.Device, _a1 error) *MockAPIServer_ReadDeviceByExternalID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockAPIServer_ReadDeviceByExternalID_Call) RunAndReturn(run func(context.Context, string) (*pb.Device, error)) *MockAPIServer_ReadDeviceByExternalID_Call { + _c.Call.Return(run) + return _c +} + // ReadDeviceById provides a mock function with given fields: ctx, deviceID func (_m *MockAPIServer) ReadDeviceById(ctx context.Context, deviceID int64) (*pb.Device, error) { ret := _m.Called(ctx, deviceID) diff --git a/internal/apiserver/database/queries/devices.sql b/internal/apiserver/database/queries/devices.sql index 253c04db..b7fa63a2 100644 --- a/internal/apiserver/database/queries/devices.sql +++ b/internal/apiserver/database/queries/devices.sql @@ -4,6 +4,9 @@ SELECT * FROM devices ORDER BY id; -- name: GetDeviceByPublicKey :one SELECT * FROM devices WHERE public_key = @public_key; +-- name: GetDeviceByExternalID :one +SELECT * FROM devices WHERE external_id = @external_id; + -- name: GetDeviceByID :one SELECT * FROM devices WHERE id = @id; diff --git a/internal/apiserver/database/schema/0005_device_external_id_idx.down.sql b/internal/apiserver/database/schema/0005_device_external_id_idx.down.sql new file mode 100644 index 00000000..43596d31 --- /dev/null +++ b/internal/apiserver/database/schema/0005_device_external_id_idx.down.sql @@ -0,0 +1 @@ +DROP INDEX devies_external_id; diff --git a/internal/apiserver/database/schema/0005_device_external_id_idx.up.sql b/internal/apiserver/database/schema/0005_device_external_id_idx.up.sql new file mode 100644 index 00000000..2c95300e --- /dev/null +++ b/internal/apiserver/database/schema/0005_device_external_id_idx.up.sql @@ -0,0 +1 @@ +CREATE UNIQUE INDEX devies_external_id ON devices ( external_id ); diff --git a/internal/apiserver/kolide/check_test.go b/internal/apiserver/kolide/check_test.go index adc47807..29c59757 100644 --- a/internal/apiserver/kolide/check_test.go +++ b/internal/apiserver/kolide/check_test.go @@ -5,37 +5,37 @@ import ( "testing" "time" - kolideclient "github.com/nais/kolide-event-handler/pkg/kolide" - + "github.com/nais/device/internal/apiserver/kolide" + "github.com/nais/device/internal/pb" "github.com/stretchr/testify/assert" ) func TestCheck(t *testing.T) { tagTests := []struct { tags []string - severity kolideclient.Severity + severity pb.Severity duration time.Duration }{ - {[]string{}, kolideclient.SeverityWarning, kolideclient.DurationWarning}, - {[]string{"foo", "bar"}, kolideclient.SeverityWarning, kolideclient.DurationWarning}, - {[]string{"foo", "notice"}, kolideclient.SeverityNotice, kolideclient.DurationNotice}, - {[]string{"warning", "notice", "danger"}, kolideclient.SeverityDanger, kolideclient.DurationDanger}, - {[]string{"notice"}, kolideclient.SeverityNotice, kolideclient.DurationNotice}, - {[]string{"warning"}, kolideclient.SeverityWarning, kolideclient.DurationWarning}, - {[]string{"danger"}, kolideclient.SeverityDanger, kolideclient.DurationDanger}, - {[]string{"critical"}, kolideclient.SeverityCritical, kolideclient.DurationCritical}, + {[]string{}, pb.Severity_Warning, kolide.DurationWarning}, + {[]string{"foo", "bar"}, pb.Severity_Warning, kolide.DurationWarning}, + {[]string{"foo", "notice"}, pb.Severity_Notice, kolide.DurationNotice}, + {[]string{"warning", "notice", "danger"}, pb.Severity_Danger, kolide.DurationDanger}, + {[]string{"notice"}, pb.Severity_Notice, kolide.DurationNotice}, + {[]string{"warning"}, pb.Severity_Warning, kolide.DurationWarning}, + {[]string{"danger"}, pb.Severity_Danger, kolide.DurationDanger}, + {[]string{"critical"}, pb.Severity_Critical, kolide.DurationCritical}, } for _, tt := range tagTests { t.Run(strings.Join(tt.tags, ", "), func(t *testing.T) { - check := kolideclient.Check{ + check := kolide.Check{ Tags: tt.tags, } severity := check.Severity() assert.Equal(t, tt.severity, severity) - assert.Equal(t, tt.duration, severity.GraceTime()) + assert.Equal(t, tt.duration, kolide.GraceTime(severity)) }) } } diff --git a/internal/apiserver/kolide/event_handler.go b/internal/apiserver/kolide/event_handler.go index 59abe2e1..5b50f932 100644 --- a/internal/apiserver/kolide/event_handler.go +++ b/internal/apiserver/kolide/event_handler.go @@ -3,13 +3,8 @@ package kolide import ( "context" "crypto/tls" - "fmt" - "strings" "time" - "github.com/nais/device/internal/apiserver/database" - "github.com/nais/device/internal/pb" - kolidepb "github.com/nais/kolide-event-handler/pkg/pb" "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -90,29 +85,3 @@ func DeviceEventStreamer(ctx context.Context, log *logrus.Entry, grpcAddress, gr return ctx.Err() } - -func LookupDevice(ctx context.Context, db database.APIServer, event *kolidepb.DeviceEvent) (*pb.Device, error) { - platform := func(platform string) string { - switch strings.ToLower(platform) { - case "darwin": - return "darwin" - case "windows": - return "windows" - default: - return "linux" - } - } - - p := platform(event.GetPlatform()) - - device, err := db.ReadDeviceBySerialPlatform(ctx, event.GetSerial(), p) - if err != nil { - return nil, fmt.Errorf("read device with serial=%s platform=%s: %w", event.GetSerial(), p, err) - } - - if device.ExternalID == "" { - device.ExternalID = event.GetExternalID() - } - - return device, nil -} diff --git a/internal/apiserver/sqlc/db.go b/internal/apiserver/sqlc/db.go index 5b000188..b66cd49f 100644 --- a/internal/apiserver/sqlc/db.go +++ b/internal/apiserver/sqlc/db.go @@ -51,6 +51,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.deleteGatewayRoutesStmt, err = db.PrepareContext(ctx, deleteGatewayRoutes); err != nil { return nil, fmt.Errorf("error preparing query DeleteGatewayRoutes: %w", err) } + if q.getDeviceByExternalIDStmt, err = db.PrepareContext(ctx, getDeviceByExternalID); err != nil { + return nil, fmt.Errorf("error preparing query GetDeviceByExternalID: %w", err) + } if q.getDeviceByIDStmt, err = db.PrepareContext(ctx, getDeviceByID); err != nil { return nil, fmt.Errorf("error preparing query GetDeviceByID: %w", err) } @@ -152,6 +155,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing deleteGatewayRoutesStmt: %w", cerr) } } + if q.getDeviceByExternalIDStmt != nil { + if cerr := q.getDeviceByExternalIDStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getDeviceByExternalIDStmt: %w", cerr) + } + } if q.getDeviceByIDStmt != nil { if cerr := q.getDeviceByIDStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getDeviceByIDStmt: %w", cerr) @@ -285,6 +293,7 @@ type Queries struct { clearDeviceIssuesExceptForStmt *sql.Stmt deleteGatewayAccessGroupIDsStmt *sql.Stmt deleteGatewayRoutesStmt *sql.Stmt + getDeviceByExternalIDStmt *sql.Stmt getDeviceByIDStmt *sql.Stmt getDeviceByPublicKeyStmt *sql.Stmt getDeviceBySerialAndPlatformStmt *sql.Stmt @@ -317,6 +326,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { clearDeviceIssuesExceptForStmt: q.clearDeviceIssuesExceptForStmt, deleteGatewayAccessGroupIDsStmt: q.deleteGatewayAccessGroupIDsStmt, deleteGatewayRoutesStmt: q.deleteGatewayRoutesStmt, + getDeviceByExternalIDStmt: q.getDeviceByExternalIDStmt, getDeviceByIDStmt: q.getDeviceByIDStmt, getDeviceByPublicKeyStmt: q.getDeviceByPublicKeyStmt, getDeviceBySerialAndPlatformStmt: q.getDeviceBySerialAndPlatformStmt, diff --git a/internal/apiserver/sqlc/devices.sql.go b/internal/apiserver/sqlc/devices.sql.go index e5007066..145feb77 100644 --- a/internal/apiserver/sqlc/devices.sql.go +++ b/internal/apiserver/sqlc/devices.sql.go @@ -53,6 +53,30 @@ func (q *Queries) ClearDeviceIssuesExceptFor(ctx context.Context, unhealthyDevic return err } +const getDeviceByExternalID = `-- name: GetDeviceByExternalID :one +SELECT id, username, serial, platform, healthy, last_updated, public_key, ipv4, ipv6, last_seen, issues, external_id FROM devices WHERE external_id = ?1 +` + +func (q *Queries) GetDeviceByExternalID(ctx context.Context, externalID sql.NullString) (*Device, error) { + row := q.queryRow(ctx, q.getDeviceByExternalIDStmt, getDeviceByExternalID, externalID) + var i Device + err := row.Scan( + &i.ID, + &i.Username, + &i.Serial, + &i.Platform, + &i.Healthy, + &i.LastUpdated, + &i.PublicKey, + &i.Ipv4, + &i.Ipv6, + &i.LastSeen, + &i.Issues, + &i.ExternalID, + ) + return &i, err +} + const getDeviceByID = `-- name: GetDeviceByID :one SELECT id, username, serial, platform, healthy, last_updated, public_key, ipv4, ipv6, last_seen, issues, external_id FROM devices WHERE id = ?1 ` diff --git a/internal/apiserver/sqlc/querier.go b/internal/apiserver/sqlc/querier.go index e7d79fbf..974ad9f7 100644 --- a/internal/apiserver/sqlc/querier.go +++ b/internal/apiserver/sqlc/querier.go @@ -6,6 +6,7 @@ package sqlc import ( "context" + "database/sql" ) type Querier interface { @@ -18,6 +19,7 @@ type Querier interface { ClearDeviceIssuesExceptFor(ctx context.Context, unhealthyDeviceIds interface{}) error DeleteGatewayAccessGroupIDs(ctx context.Context, gatewayName string) error DeleteGatewayRoutes(ctx context.Context, gatewayName string) error + GetDeviceByExternalID(ctx context.Context, externalID sql.NullString) (*Device, error) GetDeviceByID(ctx context.Context, id int64) (*Device, error) GetDeviceByPublicKey(ctx context.Context, publicKey string) (*Device, error) GetDeviceBySerialAndPlatform(ctx context.Context, arg GetDeviceBySerialAndPlatformParams) (*Device, error)