diff --git a/.github/workflows/check-generated-files.yml b/.github/workflows/check-generated-files.yml index d520358915..046bb6fbb2 100644 --- a/.github/workflows/check-generated-files.yml +++ b/.github/workflows/check-generated-files.yml @@ -61,7 +61,6 @@ jobs: - "invitations/invitations.go" - "users/emailer.go" - "users/hasher.go" - - "mqtt/events/streams.go" - "certs/certs.go" - "certs/pki/vault.go" - "certs/service.go" @@ -149,7 +148,6 @@ jobs: mv ./clients/mocks/clients_client.go ./clients/mocks/clients_client.go.tmp mv ./clients/mocks/cache.go ./clients/mocks/cache.go.tmp mv ./clients/mocks/service.go ./clients/mocks/service.go.tmp - mv ./mqtt/mocks/events.go ./mqtt/mocks/events.go.tmp mv ./readers/mocks/messages.go ./readers/mocks/messages.go.tmp mv ./pkg/sdk/mocks/sdk.go ./pkg/sdk/mocks/sdk.go.tmp mv ./pkg/messaging/mocks/pubsub.go ./pkg/messaging/mocks/pubsub.go.tmp @@ -208,7 +206,6 @@ jobs: check_mock_changes ./clients/mocks/clients_client.go " ./clients/mocks/clients_client.go" check_mock_changes ./clients/mocks/cache.go " ./clients/mocks/cache.go" check_mock_changes ./clients/mocks/service.go " ./clients/mocks/service.go" - check_mock_changes ./mqtt/mocks/events.go " ./mqtt/mocks/events.go" check_mock_changes ./readers/mocks/messages.go " ./readers/mocks/messages.go" check_mock_changes ./pkg/sdk/mocks/sdk.go " ./pkg/sdk/mocks/sdk.go" check_mock_changes ./pkg/messaging/mocks/pubsub.go " ./pkg/messaging/mocks/pubsub.go" diff --git a/cmd/mqtt/main.go b/cmd/mqtt/main.go index 8f527daf04..32bd0353c7 100644 --- a/cmd/mqtt/main.go +++ b/cmd/mqtt/main.go @@ -135,13 +135,6 @@ func main() { defer bsub.Close() bsub = brokerstracing.NewPubSub(serverConfig, tracer, bsub) - bsub, err = msgevents.NewPubSubMiddleware(ctx, bsub, cfg.ESURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - mpub, err := mqttpub.NewPublisher(fmt.Sprintf("mqtt://%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort), cfg.MQTTQoS, cfg.MQTTForwarderTimeout) if err != nil { logger.Error(fmt.Sprintf("failed to create MQTT publisher: %s", err)) @@ -181,13 +174,6 @@ func main() { return } - es, err := events.NewEventStore(ctx, cfg.ESURL, cfg.Instance) - if err != nil { - logger.Error(fmt.Sprintf("failed to create %s event store : %s", svcName, err)) - exitCode = 1 - return - } - clientsClientCfg := grpcclient.Config{} if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) @@ -220,7 +206,15 @@ func main() { defer channelsHandler.Close() logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) - h := mqtt.NewHandler(np, es, logger, clientsClient, channelsClient) + h := mqtt.NewHandler(np, logger, clientsClient, channelsClient) + + h, err = events.NewEventStoreMiddleware(ctx, h, cfg.ESURL, cfg.Instance) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + h = handler.NewTracing(tracer, h) if cfg.SendTelemetry { diff --git a/coap/adapter.go b/coap/adapter.go index 806f888262..9d6e21f6d1 100644 --- a/coap/adapter.go +++ b/coap/adapter.go @@ -121,9 +121,10 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic authzc := newAuthzClient(clientID, chanID, subtopic, svc.channels, c) subCfg := messaging.SubscriberConfig{ - ID: c.Token(), - Topic: subject, - Handler: authzc, + ID: c.Token(), + ClientID: clientID, + Topic: subject, + Handler: authzc, } return svc.pubsub.Subscribe(ctx, subCfg) } diff --git a/journal/api/responses.go b/journal/api/responses.go index b4e0f9496a..4ac44db041 100644 --- a/journal/api/responses.go +++ b/journal/api/responses.go @@ -10,7 +10,10 @@ import ( "github.com/absmach/supermq/journal" ) -var _ supermq.Response = (*pageRes)(nil) +var ( + _ supermq.Response = (*pageRes)(nil) + _ supermq.Response = (*clientTelemetryRes)(nil) +) type pageRes struct { journal.JournalsPage `json:",inline"` @@ -31,3 +34,15 @@ func (res pageRes) Empty() bool { type clientTelemetryRes struct { journal.ClientTelemetry `json:",inline"` } + +func (res clientTelemetryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res clientTelemetryRes) Code() int { + return http.StatusOK +} + +func (res clientTelemetryRes) Empty() bool { + return false +} diff --git a/journal/journal.go b/journal/journal.go index 54d9b96f9a..df4e0e0d25 100644 --- a/journal/journal.go +++ b/journal/journal.go @@ -140,13 +140,21 @@ func (page JournalsPage) MarshalJSON() ([]byte, error) { type ClientTelemetry struct { ClientID string `json:"client_id"` DomainID string `json:"domain_id"` - Subscriptions []string `json:"subscriptions"` + Subscriptions uint64 `json:"subscriptions"` InboundMessages uint64 `json:"inbound_messages"` OutboundMessages uint64 `json:"outbound_messages"` FirstSeen time.Time `json:"first_seen"` LastSeen time.Time `json:"last_seen"` } +type ClientSubscription struct { + ID string `json:"id" db:"id"` + SubscriberID string `json:"subscriber_id" db:"subscriber_id"` + ChannelID string `json:"channel_id" db:"channel_id"` + Subtopic string `json:"subtopic" db:"subtopic"` + ClientID string `json:"client_id" db:"client_id"` +} + // Service provides access to the journal log service. // //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" @@ -179,4 +187,19 @@ type Repository interface { // DeleteClientTelemetry removes telemetry data for a client from the database. DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error + + // AddSubscription adds a subscription to the client telemetry. + AddSubscription(ctx context.Context, sub ClientSubscription) error + + // CountSubscriptions returns the number of subscriptions for a client. + CountSubscriptions(ctx context.Context, clientID string) (uint64, error) + + // RemoveSubscription removes a subscription from the client telemetry. + RemoveSubscription(ctx context.Context, subscriberID string) error + + // IncrementInboundMessages increments the inbound messages count for a client. + IncrementInboundMessages(ctx context.Context, clientID string) error + + // IncrementOutboundMessages increments the outbound messages count for a client. + IncrementOutboundMessages(ctx context.Context, channelID, subtopic string) error } diff --git a/journal/middleware/authorization.go b/journal/middleware/authorization.go index 2cf9edccff..8bb0e102a2 100644 --- a/journal/middleware/authorization.go +++ b/journal/middleware/authorization.go @@ -70,7 +70,7 @@ func (am *authorizationMiddleware) RetrieveClientTelemetry(ctx context.Context, Domain: session.DomainID, SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Subject: session.UserID, + Subject: session.DomainUserID, Permission: readPermission, ObjectType: policies.ClientType, Object: clientID, diff --git a/journal/mocks/repository.go b/journal/mocks/repository.go index 32abe3ce4c..10170dc2d7 100644 --- a/journal/mocks/repository.go +++ b/journal/mocks/repository.go @@ -16,6 +16,52 @@ type Repository struct { mock.Mock } +// AddSubscription provides a mock function with given fields: ctx, sub +func (_m *Repository) AddSubscription(ctx context.Context, sub journal.ClientSubscription) error { + ret := _m.Called(ctx, sub) + + if len(ret) == 0 { + panic("no return value specified for AddSubscription") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, journal.ClientSubscription) error); ok { + r0 = rf(ctx, sub) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CountSubscriptions provides a mock function with given fields: ctx, clientID +func (_m *Repository) CountSubscriptions(ctx context.Context, clientID string) (uint64, error) { + ret := _m.Called(ctx, clientID) + + if len(ret) == 0 { + panic("no return value specified for CountSubscriptions") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (uint64, error)); ok { + return rf(ctx, clientID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) uint64); ok { + r0 = rf(ctx, clientID) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, clientID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // DeleteClientTelemetry provides a mock function with given fields: ctx, clientID, domainID func (_m *Repository) DeleteClientTelemetry(ctx context.Context, clientID string, domainID string) error { ret := _m.Called(ctx, clientID, domainID) @@ -34,6 +80,60 @@ func (_m *Repository) DeleteClientTelemetry(ctx context.Context, clientID string return r0 } +// IncrementInboundMessages provides a mock function with given fields: ctx, clientID +func (_m *Repository) IncrementInboundMessages(ctx context.Context, clientID string) error { + ret := _m.Called(ctx, clientID) + + if len(ret) == 0 { + panic("no return value specified for IncrementInboundMessages") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, clientID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IncrementOutboundMessages provides a mock function with given fields: ctx, channelID, subtopic +func (_m *Repository) IncrementOutboundMessages(ctx context.Context, channelID string, subtopic string) error { + ret := _m.Called(ctx, channelID, subtopic) + + if len(ret) == 0 { + panic("no return value specified for IncrementOutboundMessages") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, channelID, subtopic) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveSubscription provides a mock function with given fields: ctx, subscriberID +func (_m *Repository) RemoveSubscription(ctx context.Context, subscriberID string) error { + ret := _m.Called(ctx, subscriberID) + + if len(ret) == 0 { + panic("no return value specified for RemoveSubscription") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, subscriberID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // RetrieveAll provides a mock function with given fields: ctx, page func (_m *Repository) RetrieveAll(ctx context.Context, page journal.Page) (journal.JournalsPage, error) { ret := _m.Called(ctx, page) diff --git a/journal/postgres/init.go b/journal/postgres/init.go index 29ffb23565..00d22cf5e9 100644 --- a/journal/postgres/init.go +++ b/journal/postgres/init.go @@ -28,18 +28,25 @@ func Migration() *migrate.MemoryMigrationSource { `CREATE INDEX idx_journal_default_client_filter ON journal(operation, (attributes->>'id'), (attributes->>'client_id'), occurred_at DESC);`, `CREATE INDEX idx_journal_default_channel_filter ON journal(operation, (attributes->>'id'), (attributes->>'channel_id'), occurred_at DESC);`, `CREATE TABLE IF NOT EXISTS clients_telemetry ( - client_id VARCHAR(36) NOT NULL, + client_id VARCHAR(36) PRIMARY KEY, domain_id VARCHAR(36) NOT NULL, - subscriptions TEXT[], inbound_messages BIGINT DEFAULT 0, outbound_messages BIGINT DEFAULT 0, first_seen TIMESTAMP, - last_seen TIMESTAMP, - PRIMARY KEY (client_id, domain_id) + last_seen TIMESTAMP + )`, + `CREATE TABLE IF NOT EXISTS subscriptions ( + id VARCHAR(36) PRIMARY KEY, + subscriber_id VARCHAR(1024) NOT NULL, + channel_id VARCHAR(36) NOT NULL, + subtopic VARCHAR(1024), + client_id VARCHAR(36), + FOREIGN KEY (client_id) REFERENCES clients_telemetry(client_id) ON DELETE CASCADE ON UPDATE CASCADE )`, }, Down: []string{ `DROP TABLE IF EXISTS clients_telemetry`, + `DROP TABLE IF EXISTS subscriptions`, `DROP TABLE IF EXISTS journal`, }, }, diff --git a/journal/postgres/telemetry.go b/journal/postgres/telemetry.go index f231327319..f46f42d039 100644 --- a/journal/postgres/telemetry.go +++ b/journal/postgres/telemetry.go @@ -16,8 +16,8 @@ import ( ) func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.ClientTelemetry) error { - q := `INSERT INTO clients_telemetry (client_id, domain_id, messages, subscriptions, first_seen, last_seen) - VALUES (:client_id, :domain_id, :messages, :subscriptions, :first_seen, :last_seen);` + q := `INSERT INTO clients_telemetry (client_id, domain_id, inbound_messages, outbound_messages, first_seen, last_seen) + VALUES (:client_id, :domain_id, :inbound_messages, :outbound_messages, :first_seen, :last_seen);` dbct, err := toDBClientsTelemetry(ct) if err != nil { @@ -32,7 +32,7 @@ func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.Clie } func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error { - q := "DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;" + q := `DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;` dbct := dbClientTelemetry{ ClientID: clientID, @@ -50,7 +50,7 @@ func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, dom } func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, domainID string) (journal.ClientTelemetry, error) { - q := "SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;" + q := `SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;` dbct := dbClientTelemetry{ ClientID: clientID, @@ -80,14 +80,142 @@ func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, d return journal.ClientTelemetry{}, repoerr.ErrNotFound } +func (repo *repository) AddSubscription(ctx context.Context, sub journal.ClientSubscription) error { + q := `INSERT INTO subscriptions (id, subscriber_id, channel_id, subtopic, client_id) + VALUES (:id, :subscriber_id, :channel_id, :subtopic, :client_id); + ` + + result, err := repo.db.NamedExecContext(ctx, q, sub) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (repo *repository) CountSubscriptions(ctx context.Context, clientID string) (uint64, error) { + q := `SELECT COUNT(*) FROM subscriptions WHERE client_id = :client_id;` + + sb := journal.ClientSubscription{ + ClientID: clientID, + } + + total, err := postgres.Total(ctx, repo.db, q, sb) + if err != nil { + return 0, postgres.HandleError(repoerr.ErrViewEntity, err) + } + + return total, nil +} + +func (repo *repository) RemoveSubscription(ctx context.Context, subscriberID string) error { + q := `DELETE FROM subscriptions WHERE subscriber_id = :subscriber_id;` + + sb := journal.ClientSubscription{ + SubscriberID: subscriberID, + } + + _, err := repo.db.NamedExecContext(ctx, q, sb) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + return nil +} + +func (repo *repository) IncrementInboundMessages(ctx context.Context, clientID string) error { + q := ` + UPDATE clients_telemetry + SET inbound_messages = inbound_messages + 1, + last_seen = :last_seen + WHERE client_id = :client_id; + ` + + ct := journal.ClientTelemetry{ + ClientID: clientID, + LastSeen: time.Now(), + } + dbct, err := toDBClientsTelemetry(ct) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + result, err := repo.db.NamedExecContext(ctx, q, dbct) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (repo *repository) IncrementOutboundMessages(ctx context.Context, channelID, subtopic string) error { + query := ` + SELECT client_id, COUNT(*) AS match_count + FROM subscriptions + WHERE channel_id = :channel_id AND subtopic = :subtopic + GROUP BY client_id + ` + sb := journal.ClientSubscription{ + ChannelID: channelID, + Subtopic: subtopic, + } + + rows, err := repo.db.NamedQueryContext(ctx, query, sb) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer rows.Close() + + tx, err := repo.db.BeginTxx(ctx, nil) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + q := `UPDATE clients_telemetry + SET outbound_messages = outbound_messages + $1 + WHERE client_id = $2; + ` + + for rows.Next() { + var clientID string + var count uint64 + if err = rows.Scan(&clientID, &count); err != nil { + if err := tx.Rollback(); err != nil { + return errors.Wrap(errors.ErrRollbackTx, err) + } + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + if _, err = repo.db.ExecContext(ctx, q, count, clientID); err != nil { + if err := tx.Rollback(); err != nil { + return errors.Wrap(errors.ErrRollbackTx, err) + } + return errors.Wrap(errors.ErrRollbackTx, err) + } + } + + if err = tx.Commit(); err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + return nil +} + type dbClientTelemetry struct { - ClientID string `db:"client_id"` - DomainID string `db:"domain_id"` - Subscriptions pgtype.TextArray `db:"subscriptions"` - InboundMessages uint64 `db:"inbound_messages"` - OutboundMessages uint64 `db:"outbound_messages"` - FirstSeen time.Time `db:"first_seen"` - LastSeen sql.NullTime `db:"last_seen"` + ClientID string `db:"client_id"` + DomainID string `db:"domain_id"` + InboundMessages uint64 `db:"inbound_messages"` + OutboundMessages uint64 `db:"outbound_messages"` + FirstSeen time.Time `db:"first_seen"` + LastSeen sql.NullTime `db:"last_seen"` } func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) { @@ -104,7 +232,6 @@ func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) return dbClientTelemetry{ ClientID: ct.ClientID, DomainID: ct.DomainID, - Subscriptions: subs, InboundMessages: ct.InboundMessages, OutboundMessages: ct.OutboundMessages, FirstSeen: ct.FirstSeen, @@ -113,11 +240,6 @@ func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) } func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error) { - var subs []string - for _, e := range dbct.Subscriptions.Elements { - subs = append(subs, e.String) - } - var lastSeen time.Time if dbct.LastSeen.Valid { lastSeen = dbct.LastSeen.Time @@ -126,7 +248,6 @@ func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error) return journal.ClientTelemetry{ ClientID: dbct.ClientID, DomainID: dbct.DomainID, - Subscriptions: subs, InboundMessages: dbct.InboundMessages, OutboundMessages: dbct.OutboundMessages, FirstSeen: dbct.FirstSeen, diff --git a/journal/service.go b/journal/service.go index 81f9ed61bf..9367392b7c 100644 --- a/journal/service.go +++ b/journal/service.go @@ -5,6 +5,9 @@ package journal import ( "context" + "fmt" + "strings" + "time" "github.com/absmach/supermq" smqauthn "github.com/absmach/supermq/pkg/authn" @@ -12,6 +15,16 @@ import ( svcerr "github.com/absmach/supermq/pkg/errors/service" ) +const ( + clientCreate = "client.create" + clientRemove = "client.remove" + mqttSubscribe = "mqtt.client_subscribe" + mqttDisconnect = "mqtt.client_disconnect" + messagingPublish = "messaging.client_publish" + messagingSubscribe = "messaging.client_subscribe" + messagingUnsubscribe = "messaging.client_unsubscribe" +) + type service struct { idProvider supermq.IDProvider repository Repository @@ -31,7 +44,14 @@ func (svc *service) Save(ctx context.Context, journal Journal) error { } journal.ID = id - return svc.repository.Save(ctx, journal) + if err := svc.repository.Save(ctx, journal); err != nil { + return err + } + if err := svc.handleTelemetry(ctx, journal); err != nil { + return err + } + + return nil } func (svc *service) RetrieveAll(ctx context.Context, session smqauthn.Session, page Page) (JournalsPage, error) { @@ -49,5 +69,292 @@ func (svc *service) RetrieveClientTelemetry(ctx context.Context, session smqauth return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err) } + subs, err := svc.repository.CountSubscriptions(ctx, clientID) + if err != nil { + return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + ct.Subscriptions = subs + return ct, nil } + +func (svc *service) handleTelemetry(ctx context.Context, journal Journal) error { + switch journal.Operation { + case clientCreate: + return svc.addClientTelemetry(ctx, journal) + + case clientRemove: + return svc.removeClientTelemetry(ctx, journal) + + case mqttSubscribe: + return svc.addMqttSubscription(ctx, journal) + + case messagingSubscribe: + return svc.addSubscription(ctx, journal) + + case messagingUnsubscribe: + return svc.removeSubscription(ctx, journal) + + case messagingPublish: + return svc.updateMessageCount(ctx, journal) + + case mqttDisconnect: + return svc.removeMqttSubscription(ctx, journal) + + default: + return nil + } +} + +func (svc *service) addClientTelemetry(ctx context.Context, journal Journal) error { + ce, err := toClientEvent(journal) + if err != nil { + return err + } + ct := ClientTelemetry{ + ClientID: ce.id, + DomainID: ce.domain, + FirstSeen: ce.createdAt, + LastSeen: ce.createdAt, + } + return svc.repository.SaveClientTelemetry(ctx, ct) +} + +func (svc *service) removeClientTelemetry(ctx context.Context, journal Journal) error { + ce, err := toClientEvent(journal) + if err != nil { + return err + } + return svc.repository.DeleteClientTelemetry(ctx, ce.id, ce.domain) +} + +func (svc *service) addSubscription(ctx context.Context, journal Journal) error { + ae, err := toSubscribeEvent(journal) + if err != nil { + return err + } + var subtopic string + topics := strings.Split(ae.topic, ".") + if len(topics) > 2 { + subtopic = topics[2] + } + + id, err := svc.idProvider.ID() + if err != nil { + return err + } + + sub := ClientSubscription{ + ID: id, + SubscriberID: ae.subscriberID, + ChannelID: topics[1], + Subtopic: subtopic, + ClientID: ae.clientID, + } + + return svc.repository.AddSubscription(ctx, sub) +} + +func (svc *service) addMqttSubscription(ctx context.Context, journal Journal) error { + ae, err := toMqttSubscribeEvent(journal) + if err != nil { + return err + } + + id, err := svc.idProvider.ID() + if err != nil { + return err + } + + sub := ClientSubscription{ + ID: id, + SubscriberID: ae.subscriberID, + ChannelID: ae.channelID, + Subtopic: ae.subtopic, + ClientID: ae.clientID, + } + + return svc.repository.AddSubscription(ctx, sub) +} + +func (svc *service) removeSubscription(ctx context.Context, journal Journal) error { + ae, err := toUnsubscribeEvent(journal) + if err != nil { + return err + } + + return svc.repository.RemoveSubscription(ctx, ae.subscriberID) +} + +func (svc *service) removeMqttSubscription(ctx context.Context, journal Journal) error { + ae, err := toMqttDisconnectEvent(journal) + if err != nil { + return err + } + + return svc.repository.RemoveSubscription(ctx, ae.subscriberID) +} + +func (svc *service) updateMessageCount(ctx context.Context, journal Journal) error { + ae, err := toPublishEvent(journal) + if err != nil { + return err + } + if err := svc.repository.IncrementInboundMessages(ctx, ae.clientID); err != nil { + return err + } + if err := svc.repository.IncrementOutboundMessages(ctx, ae.channelID, ae.subtopic); err != nil { + return err + } + return nil +} + +type clientEvent struct { + id string + domain string + createdAt time.Time +} + +func toClientEvent(journal Journal) (clientEvent, error) { + var createdAt time.Time + var err error + id, err := getStringAttribute(journal, "id") + if err != nil { + return clientEvent{}, err + } + domain, err := getStringAttribute(journal, "domain") + if err != nil { + return clientEvent{}, err + } + + createdAtStr := journal.Attributes["created_at"].(string) + if createdAtStr != "" { + createdAt, err = time.Parse(time.RFC3339, createdAtStr) + if err != nil { + return clientEvent{}, fmt.Errorf("invalid created_at format") + } + } + return clientEvent{ + id: id, + domain: domain, + createdAt: createdAt, + }, nil +} + +type adapterEvent struct { + clientID string + channelID string + subscriberID string + topic string + subtopic string +} + +func toPublishEvent(journal Journal) (adapterEvent, error) { + clientID, err := getStringAttribute(journal, "client_id") + if err != nil { + return adapterEvent{}, err + } + channelID, err := getStringAttribute(journal, "channel_id") + if err != nil { + return adapterEvent{}, err + } + subtopic, err := getStringAttribute(journal, "subtopic") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + clientID: clientID, + channelID: channelID, + subtopic: subtopic, + }, nil +} + +func toSubscribeEvent(journal Journal) (adapterEvent, error) { + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + topic, err := getStringAttribute(journal, "topic") + if err != nil { + return adapterEvent{}, err + } + var clientID string + clientID, err = getStringAttribute(journal, "client_id") + if err != nil { + clientID = "" + } + + return adapterEvent{ + clientID: clientID, + subscriberID: subscriberID, + topic: topic, + }, nil +} + +func toUnsubscribeEvent(journal Journal) (adapterEvent, error) { + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + topic, err := getStringAttribute(journal, "topic") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + subscriberID: subscriberID, + topic: topic, + }, nil +} + +func toMqttSubscribeEvent(journal Journal) (adapterEvent, error) { + clientID, err := getStringAttribute(journal, "client_id") + if err != nil { + return adapterEvent{}, err + } + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + channelID, err := getStringAttribute(journal, "channel_id") + if err != nil { + return adapterEvent{}, err + } + subtopic, err := getStringAttribute(journal, "subtopic") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + clientID: clientID, + subscriberID: subscriberID, + channelID: channelID, + subtopic: subtopic, + }, nil +} + +func toMqttDisconnectEvent(journal Journal) (adapterEvent, error) { + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + clientID, err := getStringAttribute(journal, "client_id") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + subscriberID: subscriberID, + channelID: clientID, + }, nil +} + +func getStringAttribute(journal Journal, key string) (string, error) { + value, ok := journal.Attributes[key].(string) + if !ok { + return "", fmt.Errorf("missing or invalid %s attribute", key) + } + return value, nil +} diff --git a/mqtt/events/streams.go b/mqtt/events/streams.go index b81316e437..fa3a3097fc 100644 --- a/mqtt/events/streams.go +++ b/mqtt/events/streams.go @@ -5,73 +5,170 @@ package events import ( "context" + "net/url" + "regexp" + "strings" + "github.com/absmach/mgate/pkg/session" + "github.com/absmach/supermq/pkg/errors" "github.com/absmach/supermq/pkg/events" "github.com/absmach/supermq/pkg/events/store" ) const streamID = "supermq.mqtt" -//go:generate mockery --name EventStore --output=../mocks --filename events.go --quiet --note "Copyright (c) Abstract Machines" -type EventStore interface { - Connect(ctx context.Context, clientID, subscriberID string) error - Disconnect(ctx context.Context, clientID, subscriberID string) error - Subscribe(ctx context.Context, clientID, channelID, subscriberID, subtopic string) error -} +var ( + errFailedSession = errors.New("failed to obtain session from context") + errMalformedTopic = errors.New("malformed topic") + channelRegExp = regexp.MustCompile(`^\/?channels\/([\w\-]+)\/messages(\/[^?]*)?(\?.*)?$`) +) // EventStore is a struct used to store event streams in Redis. type eventStore struct { ep events.Publisher + handler session.Handler instance string } -// NewEventStore returns wrapper around mProxy service that sends +// NewEventStoreMiddleware returns middleware around mGate service that sends // events to event store. -func NewEventStore(ctx context.Context, url, instance string) (EventStore, error) { +func NewEventStoreMiddleware(ctx context.Context, handler session.Handler, url, instance string) (session.Handler, error) { publisher, err := store.NewPublisher(ctx, url, streamID) if err != nil { return nil, err } return &eventStore{ - instance: instance, ep: publisher, + handler: handler, + instance: instance, }, nil } -// Connect issues event on MQTT CONNECT. -func (es *eventStore) Connect(ctx context.Context, clientID, subscriberID string) error { +func (es *eventStore) AuthConnect(ctx context.Context) error { + if err := es.handler.AuthConnect(ctx); err != nil { + return err + } + s, ok := session.FromContext(ctx) + if !ok { + return errFailedSession + } + ev := connectEvent{ - clientID: clientID, operation: clientConnect, - subscriberID: subscriberID, + clientID: s.Username, + subscriberID: s.ID, instance: es.instance, } return es.ep.Publish(ctx, ev) } -// Disconnect issues event on MQTT CONNECT. -func (es *eventStore) Disconnect(ctx context.Context, clientID, subscriberID string) error { +func (es *eventStore) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { + return es.handler.AuthPublish(ctx, topic, payload) +} + +func (es *eventStore) AuthSubscribe(ctx context.Context, topics *[]string) error { + return es.handler.AuthSubscribe(ctx, topics) +} + +func (es *eventStore) Connect(ctx context.Context) error { + return es.handler.Connect(ctx) +} + +func (es *eventStore) Publish(ctx context.Context, topic *string, payload *[]byte) error { + return es.handler.Publish(ctx, topic, payload) +} + +func (es *eventStore) Subscribe(ctx context.Context, topics *[]string) error { + if err := es.handler.Subscribe(ctx, topics); err != nil { + return err + } + + s, ok := session.FromContext(ctx) + if !ok { + return errFailedSession + } + + for _, topic := range *topics { + channelID, subtopic, err := parseTopic(topic) + if err != nil { + return err + } + ev := subscribeEvent{ + operation: clientSubscribe, + clientID: s.Username, + channelID: channelID, + subscriberID: s.ID, + subtopic: subtopic, + } + + if err := es.ep.Publish(ctx, ev); err != nil { + return err + } + } + + return nil +} + +func (es *eventStore) Unsubscribe(ctx context.Context, topics *[]string) error { + return es.handler.Unsubscribe(ctx, topics) +} + +func (es *eventStore) Disconnect(ctx context.Context) error { + if err := es.handler.Disconnect(ctx); err != nil { + return err + } + + s, ok := session.FromContext(ctx) + if !ok { + return errFailedSession + } + ev := connectEvent{ - clientID: clientID, operation: clientDisconnect, - subscriberID: subscriberID, + clientID: s.Username, + subscriberID: s.ID, instance: es.instance, } return es.ep.Publish(ctx, ev) } -// Subscribe issues event on MQTT SUBSCRIBE. -func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, subscriberID, subtopic string) error { - ev := subscribeEvent{ - operation: clientSubscribe, - clientID: clientID, - channelID: channelID, - subscriberID: subscriberID, - subtopic: subtopic, +func parseTopic(topic string) (string, string, error) { + channelParts := channelRegExp.FindStringSubmatch(topic) + if len(channelParts) < 2 { + return "", "", errMalformedTopic } - return es.ep.Publish(ctx, ev) + chanID := channelParts[1] + subtopic := channelParts[2] + + if subtopic == "" { + return subtopic, chanID, nil + } + + subtopic, err := url.QueryUnescape(subtopic) + if err != nil { + return "", "", errMalformedTopic + } + subtopic = strings.ReplaceAll(subtopic, "/", ".") + + elems := strings.Split(subtopic, ".") + filteredElems := []string{} + for _, elem := range elems { + if elem == "" { + continue + } + + if len(elem) > 1 && (strings.Contains(elem, "*") || strings.Contains(elem, ">")) { + return "", "", errMalformedTopic + } + + filteredElems = append(filteredElems, elem) + } + + subtopic = strings.Join(filteredElems, ".") + + return chanID, subtopic, nil } diff --git a/mqtt/handler.go b/mqtt/handler.go index 47436331c3..53b10acd68 100644 --- a/mqtt/handler.go +++ b/mqtt/handler.go @@ -15,7 +15,6 @@ import ( "github.com/absmach/mgate/pkg/session" grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - "github.com/absmach/supermq/mqtt/events" "github.com/absmach/supermq/pkg/connections" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" @@ -67,13 +66,11 @@ type handler struct { clients grpcClientsV1.ClientsServiceClient channels grpcChannelsV1.ChannelsServiceClient logger *slog.Logger - es events.EventStore } // NewHandler creates new Handler entity. -func NewHandler(publisher messaging.Publisher, es events.EventStore, logger *slog.Logger, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) session.Handler { +func NewHandler(publisher messaging.Publisher, logger *slog.Logger, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) session.Handler { return &handler{ - es: es, logger: logger, publisher: publisher, clients: clients, @@ -107,10 +104,6 @@ func (h *handler) AuthConnect(ctx context.Context) error { return errInvalidUserId } - if err := h.es.Connect(ctx, s.Username, s.ID); err != nil { - h.logger.Error(errors.Wrap(ErrFailedPublishConnectEvent, err).Error()) - } - return nil } @@ -203,18 +196,8 @@ func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { if !ok { return errors.Wrap(ErrFailedSubscribe, ErrClientNotInitialized) } - - for _, topic := range *topics { - channelID, subTopic, err := parseTopic(topic) - if err != nil { - return err - } - if err := h.es.Subscribe(ctx, s.Username, channelID, s.ID, subTopic); err != nil { - return errors.Wrap(ErrFailedSubscribeEvent, err) - } - } - h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) + return nil } @@ -225,6 +208,7 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { return errors.Wrap(ErrFailedUnsubscribe, ErrClientNotInitialized) } h.logger.Info(fmt.Sprintf(LogInfoUnsubscribed, s.ID, strings.Join(*topics, ","))) + return nil } @@ -235,9 +219,7 @@ func (h *handler) Disconnect(ctx context.Context) error { return errors.Wrap(ErrFailedDisconnect, ErrClientNotInitialized) } h.logger.Error(fmt.Sprintf(LogInfoDisconnected, s.ID, s.Password)) - if err := h.es.Disconnect(ctx, s.Username, s.ID); err != nil { - return errors.Wrap(ErrFailedPublishDisconnectEvent, err) - } + return nil } @@ -272,23 +254,6 @@ func (h *handler) authAccess(ctx context.Context, clientID, topic string, msgTyp return nil } -func parseTopic(topic string) (string, string, error) { - channelParts := channelRegExp.FindStringSubmatch(topic) - if len(channelParts) < 2 { - return "", "", errors.Wrap(ErrFailedPublish, ErrMalformedTopic) - } - - chanID := channelParts[1] - subtopic := channelParts[2] - - subtopic, err := parseSubtopic(subtopic) - if err != nil { - return "", "", errors.Wrap(ErrFailedParseSubtopic, err) - } - - return chanID, subtopic, nil -} - func parseSubtopic(subtopic string) (string, error) { if subtopic == "" { return subtopic, nil diff --git a/mqtt/handler_test.go b/mqtt/handler_test.go index c86ba9fb00..0df3a5baad 100644 --- a/mqtt/handler_test.go +++ b/mqtt/handler_test.go @@ -68,9 +68,8 @@ var ( ) var ( - clients = new(climocks.ClientsServiceClient) - channels = new(chmocks.ChannelsServiceClient) - eventStore = new(mocks.EventStore) + clients = new(climocks.ClientsServiceClient) + channels = new(chmocks.ChannelsServiceClient) ) func TestAuthConnect(t *testing.T) { @@ -147,10 +146,8 @@ func TestAuthConnect(t *testing.T) { password = string(tc.session.Password) } clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: password}).Return(tc.authNRes, tc.authNErr) - svcCall := eventStore.On("Connect", mock.Anything, clientID, mock.Anything).Return(tc.err) err := handler.AuthConnect(ctx) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - svcCall.Unset() clientsCall.Unset() }) } @@ -445,11 +442,9 @@ func TestSubscribe(t *testing.T) { if tc.session != nil { ctx = session.NewContext(ctx, tc.session) } - eventsCall := eventStore.On("Subscribe", mock.Anything, clientID, chanID, clientID, mock.Anything).Return(nil) err := handler.Subscribe(ctx, &tc.topic) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) - eventsCall.Unset() } } @@ -519,11 +514,9 @@ func TestDisconnect(t *testing.T) { if tc.session != nil { ctx = session.NewContext(ctx, tc.session) } - svcCall := eventStore.On("Disconnect", mock.Anything, clientID, mock.Anything).Return(tc.err) err := handler.Disconnect(ctx) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) - svcCall.Unset() } } @@ -534,6 +527,5 @@ func newHandler() session.Handler { } clients = new(climocks.ClientsServiceClient) channels = new(chmocks.ChannelsServiceClient) - eventStore = new(mocks.EventStore) - return mqtt.NewHandler(mocks.NewPublisher(), eventStore, logger, clients, channels) + return mqtt.NewHandler(mocks.NewPublisher(), logger, clients, channels) } diff --git a/mqtt/mocks/events.go b/mqtt/mocks/events.go deleted file mode 100644 index 30c64e42d8..0000000000 --- a/mqtt/mocks/events.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by mockery v2.43.2. DO NOT EDIT. - -// Copyright (c) Abstract Machines - -package mocks - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" -) - -// EventStore is an autogenerated mock type for the EventStore type -type EventStore struct { - mock.Mock -} - -// Connect provides a mock function with given fields: ctx, clientID, subscriberID -func (_m *EventStore) Connect(ctx context.Context, clientID string, subscriberID string) error { - ret := _m.Called(ctx, clientID, subscriberID) - - if len(ret) == 0 { - panic("no return value specified for Connect") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, clientID, subscriberID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Disconnect provides a mock function with given fields: ctx, clientID, subscriberID -func (_m *EventStore) Disconnect(ctx context.Context, clientID string, subscriberID string) error { - ret := _m.Called(ctx, clientID, subscriberID) - - if len(ret) == 0 { - panic("no return value specified for Disconnect") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, clientID, subscriberID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Subscribe provides a mock function with given fields: ctx, clientID, channelID, subscriberID, subtopic -func (_m *EventStore) Subscribe(ctx context.Context, clientID string, channelID string, subscriberID string, subtopic string) error { - ret := _m.Called(ctx, clientID, channelID, subscriberID, subtopic) - - if len(ret) == 0 { - panic("no return value specified for Subscribe") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { - r0 = rf(ctx, clientID, channelID, subscriberID, subtopic) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// NewEventStore creates a new instance of EventStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewEventStore(t interface { - mock.TestingT - Cleanup(func()) -}) *EventStore { - mock := &EventStore{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/messaging/events/events.go b/pkg/messaging/events/events.go index 12f6ce3df4..ceead22212 100644 --- a/pkg/messaging/events/events.go +++ b/pkg/messaging/events/events.go @@ -35,13 +35,15 @@ func (pe publishEvent) Encode() (map[string]interface{}, error) { type subscribeEvent struct { operation string subscriberID string - subtopic string + clientID string + topic string } func (se subscribeEvent) Encode() (map[string]interface{}, error) { return map[string]interface{}{ "operation": se.operation, "subscriber_id": se.subscriberID, - "subtopic": se.subtopic, + "client_id": se.clientID, + "topic": se.topic, }, nil } diff --git a/pkg/messaging/events/pubsub.go b/pkg/messaging/events/pubsub.go index 8e792ae891..657f24908b 100644 --- a/pkg/messaging/events/pubsub.go +++ b/pkg/messaging/events/pubsub.go @@ -54,7 +54,8 @@ func (es *pubsubES) Subscribe(ctx context.Context, cfg messaging.SubscriberConfi se := subscribeEvent{ operation: clientSubscribe, subscriberID: cfg.ID, - subtopic: cfg.Topic, + clientID: cfg.ClientID, + topic: cfg.Topic, } return es.ep.Publish(ctx, se) @@ -68,7 +69,7 @@ func (es *pubsubES) Unsubscribe(ctx context.Context, id string, topic string) er se := subscribeEvent{ operation: clientUnsubscribe, subscriberID: id, - subtopic: topic, + topic: topic, } return es.ep.Publish(ctx, se) diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 393de64fef..acdc0e146e 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -36,6 +36,7 @@ type MessageHandler interface { type SubscriberConfig struct { ID string + ClientID string Topic string Handler MessageHandler DeliveryPolicy DeliveryPolicy diff --git a/ws/adapter.go b/ws/adapter.go index f92fe15074..02c4cfe39e 100644 --- a/ws/adapter.go +++ b/ws/adapter.go @@ -75,9 +75,10 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, chanID, sub } subCfg := messaging.SubscriberConfig{ - ID: clientID, - Topic: subject, - Handler: c, + ID: clientID, + ClientID: clientID, + Topic: subject, + Handler: c, } if err := svc.pubsub.Subscribe(ctx, subCfg); err != nil { return ErrFailedSubscription diff --git a/ws/adapter_test.go b/ws/adapter_test.go index 0bb69a8ff9..9348201026 100644 --- a/ws/adapter_test.go +++ b/ws/adapter_test.go @@ -158,9 +158,10 @@ func TestSubscribe(t *testing.T) { for _, tc := range cases { subConfig := messaging.SubscriberConfig{ - ID: clientID, - Topic: "channels." + tc.chanID + "." + subTopic, - Handler: c, + ID: clientID, + Topic: "channels." + tc.chanID + "." + subTopic, + ClientID: clientID, + Handler: c, } clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: tc.clientKey}).Return(tc.authNRes, tc.authNErr) channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{